first commit
This commit is contained in:
0
framework/changelog.md
Normal file
0
framework/changelog.md
Normal file
8
framework/config.go
Normal file
8
framework/config.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package framework
|
||||
|
||||
import "github.com/maximhq/bifrost/framework/modelcatalog"
|
||||
|
||||
// FrameworkConfig represents the configuration for the framework.
|
||||
type FrameworkConfig struct {
|
||||
// Pricing is an optional pointer to the model catalog configuration; the
// "pricing" key is omitted from JSON output when the pointer is nil.
Pricing *modelcatalog.Config `json:"pricing,omitempty"`
|
||||
}
|
||||
1239
framework/configstore/clientconfig.go
Normal file
1239
framework/configstore/clientconfig.go
Normal file
File diff suppressed because it is too large
Load Diff
242
framework/configstore/clientconfig_redaction_test.go
Normal file
242
framework/configstore/clientconfig_redaction_test.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestProviderConfig_Redacted_AutoMasksEnvBackedFields verifies that env-backed
|
||||
// values in any provider config field are automatically redacted in the JSON output
|
||||
// of a Redacted() ProviderConfig — even fields that don't have explicit Redacted()
|
||||
// calls (like Azure APIVersion). This is the defense-in-depth guarantee provided
|
||||
// by EnvVar.MarshalJSON.
|
||||
func TestProviderConfig_Redacted_AutoMasksEnvBackedFields(t *testing.T) {
|
||||
t.Setenv("MY_AZURE_API_VERSION_SECRET", "2024-10-21-preview-secret")
|
||||
|
||||
apiVersion := schemas.NewEnvVar("env.MY_AZURE_API_VERSION_SECRET")
|
||||
require.True(t, apiVersion.IsFromEnv(), "setup: APIVersion should be FromEnv")
|
||||
require.Equal(t, "2024-10-21-preview-secret", apiVersion.GetValue(),
|
||||
"setup: APIVersion should be resolved")
|
||||
|
||||
config := ProviderConfig{
|
||||
Keys: []schemas.Key{{
|
||||
ID: "k1",
|
||||
Name: "test",
|
||||
Value: schemas.EnvVar{Val: ""},
|
||||
AzureKeyConfig: &schemas.AzureKeyConfig{
|
||||
Endpoint: *schemas.NewEnvVar("https://foo.openai.azure.com"),
|
||||
APIVersion: apiVersion,
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
redacted := config.Redacted()
|
||||
require.NotNil(t, redacted)
|
||||
require.Len(t, redacted.Keys, 1)
|
||||
require.NotNil(t, redacted.Keys[0].AzureKeyConfig)
|
||||
require.NotNil(t, redacted.Keys[0].AzureKeyConfig.APIVersion)
|
||||
|
||||
// Marshal the APIVersion field as it would be sent to the UI.
|
||||
data, err := json.Marshal(redacted.Keys[0].AzureKeyConfig.APIVersion)
|
||||
require.NoError(t, err)
|
||||
|
||||
var out struct {
|
||||
Value string `json:"value"`
|
||||
EnvVar string `json:"env_var"`
|
||||
FromEnv bool `json:"from_env"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(data, &out))
|
||||
|
||||
assert.NotContains(t, out.Value, "preview-secret",
|
||||
"resolved env value leaked through APIVersion JSON output: %q", out.Value)
|
||||
assert.Equal(t, "env.MY_AZURE_API_VERSION_SECRET", out.EnvVar,
|
||||
"env var reference must be preserved so the UI can show it")
|
||||
assert.True(t, out.FromEnv, "from_env flag must be preserved")
|
||||
}
|
||||
|
||||
// TestProviderConfig_Redacted_DoesNotMaskPlainNonSecretFields verifies that the
|
||||
// auto-redaction does NOT touch plain (non-env-backed) values. A user-typed
|
||||
// api_version like "2024-10-21" must show as-is in the UI.
|
||||
func TestProviderConfig_Redacted_DoesNotMaskPlainNonSecretFields(t *testing.T) {
|
||||
config := ProviderConfig{
|
||||
Keys: []schemas.Key{{
|
||||
ID: "k1",
|
||||
Name: "test",
|
||||
Value: schemas.EnvVar{Val: ""},
|
||||
AzureKeyConfig: &schemas.AzureKeyConfig{
|
||||
Endpoint: *schemas.NewEnvVar("https://foo.openai.azure.com"),
|
||||
APIVersion: schemas.NewEnvVar("2024-10-21"),
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
redacted := config.Redacted()
|
||||
require.NotNil(t, redacted)
|
||||
require.Len(t, redacted.Keys, 1)
|
||||
require.NotNil(t, redacted.Keys[0].AzureKeyConfig)
|
||||
require.NotNil(t, redacted.Keys[0].AzureKeyConfig.APIVersion)
|
||||
|
||||
data, err := json.Marshal(redacted.Keys[0].AzureKeyConfig.APIVersion)
|
||||
require.NoError(t, err)
|
||||
|
||||
var out struct {
|
||||
Value string `json:"value"`
|
||||
FromEnv bool `json:"from_env"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(data, &out))
|
||||
|
||||
assert.Equal(t, "2024-10-21", out.Value,
|
||||
"plain APIVersion was incorrectly redacted")
|
||||
assert.False(t, out.FromEnv)
|
||||
}
|
||||
|
||||
// TestProviderConfig_Redacted_PreservesEnvVarReferenceForVertex verifies that
|
||||
// env-backed Vertex fields appear in the redacted output with the env reference
|
||||
// intact and the resolved value masked. This is the user-facing fix for the
|
||||
// "I see resolved env values in the UI" bug.
|
||||
func TestProviderConfig_Redacted_PreservesEnvVarReferenceForVertex(t *testing.T) {
|
||||
t.Setenv("MY_VERTEX_PROJECT_ID_SECRET", "super-secret-project-12345")
|
||||
|
||||
projectID := schemas.NewEnvVar("env.MY_VERTEX_PROJECT_ID_SECRET")
|
||||
require.Equal(t, "super-secret-project-12345", projectID.GetValue())
|
||||
|
||||
config := ProviderConfig{
|
||||
Keys: []schemas.Key{{
|
||||
ID: "k1",
|
||||
Name: "test",
|
||||
Value: schemas.EnvVar{Val: ""},
|
||||
VertexKeyConfig: &schemas.VertexKeyConfig{
|
||||
ProjectID: *projectID,
|
||||
Region: *schemas.NewEnvVar("us-central1"),
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
redacted := config.Redacted()
|
||||
data, err := json.Marshal(redacted.Keys[0].VertexKeyConfig.ProjectID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var out struct {
|
||||
Value string `json:"value"`
|
||||
EnvVar string `json:"env_var"`
|
||||
FromEnv bool `json:"from_env"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(data, &out))
|
||||
|
||||
assert.NotContains(t, out.Value, "super-secret-project",
|
||||
"resolved Vertex ProjectID env value leaked: %q", out.Value)
|
||||
assert.Equal(t, "env.MY_VERTEX_PROJECT_ID_SECRET", out.EnvVar)
|
||||
assert.True(t, out.FromEnv)
|
||||
}
|
||||
|
||||
// TestProviderConfig_Redacted_DoesNotMutateOriginal ensures Redacted() and the
|
||||
// subsequent JSON marshaling do not mutate the original config in memory. The
|
||||
// inference path reads from the in-memory config and calls GetValue() to build
|
||||
// outgoing LLM requests; if Redacted() or MarshalJSON were to mutate state, every
|
||||
// inference request after a UI fetch would silently start using masked values.
|
||||
func TestProviderConfig_Redacted_DoesNotMutateOriginal(t *testing.T) {
|
||||
t.Setenv("MY_REAL_KEY", "sk-real-secret-1234567890abcdef")
|
||||
|
||||
keyValue := schemas.NewEnvVar("env.MY_REAL_KEY")
|
||||
require.Equal(t, "sk-real-secret-1234567890abcdef", keyValue.GetValue())
|
||||
|
||||
config := ProviderConfig{
|
||||
Keys: []schemas.Key{{
|
||||
ID: "k1",
|
||||
Name: "test",
|
||||
Value: *keyValue,
|
||||
}},
|
||||
}
|
||||
|
||||
redacted := config.Redacted()
|
||||
_, err := json.Marshal(redacted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Original must still hold the resolved value.
|
||||
assert.Equal(t, "sk-real-secret-1234567890abcdef", config.Keys[0].Value.GetValue(),
|
||||
"Redacted() or MarshalJSON mutated the original key Value")
|
||||
}
|
||||
|
||||
// TestProviderConfig_Redacted_FullJSONHasNoLeakedEnvSecrets is a high-level smoke
|
||||
// test: build a config containing env-backed values across multiple provider types
|
||||
// and assert that no resolved secret string appears anywhere in the marshaled
|
||||
// redacted JSON.
|
||||
func TestProviderConfig_Redacted_FullJSONHasNoLeakedEnvSecrets(t *testing.T) {
|
||||
t.Setenv("LEAK_TEST_AZURE_ENDPOINT", "https://leaked-azure.example.com")
|
||||
t.Setenv("LEAK_TEST_AZURE_APIVER", "leaked-api-version-string")
|
||||
t.Setenv("LEAK_TEST_VERTEX_PROJECT", "leaked-vertex-project-id")
|
||||
t.Setenv("LEAK_TEST_BEDROCK_ACCESS", "AKIAIOSFODNN7LEAKED1")
|
||||
t.Setenv("LEAK_TEST_OPENAI_KEY", "sk-leaked-openai-key-1234567890")
|
||||
|
||||
config := ProviderConfig{
|
||||
Keys: []schemas.Key{
|
||||
{
|
||||
ID: "openai-k",
|
||||
Name: "openai",
|
||||
Value: *schemas.NewEnvVar("env.LEAK_TEST_OPENAI_KEY"),
|
||||
},
|
||||
{
|
||||
ID: "azure-k",
|
||||
Name: "azure",
|
||||
Value: schemas.EnvVar{Val: ""},
|
||||
AzureKeyConfig: &schemas.AzureKeyConfig{
|
||||
Endpoint: *schemas.NewEnvVar("env.LEAK_TEST_AZURE_ENDPOINT"),
|
||||
APIVersion: schemas.NewEnvVar("env.LEAK_TEST_AZURE_APIVER"),
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "vertex-k",
|
||||
Name: "vertex",
|
||||
Value: schemas.EnvVar{Val: ""},
|
||||
VertexKeyConfig: &schemas.VertexKeyConfig{
|
||||
ProjectID: *schemas.NewEnvVar("env.LEAK_TEST_VERTEX_PROJECT"),
|
||||
Region: *schemas.NewEnvVar("us-central1"),
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "bedrock-k",
|
||||
Name: "bedrock",
|
||||
Value: schemas.EnvVar{Val: ""},
|
||||
BedrockKeyConfig: &schemas.BedrockKeyConfig{
|
||||
AccessKey: *schemas.NewEnvVar("env.LEAK_TEST_BEDROCK_ACCESS"),
|
||||
SecretKey: schemas.EnvVar{Val: ""},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
redacted := config.Redacted()
|
||||
data, err := json.Marshal(redacted)
|
||||
require.NoError(t, err)
|
||||
jsonStr := string(data)
|
||||
|
||||
leakedSecrets := []string{
|
||||
"https://leaked-azure.example.com",
|
||||
"leaked-api-version-string",
|
||||
"leaked-vertex-project-id",
|
||||
"AKIAIOSFODNN7LEAKED1",
|
||||
"sk-leaked-openai-key-1234567890",
|
||||
}
|
||||
for _, secret := range leakedSecrets {
|
||||
assert.False(t, strings.Contains(jsonStr, secret),
|
||||
"resolved env secret %q leaked into redacted JSON output", secret)
|
||||
}
|
||||
|
||||
// And the env var references must be present so the UI can render them.
|
||||
expectedRefs := []string{
|
||||
"env.LEAK_TEST_OPENAI_KEY",
|
||||
"env.LEAK_TEST_AZURE_ENDPOINT",
|
||||
"env.LEAK_TEST_AZURE_APIVER",
|
||||
"env.LEAK_TEST_VERTEX_PROJECT",
|
||||
"env.LEAK_TEST_BEDROCK_ACCESS",
|
||||
}
|
||||
for _, ref := range expectedRefs {
|
||||
assert.True(t, strings.Contains(jsonStr, ref),
|
||||
"env var reference %q missing from redacted JSON output", ref)
|
||||
}
|
||||
}
|
||||
67
framework/configstore/config.go
Normal file
67
framework/configstore/config.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ConfigStoreType represents the type of config store.
// It is the string found under the "type" key of a config store configuration.
|
||||
type ConfigStoreType string
|
||||
|
||||
// Config store types recognized by Config.UnmarshalJSON.
|
||||
const (
|
||||
// ConfigStoreTypeSQLite selects the SQLite config store ("sqlite").
ConfigStoreTypeSQLite ConfigStoreType = "sqlite"
|
||||
// ConfigStoreTypePostgres selects the Postgres config store ("postgres").
ConfigStoreTypePostgres ConfigStoreType = "postgres"
|
||||
)
|
||||
|
||||
// Config represents the configuration for the config store.
|
||||
type Config struct {
|
||||
// Enabled reports whether the config store is turned on.
Enabled bool `json:"enabled"`
|
||||
// Type names the backend whose settings are carried in Config.
Type ConfigStoreType `json:"type"`
|
||||
// Config holds the backend settings. After UnmarshalJSON it is nil when
// Enabled is false, otherwise a *SQLiteConfig or *PostgresConfig chosen by Type.
Config any `json:"config"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals the config from JSON.
|
||||
func (c *Config) UnmarshalJSON(data []byte) error {
|
||||
// First, unmarshal into a temporary struct to get the basic fields
|
||||
type TempConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Type ConfigStoreType `json:"type"`
|
||||
Config json.RawMessage `json:"config"` // Keep as raw JSON
|
||||
}
|
||||
|
||||
var temp TempConfig
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal config store config: %w", err)
|
||||
}
|
||||
|
||||
// Set basic fields
|
||||
c.Enabled = temp.Enabled
|
||||
c.Type = temp.Type
|
||||
|
||||
if !temp.Enabled {
|
||||
c.Config = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse the config field based on type
|
||||
switch temp.Type {
|
||||
case ConfigStoreTypeSQLite:
|
||||
var sqliteConfig SQLiteConfig
|
||||
if err := json.Unmarshal(temp.Config, &sqliteConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal sqlite config: %w", err)
|
||||
}
|
||||
c.Config = &sqliteConfig
|
||||
case ConfigStoreTypePostgres:
|
||||
var postgresConfig PostgresConfig
|
||||
var err error
|
||||
if err = json.Unmarshal(temp.Config, &postgresConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal postgres config: %w", err)
|
||||
}
|
||||
c.Config = &postgresConfig
|
||||
default:
|
||||
return fmt.Errorf("unknown config store type: %s", temp.Type)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
378
framework/configstore/dlock.go
Normal file
378
framework/configstore/dlock.go
Normal file
@@ -0,0 +1,378 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
// Default lock configuration values
|
||||
const (
|
||||
// DefaultLockTTL is the TTL NewDistributedLockManager assigns unless
// overridden with WithDefaultTTL.
DefaultLockTTL = 30 * time.Second
|
||||
// DefaultRetryInterval is the pause between attempts in Lock unless
// overridden with WithRetryInterval.
DefaultRetryInterval = 100 * time.Millisecond
|
||||
// DefaultMaxRetries is the retry budget for Lock unless overridden with
// WithMaxRetries.
DefaultMaxRetries = 100
|
||||
// DefaultCleanupInterval is not referenced in this file; presumably the
// suggested period for calling CleanupExpiredLocks — confirm with callers.
DefaultCleanupInterval = 5 * time.Minute
|
||||
)
|
||||
|
||||
// Lock errors
|
||||
var (
|
||||
// ErrLockNotAcquired is returned by Lock and LockWithRetry once every
// attempt has failed to obtain the lock.
ErrLockNotAcquired = errors.New("failed to acquire lock")
|
||||
// ErrLockNotHeld is returned by Unlock and Extend when this holder does not
// hold the lock.
ErrLockNotHeld = errors.New("lock not held by this holder")
|
||||
// ErrLockExpired is not returned by the code in this file; presumably it is
// reserved for store implementations — confirm with callers.
ErrLockExpired = errors.New("lock has expired")
|
||||
// ErrEmptyLockKey is returned by NewLock when the lock key is empty.
ErrEmptyLockKey = errors.New("empty lock key")
|
||||
)
|
||||
|
||||
// LockStore defines the storage operations required for distributed locking.
|
||||
// This interface abstracts the database operations, making the lock implementation
|
||||
// testable and decoupled from the specific database implementation.
|
||||
type LockStore interface {
|
||||
// TryAcquireLock attempts to insert a lock row. Returns true if the lock was acquired.
|
||||
// If the lock already exists and is not expired, returns false.
// DistributedLock.TryLock passes a row carrying the lock key, the holder ID and
// an expiry of now (UTC) plus the lock TTL.
|
||||
TryAcquireLock(ctx context.Context, lock *tables.TableDistributedLock) (bool, error)
|
||||
|
||||
// GetLock retrieves a lock by its key. Returns nil if the lock doesn't exist.
// DistributedLock.IsHeld treats a nil lock with a nil error as "not held".
|
||||
GetLock(ctx context.Context, lockKey string) (*tables.TableDistributedLock, error)
|
||||
|
||||
// UpdateLockExpiry updates the expiration time for an existing lock.
|
||||
// Only succeeds if the holder ID matches the current lock holder.
// DistributedLock.Extend clears its local acquired flag only when the returned
// error matches ErrLockNotHeld (checked with errors.Is).
|
||||
UpdateLockExpiry(ctx context.Context, lockKey, holderID string, expiresAt time.Time) error
|
||||
|
||||
// ReleaseLock deletes a lock if the holder ID matches.
|
||||
// Returns true if the lock was released, false if it wasn't held by the given holder.
|
||||
ReleaseLock(ctx context.Context, lockKey, holderID string) (bool, error)
|
||||
|
||||
// CleanupExpiredLocks removes all locks that have expired.
|
||||
// Returns the number of locks cleaned up.
|
||||
CleanupExpiredLocks(ctx context.Context) (int64, error)
|
||||
|
||||
// CleanupExpiredLockByKey atomically deletes a lock only if it has expired.
|
||||
// Returns true if an expired lock was deleted, false if the lock doesn't exist or hasn't expired.
// DistributedLock.TryLock invokes this before every acquisition attempt.
|
||||
CleanupExpiredLockByKey(ctx context.Context, lockKey string) (bool, error)
|
||||
}
|
||||
|
||||
// DistributedLockManager creates and manages distributed locks.
|
||||
// It provides a factory for creating locks with consistent configuration.
|
||||
type DistributedLockManager struct {
|
||||
// store is the backing lock storage handed to every lock created by NewLock.
store LockStore
|
||||
// logger is handed to every lock created by NewLock.
logger schemas.Logger
|
||||
// defaultTTL is the TTL given to locks created by NewLock.
defaultTTL time.Duration
|
||||
// retryInterval is the pause between attempts copied into each new lock.
retryInterval time.Duration
|
||||
// maxRetries is the retry budget copied into each new lock.
maxRetries int
|
||||
}
|
||||
|
||||
// DistributedLockManagerOption is a function that configures a DistributedLockManager.
// Options are applied in order by NewDistributedLockManager after the defaults are set.
|
||||
type DistributedLockManagerOption func(*DistributedLockManager)
|
||||
|
||||
// WithDefaultTTL sets the default TTL for locks created by this manager.
|
||||
func WithDefaultTTL(ttl time.Duration) DistributedLockManagerOption {
|
||||
return func(m *DistributedLockManager) {
|
||||
m.defaultTTL = ttl
|
||||
}
|
||||
}
|
||||
|
||||
// WithRetryInterval sets the interval between lock acquisition retries.
|
||||
func WithRetryInterval(interval time.Duration) DistributedLockManagerOption {
|
||||
return func(m *DistributedLockManager) {
|
||||
m.retryInterval = interval
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxRetries sets the maximum number of retries for lock acquisition.
|
||||
func WithMaxRetries(maxRetries int) DistributedLockManagerOption {
|
||||
return func(m *DistributedLockManager) {
|
||||
m.maxRetries = maxRetries
|
||||
}
|
||||
}
|
||||
|
||||
// NewDistributedLockManager creates a new lock manager with the given store and options.
|
||||
func NewDistributedLockManager(store LockStore, logger schemas.Logger, opts ...DistributedLockManagerOption) *DistributedLockManager {
|
||||
m := &DistributedLockManager{
|
||||
store: store,
|
||||
logger: logger,
|
||||
defaultTTL: DefaultLockTTL,
|
||||
retryInterval: DefaultRetryInterval,
|
||||
maxRetries: DefaultMaxRetries,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(m)
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// NewLock creates a new DistributedLock for the given key.
|
||||
// The lock is not acquired until Lock() or TryLock() is called.
|
||||
// Returns an error if the lock key is empty.
|
||||
func (m *DistributedLockManager) NewLock(lockKey string) (*DistributedLock, error) {
|
||||
if lockKey == "" {
|
||||
return nil, ErrEmptyLockKey
|
||||
}
|
||||
return &DistributedLock{
|
||||
store: m.store,
|
||||
logger: m.logger,
|
||||
lockKey: lockKey,
|
||||
holderID: uuid.New().String(),
|
||||
ttl: m.defaultTTL,
|
||||
retryInterval: m.retryInterval,
|
||||
maxRetries: m.maxRetries,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewLockWithTTL creates a new DistributedLock with a custom TTL.
|
||||
// Returns an error if the lock key is empty.
|
||||
func (m *DistributedLockManager) NewLockWithTTL(lockKey string, ttl time.Duration) (*DistributedLock, error) {
|
||||
lock, err := m.NewLock(lockKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lock.ttl = ttl
|
||||
return lock, nil
|
||||
}
|
||||
|
||||
// CleanupExpiredLocks removes all expired locks from the store.
|
||||
// This can be called periodically to clean up stale locks.
|
||||
func (m *DistributedLockManager) CleanupExpiredLocks(ctx context.Context) (int64, error) {
|
||||
return m.store.CleanupExpiredLocks(ctx)
|
||||
}
|
||||
|
||||
// DistributedLock represents a distributed lock that can be acquired and released
|
||||
// across multiple processes or instances.
|
||||
type DistributedLock struct {
|
||||
// store is the backing lock storage; when nil, the lock methods skip all
// store calls (see the nil checks at the top of each method).
store LockStore
|
||||
// logger receives debug messages about acquire/release/extend/cleanup.
logger schemas.Logger
|
||||
// lockKey identifies the lock row in the store.
lockKey string
|
||||
// holderID is the UUID string generated in NewLock for this lock instance.
holderID string
|
||||
// ttl is added to the current UTC time to compute the expiry on acquire and extend.
ttl time.Duration
|
||||
// retryInterval is the pause between failed attempts in Lock.
retryInterval time.Duration
|
||||
// maxRetries bounds the attempts made by Lock to maxRetries + 1.
maxRetries int
|
||||
// acquired is the local view of whether this instance holds the lock.
acquired bool
|
||||
}
|
||||
|
||||
// Lock acquires the lock, blocking until it's available or the context is cancelled.
|
||||
// It will make up to (maxRetries + 1) attempts, sleeping retryInterval between failed attempts.
|
||||
func (l *DistributedLock) Lock(ctx context.Context) error {
|
||||
// if config_store is not present, return true
|
||||
if l.store == nil {
|
||||
return nil
|
||||
}
|
||||
for i := 0; i <= l.maxRetries; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
acquired, err := l.TryLock(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error acquiring lock: %w", err)
|
||||
}
|
||||
|
||||
if acquired {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wait before retrying
|
||||
if i < l.maxRetries {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(l.retryInterval):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ErrLockNotAcquired
|
||||
}
|
||||
|
||||
// LockWithRetry acquires the lock, blocking until it's available or the context is cancelled.
|
||||
// It will retry up to maxRetries times with retryInterval between attempts.
|
||||
func (l *DistributedLock) LockWithRetry(ctx context.Context, maxRetries int) error {
|
||||
// if config_store is not present, return true
|
||||
if l.store == nil {
|
||||
return nil
|
||||
}
|
||||
for i := 0; i <= maxRetries; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
acquired, err := l.TryLock(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error acquiring lock: %w", err)
|
||||
}
|
||||
if acquired {
|
||||
return nil
|
||||
}
|
||||
// Wait before retrying
|
||||
if i < maxRetries {
|
||||
// Exponential backoff capped to avoid overflow (max 32s).
|
||||
exp := i
|
||||
if exp > 5 {
|
||||
exp = 5
|
||||
}
|
||||
backoff := time.Duration(1<<uint(exp)) * time.Second
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
}
|
||||
return ErrLockNotAcquired
|
||||
}
|
||||
|
||||
// TryLock attempts to acquire the lock without blocking.
|
||||
// Returns true if the lock was acquired, false if it's held by another process.
|
||||
func (l *DistributedLock) TryLock(ctx context.Context) (bool, error) {
|
||||
// if config_store is not present, return true
|
||||
if l.store == nil {
|
||||
return true, nil
|
||||
}
|
||||
// First, try to clean up any expired locks for this key
|
||||
if err := l.cleanupExpiredLock(ctx); err != nil {
|
||||
l.logger.Debug("error cleaning up expired lock: %v", err)
|
||||
}
|
||||
|
||||
lock := &tables.TableDistributedLock{
|
||||
LockKey: l.lockKey,
|
||||
HolderID: l.holderID,
|
||||
ExpiresAt: time.Now().UTC().Add(l.ttl),
|
||||
}
|
||||
|
||||
acquired, err := l.store.TryAcquireLock(ctx, lock)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error trying to acquire lock: %w", err)
|
||||
}
|
||||
|
||||
if acquired {
|
||||
l.acquired = true
|
||||
l.logger.Debug("acquired lock %s with holder %s", l.lockKey, l.holderID)
|
||||
}
|
||||
|
||||
return acquired, nil
|
||||
}
|
||||
|
||||
// Unlock releases the lock if it's held by this holder.
|
||||
// Returns an error if the lock is not held by this holder.
|
||||
func (l *DistributedLock) Unlock(ctx context.Context) error {
|
||||
// if config_store is not present, return nil (no-op)
|
||||
if l.store == nil {
|
||||
return nil
|
||||
}
|
||||
if !l.acquired {
|
||||
return ErrLockNotHeld
|
||||
}
|
||||
|
||||
released, err := l.store.ReleaseLock(ctx, l.lockKey, l.holderID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error releasing lock: %w", err)
|
||||
}
|
||||
|
||||
if !released {
|
||||
l.acquired = false
|
||||
return ErrLockNotHeld
|
||||
}
|
||||
|
||||
l.acquired = false
|
||||
l.logger.Debug("released lock %s", l.lockKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extend extends the lock's TTL. This is useful for long-running operations
|
||||
// that need to hold the lock longer than the initial TTL.
|
||||
// Returns an error if the lock is not held by this holder or has expired.
|
||||
// Only clears l.acquired when ErrLockNotHeld is returned; transient errors
|
||||
// leave l.acquired untouched so Unlock() can still attempt a proper release.
|
||||
func (l *DistributedLock) Extend(ctx context.Context) error {
|
||||
// if config_store is not present, return true
|
||||
if l.store == nil {
|
||||
return nil
|
||||
}
|
||||
// if lock is not acquired, return error
|
||||
if !l.acquired {
|
||||
return ErrLockNotHeld
|
||||
}
|
||||
|
||||
newExpiresAt := time.Now().UTC().Add(l.ttl)
|
||||
if err := l.store.UpdateLockExpiry(ctx, l.lockKey, l.holderID, newExpiresAt); err != nil {
|
||||
if errors.Is(err, ErrLockNotHeld) {
|
||||
// Lock definitively not held - clear local state
|
||||
l.acquired = false
|
||||
}
|
||||
// Otherwise leave l.acquired untouched for transient errors
|
||||
return fmt.Errorf("error extending lock: %w", err)
|
||||
}
|
||||
|
||||
l.logger.Debug("extended lock %s to %v", l.lockKey, newExpiresAt)
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHeld checks if the lock is currently held by this holder.
|
||||
// Note: This checks the local state and the database state.
|
||||
// Returns (false, error) on transient database errors without clearing l.acquired,
|
||||
// allowing Unlock() to still attempt a proper release.
|
||||
func (l *DistributedLock) IsHeld(ctx context.Context) (bool, error) {
|
||||
// if config_store is not present, return true
|
||||
if l.store == nil {
|
||||
return false, nil
|
||||
}
|
||||
if !l.acquired {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
lock, err := l.store.GetLock(ctx, l.lockKey)
|
||||
if err != nil {
|
||||
// Transient error - can't confirm state, leave l.acquired untouched
|
||||
return false, fmt.Errorf("error checking lock: %w", err)
|
||||
}
|
||||
|
||||
if lock == nil {
|
||||
// Lock doesn't exist - definitively not held
|
||||
l.acquired = false
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check if we're still the holder and the lock hasn't expired
|
||||
if lock.HolderID != l.holderID || time.Now().UTC().After(lock.ExpiresAt) {
|
||||
l.acquired = false
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Key returns the lock key.
|
||||
func (l *DistributedLock) Key() string {
|
||||
return l.lockKey
|
||||
}
|
||||
|
||||
// HolderID returns the unique identifier for this lock holder.
|
||||
func (l *DistributedLock) HolderID() string {
|
||||
return l.holderID
|
||||
}
|
||||
|
||||
// cleanupExpiredLock atomically removes the lock if it has expired.
|
||||
// This is called before attempting to acquire a lock.
|
||||
func (l *DistributedLock) cleanupExpiredLock(ctx context.Context) error {
|
||||
// if config_store is not present, return nil
|
||||
if l.store == nil {
|
||||
return nil
|
||||
}
|
||||
cleaned, err := l.store.CleanupExpiredLockByKey(ctx, l.lockKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error cleaning up expired lock: %w", err)
|
||||
}
|
||||
|
||||
if cleaned {
|
||||
l.logger.Debug("cleaned up expired lock %s", l.lockKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
1223
framework/configstore/dlock_test.go
Normal file
1223
framework/configstore/dlock_test.go
Normal file
File diff suppressed because it is too large
Load Diff
369
framework/configstore/encryption.go
Normal file
369
framework/configstore/encryption.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Values used by the plaintext-to-encrypted migration in this file.
const (
|
||||
// encryptionStatusPlainText is the encryption_status value the migration
// queries select (alongside NULL and the empty string).
encryptionStatusPlainText = "plain_text"
|
||||
// encryptionStatusEncrypted is not referenced in the visible part of this
// file; presumably the status written once a row is encrypted — confirm.
encryptionStatusEncrypted = "encrypted"
|
||||
// encryptionBatchSize is the maximum number of rows loaded and re-saved per batch.
encryptionBatchSize = 100
|
||||
)
|
||||
|
||||
// EncryptPlaintextRows encrypts all rows with encryption_status='plain_text'
|
||||
// across all sensitive tables. Called during startup when encryption is enabled.
|
||||
// Each table's GORM BeforeSave hook handles the actual encryption.
|
||||
func (s *RDBConfigStore) EncryptPlaintextRows(ctx context.Context) error {
|
||||
if !encrypt.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
var totalEncrypted int
|
||||
|
||||
// config_keys
|
||||
count, err := s.encryptPlaintextKeys(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt config_keys: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// governance_virtual_keys
|
||||
count, err = s.encryptPlaintextVirtualKeys(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt virtual_keys: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// sessions
|
||||
count, err = s.encryptPlaintextSessions(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt sessions: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// oauth_tokens
|
||||
count, err = s.encryptPlaintextOAuthTokens(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth_tokens: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// oauth_configs
|
||||
count, err = s.encryptPlaintextOAuthConfigs(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth_configs: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// config_mcp_clients
|
||||
count, err = s.encryptPlaintextMCPClients(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt mcp_clients: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// config_providers (proxy config)
|
||||
count, err = s.encryptPlaintextProviderProxies(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt provider proxy configs: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// config_vector_store
|
||||
count, err = s.encryptPlaintextVectorStoreConfigs(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt vector_store configs: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
// config_plugins
|
||||
count, err = s.encryptPlaintextPlugins(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt plugin configs: %w", err)
|
||||
}
|
||||
totalEncrypted += count
|
||||
|
||||
if totalEncrypted > 0 && s.logger != nil {
|
||||
s.logger.Info(fmt.Sprintf("encrypted %d plaintext rows across all tables", totalEncrypted))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// encryptPlaintextKeys finds all config_keys rows with plaintext encryption status and
|
||||
// re-saves them in batches. The TableKey.BeforeSave hook handles the actual encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextKeys(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var keys []tables.TableKey
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&keys).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range keys {
|
||||
if err := tx.Save(&keys[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(keys)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextVirtualKeys finds all governance_virtual_keys rows with plaintext encryption
|
||||
// status and re-saves them in batches. The TableVirtualKey.BeforeSave hook handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextVirtualKeys(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var vks []tables.TableVirtualKey
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND value != ''", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&vks).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(vks) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range vks {
|
||||
if err := tx.Save(&vks[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(vks)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextSessions finds all sessions rows with plaintext encryption status and
|
||||
// re-saves them in batches. The SessionsTable.BeforeSave hook handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextSessions(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var sessions []tables.SessionsTable
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND token != ''", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&sessions).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(sessions) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range sessions {
|
||||
if err := tx.Save(&sessions[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(sessions)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextOAuthTokens finds all oauth_tokens rows with plaintext encryption status
|
||||
// and re-saves them in batches. The TableOauthToken.BeforeSave hook handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextOAuthTokens(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var tokens []tables.TableOauthToken
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&tokens).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(tokens) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range tokens {
|
||||
if err := tx.Save(&tokens[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(tokens)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextOAuthConfigs finds all oauth_configs rows with plaintext encryption status
|
||||
// and re-saves them in batches. The TableOauthConfig.BeforeSave hook handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextOAuthConfigs(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var configs []tables.TableOauthConfig
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND (client_secret != '' OR code_verifier != '')", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&configs).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(configs) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range configs {
|
||||
if err := tx.Save(&configs[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(configs)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextMCPClients finds all config_mcp_clients rows with plaintext encryption
|
||||
// status and re-saves them in batches. The TableMCPClient.BeforeSave hook handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextMCPClients(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var clients []tables.TableMCPClient
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("encryption_status = ? OR encryption_status IS NULL OR encryption_status = ''", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&clients).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(clients) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range clients {
|
||||
if err := tx.Save(&clients[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(clients)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextProviderProxies finds all config_providers rows that have a non-empty
|
||||
// proxy config with plaintext encryption status and re-saves them in batches. The
|
||||
// TableProvider.BeforeSave hook handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextProviderProxies(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var providers []tables.TableProvider
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND proxy_config_json != '' AND proxy_config_json IS NOT NULL", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&providers).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range providers {
|
||||
if err := tx.Save(&providers[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(providers)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextVectorStoreConfigs finds all config_vector_store rows that have a non-empty
|
||||
// config with plaintext encryption status and re-saves them in batches. The
|
||||
// TableVectorStoreConfig.BeforeSave hook handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextVectorStoreConfigs(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var configs []tables.TableVectorStoreConfig
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND config IS NOT NULL AND config != ''", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&configs).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(configs) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range configs {
|
||||
if err := tx.Save(&configs[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(configs)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextPlugins finds all config_plugins rows that have a non-empty config with
|
||||
// plaintext encryption status and re-saves them in batches. The TablePlugin.BeforeSave hook
|
||||
// handles encryption.
|
||||
func (s *RDBConfigStore) encryptPlaintextPlugins(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
for {
|
||||
var plugins []tables.TablePlugin
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Where("(encryption_status = ? OR encryption_status IS NULL OR encryption_status = '') AND config_json != '' AND config_json != '{}'", encryptionStatusPlainText).
|
||||
Limit(encryptionBatchSize).
|
||||
Find(&plugins).Error; err != nil {
|
||||
return count, err
|
||||
}
|
||||
if len(plugins) == 0 {
|
||||
break
|
||||
}
|
||||
if err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for i := range plugins {
|
||||
if err := tx.Save(&plugins[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count += len(plugins)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
1397
framework/configstore/encryption_test.go
Normal file
1397
framework/configstore/encryption_test.go
Normal file
File diff suppressed because it is too large
Load Diff
19
framework/configstore/errors.go
Normal file
19
framework/configstore/errors.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ErrNotFound is the sentinel error reported by store lookups, updates and deletes when
// the requested record does not exist. Compare with errors.Is.
var ErrNotFound = errors.New("not found")

// ErrAlreadyExists is the sentinel error for a record that is already present.
// NOTE(review): no producer is visible in this part of the package — confirm where it is returned.
var ErrAlreadyExists = errors.New("already exists")
|
||||
|
||||
// ErrUnresolvedKeys is the error type used when one or more keys could not be resolved.
// Identifiers lists the keys that failed to resolve, in the order they were recorded.
type ErrUnresolvedKeys struct {
	Identifiers []string
}

// Error implements the error interface. The message is the fixed prefix
// "could not resolve keys: " followed by the identifiers joined with ", ".
func (e *ErrUnresolvedKeys) Error() string {
	joined := strings.Join(e.Identifiers, ", ")
	return fmt.Sprintf("could not resolve keys: %s", joined)
}
|
||||
45
framework/configstore/logger.go
Normal file
45
framework/configstore/logger.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
gormLibLogger "gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// gormLogger adapts a schemas.Logger to GORM's logger interface by forwarding
// Info, Warn and Error messages to it; LogMode and Trace are no-ops.
type gormLogger struct {
	// logger receives every forwarded message.
	logger schemas.Logger
}
|
||||
|
||||
// LogMode sets the log mode for the logger.
|
||||
func (l *gormLogger) LogMode(level gormLibLogger.LogLevel) gormLibLogger.Interface {
|
||||
// NOOP
|
||||
return l
|
||||
}
|
||||
|
||||
// Info logs an info message.
|
||||
func (l *gormLogger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||
l.logger.Info(msg, data...)
|
||||
}
|
||||
|
||||
// Warn logs a warning message.
|
||||
func (l *gormLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||
l.logger.Warn(msg, data...)
|
||||
}
|
||||
|
||||
// Error logs an error message.
|
||||
func (l *gormLogger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||
l.logger.Error(msg, data...)
|
||||
}
|
||||
|
||||
// Trace logs a trace message.
|
||||
func (l *gormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
// NOOP
|
||||
}
|
||||
|
||||
// newGormLogger creates a new GormLogger.
|
||||
func newGormLogger(l schemas.Logger) *gormLogger {
|
||||
return &gormLogger{logger: l}
|
||||
}
|
||||
7004
framework/configstore/migrations.go
Normal file
7004
framework/configstore/migrations.go
Normal file
File diff suppressed because it is too large
Load Diff
2373
framework/configstore/migrations_test.go
Normal file
2373
framework/configstore/migrations_test.go
Normal file
File diff suppressed because it is too large
Load Diff
169
framework/configstore/postgres.go
Normal file
169
framework/configstore/postgres.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PostgresConfig represents the configuration for a Postgres database.
//
// The connection settings are schemas.EnvVar values that are read with GetValue when
// the DSN is built. newPostgresConfigStore rejects a config whose Host, Port, User,
// DBName or SSLMode is nil or resolves to an empty string, and whose Password is nil
// (an empty password value is accepted).
type PostgresConfig struct {
	Host         *schemas.EnvVar `json:"host"`           // Server host; required.
	Port         *schemas.EnvVar `json:"port"`           // Server port; required.
	User         *schemas.EnvVar `json:"user"`           // Login user; required.
	Password     *schemas.EnvVar `json:"password"`       // Login password; must be non-nil, may be empty.
	DBName       *schemas.EnvVar `json:"db_name"`        // Database name; required.
	SSLMode      *schemas.EnvVar `json:"ssl_mode"`       // Value for the DSN sslmode setting; required.
	MaxIdleConns int             `json:"max_idle_conns"` // Idle connection cap; 0 falls back to 5.
	MaxOpenConns int             `json:"max_open_conns"` // Open connection cap; 0 falls back to 50.
}
|
||||
|
||||
// buildPostgresDSN assembles a libpq-style DSN from the validated config.
|
||||
func buildPostgresDSN(config *PostgresConfig) string {
|
||||
return fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
|
||||
config.Host.GetValue(), config.Port.GetValue(), config.User.GetValue(),
|
||||
config.Password.GetValue(), config.DBName.GetValue(), config.SSLMode.GetValue())
|
||||
}
|
||||
|
||||
// openPostresConnection opens a *gorm.DB against the configured Postgres instance
|
||||
// using the shared bifrost logger. Used for both the throwaway migration pool
|
||||
// and the runtime pool.
|
||||
func openPostresConnection(dsn string, logger schemas.Logger) (*gorm.DB, error) {
|
||||
return gorm.Open(postgres.New(postgres.Config{DSN: dsn}), &gorm.Config{
|
||||
Logger: newGormLogger(logger),
|
||||
})
|
||||
}
|
||||
|
||||
// closeDbConn closes the *sql.DB backing a *gorm.DB, logging any error.
|
||||
// Used in error paths and for the throwaway migration pool.
|
||||
func closeDbConn(db *gorm.DB, logger schemas.Logger) {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
logger.Error("failed to resolve *sql.DB for close: %v", err)
|
||||
return
|
||||
}
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
logger.Error("failed to close DB connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// applyPostgresPoolTuning applies MaxIdleConns / MaxOpenConns from config to
|
||||
// the supplied *gorm.DB, falling back to defaults when the config leaves the
|
||||
// field at zero.
|
||||
func applyPostgresPoolTuning(db *gorm.DB, config *PostgresConfig) error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
maxIdleConns := config.MaxIdleConns
|
||||
if maxIdleConns == 0 {
|
||||
maxIdleConns = 5
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(maxIdleConns)
|
||||
maxOpenConns := config.MaxOpenConns
|
||||
if maxOpenConns == 0 {
|
||||
maxOpenConns = 50
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(maxOpenConns)
|
||||
return nil
|
||||
}
|
||||
|
||||
// newPostgresConfigStore creates a new Postgres config store.
|
||||
//
|
||||
// Uses a two-pool lifecycle to avoid SQLSTATE 0A000 ("cached plan must not
|
||||
// change result type"): a throwaway migration pool runs DDL and is closed
|
||||
// immediately, then a fresh runtime pool is opened. The runtime pool's
|
||||
// connections never see pre-migration schema, so their cached prepared-plans
|
||||
// stay valid for the life of the process.
|
||||
func newPostgresConfigStore(ctx context.Context, config *PostgresConfig, logger schemas.Logger) (ConfigStore, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
if config.Host == nil || config.Host.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres host is required")
|
||||
}
|
||||
if config.Port == nil || config.Port.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres port is required")
|
||||
}
|
||||
if config.User == nil || config.User.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres user is required")
|
||||
}
|
||||
if config.Password == nil {
|
||||
return nil, fmt.Errorf("postgres password is required")
|
||||
}
|
||||
if config.DBName == nil || config.DBName.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres db name is required")
|
||||
}
|
||||
if config.SSLMode == nil || config.SSLMode.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres ssl mode is required")
|
||||
}
|
||||
dsn := buildPostgresDSN(config)
|
||||
|
||||
// Throwaway pool for schema migrations. Closing it before the runtime pool
|
||||
// opens guarantees no cached prepared-plan survives the DDL.
|
||||
mDb, err := openPostresConnection(dsn, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := triggerMigrations(ctx, mDb); err != nil {
|
||||
closeDbConn(mDb, logger)
|
||||
return nil, err
|
||||
}
|
||||
closeDbConn(mDb, logger)
|
||||
|
||||
// Runtime pool. Opens against post-migration schema.
|
||||
db, err := openPostresConnection(dsn, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := applyPostgresPoolTuning(db, config); err != nil {
|
||||
closeDbConn(db, logger)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := &RDBConfigStore{logger: logger}
|
||||
d.db.Store(db)
|
||||
|
||||
// migrateOnFreshFn: downstream consumers (e.g. bifrost-enterprise) run
|
||||
// their migrations via this hook on a throwaway pool that closes after fn.
|
||||
d.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error {
|
||||
tempDB, err := openPostresConnection(dsn, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer closeDbConn(tempDB, logger)
|
||||
return fn(ctx, tempDB)
|
||||
}
|
||||
|
||||
// refreshPoolFn: open fresh runtime pool first (so a failure leaves the
|
||||
// existing pool in place), swap atomically, then close the old pool.
|
||||
// sql.DB.Close blocks until in-flight queries finish, so callers already
|
||||
// using the old pool complete safely.
|
||||
d.refreshPoolFn = func(ctx context.Context) error {
|
||||
newDB, err := openPostresConnection(dsn, logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open fresh runtime pool: %w", err)
|
||||
}
|
||||
if err := applyPostgresPoolTuning(newDB, config); err != nil {
|
||||
closeDbConn(newDB, logger)
|
||||
return fmt.Errorf("failed to tune fresh runtime pool: %w", err)
|
||||
}
|
||||
oldDB := d.db.Swap(newDB)
|
||||
if oldDB != nil {
|
||||
closeDbConn(oldDB, logger)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encrypt any plaintext rows if encryption is enabled. Runs on the
|
||||
// runtime pool — pure DML (SELECT + UPDATE), no DDL, so cached plans it
|
||||
// installs remain valid until the next external migration batch.
|
||||
if err := d.EncryptPlaintextRows(ctx); err != nil {
|
||||
closeDbConn(db, logger)
|
||||
return nil, fmt.Errorf("failed to encrypt plaintext rows: %w", err)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
567
framework/configstore/prompts.go
Normal file
567
framework/configstore/prompts.go
Normal file
@@ -0,0 +1,567 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// isUniqueConstraintError reports whether err is a unique-constraint violation, detected
// by looking for the SQLite ("UNIQUE constraint failed") or PostgreSQL ("duplicate key
// value violates unique constraint") wording in the error message. A nil error is never
// a violation.
func isUniqueConstraintError(err error) bool {
	if err == nil {
		return false
	}
	text := err.Error()
	for _, marker := range []string{
		"UNIQUE constraint failed",
		"duplicate key value violates unique constraint",
	} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// ============================================================================
|
||||
// Prompt Repository - Folders
|
||||
// ============================================================================
|
||||
|
||||
// GetFolders gets all folders
|
||||
func (s *RDBConfigStore) GetFolders(ctx context.Context) ([]tables.TableFolder, error) {
|
||||
var folders []tables.TableFolder
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Order("created_at DESC").
|
||||
Find(&folders).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get prompts count for each folder
|
||||
for i := range folders {
|
||||
var count int64
|
||||
if err := s.DB().WithContext(ctx).Model(&tables.TablePrompt{}).Where("folder_id = ?", folders[i].ID).Count(&count).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
folders[i].PromptsCount = int(count)
|
||||
}
|
||||
|
||||
return folders, nil
|
||||
}
|
||||
|
||||
// GetFolderByID gets a folder by ID
|
||||
func (s *RDBConfigStore) GetFolderByID(ctx context.Context, id string) (*tables.TableFolder, error) {
|
||||
var folder tables.TableFolder
|
||||
if err := s.DB().WithContext(ctx).
|
||||
First(&folder, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &folder, nil
|
||||
}
|
||||
|
||||
// CreateFolder creates a new folder
|
||||
func (s *RDBConfigStore) CreateFolder(ctx context.Context, folder *tables.TableFolder) error {
|
||||
return s.DB().WithContext(ctx).Create(folder).Error
|
||||
}
|
||||
|
||||
// UpdateFolder updates a folder
|
||||
func (s *RDBConfigStore) UpdateFolder(ctx context.Context, folder *tables.TableFolder) error {
|
||||
res := s.DB().WithContext(ctx).Where("id = ?", folder.ID).Save(folder)
|
||||
if res.Error != nil {
|
||||
return res.Error
|
||||
}
|
||||
if res.RowsAffected == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteFolder deletes a folder and all its child prompts (with their versions, sessions, and messages).
|
||||
// PostgreSQL uses native ON DELETE CASCADE; SQLite requires manual cascade because it cannot
|
||||
// alter foreign key constraints after table creation.
|
||||
func (s *RDBConfigStore) DeleteFolder(ctx context.Context, id string) error {
|
||||
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Check folder exists
|
||||
var folder tables.TableFolder
|
||||
if err := tx.First(&folder, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// PostgreSQL: ON DELETE CASCADE handles all child deletions
|
||||
if s.DB().Dialector.Name() == "postgres" {
|
||||
return tx.Delete(&folder).Error
|
||||
}
|
||||
|
||||
// SQLite: manual cascade deletion
|
||||
var promptIDs []string
|
||||
if err := tx.Model(&tables.TablePrompt{}).Where("folder_id = ?", id).Pluck("id", &promptIDs).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(promptIDs) > 0 {
|
||||
// Delete version messages
|
||||
if err := tx.Where("prompt_id IN ?", promptIDs).Delete(&tables.TablePromptVersionMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// Delete versions
|
||||
if err := tx.Where("prompt_id IN ?", promptIDs).Delete(&tables.TablePromptVersion{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// Delete session messages
|
||||
if err := tx.Where("prompt_id IN ?", promptIDs).Delete(&tables.TablePromptSessionMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// Delete sessions
|
||||
if err := tx.Where("prompt_id IN ?", promptIDs).Delete(&tables.TablePromptSession{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// Delete prompts
|
||||
if err := tx.Where("folder_id = ?", id).Delete(&tables.TablePrompt{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the folder
|
||||
return tx.Delete(&folder).Error
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Prompt Repository - Prompts
|
||||
// ============================================================================
|
||||
|
||||
// GetPrompts gets all prompts, optionally filtered by folder ID
|
||||
func (s *RDBConfigStore) GetPrompts(ctx context.Context, folderID *string) ([]tables.TablePrompt, error) {
|
||||
var prompts []tables.TablePrompt
|
||||
query := s.DB().WithContext(ctx).
|
||||
Preload("Folder").
|
||||
Order("created_at DESC")
|
||||
|
||||
if folderID != nil {
|
||||
query = query.Where("folder_id = ?", *folderID)
|
||||
}
|
||||
|
||||
if err := query.Find(&prompts).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get latest version for each prompt
|
||||
for i := range prompts {
|
||||
var latestVersion tables.TablePromptVersion
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Where("prompt_id = ? AND is_latest = ?", prompts[i].ID, true).
|
||||
First(&latestVersion).Error; err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
prompts[i].LatestVersion = &latestVersion
|
||||
}
|
||||
}
|
||||
|
||||
return prompts, nil
|
||||
}
|
||||
|
||||
// GetPromptByID gets a prompt by ID with latest version
|
||||
func (s *RDBConfigStore) GetPromptByID(ctx context.Context, id string) (*tables.TablePrompt, error) {
|
||||
var prompt tables.TablePrompt
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Folder").
|
||||
First(&prompt, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get latest version
|
||||
var latestVersion tables.TablePromptVersion
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Where("prompt_id = ? AND is_latest = ?", prompt.ID, true).
|
||||
First(&latestVersion).Error; err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
prompt.LatestVersion = &latestVersion
|
||||
}
|
||||
|
||||
return &prompt, nil
|
||||
}
|
||||
|
||||
// CreatePrompt creates a new prompt
|
||||
func (s *RDBConfigStore) CreatePrompt(ctx context.Context, prompt *tables.TablePrompt) error {
|
||||
return s.DB().WithContext(ctx).Create(prompt).Error
|
||||
}
|
||||
|
||||
// UpdatePrompt updates a prompt
|
||||
func (s *RDBConfigStore) UpdatePrompt(ctx context.Context, prompt *tables.TablePrompt) error {
|
||||
// Use Select to explicitly include FolderID so GORM writes NULL when it's nil
|
||||
res := s.DB().WithContext(ctx).
|
||||
Model(prompt).
|
||||
Where("id = ?", prompt.ID).
|
||||
Select("Name", "FolderID", "UpdatedAt").
|
||||
Updates(prompt)
|
||||
if res.Error != nil {
|
||||
return res.Error
|
||||
}
|
||||
if res.RowsAffected == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePrompt deletes a prompt and all its child versions, sessions, and messages.
|
||||
// PostgreSQL uses native ON DELETE CASCADE; SQLite requires manual cascade because it cannot
|
||||
// alter foreign key constraints after table creation.
|
||||
func (s *RDBConfigStore) DeletePrompt(ctx context.Context, id string) error {
|
||||
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Check prompt exists
|
||||
var prompt tables.TablePrompt
|
||||
if err := tx.First(&prompt, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// PostgreSQL: ON DELETE CASCADE handles all child deletions
|
||||
if s.DB().Dialector.Name() == "postgres" {
|
||||
return tx.Delete(&prompt).Error
|
||||
}
|
||||
|
||||
// SQLite: manual cascade deletion
|
||||
if err := tx.Where("prompt_id = ?", id).Delete(&tables.TablePromptVersionMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Where("prompt_id = ?", id).Delete(&tables.TablePromptVersion{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Where("prompt_id = ?", id).Delete(&tables.TablePromptSessionMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Where("prompt_id = ?", id).Delete(&tables.TablePromptSession{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Delete(&prompt).Error
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Prompt Repository - Versions
|
||||
// ============================================================================
|
||||
|
||||
// GetAllPromptVersions returns every version across all prompts in a single query.
|
||||
func (s *RDBConfigStore) GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error) {
|
||||
var versions []tables.TablePromptVersion
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Order("prompt_id ASC, version_number DESC").
|
||||
Find(&versions).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return versions, nil
|
||||
}
|
||||
|
||||
// GetPromptVersions gets all versions for a prompt
|
||||
func (s *RDBConfigStore) GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error) {
|
||||
var versions []tables.TablePromptVersion
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Where("prompt_id = ?", promptID).
|
||||
Order("version_number DESC").
|
||||
Find(&versions).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return versions, nil
|
||||
}
|
||||
|
||||
// GetPromptVersionByID gets a version by ID
|
||||
func (s *RDBConfigStore) GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error) {
|
||||
var version tables.TablePromptVersion
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Preload("Prompt").
|
||||
First(&version, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &version, nil
|
||||
}
|
||||
|
||||
// GetLatestPromptVersion gets the latest version for a prompt
|
||||
func (s *RDBConfigStore) GetLatestPromptVersion(ctx context.Context, promptID string) (*tables.TablePromptVersion, error) {
|
||||
var version tables.TablePromptVersion
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Where("prompt_id = ? AND is_latest = ?", promptID, true).
|
||||
First(&version).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &version, nil
|
||||
}
|
||||
|
||||
// CreatePromptVersion creates a new version and marks it as latest.
|
||||
// Retries on unique constraint conflict (concurrent version_number allocation).
|
||||
func (s *RDBConfigStore) CreatePromptVersion(ctx context.Context, version *tables.TablePromptVersion) error {
|
||||
const maxRetries = 3
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
err := s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Get the next version number
|
||||
var maxVersionNumber int
|
||||
if err := tx.Model(&tables.TablePromptVersion{}).
|
||||
Where("prompt_id = ?", version.PromptID).
|
||||
Select("COALESCE(MAX(version_number), 0)").
|
||||
Scan(&maxVersionNumber).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
version.VersionNumber = maxVersionNumber + 1
|
||||
|
||||
// Mark all existing versions as not latest
|
||||
if err := tx.Model(&tables.TablePromptVersion{}).
|
||||
Where("prompt_id = ?", version.PromptID).
|
||||
Update("is_latest", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mark new version as latest
|
||||
version.IsLatest = true
|
||||
|
||||
// Reset IDs and set order index on messages before create (GORM will auto-create associations)
|
||||
for i := range version.Messages {
|
||||
version.Messages[i].ID = 0
|
||||
version.Messages[i].PromptID = version.PromptID
|
||||
version.Messages[i].OrderIndex = i
|
||||
}
|
||||
|
||||
// Create the version (GORM auto-creates associated messages)
|
||||
if err := tx.Create(version).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
// Retry on unique constraint conflict, otherwise return immediately
|
||||
if !isUniqueConstraintError(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to create prompt version after %d retries due to concurrent version_number conflict", maxRetries)
|
||||
}
|
||||
|
||||
// DeletePromptVersion deletes a version and promotes the previous version to latest if needed.
|
||||
// PostgreSQL uses native ON DELETE CASCADE for messages; SQLite requires manual cascade.
|
||||
func (s *RDBConfigStore) DeletePromptVersion(ctx context.Context, id uint) error {
|
||||
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Get the version to check if it's latest
|
||||
var version tables.TablePromptVersion
|
||||
if err := tx.First(&version, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// SQLite: manually delete version messages (PostgreSQL CASCADE handles this)
|
||||
if s.DB().Dialector.Name() != "postgres" {
|
||||
if err := tx.Where("version_id = ?", id).Delete(&tables.TablePromptVersionMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the version
|
||||
if err := tx.Delete(&tables.TablePromptVersion{}, "id = ?", id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If this was the latest version, mark the previous one as latest
|
||||
if version.IsLatest {
|
||||
var prevVersion tables.TablePromptVersion
|
||||
if err := tx.Where("prompt_id = ?", version.PromptID).
|
||||
Order("version_number DESC").
|
||||
First(&prevVersion).Error; err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := tx.Model(&prevVersion).UpdateColumn("is_latest", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Prompt Repository - Sessions
|
||||
// ============================================================================
|
||||
|
||||
// GetPromptSessions gets all sessions for a prompt
|
||||
func (s *RDBConfigStore) GetPromptSessions(ctx context.Context, promptID string) ([]tables.TablePromptSession, error) {
|
||||
var sessions []tables.TablePromptSession
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Preload("Version").
|
||||
Where("prompt_id = ?", promptID).
|
||||
Order("created_at DESC").
|
||||
Find(&sessions).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// GetPromptSessionByID gets a session by ID
|
||||
func (s *RDBConfigStore) GetPromptSessionByID(ctx context.Context, id uint) (*tables.TablePromptSession, error) {
|
||||
var session tables.TablePromptSession
|
||||
if err := s.DB().WithContext(ctx).
|
||||
Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }).
|
||||
Preload("Prompt").
|
||||
Preload("Version").
|
||||
First(&session, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// CreatePromptSession creates a new session
|
||||
func (s *RDBConfigStore) CreatePromptSession(ctx context.Context, session *tables.TablePromptSession) error {
|
||||
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Verify version belongs to the same prompt if set
|
||||
if session.VersionID != nil {
|
||||
var version tables.TablePromptVersion
|
||||
if err := tx.First(&version, "id = ?", *session.VersionID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("version not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
if version.PromptID != session.PromptID {
|
||||
return fmt.Errorf("version does not belong to the specified prompt")
|
||||
}
|
||||
}
|
||||
|
||||
// Save messages and clear from session to prevent GORM auto-creating them
|
||||
msgs := session.Messages
|
||||
session.Messages = nil
|
||||
|
||||
// Create the session without associated messages
|
||||
if err := tx.Create(session).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create messages with fresh IDs
|
||||
for i := range msgs {
|
||||
msgs[i].ID = 0 // Ensure new auto-increment ID
|
||||
msgs[i].PromptID = session.PromptID
|
||||
msgs[i].SessionID = session.ID
|
||||
msgs[i].OrderIndex = i
|
||||
if err := tx.Create(&msgs[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
session.Messages = msgs
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// UpdatePromptSession updates a session and its messages
|
||||
func (s *RDBConfigStore) UpdatePromptSession(ctx context.Context, session *tables.TablePromptSession) error {
|
||||
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Verify version belongs to the same prompt if set
|
||||
if session.VersionID != nil {
|
||||
var version tables.TablePromptVersion
|
||||
if err := tx.First(&version, "id = ?", *session.VersionID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("version not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
if version.PromptID != session.PromptID {
|
||||
return fmt.Errorf("version does not belong to the specified prompt")
|
||||
}
|
||||
}
|
||||
|
||||
// Update the session
|
||||
res := tx.Where("id = ?", session.ID).Save(session)
|
||||
if res.Error != nil {
|
||||
return res.Error
|
||||
}
|
||||
if res.RowsAffected == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
// Delete old messages
|
||||
if err := tx.Where("session_id = ?", session.ID).Delete(&tables.TablePromptSessionMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create new messages
|
||||
for i := range session.Messages {
|
||||
session.Messages[i].PromptID = session.PromptID
|
||||
session.Messages[i].SessionID = session.ID
|
||||
session.Messages[i].OrderIndex = i
|
||||
session.Messages[i].ID = 0 // Reset ID for new creation
|
||||
if err := tx.Create(&session.Messages[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RenamePromptSession updates only the name of a session
|
||||
func (s *RDBConfigStore) RenamePromptSession(ctx context.Context, id uint, name string) error {
|
||||
result := s.DB().WithContext(ctx).Model(&tables.TablePromptSession{}).Where("id = ?", id).Update("name", name)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePromptSession deletes a session and its messages.
|
||||
// PostgreSQL uses native ON DELETE CASCADE for messages; SQLite requires manual cascade.
|
||||
func (s *RDBConfigStore) DeletePromptSession(ctx context.Context, id uint) error {
|
||||
return s.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
var session tables.TablePromptSession
|
||||
if err := tx.First(&session, "id = ?", id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// PostgreSQL: ON DELETE CASCADE handles message deletion
|
||||
if s.DB().Dialector.Name() == "postgres" {
|
||||
return tx.Delete(&session).Error
|
||||
}
|
||||
|
||||
// SQLite: manually delete messages first
|
||||
if err := tx.Where("session_id = ?", id).Delete(&tables.TablePromptSessionMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Delete(&session).Error
|
||||
})
|
||||
}
|
||||
4623
framework/configstore/rdb.go
Normal file
4623
framework/configstore/rdb.go
Normal file
File diff suppressed because it is too large
Load Diff
1505
framework/configstore/rdb_test.go
Normal file
1505
framework/configstore/rdb_test.go
Normal file
File diff suppressed because it is too large
Load Diff
62
framework/configstore/sqlite.go
Normal file
62
framework/configstore/sqlite.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SQLiteConfig holds the settings needed to open a SQLite-backed config store.
type SQLiteConfig struct {
	// Path is the filesystem location of the SQLite database file
	// (serialized as "path").
	Path string `json:"path"`
}
|
||||
|
||||
// newSqliteConfigStore creates a new SQLite config store.
|
||||
func newSqliteConfigStore(ctx context.Context, config *SQLiteConfig, logger schemas.Logger) (ConfigStore, error) {
|
||||
if _, err := os.Stat(config.Path); os.IsNotExist(err) {
|
||||
// Create DB file
|
||||
f, err := os.Create(config.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = f.Close()
|
||||
}
|
||||
dsn := fmt.Sprintf("%s?_journal_mode=WAL&_synchronous=NORMAL&_cache_size=10000&_busy_timeout=60000&_wal_autocheckpoint=1000&_foreign_keys=1", config.Path)
|
||||
logger.Debug("opening DB with dsn: %s", dsn)
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{
|
||||
Logger: newGormLogger(logger),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.Debug("db opened for configstore")
|
||||
s := &RDBConfigStore{logger: logger}
|
||||
s.db.Store(db)
|
||||
// SQLite has no server-side prepared-plan cache, and opening a second
|
||||
// handle on the same file would contend for the single-writer lock —
|
||||
// so both hooks operate on the existing *gorm.DB.
|
||||
s.migrateOnFreshFn = func(ctx context.Context, fn func(context.Context, *gorm.DB) error) error {
|
||||
return fn(ctx, s.DB())
|
||||
}
|
||||
s.refreshPoolFn = func(ctx context.Context) error { return nil }
|
||||
|
||||
logger.Debug("running migration to remove duplicate keys")
|
||||
// Run migration to remove duplicate keys before AutoMigrate
|
||||
if err := s.removeDuplicateKeysAndNullKeys(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to remove duplicate keys: %w", err)
|
||||
}
|
||||
// Run migrations
|
||||
if err := triggerMigrations(ctx, db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Encrypt any plaintext rows if encryption is enabled
|
||||
if err := s.EncryptPlaintextRows(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt plaintext rows: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
441
framework/configstore/store.go
Normal file
441
framework/configstore/store.go
Normal file
@@ -0,0 +1,441 @@
|
||||
// Package configstore provides a persistent configuration store for Bifrost.
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/framework/logstore"
|
||||
"github.com/maximhq/bifrost/framework/vectorstore"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// VirtualKeyQueryParams bundles the pagination, filtering, search and sorting
// options accepted by virtual key list queries.
type VirtualKeyQueryParams struct {
	// Pagination window.
	Limit, Offset int

	// Search term and optional owner filters.
	Search, CustomerID, TeamID string

	// SortBy selects the sort column: name, budget_spent, created_at, status
	// (default: created_at).
	SortBy string
	// Order is the sort direction: asc, desc (default: asc).
	Order string

	// Export, when true, skips default pagination limits (caller controls limit).
	Export bool
	// ExcludeAccessProfileManagedVirtual, when true, excludes VKs managed
	// through enterprise access profiles.
	ExcludeAccessProfileManagedVirtual bool
}
|
||||
|
||||
// ModelConfigsQueryParams bundles the pagination and search options accepted
// by model config list queries.
type ModelConfigsQueryParams struct {
	// Pagination window.
	Limit, Offset int
	// Search is the search term applied to the query.
	Search string
}
|
||||
|
||||
// RoutingRulesQueryParams bundles the pagination and search options accepted
// by routing rule list queries.
type RoutingRulesQueryParams struct {
	// Pagination window.
	Limit, Offset int
	// Search is the search term applied to the query.
	Search string
}
|
||||
|
||||
// MCPClientsQueryParams bundles the pagination and search options accepted by
// MCP client list queries.
type MCPClientsQueryParams struct {
	// Pagination window.
	Limit, Offset int
	// Search is the search term applied to the query.
	Search string
}
|
||||
|
||||
// TeamsQueryParams bundles the pagination, search and customer filter options
// accepted by team list queries.
type TeamsQueryParams struct {
	// Pagination window.
	Limit, Offset int
	// Search term and optional customer filter.
	Search, CustomerID string
}
|
||||
|
||||
// CustomersQueryParams bundles the pagination and search options accepted by
// customer list queries.
type CustomersQueryParams struct {
	// Pagination window.
	Limit, Offset int
	// Search is the search term applied to the query.
	Search string
}
|
||||
|
||||
// PricingOverrideFilters carries the optional filters for pricing override
// lookups. Every field is a pointer, so a filter can be left unset (nil).
type PricingOverrideFilters struct {
	ScopeKind,
	VirtualKeyID,
	ProviderID,
	ProviderKeyID *string
}
|
||||
|
||||
// PricingOverridesQueryParams bundles the pagination, search and filter
// options accepted by pricing override list queries. The pointer fields are
// the same four optional filters found on PricingOverrideFilters.
type PricingOverridesQueryParams struct {
	// Pagination window.
	Limit, Offset int
	// Search is the search term applied to the query.
	Search string

	// Optional filters; nil leaves the filter unset.
	ScopeKind,
	VirtualKeyID,
	ProviderID,
	ProviderKeyID *string
}
|
||||
|
||||
// ConfigStore is the interface for the config store.
|
||||
type ConfigStore interface {
|
||||
// Health check
|
||||
Ping(ctx context.Context) error
|
||||
|
||||
// Encryption
|
||||
EncryptPlaintextRows(ctx context.Context) error
|
||||
|
||||
// Client config CRUD
|
||||
UpdateClientConfig(ctx context.Context, config *ClientConfig) error
|
||||
GetClientConfig(ctx context.Context) (*ClientConfig, error)
|
||||
|
||||
// Framework config CRUD
|
||||
UpdateFrameworkConfig(ctx context.Context, config *tables.TableFrameworkConfig) error
|
||||
GetFrameworkConfig(ctx context.Context) (*tables.TableFrameworkConfig, error)
|
||||
|
||||
// Provider config CRUD
|
||||
UpdateProvidersConfig(ctx context.Context, providers map[schemas.ModelProvider]ProviderConfig, tx ...*gorm.DB) error
|
||||
AddProvider(ctx context.Context, provider schemas.ModelProvider, config ProviderConfig, tx ...*gorm.DB) error
|
||||
UpdateProvider(ctx context.Context, provider schemas.ModelProvider, config ProviderConfig, tx ...*gorm.DB) error
|
||||
DeleteProvider(ctx context.Context, provider schemas.ModelProvider, tx ...*gorm.DB) error
|
||||
GetProvidersConfig(ctx context.Context) (map[schemas.ModelProvider]ProviderConfig, error)
|
||||
GetProviderConfig(ctx context.Context, provider schemas.ModelProvider) (*ProviderConfig, error)
|
||||
GetProviderKeys(ctx context.Context, provider schemas.ModelProvider) ([]schemas.Key, error)
|
||||
GetProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string) (*schemas.Key, error)
|
||||
CreateProviderKey(ctx context.Context, provider schemas.ModelProvider, key schemas.Key, tx ...*gorm.DB) error
|
||||
UpdateProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, key schemas.Key, tx ...*gorm.DB) error
|
||||
DeleteProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, tx ...*gorm.DB) error
|
||||
GetProviders(ctx context.Context) ([]tables.TableProvider, error)
|
||||
GetProvider(ctx context.Context, provider schemas.ModelProvider) (*tables.TableProvider, error)
|
||||
UpdateStatus(ctx context.Context, provider schemas.ModelProvider, keyID string, status, errorMsg string) error
|
||||
|
||||
// MCP config CRUD
|
||||
GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error)
|
||||
GetMCPClientByID(ctx context.Context, id string) (*tables.TableMCPClient, error)
|
||||
GetMCPClientConfigByID(ctx context.Context, id string) (*schemas.MCPClientConfig, error)
|
||||
GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error)
|
||||
GetMCPClientsPaginated(ctx context.Context, params MCPClientsQueryParams) ([]tables.TableMCPClient, int64, error)
|
||||
CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error
|
||||
UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error
|
||||
DeleteMCPClientConfig(ctx context.Context, id string) error
|
||||
|
||||
// Vector store config CRUD
|
||||
UpdateVectorStoreConfig(ctx context.Context, config *vectorstore.Config) error
|
||||
GetVectorStoreConfig(ctx context.Context) (*vectorstore.Config, error)
|
||||
|
||||
// Logs store config CRUD
|
||||
UpdateLogsStoreConfig(ctx context.Context, config *logstore.Config) error
|
||||
GetLogsStoreConfig(ctx context.Context) (*logstore.Config, error)
|
||||
|
||||
// Config CRUD
|
||||
GetConfig(ctx context.Context, key string) (*tables.TableGovernanceConfig, error)
|
||||
UpdateConfig(ctx context.Context, config *tables.TableGovernanceConfig, tx ...*gorm.DB) error
|
||||
|
||||
// Plugins CRUD
|
||||
GetPlugins(ctx context.Context) ([]*tables.TablePlugin, error)
|
||||
GetPlugin(ctx context.Context, name string) (*tables.TablePlugin, error)
|
||||
CreatePlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error
|
||||
UpsertPlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error
|
||||
UpdatePlugin(ctx context.Context, plugin *tables.TablePlugin, tx ...*gorm.DB) error
|
||||
DeletePlugin(ctx context.Context, name string, tx ...*gorm.DB) error
|
||||
|
||||
// Governance config CRUD
|
||||
GetVirtualKeys(ctx context.Context) ([]tables.TableVirtualKey, error)
|
||||
GetVirtualKeysPaginated(ctx context.Context, params VirtualKeyQueryParams) ([]tables.TableVirtualKey, int64, error)
|
||||
GetRedactedVirtualKeys(ctx context.Context, ids []string) ([]tables.TableVirtualKey, error) // leave ids empty to get all
|
||||
GetVirtualKey(ctx context.Context, id string) (*tables.TableVirtualKey, error)
|
||||
GetVirtualKeyByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error)
|
||||
GetVirtualKeyQuotaByValue(ctx context.Context, value string) (*tables.TableVirtualKey, error)
|
||||
CreateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error
|
||||
UpdateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error
|
||||
DeleteVirtualKey(ctx context.Context, id string) error
|
||||
|
||||
// Virtual key provider config CRUD
|
||||
GetVirtualKeyProviderConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyProviderConfig, error)
|
||||
CreateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *tables.TableVirtualKeyProviderConfig, tx ...*gorm.DB) error
|
||||
UpdateVirtualKeyProviderConfig(ctx context.Context, virtualKeyProviderConfig *tables.TableVirtualKeyProviderConfig, tx ...*gorm.DB) error
|
||||
DeleteVirtualKeyProviderConfig(ctx context.Context, id uint, tx ...*gorm.DB) error
|
||||
|
||||
// Virtual key MCP config CRUD
|
||||
GetVirtualKeyMCPConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyMCPConfig, error)
|
||||
GetVirtualKeyMCPConfigsByMCPClientID(ctx context.Context, mcpClientID uint) ([]tables.TableVirtualKeyMCPConfig, error)
|
||||
GetVirtualKeyMCPConfigsByMCPClientIDs(ctx context.Context, mcpClientIDs []uint) ([]tables.TableVirtualKeyMCPConfig, error)
|
||||
GetVirtualKeyMCPConfigsByMCPClientStringIDs(ctx context.Context, clientIDs []string) ([]tables.TableVirtualKeyMCPConfig, error)
|
||||
CreateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error
|
||||
UpdateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error
|
||||
DeleteVirtualKeyMCPConfig(ctx context.Context, id uint, tx ...*gorm.DB) error
|
||||
|
||||
// Team CRUD
|
||||
GetTeams(ctx context.Context, customerID string) ([]tables.TableTeam, error)
|
||||
GetTeamsPaginated(ctx context.Context, params TeamsQueryParams) ([]tables.TableTeam, int64, error)
|
||||
GetTeam(ctx context.Context, id string) (*tables.TableTeam, error)
|
||||
CreateTeam(ctx context.Context, team *tables.TableTeam, tx ...*gorm.DB) error
|
||||
UpdateTeam(ctx context.Context, team *tables.TableTeam, tx ...*gorm.DB) error
|
||||
DeleteTeam(ctx context.Context, id string) error
|
||||
|
||||
// Customer CRUD
|
||||
GetCustomers(ctx context.Context) ([]tables.TableCustomer, error)
|
||||
GetCustomersPaginated(ctx context.Context, params CustomersQueryParams) ([]tables.TableCustomer, int64, error)
|
||||
GetCustomer(ctx context.Context, id string) (*tables.TableCustomer, error)
|
||||
CreateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error
|
||||
UpdateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error
|
||||
DeleteCustomer(ctx context.Context, id string) error
|
||||
|
||||
// Rate limit CRUD
|
||||
GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error)
|
||||
GetRateLimit(ctx context.Context, id string, tx ...*gorm.DB) (*tables.TableRateLimit, error)
|
||||
CreateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error
|
||||
UpdateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error
|
||||
UpdateRateLimits(ctx context.Context, rateLimits []*tables.TableRateLimit, tx ...*gorm.DB) error
|
||||
DeleteRateLimit(ctx context.Context, id string, tx ...*gorm.DB) error
|
||||
|
||||
// Budget CRUD
|
||||
GetBudgets(ctx context.Context) ([]tables.TableBudget, error)
|
||||
GetBudget(ctx context.Context, id string, tx ...*gorm.DB) (*tables.TableBudget, error)
|
||||
CreateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error
|
||||
UpdateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error
|
||||
UpdateBudgets(ctx context.Context, budgets []*tables.TableBudget, tx ...*gorm.DB) error
|
||||
DeleteBudget(ctx context.Context, id string, tx ...*gorm.DB) error
|
||||
UpdateBudgetUsage(ctx context.Context, id string, currentUsage float64) error
|
||||
UpdateRateLimitUsage(ctx context.Context, id string, tokenCurrentUsage int64, requestCurrentUsage int64) error
|
||||
|
||||
// Routing Rules CRUD
|
||||
GetRoutingRules(ctx context.Context) ([]tables.TableRoutingRule, error)
|
||||
GetRoutingRulesByScope(ctx context.Context, scope string, scopeID string) ([]tables.TableRoutingRule, error)
|
||||
GetRoutingRule(ctx context.Context, id string) (*tables.TableRoutingRule, error)
|
||||
GetRedactedRoutingRules(ctx context.Context, ids []string) ([]tables.TableRoutingRule, error) // leave ids empty to get all
|
||||
GetRoutingRulesPaginated(ctx context.Context, params RoutingRulesQueryParams) ([]tables.TableRoutingRule, int64, error)
|
||||
CreateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error
|
||||
UpdateRoutingRule(ctx context.Context, rule *tables.TableRoutingRule, tx ...*gorm.DB) error
|
||||
DeleteRoutingRule(ctx context.Context, id string, tx ...*gorm.DB) error
|
||||
|
||||
// Model config CRUD
|
||||
GetModelConfigs(ctx context.Context) ([]tables.TableModelConfig, error)
|
||||
GetModelConfigsPaginated(ctx context.Context, params ModelConfigsQueryParams) ([]tables.TableModelConfig, int64, error)
|
||||
GetModelConfig(ctx context.Context, modelName string, provider *string) (*tables.TableModelConfig, error)
|
||||
GetModelConfigByID(ctx context.Context, id string) (*tables.TableModelConfig, error)
|
||||
CreateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error
|
||||
UpdateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error
|
||||
UpdateModelConfigs(ctx context.Context, modelConfigs []*tables.TableModelConfig, tx ...*gorm.DB) error
|
||||
DeleteModelConfig(ctx context.Context, id string) error
|
||||
|
||||
// Governance config CRUD
|
||||
GetGovernanceConfig(ctx context.Context) (*GovernanceConfig, error)
|
||||
|
||||
// Auth config CRUD
|
||||
GetAuthConfig(ctx context.Context) (*AuthConfig, error)
|
||||
UpdateAuthConfig(ctx context.Context, config *AuthConfig) error
|
||||
|
||||
// Proxy config CRUD
|
||||
GetProxyConfig(ctx context.Context) (*tables.GlobalProxyConfig, error)
|
||||
UpdateProxyConfig(ctx context.Context, config *tables.GlobalProxyConfig) error
|
||||
|
||||
// Restart required config CRUD
|
||||
GetRestartRequiredConfig(ctx context.Context) (*tables.RestartRequiredConfig, error)
|
||||
SetRestartRequiredConfig(ctx context.Context, config *tables.RestartRequiredConfig) error
|
||||
ClearRestartRequiredConfig(ctx context.Context) error
|
||||
|
||||
// Session CRUD
|
||||
GetSession(ctx context.Context, token string) (*tables.SessionsTable, error)
|
||||
CreateSession(ctx context.Context, session *tables.SessionsTable) error
|
||||
DeleteSession(ctx context.Context, token string) error
|
||||
FlushSessions(ctx context.Context) error
|
||||
|
||||
// Model pricing CRUD
|
||||
GetModelPrices(ctx context.Context) ([]tables.TableModelPricing, error)
|
||||
UpsertModelPrices(ctx context.Context, pricing *tables.TableModelPricing, tx ...*gorm.DB) error
|
||||
DeleteModelPrices(ctx context.Context, tx ...*gorm.DB) error
|
||||
|
||||
// Governance pricing overrides CRUD
|
||||
GetPricingOverrides(ctx context.Context, filters PricingOverrideFilters) ([]tables.TablePricingOverride, error)
|
||||
GetPricingOverridesPaginated(ctx context.Context, params PricingOverridesQueryParams) ([]tables.TablePricingOverride, int64, error)
|
||||
GetPricingOverrideByID(ctx context.Context, id string) (*tables.TablePricingOverride, error)
|
||||
CreatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error
|
||||
UpdatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error
|
||||
DeletePricingOverride(ctx context.Context, id string, tx ...*gorm.DB) error
|
||||
|
||||
// Model parameters
|
||||
GetModelParameters(ctx context.Context) ([]tables.TableModelParameters, error)
|
||||
GetModelParametersByModel(ctx context.Context, model string) (*tables.TableModelParameters, error)
|
||||
UpsertModelParameters(ctx context.Context, params *tables.TableModelParameters, tx ...*gorm.DB) error
|
||||
|
||||
// Key management
|
||||
GetKeysByIDs(ctx context.Context, ids []string) ([]tables.TableKey, error)
|
||||
GetKeysByProvider(ctx context.Context, provider string) ([]tables.TableKey, error)
|
||||
GetAllRedactedKeys(ctx context.Context, ids []string) ([]schemas.Key, error) // leave ids empty to get all
|
||||
|
||||
// Generic transaction manager
|
||||
ExecuteTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error
|
||||
|
||||
// TryAcquireLock attempts to insert a lock row. Returns true if the lock was acquired.
|
||||
// If the lock already exists and is not expired, returns false.
|
||||
TryAcquireLock(ctx context.Context, lock *tables.TableDistributedLock) (bool, error)
|
||||
|
||||
// GetLock retrieves a lock by its key. Returns nil if the lock doesn't exist.
|
||||
GetLock(ctx context.Context, lockKey string) (*tables.TableDistributedLock, error)
|
||||
|
||||
// UpdateLockExpiry updates the expiration time for an existing lock.
|
||||
// Only succeeds if the holder ID matches the current lock holder.
|
||||
UpdateLockExpiry(ctx context.Context, lockKey, holderID string, expiresAt time.Time) error
|
||||
|
||||
// ReleaseLock deletes a lock if the holder ID matches.
|
||||
// Returns true if the lock was released, false if it wasn't held by the given holder.
|
||||
ReleaseLock(ctx context.Context, lockKey, holderID string) (bool, error)
|
||||
|
||||
// CleanupExpiredLockByKey atomically deletes a specific lock only if it has expired.
|
||||
// Returns true if an expired lock was deleted, false if the lock doesn't exist or hasn't expired.
|
||||
CleanupExpiredLockByKey(ctx context.Context, lockKey string) (bool, error)
|
||||
|
||||
// CleanupExpiredLocks removes all locks that have expired.
|
||||
// Returns the number of locks cleaned up.
|
||||
CleanupExpiredLocks(ctx context.Context) (int64, error)
|
||||
|
||||
// OAuth config CRUD
|
||||
GetOauthConfigByID(ctx context.Context, id string) (*tables.TableOauthConfig, error)
|
||||
GetOauthConfigByState(ctx context.Context, state string) (*tables.TableOauthConfig, error)
|
||||
GetOauthConfigByTokenID(ctx context.Context, tokenID string) (*tables.TableOauthConfig, error)
|
||||
CreateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error
|
||||
UpdateOauthConfig(ctx context.Context, config *tables.TableOauthConfig) error
|
||||
|
||||
// OAuth token CRUD
|
||||
GetOauthTokenByID(ctx context.Context, id string) (*tables.TableOauthToken, error)
|
||||
GetExpiringOauthTokens(ctx context.Context, before time.Time) ([]*tables.TableOauthToken, error)
|
||||
CreateOauthToken(ctx context.Context, token *tables.TableOauthToken) error
|
||||
UpdateOauthToken(ctx context.Context, token *tables.TableOauthToken) error
|
||||
DeleteOauthToken(ctx context.Context, id string) error
|
||||
|
||||
// Per-user OAuth session CRUD
|
||||
GetOauthUserSessionByID(ctx context.Context, id string) (*tables.TableOauthUserSession, error)
|
||||
GetOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error)
|
||||
ClaimOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error)
|
||||
GetOauthUserSessionBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserSession, error)
|
||||
CreateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error
|
||||
UpdateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error
|
||||
|
||||
// Per-user OAuth token CRUD
|
||||
GetOauthUserTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (*tables.TableOauthUserToken, error)
|
||||
GetOauthUserTokenBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserToken, error)
|
||||
CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error
|
||||
UpdateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error
|
||||
DeleteOauthUserToken(ctx context.Context, id string) error
|
||||
DeleteOauthUserTokensByMCPClient(ctx context.Context, mcpClientID string) error
|
||||
|
||||
// Per-user OAuth Authorization Server CRUD (Bifrost as OAuth server)
|
||||
GetPerUserOAuthClientByClientID(ctx context.Context, clientID string) (*tables.TablePerUserOAuthClient, error)
|
||||
CreatePerUserOAuthClient(ctx context.Context, client *tables.TablePerUserOAuthClient) error
|
||||
GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error)
|
||||
GetPerUserOAuthSessionByID(ctx context.Context, id string) (*tables.TablePerUserOAuthSession, error)
|
||||
CreatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error
|
||||
UpdatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error
|
||||
DeletePerUserOAuthSession(ctx context.Context, id string) error
|
||||
GetPerUserOAuthCodeByCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error)
|
||||
ClaimPerUserOAuthCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error)
|
||||
CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error
|
||||
UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error
|
||||
|
||||
// Per-user OAuth consent flow (pending flows before code issuance)
|
||||
GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error)
|
||||
CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error
|
||||
UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error
|
||||
DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error
|
||||
// ConsumePerUserOAuthPendingFlow atomically deletes a pending flow and returns the number of
|
||||
// rows affected. Returns 0 if the flow was already consumed by a concurrent request.
|
||||
ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error)
|
||||
// FinalizePerUserOAuthConsent atomically consumes a pending flow, creates the session,
|
||||
// and creates the authorization code in a single transaction. Returns (0, nil) if the
|
||||
// flow was already consumed by a concurrent request.
|
||||
FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error)
|
||||
// GetOauthUserTokensByGatewaySessionID returns all upstream tokens linked to a gateway session ID.
|
||||
// Used during consent submit to discover which MCPs the user authenticated with.
|
||||
// Queries tokens via upstream sessions matching the given gateway session ID.
|
||||
GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error)
|
||||
// TransferOauthUserTokensFromGatewaySession migrates upstream tokens from all flow proxy sessions
|
||||
// (identified by gateway_session_id) to the real Bifrost session token, and sets VirtualKeyID/UserID on each record.
|
||||
TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID string) error
|
||||
|
||||
// Not found retry wrapper
|
||||
RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error)
|
||||
|
||||
// Prompt Repository - Folders
|
||||
GetFolders(ctx context.Context) ([]tables.TableFolder, error)
|
||||
GetFolderByID(ctx context.Context, id string) (*tables.TableFolder, error)
|
||||
CreateFolder(ctx context.Context, folder *tables.TableFolder) error
|
||||
UpdateFolder(ctx context.Context, folder *tables.TableFolder) error
|
||||
DeleteFolder(ctx context.Context, id string) error
|
||||
|
||||
// Prompt Repository - Prompts
|
||||
GetPrompts(ctx context.Context, folderID *string) ([]tables.TablePrompt, error)
|
||||
GetPromptByID(ctx context.Context, id string) (*tables.TablePrompt, error)
|
||||
CreatePrompt(ctx context.Context, prompt *tables.TablePrompt) error
|
||||
UpdatePrompt(ctx context.Context, prompt *tables.TablePrompt) error
|
||||
DeletePrompt(ctx context.Context, id string) error
|
||||
|
||||
// Prompt Repository - Versions
|
||||
GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error)
|
||||
GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error)
|
||||
GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error)
|
||||
GetLatestPromptVersion(ctx context.Context, promptID string) (*tables.TablePromptVersion, error)
|
||||
CreatePromptVersion(ctx context.Context, version *tables.TablePromptVersion) error
|
||||
DeletePromptVersion(ctx context.Context, id uint) error
|
||||
|
||||
// Prompt Repository - Sessions
|
||||
GetPromptSessions(ctx context.Context, promptID string) ([]tables.TablePromptSession, error)
|
||||
GetPromptSessionByID(ctx context.Context, id uint) (*tables.TablePromptSession, error)
|
||||
CreatePromptSession(ctx context.Context, session *tables.TablePromptSession) error
|
||||
UpdatePromptSession(ctx context.Context, session *tables.TablePromptSession) error
|
||||
RenamePromptSession(ctx context.Context, id uint, name string) error
|
||||
DeletePromptSession(ctx context.Context, id uint) error
|
||||
|
||||
// DB returns the underlying database connection.
|
||||
DB() *gorm.DB
|
||||
|
||||
// RunMigration opens a throwaway *gorm.DB against the same
|
||||
// backing database, invokes fn with it, and closes the connection. Use
|
||||
// this for DDL (typically downstream-consumer migrations) that must not
|
||||
// leave cached prepared-statement plans on the runtime pool.
|
||||
//
|
||||
// After fn returns successfully, callers should invoke
|
||||
// RefreshConnectionPool if the migration altered tables the runtime pool
|
||||
// has already queried — otherwise SQLSTATE 0A000 can surface on reads
|
||||
// whose cached plans predate the DDL.
|
||||
//
|
||||
// For SQLite backends, this is a pass-through that runs fn on the
|
||||
// existing connection (no server-side plan cache, single-writer lock).
|
||||
RunMigration(ctx context.Context, fn func(context.Context, *gorm.DB) error) error
|
||||
|
||||
// RefreshConnectionPool tears down the runtime pool and opens a fresh
|
||||
// one against the same configuration. In-flight queries on the old
|
||||
// pool complete before it closes; subsequent DB() calls return the new
|
||||
// pool, whose connections carry no cached plans. SQLite is a no-op.
|
||||
RefreshConnectionPool(ctx context.Context) error
|
||||
|
||||
// Cleanup
|
||||
Close(ctx context.Context) error
|
||||
}
|
||||
|
||||
// NewConfigStore creates a new config store based on the configuration
|
||||
func NewConfigStore(ctx context.Context, config *Config, logger schemas.Logger) (ConfigStore, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config cannot be nil")
|
||||
}
|
||||
if !config.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
switch config.Type {
|
||||
case ConfigStoreTypeSQLite:
|
||||
if sqliteConfig, ok := config.Config.(*SQLiteConfig); ok {
|
||||
return newSqliteConfigStore(ctx, sqliteConfig, logger)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid sqlite config: %T", config.Config)
|
||||
case ConfigStoreTypePostgres:
|
||||
if postgresConfig, ok := config.Config.(*PostgresConfig); ok {
|
||||
return newPostgresConfigStore(ctx, postgresConfig, logger)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid postgres config: %T", config.Config)
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported config store type: %s", config.Type)
|
||||
}
|
||||
64
framework/configstore/tables/budget.go
Normal file
64
framework/configstore/tables/budget.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableBudget is the persisted model for a spending limit that resets on a
// configurable period.
type TableBudget struct {
	ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
	// MaxLimit is the maximum budget in dollars.
	MaxLimit float64 `gorm:"not null" json:"max_limit"`
	// ResetDuration is the reset period, e.g. "30s", "5m", "1h", "1d", "1w", "1M", "1Y".
	ResetDuration string `gorm:"type:varchar(50);not null" json:"reset_duration"`
	// LastReset is the last time the budget was reset.
	LastReset time.Time `gorm:"index" json:"last_reset"`
	// CurrentUsage is the current usage in dollars.
	CurrentUsage float64 `gorm:"default:0" json:"current_usage"`

	// Owner foreign keys: a budget belongs to at most one Team, one virtual
	// key, or one provider config.
	TeamID           *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"`
	VirtualKeyID     *string `gorm:"type:varchar(255);index" json:"virtual_key_id,omitempty"`
	ProviderConfigID *uint   `gorm:"index" json:"provider_config_id,omitempty"`

	// CalendarAligned, when true, makes all budgets under this virtual key
	// reset at clean calendar boundaries.
	CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"`

	// ConfigHash is used to detect changes synced from the config.json file;
	// it is updated every time config.json is synced.
	ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`

	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableBudget) TableName() string { return "governance_budgets" }
|
||||
|
||||
// BeforeSave hook for Budget to validate reset duration format and max limit
|
||||
func (b *TableBudget) BeforeSave(tx *gorm.DB) error {
|
||||
// A budget belongs to at most one owner type
|
||||
owners := 0
|
||||
if b.TeamID != nil {
|
||||
owners++
|
||||
}
|
||||
if b.VirtualKeyID != nil {
|
||||
owners++
|
||||
}
|
||||
if b.ProviderConfigID != nil {
|
||||
owners++
|
||||
}
|
||||
if owners > 1 {
|
||||
return fmt.Errorf("budget cannot have more than one owner (team/virtual key/provider config)")
|
||||
}
|
||||
// Validate that ResetDuration is in correct format (e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y")
|
||||
if d, err := ParseDuration(b.ResetDuration); err != nil {
|
||||
return fmt.Errorf("invalid reset duration format: %s", b.ResetDuration)
|
||||
} else if d <= 0 {
|
||||
return fmt.Errorf("reset duration must be > 0: %s", b.ResetDuration)
|
||||
}
|
||||
// Validate that MaxLimit is not negative (budgets should be positive)
|
||||
if b.MaxLimit < 0 {
|
||||
return fmt.Errorf("budget max_limit cannot be negative: %.2f", b.MaxLimit)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
187
framework/configstore/tables/clientconfig.go
Normal file
187
framework/configstore/tables/clientconfig.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableClientConfig represents global client configuration in the database
|
||||
type TableClientConfig struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
DropExcessRequests bool `gorm:"default:false" json:"drop_excess_requests"`
|
||||
PrometheusLabelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
AllowedOriginsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
AllowedHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
HeaderFilterConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized GlobalHeaderFilterConfig
|
||||
InitialPoolSize int `gorm:"default:300" json:"initial_pool_size"`
|
||||
EnableLogging *bool `gorm:"default:true" json:"enable_logging"`
|
||||
DisableContentLogging bool `gorm:"default:false" json:"disable_content_logging"` // DisableContentLogging controls whether sensitive content (inputs, outputs, embeddings, etc.) is logged
|
||||
DisableDBPingsInHealth bool `gorm:"default:false" json:"disable_db_pings_in_health"`
|
||||
LogRetentionDays int `gorm:"default:365" json:"log_retention_days" validate:"min=1"` // Number of days to retain logs (minimum 1 day)
|
||||
EnforceAuthOnInference bool `gorm:"default:false" json:"enforce_auth_on_inference"`
|
||||
EnforceGovernanceHeader bool `gorm:"" json:"enforce_governance_header"`
|
||||
EnforceSCIMAuth bool `gorm:"default:false" json:"enforce_scim_auth"`
|
||||
AllowDirectKeys bool `gorm:"" json:"allow_direct_keys"`
|
||||
MaxRequestBodySizeMB int `gorm:"default:100" json:"max_request_body_size_mb"`
|
||||
MCPAgentDepth int `gorm:"default:10" json:"mcp_agent_depth"`
|
||||
MCPToolExecutionTimeout int `gorm:"default:30" json:"mcp_tool_execution_timeout"` // Timeout for individual tool execution in seconds (default: 30)
|
||||
MCPCodeModeBindingLevel string `gorm:"default:server" json:"mcp_code_mode_binding_level"` // How tools are exposed in VFS: "server" or "tool"
|
||||
MCPToolSyncInterval int `gorm:"default:10" json:"mcp_tool_sync_interval"` // Global tool sync interval in minutes (default: 10, 0 = disabled)
|
||||
MCPDisableAutoToolInject bool `gorm:"default:false" json:"mcp_disable_auto_tool_inject"` // When true, MCP tools are not injected into requests by default
|
||||
AsyncJobResultTTL int `gorm:"default:3600" json:"async_job_result_ttl"` // Default TTL for async job results in seconds (default: 3600 = 1 hour)
|
||||
RequiredHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
LoggingHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
HideDeletedVirtualKeysInFilters bool `gorm:"default:false" json:"hide_deleted_virtual_keys_in_filters"` // Hide deleted virtual keys in logs filter dropdowns
|
||||
RoutingChainMaxDepth int `gorm:"default:10" json:"routing_chain_max_depth"` // Maximum depth for routing rule chain evaluation (default: 10)
|
||||
WhitelistedRoutesJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
|
||||
// Compat plugin feature flags
|
||||
CompatConvertTextToChat bool `gorm:"column:compat_convert_text_to_chat;default:false" json:"-"`
|
||||
CompatConvertChatToResponses bool `gorm:"column:compat_convert_chat_to_responses;default:false" json:"-"`
|
||||
CompatShouldDropParams bool `gorm:"column:compat_should_drop_params;default:false" json:"-"`
|
||||
CompatShouldConvertParams bool `gorm:"column:compat_should_convert_params;default:false" json:"-"`
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
|
||||
// Virtual fields for runtime use (not stored in DB)
|
||||
PrometheusLabels []string `gorm:"-" json:"prometheus_labels"`
|
||||
AllowedOrigins []string `gorm:"-" json:"allowed_origins,omitempty"`
|
||||
AllowedHeaders []string `gorm:"-" json:"allowed_headers,omitempty"`
|
||||
RequiredHeaders []string `gorm:"-" json:"required_headers,omitempty"`
|
||||
LoggingHeaders []string `gorm:"-" json:"logging_headers,omitempty"`
|
||||
WhitelistedRoutes []string `gorm:"-" json:"whitelisted_routes,omitempty"`
|
||||
HeaderFilterConfig *GlobalHeaderFilterConfig `gorm:"-" json:"header_filter_config,omitempty"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableClientConfig) TableName() string { return "config_client" }
|
||||
|
||||
func (cc *TableClientConfig) BeforeSave(tx *gorm.DB) error {
|
||||
if cc.PrometheusLabels != nil {
|
||||
data, err := json.Marshal(cc.PrometheusLabels)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.PrometheusLabelsJSON = string(data)
|
||||
} else {
|
||||
cc.PrometheusLabelsJSON = "[]"
|
||||
}
|
||||
|
||||
if cc.AllowedOrigins != nil {
|
||||
data, err := json.Marshal(cc.AllowedOrigins)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.AllowedOriginsJSON = string(data)
|
||||
} else {
|
||||
cc.AllowedOriginsJSON = "[]"
|
||||
}
|
||||
|
||||
if cc.AllowedHeaders != nil {
|
||||
data, err := json.Marshal(cc.AllowedHeaders)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.AllowedHeadersJSON = string(data)
|
||||
} else {
|
||||
cc.AllowedHeadersJSON = "[]"
|
||||
}
|
||||
|
||||
if cc.WhitelistedRoutes != nil {
|
||||
data, err := json.Marshal(cc.WhitelistedRoutes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.WhitelistedRoutesJSON = string(data)
|
||||
} else {
|
||||
cc.WhitelistedRoutesJSON = "[]"
|
||||
}
|
||||
|
||||
if cc.RequiredHeaders != nil {
|
||||
data, err := json.Marshal(cc.RequiredHeaders)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.RequiredHeadersJSON = string(data)
|
||||
} else {
|
||||
cc.RequiredHeadersJSON = "[]"
|
||||
}
|
||||
|
||||
if cc.LoggingHeaders != nil {
|
||||
data, err := json.Marshal(cc.LoggingHeaders)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.LoggingHeadersJSON = string(data)
|
||||
} else {
|
||||
cc.LoggingHeadersJSON = "[]"
|
||||
}
|
||||
|
||||
if cc.HeaderFilterConfig != nil {
|
||||
data, err := json.Marshal(cc.HeaderFilterConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.HeaderFilterConfigJSON = string(data)
|
||||
} else {
|
||||
cc.HeaderFilterConfigJSON = ""
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind hooks for deserialization
|
||||
func (cc *TableClientConfig) AfterFind(tx *gorm.DB) error {
|
||||
if cc.PrometheusLabelsJSON != "" {
|
||||
if err := json.Unmarshal([]byte(cc.PrometheusLabelsJSON), &cc.PrometheusLabels); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cc.AllowedOriginsJSON != "" {
|
||||
if err := json.Unmarshal([]byte(cc.AllowedOriginsJSON), &cc.AllowedOrigins); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cc.AllowedHeadersJSON != "" {
|
||||
if err := json.Unmarshal([]byte(cc.AllowedHeadersJSON), &cc.AllowedHeaders); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cc.WhitelistedRoutesJSON != "" {
|
||||
if err := json.Unmarshal([]byte(cc.WhitelistedRoutesJSON), &cc.WhitelistedRoutes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cc.RequiredHeadersJSON != "" {
|
||||
if err := json.Unmarshal([]byte(cc.RequiredHeadersJSON), &cc.RequiredHeaders); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cc.LoggingHeadersJSON != "" {
|
||||
if err := json.Unmarshal([]byte(cc.LoggingHeadersJSON), &cc.LoggingHeaders); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cc.HeaderFilterConfigJSON != "" {
|
||||
var headerFilterConfig GlobalHeaderFilterConfig
|
||||
if err := json.Unmarshal([]byte(cc.HeaderFilterConfigJSON), &headerFilterConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
cc.HeaderFilterConfig = &headerFilterConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
56
framework/configstore/tables/config.go
Normal file
56
framework/configstore/tables/config.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package tables
|
||||
|
||||
import "github.com/maximhq/bifrost/core/network"
|
||||
|
||||
// Well-known configuration key names.
// NOTE(review): presumably these are the Key values of TableGovernanceConfig
// rows; confirm against callers.
const (
	ConfigAdminUsernameKey          = "admin_username"
	ConfigAdminPasswordKey          = "admin_password"
	ConfigIsAuthEnabledKey          = "is_auth_enabled"
	ConfigDisableAuthOnInferenceKey = "disable_auth_on_inference"
	ConfigProxyKey                  = "proxy_config"
	ConfigRestartRequiredKey        = "restart_required"
	ConfigHeaderFilterKey           = "header_filter_config"
)
|
||||
|
||||
// RestartRequiredConfig records that a configuration change needs a server
// restart before it takes effect, together with an optional reason.
type RestartRequiredConfig struct {
	Required bool   `json:"required"`
	Reason   string `json:"reason,omitempty"`
}
|
||||
|
||||
// GlobalProxyConfig represents the global proxy configuration
|
||||
type GlobalProxyConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Type network.GlobalProxyType `json:"type"` // "http", "socks5", "tcp"
|
||||
URL string `json:"url"` // Proxy URL (e.g., http://proxy.example.com:8080)
|
||||
Username string `json:"username,omitempty"` // Optional authentication username
|
||||
Password string `json:"password,omitempty"` // Optional authentication password
|
||||
NoProxy string `json:"no_proxy,omitempty"` // Comma-separated list of hosts to bypass proxy
|
||||
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds
|
||||
SkipTLSVerify bool `json:"skip_tls_verify,omitempty"` // Skip TLS certificate verification
|
||||
// Entity enablement flags
|
||||
EnableForSCIM bool `json:"enable_for_scim"` // Enable proxy for SCIM requests (enterprise only)
|
||||
EnableForInference bool `json:"enable_for_inference"` // Enable proxy for inference requests
|
||||
EnableForAPI bool `json:"enable_for_api"` // Enable proxy for API requests
|
||||
}
|
||||
|
||||
// GlobalHeaderFilterConfig holds the global filter for headers forwarded to
// LLM providers via the x-bf-eh-* prefix.
//
// Filter logic:
//   - a non-empty allowlist means only the listed headers are forwarded;
//   - a non-empty denylist means the listed headers are dropped;
//   - when both are non-empty the allowlist is applied first and the denylist
//     then filters the result.
type GlobalHeaderFilterConfig struct {
	// Allowlist, if non-empty, names the only headers that are allowed.
	Allowlist []string `json:"allowlist,omitempty"`
	// Denylist names headers that are always blocked.
	Denylist []string `json:"denylist,omitempty"`
}
|
||||
|
||||
// TableGovernanceConfig is the persisted model for a generic configuration
// entry: a string key (primary key) mapped to a text value.
type TableGovernanceConfig struct {
	Key   string `gorm:"primaryKey;type:varchar(255)" json:"key"`
	Value string `gorm:"type:text" json:"value"`
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableGovernanceConfig) TableName() string { return "governance_config" }
|
||||
15
framework/configstore/tables/confighash.go
Normal file
15
framework/configstore/tables/confighash.go
Normal file
@@ -0,0 +1,15 @@
|
||||
// Package tables contains the database tables for the configstore.
|
||||
package tables
|
||||
|
||||
import "time"
|
||||
|
||||
// TableConfigHash is the persisted model for a configuration hash. Hash is
// required and carries a unique index.
type TableConfigHash struct {
	ID        uint      `gorm:"primaryKey;autoIncrement" json:"id"`
	Hash      string    `gorm:"type:varchar(255);uniqueIndex;not null" json:"hash"`
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableConfigHash) TableName() string { return "config_hashes" }
|
||||
27
framework/configstore/tables/customer.go
Normal file
27
framework/configstore/tables/customer.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package tables
|
||||
|
||||
import "time"
|
||||
|
||||
// TableCustomer represents a customer entity with budget and rate limit
|
||||
type TableCustomer struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
|
||||
Name string `gorm:"type:varchar(255);not null" json:"name"`
|
||||
BudgetID *string `gorm:"type:varchar(255);index" json:"budget_id,omitempty"`
|
||||
RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`
|
||||
|
||||
// Relationships
|
||||
Budget *TableBudget `gorm:"foreignKey:BudgetID" json:"budget,omitempty"`
|
||||
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID" json:"rate_limit,omitempty"`
|
||||
Teams []TableTeam `gorm:"foreignKey:CustomerID" json:"teams"`
|
||||
VirtualKeys []TableVirtualKey `gorm:"foreignKey:CustomerID" json:"virtual_keys"`
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableCustomer) TableName() string { return "governance_customers" }
|
||||
17
framework/configstore/tables/dlock.go
Normal file
17
framework/configstore/tables/dlock.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package tables
|
||||
|
||||
import "time"
|
||||
|
||||
// TableDistributedLock is the persisted model for one distributed lock entry.
// The table backs distributed locking across multiple instances: LockKey
// identifies the lock, HolderID its current holder, and ExpiresAt its expiry.
type TableDistributedLock struct {
	LockKey   string    `gorm:"primaryKey;column:lock_key;size:255" json:"lock_key"`
	HolderID  string    `gorm:"column:holder_id;size:255;not null" json:"holder_id"`
	ExpiresAt time.Time `gorm:"column:expires_at;not null;index" json:"expires_at"`
	CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
}
|
||||
|
||||
// TableName returns the table name for the distributed lock table.
|
||||
func (TableDistributedLock) TableName() string {
|
||||
return "distributed_locks"
|
||||
}
|
||||
87
framework/configstore/tables/encryption.go
Normal file
87
framework/configstore/tables/encryption.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
)
|
||||
|
||||
// Values recorded in a row's encryption-status column.
const (
	// EncryptionStatusPlainText marks a row whose sensitive fields are stored as plaintext.
	EncryptionStatusPlainText = "plain_text"
	// EncryptionStatusEncrypted marks a row whose sensitive fields have been encrypted.
	EncryptionStatusEncrypted = "encrypted"
)
|
||||
|
||||
// encryptEnvVar encrypts the Val field of an EnvVar in place using AES-256-GCM.
|
||||
// It is a no-op if the field is nil, references an environment variable, or has an empty value.
|
||||
func encryptEnvVar(field *schemas.EnvVar) error {
|
||||
if field == nil || field.IsFromEnv() || field.GetValue() == "" {
|
||||
return nil
|
||||
}
|
||||
encrypted, err := encrypt.Encrypt(field.Val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
field.Val = encrypted
|
||||
return nil
|
||||
}
|
||||
|
||||
// decryptEnvVar decrypts the Val field of an EnvVar in place using AES-256-GCM.
|
||||
// It is a no-op if the field is nil, references an environment variable, or has an empty value.
|
||||
func decryptEnvVar(field *schemas.EnvVar) error {
|
||||
if field == nil || field.IsFromEnv() || field.GetValue() == "" {
|
||||
return nil
|
||||
}
|
||||
decrypted, err := encrypt.Decrypt(field.Val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
field.Val = decrypted
|
||||
return nil
|
||||
}
|
||||
|
||||
// encryptEnvVarPtr encrypts the Val field of a pointer-to-EnvVar in place.
|
||||
// It is a no-op if the pointer or the EnvVar it points to is nil.
|
||||
func encryptEnvVarPtr(field **schemas.EnvVar) error {
|
||||
if field == nil || *field == nil {
|
||||
return nil
|
||||
}
|
||||
return encryptEnvVar(*field)
|
||||
}
|
||||
|
||||
// decryptEnvVarPtr decrypts the Val field of a pointer-to-EnvVar in place.
|
||||
// It is a no-op if the pointer or the EnvVar it points to is nil.
|
||||
func decryptEnvVarPtr(field **schemas.EnvVar) error {
|
||||
if field == nil || *field == nil {
|
||||
return nil
|
||||
}
|
||||
return decryptEnvVar(*field)
|
||||
}
|
||||
|
||||
// encryptString encrypts the string pointed to by value in place using AES-256-GCM.
|
||||
// It is a no-op if the pointer is nil or the string is empty.
|
||||
func encryptString(value *string) error {
|
||||
if value == nil || *value == "" {
|
||||
return nil
|
||||
}
|
||||
encrypted, err := encrypt.Encrypt(*value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*value = encrypted
|
||||
return nil
|
||||
}
|
||||
|
||||
// decryptString decrypts the string pointed to by value in place using AES-256-GCM.
|
||||
// It is a no-op if the pointer is nil or the string is empty.
|
||||
func decryptString(value *string) error {
|
||||
if value == nil || *value == "" {
|
||||
return nil
|
||||
}
|
||||
decrypted, err := encrypt.Decrypt(*value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*value = decrypted
|
||||
return nil
|
||||
}
|
||||
1982
framework/configstore/tables/encryption_test.go
Normal file
1982
framework/configstore/tables/encryption_test.go
Normal file
File diff suppressed because it is too large
Load Diff
17
framework/configstore/tables/env.go
Normal file
17
framework/configstore/tables/env.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package tables
|
||||
|
||||
import "time"
|
||||
|
||||
// TableEnvKey is the persisted model that tracks where an environment
// variable is used in the configuration.
type TableEnvKey struct {
	ID     uint   `gorm:"primaryKey;autoIncrement" json:"id"`
	EnvVar string `gorm:"type:varchar(255);index;not null" json:"env_var"`
	// Provider is empty for MCP/client configs.
	Provider string `gorm:"type:varchar(50);index" json:"provider"`
	// KeyType is one of "api_key", "azure_config", "vertex_config", "bedrock_config", "connection_string".
	KeyType string `gorm:"type:varchar(50);not null" json:"key_type"`
	// ConfigPath is a descriptive path of where this env var is used.
	ConfigPath string `gorm:"type:varchar(500);not null" json:"config_path"`
	// KeyID is the key UUID (empty for non-key configs).
	KeyID     string    `gorm:"type:varchar(255);index" json:"key_id"`
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableEnvKey) TableName() string { return "config_env_keys" }
|
||||
22
framework/configstore/tables/folders.go
Normal file
22
framework/configstore/tables/folders.go
Normal file
@@ -0,0 +1,22 @@
|
||||
// Package tables provides tables for the configstore
|
||||
package tables
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// TableFolder is the persisted model for a generic folder that can contain
// prompts. ConfigHash is stored but never exposed in JSON.
type TableFolder struct {
	ID          string    `gorm:"type:varchar(36);primaryKey" json:"id"`
	Name        string    `gorm:"type:varchar(255);not null" json:"name"`
	Description *string   `gorm:"type:text" json:"description,omitempty"`
	CreatedAt   time.Time `gorm:"not null" json:"created_at"`
	UpdatedAt   time.Time `gorm:"not null" json:"updated_at"`
	ConfigHash  string    `gorm:"type:varchar(64)" json:"-"`

	// Virtual fields (not stored in the database).
	PromptsCount int `gorm:"-" json:"prompts_count,omitempty"`
}
|
||||
|
||||
// TableName for TableFolder
|
||||
func (TableFolder) TableName() string { return "folders" }
|
||||
12
framework/configstore/tables/framework.go
Normal file
12
framework/configstore/tables/framework.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package tables
|
||||
|
||||
// TableFrameworkConfig is the persisted model for framework-level settings.
// New columns are added here as framework features are introduced.
type TableFrameworkConfig struct {
	ID                  uint    `gorm:"primaryKey;autoIncrement" json:"id"`
	PricingURL          *string `gorm:"type:text" json:"pricing_url"`
	PricingSyncInterval *int64  `gorm:"" json:"pricing_sync_interval"`
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableFrameworkConfig) TableName() string { return "framework_configs" }
|
||||
644
framework/configstore/tables/key.go
Normal file
644
framework/configstore/tables/key.go
Normal file
@@ -0,0 +1,644 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableKey represents an API key configuration in the database
|
||||
type TableKey struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Name string `gorm:"type:varchar(255);uniqueIndex:idx_key_name;not null" json:"name"`
|
||||
ProviderID uint `gorm:"index;not null" json:"provider_id"`
|
||||
Provider string `gorm:"index;type:varchar(50)" json:"provider"` // ModelProvider as string
|
||||
KeyID string `gorm:"type:varchar(255);uniqueIndex:idx_key_id;not null" json:"key_id"` // UUID from schemas.Key
|
||||
Value schemas.EnvVar `gorm:"type:text;not null" json:"value"`
|
||||
ModelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
BlacklistedModelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
Weight *float64 `json:"weight"`
|
||||
Enabled *bool `gorm:"default:true" json:"enabled,omitempty"`
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
|
||||
// Config hash is used to detect changes synced from config.json file
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
// Unified aliases
|
||||
AliasesJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.KeyAliases
|
||||
|
||||
// Azure config fields (embedded instead of separate table for simplicity)
|
||||
AzureEndpoint *schemas.EnvVar `gorm:"type:text" json:"azure_endpoint,omitempty"`
|
||||
AzureAPIVersion *schemas.EnvVar `gorm:"type:text" json:"azure_api_version,omitempty"`
|
||||
AzureClientID *schemas.EnvVar `gorm:"type:text" json:"azure_client_id,omitempty"`
|
||||
AzureClientSecret *schemas.EnvVar `gorm:"type:text" json:"azure_client_secret,omitempty"`
|
||||
AzureTenantID *schemas.EnvVar `gorm:"type:text" json:"azure_tenant_id,omitempty"`
|
||||
AzureScopesJSON *string `gorm:"column:azure_scopes;type:text" json:"-"` // JSON serialized []string
|
||||
|
||||
// Vertex config fields (embedded)
|
||||
VertexProjectID *schemas.EnvVar `gorm:"type:text" json:"vertex_project_id,omitempty"`
|
||||
VertexProjectNumber *schemas.EnvVar `gorm:"type:text" json:"vertex_project_number,omitempty"`
|
||||
VertexRegion *schemas.EnvVar `gorm:"type:text" json:"vertex_region,omitempty"`
|
||||
VertexAuthCredentials *schemas.EnvVar `gorm:"type:text" json:"vertex_auth_credentials,omitempty"`
|
||||
|
||||
// Bedrock config fields (embedded)
|
||||
BedrockAccessKey *schemas.EnvVar `gorm:"type:text" json:"bedrock_access_key,omitempty"`
|
||||
BedrockSecretKey *schemas.EnvVar `gorm:"type:text" json:"bedrock_secret_key,omitempty"`
|
||||
BedrockSessionToken *schemas.EnvVar `gorm:"type:text" json:"bedrock_session_token,omitempty"`
|
||||
BedrockRegion *schemas.EnvVar `gorm:"type:text" json:"bedrock_region,omitempty"`
|
||||
BedrockARN *schemas.EnvVar `gorm:"type:text" json:"bedrock_arn,omitempty"`
|
||||
BedrockRoleARN *schemas.EnvVar `gorm:"type:text" json:"bedrock_role_arn,omitempty"`
|
||||
BedrockExternalID *schemas.EnvVar `gorm:"type:text" json:"bedrock_external_id,omitempty"`
|
||||
BedrockRoleSessionName *schemas.EnvVar `gorm:"type:text" json:"bedrock_role_session_name,omitempty"`
|
||||
BedrockBatchS3ConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.BatchS3Config
|
||||
|
||||
// VLLM config fields (embedded)
|
||||
VLLMUrl *schemas.EnvVar `gorm:"type:text" json:"vllm_url,omitempty"`
|
||||
VLLMModelName *string `gorm:"type:varchar(255)" json:"vllm_model_name,omitempty"`
|
||||
|
||||
// Replicate config fields (embedded)
|
||||
ReplicateUseDeploymentsEndpoint *bool `gorm:"column:replicate_use_deployments_endpoint" json:"replicate_use_deployments_endpoint,omitempty"`
|
||||
|
||||
// Ollama config fields (embedded)
|
||||
OllamaUrl *schemas.EnvVar `gorm:"type:text" json:"ollama_url,omitempty"`
|
||||
|
||||
// SGL config fields (embedded)
|
||||
SGLUrl *schemas.EnvVar `gorm:"type:text" json:"sgl_url,omitempty"`
|
||||
|
||||
// Batch API configuration
|
||||
UseForBatchAPI *bool `gorm:"default:false" json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations
|
||||
|
||||
Status string `gorm:"type:varchar(50);default:'unknown'" json:"status"`
|
||||
Description string `gorm:"type:text" json:"description,omitempty"`
|
||||
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
|
||||
// Virtual fields for runtime use (not stored in DB)
|
||||
Models schemas.WhiteList `gorm:"-" json:"models"` // ["*"] allows all models; empty denies all (deny-by-default)
|
||||
BlacklistedModels schemas.BlackList `gorm:"-" json:"blacklisted_models"`
|
||||
Aliases schemas.KeyAliases `gorm:"-" json:"aliases,omitempty"`
|
||||
AzureKeyConfig *schemas.AzureKeyConfig `gorm:"-" json:"azure_key_config,omitempty"`
|
||||
VertexKeyConfig *schemas.VertexKeyConfig `gorm:"-" json:"vertex_key_config,omitempty"`
|
||||
BedrockKeyConfig *schemas.BedrockKeyConfig `gorm:"-" json:"bedrock_key_config,omitempty"`
|
||||
VLLMKeyConfig *schemas.VLLMKeyConfig `gorm:"-" json:"vllm_key_config,omitempty"`
|
||||
ReplicateKeyConfig *schemas.ReplicateKeyConfig `gorm:"-" json:"replicate_key_config,omitempty"`
|
||||
OllamaKeyConfig *schemas.OllamaKeyConfig `gorm:"-" json:"ollama_key_config,omitempty"`
|
||||
SGLKeyConfig *schemas.SGLKeyConfig `gorm:"-" json:"sgl_key_config,omitempty"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableKey) TableName() string { return "config_keys" }
|
||||
|
||||
// BeforeSave is a GORM hook that serializes runtime config structs into JSON columns and
|
||||
// encrypts sensitive fields (API key value, Azure endpoint/client ID/secret/tenant ID/API version,
|
||||
// Vertex project ID/project number/region/credentials, Bedrock keys/region/ARN/deployments/
|
||||
// batch S3 config) before writing to the database. Encryption runs last to ensure it
|
||||
// operates on the final serialized values.
|
||||
func (k *TableKey) BeforeSave(tx *gorm.DB) error {
|
||||
if err := k.Models.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.Marshal(k.Models)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
k.ModelsJSON = string(data)
|
||||
if err := k.BlacklistedModels.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err = json.Marshal(k.BlacklistedModels)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
k.BlacklistedModelsJSON = string(data)
|
||||
if k.Enabled == nil {
|
||||
enabled := true // DB default
|
||||
k.Enabled = &enabled
|
||||
}
|
||||
if k.UseForBatchAPI == nil {
|
||||
useForBatchAPI := false // DB default
|
||||
k.UseForBatchAPI = &useForBatchAPI
|
||||
}
|
||||
// IMPORTANT: All *EnvVar fields assigned from provider config structs (AzureKeyConfig,
|
||||
// VertexKeyConfig, BedrockKeyConfig) MUST be value-copied before assignment. The caller
|
||||
// may retain the config struct pointer; if BeforeSave (or future encryption) mutates a
|
||||
// shared pointer, the caller's in-memory config is silently corrupted.
|
||||
// See: TestBeforeSave_DoesNotMutateSharedProviderConfigs
|
||||
if k.AzureKeyConfig != nil {
|
||||
if k.AzureKeyConfig.Endpoint.IsSet() {
|
||||
ep := k.AzureKeyConfig.Endpoint
|
||||
k.AzureEndpoint = &ep
|
||||
} else {
|
||||
k.AzureEndpoint = nil
|
||||
}
|
||||
if k.AzureKeyConfig.APIVersion != nil {
|
||||
av := *k.AzureKeyConfig.APIVersion
|
||||
k.AzureAPIVersion = &av
|
||||
} else {
|
||||
k.AzureAPIVersion = nil
|
||||
}
|
||||
if k.AzureKeyConfig.ClientID != nil {
|
||||
cid := *k.AzureKeyConfig.ClientID
|
||||
k.AzureClientID = &cid
|
||||
} else {
|
||||
k.AzureClientID = nil
|
||||
}
|
||||
if k.AzureKeyConfig.ClientSecret != nil {
|
||||
cs := *k.AzureKeyConfig.ClientSecret
|
||||
k.AzureClientSecret = &cs
|
||||
} else {
|
||||
k.AzureClientSecret = nil
|
||||
}
|
||||
if k.AzureKeyConfig.TenantID != nil {
|
||||
tid := *k.AzureKeyConfig.TenantID
|
||||
k.AzureTenantID = &tid
|
||||
} else {
|
||||
k.AzureTenantID = nil
|
||||
}
|
||||
if len(k.AzureKeyConfig.Scopes) > 0 {
|
||||
data, err := json.Marshal(k.AzureKeyConfig.Scopes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s := string(data)
|
||||
k.AzureScopesJSON = &s
|
||||
} else {
|
||||
k.AzureScopesJSON = nil
|
||||
}
|
||||
} else {
|
||||
k.AzureEndpoint = nil
|
||||
k.AzureAPIVersion = nil
|
||||
k.AzureClientID = nil
|
||||
k.AzureClientSecret = nil
|
||||
k.AzureTenantID = nil
|
||||
k.AzureScopesJSON = nil
|
||||
}
|
||||
if k.VertexKeyConfig != nil {
|
||||
if k.VertexKeyConfig.ProjectID.IsSet() {
|
||||
pid := k.VertexKeyConfig.ProjectID
|
||||
k.VertexProjectID = &pid
|
||||
} else {
|
||||
k.VertexProjectID = nil
|
||||
}
|
||||
if k.VertexKeyConfig.ProjectNumber.IsSet() {
|
||||
pn := k.VertexKeyConfig.ProjectNumber
|
||||
k.VertexProjectNumber = &pn
|
||||
} else {
|
||||
k.VertexProjectNumber = nil
|
||||
}
|
||||
if k.VertexKeyConfig.Region.IsSet() {
|
||||
vr := k.VertexKeyConfig.Region
|
||||
k.VertexRegion = &vr
|
||||
} else {
|
||||
k.VertexRegion = nil
|
||||
}
|
||||
if k.VertexKeyConfig.AuthCredentials.IsSet() {
|
||||
ac := k.VertexKeyConfig.AuthCredentials
|
||||
k.VertexAuthCredentials = &ac
|
||||
} else {
|
||||
k.VertexAuthCredentials = nil
|
||||
}
|
||||
} else {
|
||||
k.VertexProjectID = nil
|
||||
k.VertexProjectNumber = nil
|
||||
k.VertexRegion = nil
|
||||
k.VertexAuthCredentials = nil
|
||||
}
|
||||
if k.BedrockKeyConfig != nil {
|
||||
if k.BedrockKeyConfig.AccessKey.IsSet() {
|
||||
// Copy to avoid encrypting the shared BedrockKeyConfig through the pointer
|
||||
ak := k.BedrockKeyConfig.AccessKey
|
||||
k.BedrockAccessKey = &ak
|
||||
} else {
|
||||
k.BedrockAccessKey = nil
|
||||
}
|
||||
if k.BedrockKeyConfig.SecretKey.IsSet() {
|
||||
// Copy to avoid encrypting the shared BedrockKeyConfig through the pointer
|
||||
sk := k.BedrockKeyConfig.SecretKey
|
||||
k.BedrockSecretKey = &sk
|
||||
} else {
|
||||
k.BedrockSecretKey = nil
|
||||
}
|
||||
// Copy to avoid encrypting the shared BedrockKeyConfig through the pointer
|
||||
if k.BedrockKeyConfig.SessionToken != nil {
|
||||
st := *k.BedrockKeyConfig.SessionToken
|
||||
k.BedrockSessionToken = &st
|
||||
} else {
|
||||
k.BedrockSessionToken = nil
|
||||
}
|
||||
if k.BedrockKeyConfig.Region != nil {
|
||||
br := *k.BedrockKeyConfig.Region
|
||||
k.BedrockRegion = &br
|
||||
} else {
|
||||
k.BedrockRegion = nil
|
||||
}
|
||||
if k.BedrockKeyConfig.ARN != nil {
|
||||
ba := *k.BedrockKeyConfig.ARN
|
||||
k.BedrockARN = &ba
|
||||
} else {
|
||||
k.BedrockARN = nil
|
||||
}
|
||||
if k.BedrockKeyConfig.RoleARN != nil {
|
||||
bra := *k.BedrockKeyConfig.RoleARN
|
||||
k.BedrockRoleARN = &bra
|
||||
} else {
|
||||
k.BedrockRoleARN = nil
|
||||
}
|
||||
if k.BedrockKeyConfig.ExternalID != nil {
|
||||
ei := *k.BedrockKeyConfig.ExternalID
|
||||
k.BedrockExternalID = &ei
|
||||
} else {
|
||||
k.BedrockExternalID = nil
|
||||
}
|
||||
if k.BedrockKeyConfig.RoleSessionName != nil {
|
||||
rsn := *k.BedrockKeyConfig.RoleSessionName
|
||||
k.BedrockRoleSessionName = &rsn
|
||||
} else {
|
||||
k.BedrockRoleSessionName = nil
|
||||
}
|
||||
if k.BedrockKeyConfig.BatchS3Config != nil {
|
||||
data, err := sonic.Marshal(k.BedrockKeyConfig.BatchS3Config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s := string(data)
|
||||
k.BedrockBatchS3ConfigJSON = &s
|
||||
} else {
|
||||
k.BedrockBatchS3ConfigJSON = nil
|
||||
}
|
||||
} else {
|
||||
k.BedrockAccessKey = nil
|
||||
k.BedrockSecretKey = nil
|
||||
k.BedrockSessionToken = nil
|
||||
k.BedrockRegion = nil
|
||||
k.BedrockARN = nil
|
||||
k.BedrockRoleARN = nil
|
||||
k.BedrockExternalID = nil
|
||||
k.BedrockRoleSessionName = nil
|
||||
k.BedrockBatchS3ConfigJSON = nil
|
||||
}
|
||||
|
||||
if k.Aliases != nil {
|
||||
if err := k.Aliases.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := sonic.Marshal(k.Aliases)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s := string(data)
|
||||
k.AliasesJSON = &s
|
||||
} else {
|
||||
k.AliasesJSON = nil
|
||||
}
|
||||
|
||||
if k.VLLMKeyConfig != nil {
|
||||
if k.VLLMKeyConfig.URL.IsSet() {
|
||||
u := k.VLLMKeyConfig.URL // Value-copy to prevent shared pointer mutation
|
||||
k.VLLMUrl = &u
|
||||
} else {
|
||||
k.VLLMUrl = nil
|
||||
}
|
||||
if k.VLLMKeyConfig.ModelName != "" {
|
||||
mn := k.VLLMKeyConfig.ModelName
|
||||
k.VLLMModelName = &mn
|
||||
} else {
|
||||
k.VLLMModelName = nil
|
||||
}
|
||||
} else {
|
||||
k.VLLMUrl = nil
|
||||
k.VLLMModelName = nil
|
||||
}
|
||||
|
||||
if k.ReplicateKeyConfig != nil {
|
||||
v := k.ReplicateKeyConfig.UseDeploymentsEndpoint
|
||||
k.ReplicateUseDeploymentsEndpoint = &v
|
||||
} else {
|
||||
k.ReplicateUseDeploymentsEndpoint = nil
|
||||
}
|
||||
|
||||
if k.OllamaKeyConfig != nil && k.OllamaKeyConfig.URL.IsSet() {
|
||||
u := k.OllamaKeyConfig.URL
|
||||
k.OllamaUrl = &u
|
||||
} else {
|
||||
k.OllamaUrl = nil
|
||||
}
|
||||
|
||||
if k.SGLKeyConfig != nil && k.SGLKeyConfig.URL.IsSet() {
|
||||
u := k.SGLKeyConfig.URL
|
||||
k.SGLUrl = &u
|
||||
} else {
|
||||
k.SGLUrl = nil
|
||||
}
|
||||
|
||||
// Encrypt sensitive fields after serialization
|
||||
if encrypt.IsEnabled() {
|
||||
if err := encryptEnvVar(&k.Value); err != nil {
|
||||
return fmt.Errorf("failed to encrypt key value: %w", err)
|
||||
}
|
||||
// Azure
|
||||
if err := encryptEnvVarPtr(&k.AzureEndpoint); err != nil {
|
||||
return fmt.Errorf("failed to encrypt azure endpoint: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.AzureClientID); err != nil {
|
||||
return fmt.Errorf("failed to encrypt azure client id: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.AzureClientSecret); err != nil {
|
||||
return fmt.Errorf("failed to encrypt azure client secret: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.AzureTenantID); err != nil {
|
||||
return fmt.Errorf("failed to encrypt azure tenant id: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.AzureAPIVersion); err != nil {
|
||||
return fmt.Errorf("failed to encrypt azure api version: %w", err)
|
||||
}
|
||||
// Vertex
|
||||
if err := encryptEnvVarPtr(&k.VertexProjectID); err != nil {
|
||||
return fmt.Errorf("failed to encrypt vertex project id: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.VertexProjectNumber); err != nil {
|
||||
return fmt.Errorf("failed to encrypt vertex project number: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.VertexRegion); err != nil {
|
||||
return fmt.Errorf("failed to encrypt vertex region: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.VertexAuthCredentials); err != nil {
|
||||
return fmt.Errorf("failed to encrypt vertex auth credentials: %w", err)
|
||||
}
|
||||
// Bedrock
|
||||
if err := encryptEnvVarPtr(&k.BedrockAccessKey); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock access key: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.BedrockSecretKey); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock secret key: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.BedrockSessionToken); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock session token: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.BedrockRegion); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock region: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.BedrockARN); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock arn: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.BedrockRoleARN); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock role arn: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.BedrockExternalID); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock external id: %w", err)
|
||||
}
|
||||
if err := encryptEnvVarPtr(&k.BedrockRoleSessionName); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock role session name: %w", err)
|
||||
}
|
||||
if err := encryptString(k.BedrockBatchS3ConfigJSON); err != nil {
|
||||
return fmt.Errorf("failed to encrypt bedrock batch s3 config: %w", err)
|
||||
}
|
||||
// Aliases
|
||||
if err := encryptString(k.AliasesJSON); err != nil {
|
||||
return fmt.Errorf("failed to encrypt aliases: %w", err)
|
||||
}
|
||||
// VLLM
|
||||
if err := encryptEnvVarPtr(&k.VLLMUrl); err != nil {
|
||||
return fmt.Errorf("failed to encrypt vllm url: %w", err)
|
||||
}
|
||||
// Ollama
|
||||
if err := encryptEnvVarPtr(&k.OllamaUrl); err != nil {
|
||||
return fmt.Errorf("failed to encrypt ollama url: %w", err)
|
||||
}
|
||||
// SGL
|
||||
if err := encryptEnvVarPtr(&k.SGLUrl); err != nil {
|
||||
return fmt.Errorf("failed to encrypt sgl url: %w", err)
|
||||
}
|
||||
k.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind is a GORM hook that decrypts sensitive fields and reconstructs runtime config
|
||||
// structs after reading from the database. Decryption runs first so that value copies into
|
||||
// AzureKeyConfig, VertexKeyConfig, etc. receive plaintext data.
|
||||
func (k *TableKey) AfterFind(tx *gorm.DB) error {
|
||||
// Decrypt sensitive fields before deserialization/reconstruction
|
||||
if k.EncryptionStatus == EncryptionStatusEncrypted {
|
||||
if err := decryptEnvVar(&k.Value); err != nil {
|
||||
return fmt.Errorf("failed to decrypt key value: %w", err)
|
||||
}
|
||||
// Azure
|
||||
if err := decryptEnvVarPtr(&k.AzureEndpoint); err != nil {
|
||||
return fmt.Errorf("failed to decrypt azure endpoint: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.AzureClientID); err != nil {
|
||||
return fmt.Errorf("failed to decrypt azure client id: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.AzureClientSecret); err != nil {
|
||||
return fmt.Errorf("failed to decrypt azure client secret: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.AzureTenantID); err != nil {
|
||||
return fmt.Errorf("failed to decrypt azure tenant id: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.AzureAPIVersion); err != nil {
|
||||
return fmt.Errorf("failed to decrypt azure api version: %w", err)
|
||||
}
|
||||
// Vertex
|
||||
if err := decryptEnvVarPtr(&k.VertexProjectID); err != nil {
|
||||
return fmt.Errorf("failed to decrypt vertex project id: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.VertexProjectNumber); err != nil {
|
||||
return fmt.Errorf("failed to decrypt vertex project number: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.VertexRegion); err != nil {
|
||||
return fmt.Errorf("failed to decrypt vertex region: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.VertexAuthCredentials); err != nil {
|
||||
return fmt.Errorf("failed to decrypt vertex auth credentials: %w", err)
|
||||
}
|
||||
// Bedrock
|
||||
if err := decryptEnvVarPtr(&k.BedrockAccessKey); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock access key: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.BedrockSecretKey); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock secret key: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.BedrockSessionToken); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock session token: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.BedrockRegion); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock region: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.BedrockARN); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock arn: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.BedrockRoleARN); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock role arn: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.BedrockExternalID); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock external id: %w", err)
|
||||
}
|
||||
if err := decryptEnvVarPtr(&k.BedrockRoleSessionName); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock role session name: %w", err)
|
||||
}
|
||||
if err := decryptString(k.BedrockBatchS3ConfigJSON); err != nil {
|
||||
return fmt.Errorf("failed to decrypt bedrock batch s3 config: %w", err)
|
||||
}
|
||||
// Aliases
|
||||
if err := decryptString(k.AliasesJSON); err != nil {
|
||||
return fmt.Errorf("failed to decrypt aliases: %w", err)
|
||||
}
|
||||
// VLLM
|
||||
if err := decryptEnvVarPtr(&k.VLLMUrl); err != nil {
|
||||
return fmt.Errorf("failed to decrypt vllm url: %w", err)
|
||||
}
|
||||
// Ollama
|
||||
if err := decryptEnvVarPtr(&k.OllamaUrl); err != nil {
|
||||
return fmt.Errorf("failed to decrypt ollama url: %w", err)
|
||||
}
|
||||
// SGL
|
||||
if err := decryptEnvVarPtr(&k.SGLUrl); err != nil {
|
||||
return fmt.Errorf("failed to decrypt sgl url: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if k.ModelsJSON != "" {
|
||||
if err := json.Unmarshal([]byte(k.ModelsJSON), &k.Models); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if k.BlacklistedModelsJSON != "" {
|
||||
if err := json.Unmarshal([]byte(k.BlacklistedModelsJSON), &k.BlacklistedModels); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if k.Enabled == nil {
|
||||
enabled := true // DB default
|
||||
k.Enabled = &enabled
|
||||
}
|
||||
if k.UseForBatchAPI == nil {
|
||||
useForBatchAPI := false // DB default
|
||||
k.UseForBatchAPI = &useForBatchAPI
|
||||
}
|
||||
// Reconstruct Azure config if fields are present
|
||||
if k.AzureEndpoint != nil || k.AzureAPIVersion != nil || k.AzureClientID != nil || k.AzureClientSecret != nil || k.AzureTenantID != nil || (k.AzureScopesJSON != nil && *k.AzureScopesJSON != "") {
|
||||
var scopes []string
|
||||
if k.AzureScopesJSON != nil && *k.AzureScopesJSON != "" {
|
||||
if err := json.Unmarshal([]byte(*k.AzureScopesJSON), &scopes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
azureConfig := &schemas.AzureKeyConfig{
|
||||
Endpoint: *schemas.NewEnvVar(""),
|
||||
APIVersion: k.AzureAPIVersion,
|
||||
ClientID: k.AzureClientID,
|
||||
ClientSecret: k.AzureClientSecret,
|
||||
TenantID: k.AzureTenantID,
|
||||
Scopes: scopes,
|
||||
}
|
||||
|
||||
if k.AzureEndpoint != nil {
|
||||
azureConfig.Endpoint = *k.AzureEndpoint
|
||||
}
|
||||
|
||||
k.AzureKeyConfig = azureConfig
|
||||
}
|
||||
// Reconstruct Vertex config if fields are present
|
||||
if k.VertexProjectID != nil || k.VertexProjectNumber != nil || k.VertexRegion != nil || k.VertexAuthCredentials != nil {
|
||||
config := &schemas.VertexKeyConfig{}
|
||||
|
||||
if k.VertexProjectID != nil {
|
||||
config.ProjectID = *k.VertexProjectID
|
||||
}
|
||||
|
||||
if k.VertexProjectNumber != nil {
|
||||
config.ProjectNumber = *k.VertexProjectNumber
|
||||
}
|
||||
|
||||
if k.VertexRegion != nil {
|
||||
config.Region = *k.VertexRegion
|
||||
}
|
||||
if k.VertexAuthCredentials != nil {
|
||||
config.AuthCredentials = *k.VertexAuthCredentials
|
||||
}
|
||||
k.VertexKeyConfig = config
|
||||
}
|
||||
// Reconstruct Bedrock config if fields are present
|
||||
if k.BedrockAccessKey != nil || k.BedrockSecretKey != nil || k.BedrockSessionToken != nil || k.BedrockRegion != nil || k.BedrockARN != nil || k.BedrockRoleARN != nil || k.BedrockExternalID != nil || k.BedrockRoleSessionName != nil || (k.BedrockBatchS3ConfigJSON != nil && *k.BedrockBatchS3ConfigJSON != "") {
|
||||
bedrockConfig := &schemas.BedrockKeyConfig{}
|
||||
|
||||
if k.BedrockAccessKey != nil {
|
||||
bedrockConfig.AccessKey = *k.BedrockAccessKey
|
||||
}
|
||||
|
||||
bedrockConfig.SessionToken = k.BedrockSessionToken
|
||||
bedrockConfig.Region = k.BedrockRegion
|
||||
bedrockConfig.ARN = k.BedrockARN
|
||||
bedrockConfig.RoleARN = k.BedrockRoleARN
|
||||
bedrockConfig.ExternalID = k.BedrockExternalID
|
||||
bedrockConfig.RoleSessionName = k.BedrockRoleSessionName
|
||||
|
||||
if k.BedrockSecretKey != nil {
|
||||
bedrockConfig.SecretKey = *k.BedrockSecretKey
|
||||
}
|
||||
|
||||
if k.BedrockBatchS3ConfigJSON != nil && *k.BedrockBatchS3ConfigJSON != "" {
|
||||
var batchS3Config schemas.BatchS3Config
|
||||
if err := json.Unmarshal([]byte(*k.BedrockBatchS3ConfigJSON), &batchS3Config); err != nil {
|
||||
return err
|
||||
}
|
||||
bedrockConfig.BatchS3Config = &batchS3Config
|
||||
}
|
||||
|
||||
k.BedrockKeyConfig = bedrockConfig
|
||||
}
|
||||
// Reconstruct Aliases
|
||||
if k.AliasesJSON != nil && *k.AliasesJSON != "" {
|
||||
var aliases schemas.KeyAliases
|
||||
if err := sonic.Unmarshal([]byte(*k.AliasesJSON), &aliases); err != nil {
|
||||
return err
|
||||
}
|
||||
k.Aliases = aliases
|
||||
} else {
|
||||
k.Aliases = nil
|
||||
}
|
||||
// Reconstruct VLLM config if fields are present
|
||||
if k.VLLMUrl != nil || (k.VLLMModelName != nil && *k.VLLMModelName != "") {
|
||||
vllmConfig := &schemas.VLLMKeyConfig{}
|
||||
if k.VLLMUrl != nil {
|
||||
vllmConfig.URL = *k.VLLMUrl
|
||||
}
|
||||
if k.VLLMModelName != nil {
|
||||
vllmConfig.ModelName = *k.VLLMModelName
|
||||
}
|
||||
k.VLLMKeyConfig = vllmConfig
|
||||
} else {
|
||||
k.VLLMKeyConfig = nil
|
||||
}
|
||||
// Reconstruct Replicate config if fields are present
|
||||
if k.ReplicateUseDeploymentsEndpoint != nil {
|
||||
k.ReplicateKeyConfig = &schemas.ReplicateKeyConfig{
|
||||
UseDeploymentsEndpoint: *k.ReplicateUseDeploymentsEndpoint,
|
||||
}
|
||||
} else {
|
||||
k.ReplicateKeyConfig = nil
|
||||
}
|
||||
// Reconstruct Ollama config if fields are present
|
||||
if k.OllamaUrl != nil {
|
||||
k.OllamaKeyConfig = &schemas.OllamaKeyConfig{
|
||||
URL: *k.OllamaUrl,
|
||||
}
|
||||
} else {
|
||||
k.OllamaKeyConfig = nil
|
||||
}
|
||||
// Reconstruct SGL config if fields are present
|
||||
if k.SGLUrl != nil {
|
||||
k.SGLKeyConfig = &schemas.SGLKeyConfig{
|
||||
URL: *k.SGLUrl,
|
||||
}
|
||||
} else {
|
||||
k.SGLKeyConfig = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
16
framework/configstore/tables/logstore.go
Normal file
16
framework/configstore/tables/logstore.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package tables
|
||||
|
||||
import "time"
|
||||
|
||||
// TableLogStoreConfig is the database row that holds the log store configuration.
type TableLogStoreConfig struct {
	// ID is the auto-incremented primary key.
	ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
	// Enabled reports whether the log store is turned on.
	Enabled bool `json:"enabled"`
	// Type names the log store backend, e.g. "sqlite".
	Type string `gorm:"type:varchar(50);not null" json:"type"`
	// Config is the JSON serialized logstore.Config; nil when no config is stored.
	Config *string `gorm:"type:text" json:"config"`
	// CreatedAt and UpdatedAt are the row timestamps; both columns are indexed.
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableLogStoreConfig) TableName() string { return "config_log_store" }
|
||||
252
framework/configstore/tables/mcp.go
Normal file
252
framework/configstore/tables/mcp.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableMCPClient represents an MCP client configuration in the database
|
||||
type TableMCPClient struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present.
|
||||
ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"`
|
||||
Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"`
|
||||
IsCodeModeClient bool `gorm:"default:false" json:"is_code_mode_client"` // Whether the client is a code mode client
|
||||
ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType
|
||||
ConnectionString *schemas.EnvVar `gorm:"type:text" json:"connection_string,omitempty"`
|
||||
StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig
|
||||
ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
ToolsToAutoExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string
|
||||
AllowedExtraHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string
|
||||
IsPingAvailable *bool `gorm:"default:true" json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks
|
||||
ToolPricingJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]float64
|
||||
ToolSyncInterval int `gorm:"default:0" json:"tool_sync_interval"` // Per-client tool sync interval in minutes (0 = use global, -1 = disabled)
|
||||
|
||||
// Per-user OAuth: discovered tools persisted so they survive restart
|
||||
DiscoveredToolsJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]schemas.ChatTool
|
||||
ToolNameMappingJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string
|
||||
|
||||
// OAuth authentication fields
|
||||
AuthType string `gorm:"type:varchar(20);default:'headers'" json:"auth_type"` // "none", "headers", "oauth"
|
||||
OauthConfigID *string `gorm:"type:varchar(255);index;constraint:OnDelete:CASCADE" json:"oauth_config_id"` // Foreign key to oauth_configs.ID with CASCADE delete
|
||||
OauthConfig *TableOauthConfig `gorm:"foreignKey:OauthConfigID;references:ID;constraint:OnDelete:CASCADE" json:"-"` // Gorm relationship
|
||||
|
||||
AllowOnAllVirtualKeys bool `gorm:"default:false" json:"allow_on_all_virtual_keys"` // Whether to allow the MCP client to run on all virtual keys
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
|
||||
// Virtual fields for runtime use (not stored in DB)
|
||||
StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"`
|
||||
ToolsToExecute schemas.WhiteList `gorm:"-" json:"tools_to_execute"`
|
||||
ToolsToAutoExecute schemas.WhiteList `gorm:"-" json:"tools_to_auto_execute"`
|
||||
Headers map[string]schemas.EnvVar `gorm:"-" json:"headers"`
|
||||
AllowedExtraHeaders schemas.WhiteList `gorm:"-" json:"allowed_extra_headers"`
|
||||
ToolPricing map[string]float64 `gorm:"-" json:"tool_pricing"`
|
||||
DiscoveredTools map[string]schemas.ChatTool `gorm:"-" json:"-"`
|
||||
DiscoveredToolNameMapping map[string]string `gorm:"-" json:"-"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableMCPClient) TableName() string { return "config_mcp_clients" }
|
||||
|
||||
// BeforeSave is a GORM hook that serializes runtime fields (stdio config, tools, headers,
|
||||
// pricing) into JSON columns and encrypts the connection string and headers before writing
|
||||
// to the database. Environment-variable-backed connection strings are not encrypted.
|
||||
func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error {
|
||||
if c.StdioConfig != nil {
|
||||
data, err := json.Marshal(c.StdioConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config := string(data)
|
||||
c.StdioConfigJSON = &config
|
||||
} else {
|
||||
c.StdioConfigJSON = nil
|
||||
}
|
||||
|
||||
if c.ToolsToExecute != nil {
|
||||
if err := c.ToolsToExecute.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid tools_to_execute: %w", err)
|
||||
}
|
||||
data, err := json.Marshal(c.ToolsToExecute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.ToolsToExecuteJSON = string(data)
|
||||
} else {
|
||||
c.ToolsToExecuteJSON = "[]"
|
||||
}
|
||||
|
||||
if c.ToolsToAutoExecute != nil {
|
||||
if err := c.ToolsToAutoExecute.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid tools_to_auto_execute: %w", err)
|
||||
}
|
||||
data, err := json.Marshal(c.ToolsToAutoExecute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.ToolsToAutoExecuteJSON = string(data)
|
||||
} else {
|
||||
c.ToolsToAutoExecuteJSON = "[]"
|
||||
}
|
||||
|
||||
if c.Headers != nil {
|
||||
headersToSerialize := make(map[string]string, len(c.Headers))
|
||||
for key, value := range c.Headers {
|
||||
if value.IsFromEnv() {
|
||||
headersToSerialize[key] = value.EnvVar
|
||||
} else {
|
||||
headersToSerialize[key] = value.GetValue()
|
||||
}
|
||||
}
|
||||
data, err := json.Marshal(headersToSerialize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.HeadersJSON = string(data)
|
||||
} else {
|
||||
c.HeadersJSON = "{}"
|
||||
}
|
||||
|
||||
if c.AllowedExtraHeaders != nil {
|
||||
if err := c.AllowedExtraHeaders.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid allowed_extra_headers: %w", err)
|
||||
}
|
||||
data, err := json.Marshal(c.AllowedExtraHeaders)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.AllowedExtraHeadersJSON = string(data)
|
||||
} else {
|
||||
c.AllowedExtraHeadersJSON = "[]"
|
||||
}
|
||||
|
||||
if c.ToolPricing != nil {
|
||||
data, err := json.Marshal(c.ToolPricing)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.ToolPricingJSON = string(data)
|
||||
} else {
|
||||
c.ToolPricingJSON = "{}"
|
||||
}
|
||||
|
||||
if c.DiscoveredTools != nil {
|
||||
data, err := json.Marshal(c.DiscoveredTools)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.DiscoveredToolsJSON = string(data)
|
||||
}
|
||||
|
||||
if c.DiscoveredToolNameMapping != nil {
|
||||
data, err := json.Marshal(c.DiscoveredToolNameMapping)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.ToolNameMappingJSON = string(data)
|
||||
}
|
||||
|
||||
// Encrypt sensitive fields after serialization.
|
||||
// Always set EncryptionStatus when encryption is enabled so the startup
|
||||
// batch pass does not re-process this row indefinitely.
|
||||
if encrypt.IsEnabled() {
|
||||
if c.ConnectionString != nil && !c.ConnectionString.IsFromEnv() && c.ConnectionString.GetValue() != "" {
|
||||
// Copy to avoid encrypting the shared ConnectionString through the pointer
|
||||
cs := *c.ConnectionString
|
||||
enc, err := encrypt.Encrypt(cs.Val)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt mcp connection string: %w", err)
|
||||
}
|
||||
cs.Val = enc
|
||||
c.ConnectionString = &cs
|
||||
}
|
||||
if c.HeadersJSON != "" && c.HeadersJSON != "{}" {
|
||||
enc, err := encrypt.Encrypt(c.HeadersJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt mcp headers: %w", err)
|
||||
}
|
||||
c.HeadersJSON = enc
|
||||
}
|
||||
c.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind is a GORM hook that decrypts the connection string and headers (if encrypted)
|
||||
// and deserializes JSON columns back into runtime structs after reading from the database.
|
||||
func (c *TableMCPClient) AfterFind(tx *gorm.DB) error {
|
||||
if c.EncryptionStatus == "encrypted" {
|
||||
if c.HeadersJSON != "" && c.HeadersJSON != "{}" {
|
||||
decrypted, err := encrypt.Decrypt(c.HeadersJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt mcp headers: %w", err)
|
||||
}
|
||||
c.HeadersJSON = decrypted
|
||||
}
|
||||
if c.ConnectionString != nil && !c.ConnectionString.IsFromEnv() && c.ConnectionString.GetValue() != "" {
|
||||
decrypted, err := encrypt.Decrypt(c.ConnectionString.Val)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt mcp connection string: %w", err)
|
||||
}
|
||||
c.ConnectionString.Val = decrypted
|
||||
}
|
||||
}
|
||||
if c.StdioConfigJSON != nil {
|
||||
var config schemas.MCPStdioConfig
|
||||
if err := sonic.Unmarshal([]byte(*c.StdioConfigJSON), &config); err != nil {
|
||||
return err
|
||||
}
|
||||
c.StdioConfig = &config
|
||||
}
|
||||
if c.ToolsToExecuteJSON != "" {
|
||||
if err := sonic.Unmarshal([]byte(c.ToolsToExecuteJSON), &c.ToolsToExecute); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.ToolsToAutoExecuteJSON != "" {
|
||||
if err := sonic.Unmarshal([]byte(c.ToolsToAutoExecuteJSON), &c.ToolsToAutoExecute); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.HeadersJSON != "" {
|
||||
if err := sonic.Unmarshal([]byte(c.HeadersJSON), &c.Headers); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.AllowedExtraHeadersJSON != "" {
|
||||
if err := sonic.Unmarshal([]byte(c.AllowedExtraHeadersJSON), &c.AllowedExtraHeaders); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.ToolPricingJSON != "" {
|
||||
if err := json.Unmarshal([]byte(c.ToolPricingJSON), &c.ToolPricing); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.DiscoveredToolsJSON != "" {
|
||||
if err := sonic.Unmarshal([]byte(c.DiscoveredToolsJSON), &c.DiscoveredTools); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.ToolNameMappingJSON != "" {
|
||||
if err := sonic.Unmarshal([]byte(c.ToolNameMappingJSON), &c.DiscoveredToolNameMapping); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
15
framework/configstore/tables/model.go
Normal file
15
framework/configstore/tables/model.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package tables
|
||||
|
||||
import "time"
|
||||
|
||||
// TableModel is the database row for a single model entry. The pair
// (ProviderID, Name) shares the unique index idx_provider_name.
type TableModel struct {
	ID         string    `gorm:"primaryKey" json:"id"`
	ProviderID uint      `gorm:"index;not null;uniqueIndex:idx_provider_name" json:"provider_id"`
	Name       string    `gorm:"uniqueIndex:idx_provider_name" json:"name"`
	CreatedAt  time.Time `json:"created_at"`
	UpdatedAt  time.Time `json:"updated_at"`
}

// TableName reports the database table that backs TableModel.
func (TableModel) TableName() string {
	return "config_models"
}
|
||||
59
framework/configstore/tables/modelconfig.go
Normal file
59
framework/configstore/tables/modelconfig.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableModelConfig represents a model configuration with rate limiting and budgeting
|
||||
type TableModelConfig struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
|
||||
ModelName string `gorm:"type:varchar(255);not null;uniqueIndex:idx_model_provider" json:"model_name"`
|
||||
Provider *string `gorm:"type:varchar(50);uniqueIndex:idx_model_provider" json:"provider,omitempty"` // Optional provider, nullable
|
||||
BudgetID *string `gorm:"type:varchar(255);index:idx_model_config_budget" json:"budget_id,omitempty"`
|
||||
RateLimitID *string `gorm:"type:varchar(255);index:idx_model_config_rate_limit" json:"rate_limit_id,omitempty"`
|
||||
|
||||
// Relationships
|
||||
Budget *TableBudget `gorm:"foreignKey:BudgetID;onDelete:CASCADE" json:"budget,omitempty"`
|
||||
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableModelConfig) TableName() string {
|
||||
return "governance_model_configs"
|
||||
}
|
||||
|
||||
// BeforeSave hook for ModelConfig to validate required fields
|
||||
func (mc *TableModelConfig) BeforeSave(tx *gorm.DB) error {
|
||||
// Validate that ModelName is not empty
|
||||
if strings.TrimSpace(mc.ModelName) == "" {
|
||||
return fmt.Errorf("model_name cannot be empty")
|
||||
}
|
||||
|
||||
// Validate that if BudgetID is provided, it's not an empty string
|
||||
if mc.BudgetID != nil && strings.TrimSpace(*mc.BudgetID) == "" {
|
||||
return fmt.Errorf("budget_id cannot be an empty string")
|
||||
}
|
||||
|
||||
// Validate that if RateLimitID is provided, it's not an empty string
|
||||
if mc.RateLimitID != nil && strings.TrimSpace(*mc.RateLimitID) == "" {
|
||||
return fmt.Errorf("rate_limit_id cannot be an empty string")
|
||||
}
|
||||
|
||||
// Validate that if Provider is provided, it's not an empty string
|
||||
if mc.Provider != nil && strings.TrimSpace(*mc.Provider) == "" {
|
||||
return fmt.Errorf("provider cannot be an empty string")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
13
framework/configstore/tables/modelparameters.go
Normal file
13
framework/configstore/tables/modelparameters.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package tables
|
||||
|
||||
// TableModelParameters is the database row holding the parameter and
// capability data for one model, as synced from the external datasheet API.
// Model is unique per row (index idx_model_params_model).
type TableModelParameters struct {
	ID    uint   `gorm:"primaryKey;autoIncrement" json:"id"`
	Model string `gorm:"type:varchar(255);not null;uniqueIndex:idx_model_params_model" json:"model"`

	// Data is the raw JSON blob for the model.
	Data string `gorm:"type:text;not null" json:"data"`
}

// TableName reports the database table that backs TableModelParameters.
func (TableModelParameters) TableName() string {
	return "governance_model_parameters"
}
|
||||
97
framework/configstore/tables/modelpricing.go
Normal file
97
framework/configstore/tables/modelpricing.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package tables
|
||||
|
||||
import "github.com/maximhq/bifrost/core/schemas"
|
||||
|
||||
// TableModelPricing represents pricing information for AI models
|
||||
type TableModelPricing struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Model string `gorm:"type:varchar(255);not null;uniqueIndex:idx_model_provider_mode" json:"model"`
|
||||
BaseModel string `gorm:"type:varchar(255);default:null" json:"base_model,omitempty"`
|
||||
Provider string `gorm:"type:varchar(50);not null;uniqueIndex:idx_model_provider_mode" json:"provider"`
|
||||
Mode string `gorm:"type:varchar(50);not null;uniqueIndex:idx_model_provider_mode" json:"mode"`
|
||||
ContextLength *int `gorm:"default:null" json:"context_length,omitempty"`
|
||||
MaxInputTokens *int `gorm:"default:null" json:"max_input_tokens,omitempty"`
|
||||
MaxOutputTokens *int `gorm:"default:null" json:"max_output_tokens,omitempty"`
|
||||
Architecture *schemas.Architecture `gorm:"type:text;serializer:json;default:null" json:"architecture,omitempty"`
|
||||
|
||||
// Costs - Text
|
||||
InputCostPerToken *float64 `gorm:"default:null" json:"input_cost_per_token,omitempty"`
|
||||
OutputCostPerToken *float64 `gorm:"default:null" json:"output_cost_per_token,omitempty"`
|
||||
InputCostPerTokenBatches *float64 `gorm:"default:null;column:input_cost_per_token_batches" json:"input_cost_per_token_batches,omitempty"`
|
||||
OutputCostPerTokenBatches *float64 `gorm:"default:null;column:output_cost_per_token_batches" json:"output_cost_per_token_batches,omitempty"`
|
||||
InputCostPerTokenPriority *float64 `gorm:"default:null;column:input_cost_per_token_priority" json:"input_cost_per_token_priority,omitempty"`
|
||||
OutputCostPerTokenPriority *float64 `gorm:"default:null;column:output_cost_per_token_priority" json:"output_cost_per_token_priority,omitempty"`
|
||||
InputCostPerTokenFlex *float64 `gorm:"default:null;column:input_cost_per_token_flex" json:"input_cost_per_token_flex,omitempty"`
|
||||
OutputCostPerTokenFlex *float64 `gorm:"default:null;column:output_cost_per_token_flex" json:"output_cost_per_token_flex,omitempty"`
|
||||
InputCostPerCharacter *float64 `gorm:"default:null;column:input_cost_per_character" json:"input_cost_per_character,omitempty"`
|
||||
// Costs - 128k Tier
|
||||
InputCostPerTokenAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_token_above_128k_tokens" json:"input_cost_per_token_above_128k_tokens,omitempty"`
|
||||
InputCostPerImageAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_image_above_128k_tokens" json:"input_cost_per_image_above_128k_tokens,omitempty"`
|
||||
InputCostPerVideoPerSecondAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_video_per_second_above_128k_tokens" json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"`
|
||||
InputCostPerAudioPerSecondAbove128kTokens *float64 `gorm:"default:null;column:input_cost_per_audio_per_second_above_128k_tokens" json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"`
|
||||
OutputCostPerTokenAbove128kTokens *float64 `gorm:"default:null;column:output_cost_per_token_above_128k_tokens" json:"output_cost_per_token_above_128k_tokens,omitempty"`
|
||||
// Costs - 200k Tier
|
||||
InputCostPerTokenAbove200kTokens *float64 `gorm:"default:null;column:input_cost_per_token_above_200k_tokens" json:"input_cost_per_token_above_200k_tokens,omitempty"`
|
||||
InputCostPerTokenAbove200kTokensPriority *float64 `gorm:"default:null;column:input_cost_per_token_above_200k_tokens_priority" json:"input_cost_per_token_above_200k_tokens_priority,omitempty"`
|
||||
OutputCostPerTokenAbove200kTokens *float64 `gorm:"default:null;column:output_cost_per_token_above_200k_tokens" json:"output_cost_per_token_above_200k_tokens,omitempty"`
|
||||
OutputCostPerTokenAbove200kTokensPriority *float64 `gorm:"default:null;column:output_cost_per_token_above_200k_tokens_priority" json:"output_cost_per_token_above_200k_tokens_priority,omitempty"`
|
||||
// Costs - 272k Tier
|
||||
InputCostPerTokenAbove272kTokens *float64 `gorm:"default:null;column:input_cost_per_token_above_272k_tokens" json:"input_cost_per_token_above_272k_tokens,omitempty"`
|
||||
InputCostPerTokenAbove272kTokensPriority *float64 `gorm:"default:null;column:input_cost_per_token_above_272k_tokens_priority" json:"input_cost_per_token_above_272k_tokens_priority,omitempty"`
|
||||
OutputCostPerTokenAbove272kTokens *float64 `gorm:"default:null;column:output_cost_per_token_above_272k_tokens" json:"output_cost_per_token_above_272k_tokens,omitempty"`
|
||||
OutputCostPerTokenAbove272kTokensPriority *float64 `gorm:"default:null;column:output_cost_per_token_above_272k_tokens_priority" json:"output_cost_per_token_above_272k_tokens_priority,omitempty"`
|
||||
|
||||
// Costs - Cache
|
||||
CacheCreationInputTokenCost *float64 `gorm:"default:null;column:cache_creation_input_token_cost" json:"cache_creation_input_token_cost,omitempty"`
|
||||
CacheReadInputTokenCost *float64 `gorm:"default:null;column:cache_read_input_token_cost" json:"cache_read_input_token_cost,omitempty"`
|
||||
CacheCreationInputTokenCostAbove200kTokens *float64 `gorm:"default:null;column:cache_creation_input_token_cost_above_200k_tokens" json:"cache_creation_input_token_cost_above_200k_tokens,omitempty"`
|
||||
CacheReadInputTokenCostAbove200kTokens *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_200k_tokens" json:"cache_read_input_token_cost_above_200k_tokens,omitempty"`
|
||||
CacheReadInputTokenCostAbove200kTokensPriority *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_200k_tokens_priority" json:"cache_read_input_token_cost_above_200k_tokens_priority,omitempty"`
|
||||
CacheCreationInputTokenCostAbove1hr *float64 `gorm:"default:null;column:cache_creation_input_token_cost_above_1hr" json:"cache_creation_input_token_cost_above_1hr,omitempty"`
|
||||
CacheCreationInputTokenCostAbove1hrAbove200kTokens *float64 `gorm:"default:null;column:cache_creation_input_token_cost_above_1hr_above_200k_tokens" json:"cache_creation_input_token_cost_above_1hr_above_200k_tokens,omitempty"`
|
||||
CacheCreationInputAudioTokenCost *float64 `gorm:"default:null;column:cache_creation_input_audio_token_cost" json:"cache_creation_input_audio_token_cost,omitempty"`
|
||||
CacheReadInputTokenCostPriority *float64 `gorm:"default:null;column:cache_read_input_token_cost_priority" json:"cache_read_input_token_cost_priority,omitempty"`
|
||||
CacheReadInputTokenCostFlex *float64 `gorm:"default:null;column:cache_read_input_token_cost_flex" json:"cache_read_input_token_cost_flex,omitempty"`
|
||||
CacheReadInputImageTokenCost *float64 `gorm:"default:null;column:cache_read_input_image_token_cost" json:"cache_read_input_image_token_cost,omitempty"`
|
||||
CacheReadInputTokenCostAbove272kTokens *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_272k_tokens" json:"cache_read_input_token_cost_above_272k_tokens,omitempty"`
|
||||
CacheReadInputTokenCostAbove272kTokensPriority *float64 `gorm:"default:null;column:cache_read_input_token_cost_above_272k_tokens_priority" json:"cache_read_input_token_cost_above_272k_tokens_priority,omitempty"`
|
||||
|
||||
// Costs - Image
|
||||
InputCostPerImage *float64 `gorm:"default:null;column:input_cost_per_image" json:"input_cost_per_image,omitempty"`
|
||||
InputCostPerPixel *float64 `gorm:"default:null;column:input_cost_per_pixel" json:"input_cost_per_pixel,omitempty"`
|
||||
OutputCostPerImage *float64 `gorm:"default:null;column:output_cost_per_image" json:"output_cost_per_image,omitempty"`
|
||||
OutputCostPerPixel *float64 `gorm:"default:null;column:output_cost_per_pixel" json:"output_cost_per_pixel,omitempty"`
|
||||
OutputCostPerImagePremiumImage *float64 `gorm:"default:null;column:output_cost_per_image_premium_image" json:"output_cost_per_image_premium_image,omitempty"`
|
||||
OutputCostPerImageAbove512x512Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_512_and_512_pixels" json:"output_cost_per_image_above_512_and_512_pixels,omitempty"`
|
||||
OutputCostPerImageAbove512x512PixelsPremium *float64 `gorm:"default:null;column:output_cost_per_image_above_512x512_pixels_premium" json:"output_cost_per_image_above_512_and_512_pixels_and_premium_image,omitempty"`
|
||||
OutputCostPerImageAbove1024x1024Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_1024_and_1024_pixels" json:"output_cost_per_image_above_1024_and_1024_pixels,omitempty"`
|
||||
OutputCostPerImageAbove1024x1024PixelsPremium *float64 `gorm:"default:null;column:output_cost_per_image_above_1024x1024_pixels_premium" json:"output_cost_per_image_above_1024_and_1024_pixels_and_premium_image,omitempty"`
|
||||
OutputCostPerImageAbove2048x2048Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_2048_and_2048_pixels" json:"output_cost_per_image_above_2048_and_2048_pixels,omitempty"`
|
||||
OutputCostPerImageAbove4096x4096Pixels *float64 `gorm:"default:null;column:output_cost_per_image_above_4096_and_4096_pixels" json:"output_cost_per_image_above_4096_and_4096_pixels,omitempty"`
|
||||
OutputCostPerImageLowQuality *float64 `gorm:"default:null;column:output_cost_per_image_low_quality" json:"output_cost_per_image_low_quality,omitempty"`
|
||||
OutputCostPerImageMediumQuality *float64 `gorm:"default:null;column:output_cost_per_image_medium_quality" json:"output_cost_per_image_medium_quality,omitempty"`
|
||||
OutputCostPerImageHighQuality *float64 `gorm:"default:null;column:output_cost_per_image_high_quality" json:"output_cost_per_image_high_quality,omitempty"`
|
||||
OutputCostPerImageAutoQuality *float64 `gorm:"default:null;column:output_cost_per_image_auto_quality" json:"output_cost_per_image_auto_quality,omitempty"`
|
||||
InputCostPerImageToken *float64 `gorm:"default:null;column:input_cost_per_image_token" json:"input_cost_per_image_token,omitempty"`
|
||||
OutputCostPerImageToken *float64 `gorm:"default:null;column:output_cost_per_image_token" json:"output_cost_per_image_token,omitempty"`
|
||||
|
||||
// Costs - Audio/Video
|
||||
InputCostPerAudioToken *float64 `gorm:"default:null;column:input_cost_per_audio_token" json:"input_cost_per_audio_token,omitempty"`
|
||||
InputCostPerAudioPerSecond *float64 `gorm:"default:null;column:input_cost_per_audio_per_second" json:"input_cost_per_audio_per_second,omitempty"`
|
||||
InputCostPerSecond *float64 `gorm:"default:null;column:input_cost_per_second" json:"input_cost_per_second,omitempty"` // Only for transcription models
|
||||
InputCostPerVideoPerSecond *float64 `gorm:"default:null;column:input_cost_per_video_per_second" json:"input_cost_per_video_per_second,omitempty"`
|
||||
OutputCostPerAudioToken *float64 `gorm:"default:null;column:output_cost_per_audio_token" json:"output_cost_per_audio_token,omitempty"`
|
||||
OutputCostPerVideoPerSecond *float64 `gorm:"default:null;column:output_cost_per_video_per_second" json:"output_cost_per_video_per_second,omitempty"`
|
||||
OutputCostPerSecond *float64 `gorm:"default:null;column:output_cost_per_second" json:"output_cost_per_second,omitempty"` // For both speech and video models
|
||||
|
||||
// Costs - Other
|
||||
SearchContextCostPerQuery *float64 `gorm:"default:null;column:search_context_cost_per_query" json:"search_context_cost_per_query,omitempty"`
|
||||
CodeInterpreterCostPerSession *float64 `gorm:"default:null;column:code_interpreter_cost_per_session" json:"code_interpreter_cost_per_session,omitempty"`
|
||||
|
||||
// Costs - OCR
|
||||
OCRCostPerPage *float64 `gorm:"default:null;column:ocr_cost_per_page" json:"ocr_cost_per_page,omitempty"`
|
||||
AnnotationCostPerPage *float64 `gorm:"default:null;column:annotation_cost_per_page" json:"annotation_cost_per_page,omitempty"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableModelPricing) TableName() string { return "governance_model_pricing" }
|
||||
379
framework/configstore/tables/oauth.go
Normal file
379
framework/configstore/tables/oauth.go
Normal file
@@ -0,0 +1,379 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableOauthConfig is the database row for one OAuth configuration. It holds
// both the OAuth client settings and the state of the authorization flow.
type TableOauthConfig struct {
	// ID is a UUID.
	ID string `gorm:"type:varchar(255);primaryKey" json:"id"`

	// ClientID is the OAuth provider's client ID (optional for public clients).
	ClientID string `gorm:"type:varchar(512)" json:"client_id"`

	// ClientSecret is the encrypted OAuth client secret (optional for public clients).
	ClientSecret string `gorm:"type:text" json:"-"`

	// AuthorizeURL is the provider's authorization endpoint (optional, can be discovered).
	AuthorizeURL string `gorm:"type:text" json:"authorize_url"`

	// TokenURL is the provider's token endpoint (optional, can be discovered).
	TokenURL string `gorm:"type:text" json:"token_url"`

	// RegistrationURL is the provider's dynamic registration endpoint (optional, can be discovered).
	RegistrationURL *string `gorm:"type:text" json:"registration_url,omitempty"`

	// RedirectURI is the callback URL.
	RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"`

	// Scopes is a JSON array of scopes (optional, can be discovered).
	Scopes string `gorm:"type:text" json:"scopes"`

	// State is the CSRF state token.
	State string `gorm:"type:varchar(255);uniqueIndex;not null" json:"-"`

	// CodeVerifier is the generated PKCE code verifier; it is kept secret.
	CodeVerifier string `gorm:"type:text" json:"-"`

	// CodeChallenge is the PKCE code challenge sent to the provider.
	CodeChallenge string `gorm:"type:varchar(255)" json:"code_challenge"`

	// Status is one of "pending", "authorized", "failed", "expired", "revoked".
	Status string `gorm:"type:varchar(50);not null;index" json:"status"`

	// TokenID is the foreign key to oauth_tokens.ID; it is set after the callback.
	TokenID *string `gorm:"type:varchar(255);index" json:"token_id"`

	// ServerURL is the MCP server URL used for OAuth discovery.
	ServerURL string `gorm:"type:text" json:"server_url"`

	// UseDiscovery enables OAuth discovery.
	UseDiscovery bool `gorm:"default:false" json:"use_discovery"`

	// MCPClientConfigJSON is the JSON-serialized MCPClientConfig of a pending
	// MCP client waiting for OAuth completion (multi-instance support).
	MCPClientConfigJSON *string `gorm:"type:text" json:"-"`

	EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`

	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`

	// ExpiresAt is the state expiry (15 min).
	ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"`
}

// TableName reports the database table that backs TableOauthConfig.
func (TableOauthConfig) TableName() string { return "oauth_configs" }
|
||||
|
||||
// BeforeSave hook
|
||||
func (c *TableOauthConfig) BeforeSave(tx *gorm.DB) error {
|
||||
// Ensure status is valid
|
||||
if c.Status == "" {
|
||||
c.Status = "pending"
|
||||
}
|
||||
|
||||
// Encrypt sensitive fields
|
||||
if encrypt.IsEnabled() {
|
||||
encrypted := false
|
||||
if c.ClientSecret != "" {
|
||||
if err := encryptString(&c.ClientSecret); err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth client secret: %w", err)
|
||||
}
|
||||
encrypted = true
|
||||
}
|
||||
if c.CodeVerifier != "" {
|
||||
if err := encryptString(&c.CodeVerifier); err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth code verifier: %w", err)
|
||||
}
|
||||
encrypted = true
|
||||
}
|
||||
if encrypted {
|
||||
c.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind hook to decrypt sensitive fields
|
||||
func (c *TableOauthConfig) AfterFind(tx *gorm.DB) error {
|
||||
if c.EncryptionStatus == EncryptionStatusEncrypted {
|
||||
if err := decryptString(&c.ClientSecret); err != nil {
|
||||
return fmt.Errorf("failed to decrypt oauth client secret: %w", err)
|
||||
}
|
||||
if err := decryptString(&c.CodeVerifier); err != nil {
|
||||
return fmt.Errorf("failed to decrypt oauth code verifier: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableOauthToken is the database row holding the actual OAuth access and
// refresh tokens.
type TableOauthToken struct {
	// ID is a UUID.
	ID string `gorm:"type:varchar(255);primaryKey" json:"id"`

	// AccessToken is the encrypted access token.
	AccessToken string `gorm:"type:text;not null" json:"-"`

	// RefreshToken is the encrypted refresh token (optional).
	RefreshToken string `gorm:"type:text" json:"-"`

	// TokenType is the token type, e.g. "Bearer".
	TokenType string `gorm:"type:varchar(50);not null" json:"token_type"`

	// ExpiresAt is the token expiration time.
	ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"`

	// Scopes is a JSON array of granted scopes.
	Scopes string `gorm:"type:text" json:"scopes"`

	// LastRefreshedAt tracks when the token was last refreshed.
	LastRefreshedAt *time.Time `gorm:"index" json:"last_refreshed_at,omitempty"`

	EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`

	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}

// TableName reports the database table that backs TableOauthToken.
func (TableOauthToken) TableName() string { return "oauth_tokens" }
|
||||
|
||||
// BeforeSave hook
|
||||
func (t *TableOauthToken) BeforeSave(tx *gorm.DB) error {
|
||||
// Ensure token type is set
|
||||
if t.TokenType == "" {
|
||||
t.TokenType = "Bearer"
|
||||
}
|
||||
|
||||
// Encrypt sensitive fields
|
||||
if encrypt.IsEnabled() {
|
||||
if err := encryptString(&t.AccessToken); err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth access token: %w", err)
|
||||
}
|
||||
if err := encryptString(&t.RefreshToken); err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth refresh token: %w", err)
|
||||
}
|
||||
t.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind hook to decrypt sensitive fields
|
||||
func (t *TableOauthToken) AfterFind(tx *gorm.DB) error {
|
||||
if t.EncryptionStatus == EncryptionStatusEncrypted {
|
||||
if err := decryptString(&t.AccessToken); err != nil {
|
||||
return fmt.Errorf("failed to decrypt oauth access token: %w", err)
|
||||
}
|
||||
if err := decryptString(&t.RefreshToken); err != nil {
|
||||
return fmt.Errorf("failed to decrypt oauth refresh token: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------- Per-User OAuth Tables ----------
|
||||
|
||||
// TableOauthUserSession is the database row for a pending per-user OAuth
// flow. It ties an OAuth state token to a specific MCP client so the callback
// can attach the resulting tokens to the right user session.
type TableOauthUserSession struct {
	// ID is the session UUID.
	ID string `gorm:"type:varchar(255);primaryKey" json:"id"`

	// MCPClientID identifies the MCP server this authorization is for.
	MCPClientID string `gorm:"type:varchar(255);not null;index" json:"mcp_client_id"`

	// OauthConfigID references the template OAuth config (client_id, token_url, etc.).
	OauthConfigID string `gorm:"type:varchar(255);not null;index" json:"oauth_config_id"`

	// State is the CSRF state token sent to the OAuth provider.
	State string `gorm:"type:varchar(255);uniqueIndex;not null" json:"-"`

	// RedirectURI is the per-request redirect URI used in the authorize step.
	RedirectURI string `gorm:"type:text" json:"-"`

	// CodeVerifier is the PKCE code verifier; it is kept secret.
	CodeVerifier string `gorm:"type:text" json:"-"`

	// SessionToken is the Bifrost session ID (links to oauth_per_user_sessions).
	SessionToken string `gorm:"type:varchar(255)" json:"-"`

	// SessionTokenHash is the SHA-256 hash of SessionToken, used for secure lookups.
	SessionTokenHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"`

	// GatewaySessionID is the Bifrost MCP gateway session ID (separate from SessionToken).
	GatewaySessionID string `gorm:"type:varchar(255);index" json:"-"`

	// VirtualKeyID is the VK identity (propagated to oauth_user_tokens).
	VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"`

	// UserID is the enterprise user identity (propagated to oauth_user_tokens).
	UserID *string `gorm:"type:varchar(255);index" json:"user_id"`

	// Status is one of "pending", "authorized", "failed", "expired".
	Status string `gorm:"type:varchar(50);not null;index" json:"status"`

	EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`

	// ExpiresAt is the flow expiration (15 min).
	ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"`

	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}

// TableName reports the database table that backs TableOauthUserSession.
func (TableOauthUserSession) TableName() string { return "oauth_user_sessions" }
|
||||
|
||||
func (s *TableOauthUserSession) BeforeSave(tx *gorm.DB) error {
|
||||
if s.Status == "" {
|
||||
s.Status = "pending"
|
||||
}
|
||||
if s.SessionToken != "" {
|
||||
s.SessionTokenHash = encrypt.HashSHA256(s.SessionToken)
|
||||
}
|
||||
if encrypt.IsEnabled() {
|
||||
if s.CodeVerifier != "" {
|
||||
if err := encryptString(&s.CodeVerifier); err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth user session code verifier: %w", err)
|
||||
}
|
||||
}
|
||||
s.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TableOauthUserSession) AfterFind(tx *gorm.DB) error {
|
||||
if s.EncryptionStatus == EncryptionStatusEncrypted && s.CodeVerifier != "" {
|
||||
if err := decryptString(&s.CodeVerifier); err != nil {
|
||||
return fmt.Errorf("failed to decrypt oauth user session code verifier: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableOauthUserToken is the database row holding per-user OAuth credentials:
// the access and refresh tokens for one user session and MCP client pair.
// Lookup is by SessionToken.
type TableOauthUserToken struct {
	// ID is the token UUID.
	ID string `gorm:"type:varchar(255);primaryKey" json:"id"`

	// SessionToken maps to the Bifrost session (fallback for anonymous users).
	SessionToken string `gorm:"type:varchar(255)" json:"-"`

	// SessionTokenHash is the SHA-256 hash of SessionToken, used for secure lookups.
	SessionTokenHash string `gorm:"type:varchar(64);index" json:"-"`

	// VirtualKeyID is the VK identity (persistent across sessions).
	VirtualKeyID *string `gorm:"type:varchar(255);index:idx_vk_mcp" json:"virtual_key_id"`

	// UserID is the enterprise user identity (persistent across sessions).
	UserID *string `gorm:"type:varchar(255);index:idx_user_mcp" json:"user_id"`

	// MCPClientID identifies the MCP server; it is part of both idx_vk_mcp and idx_user_mcp.
	MCPClientID string `gorm:"type:varchar(255);not null;index:idx_vk_mcp;index:idx_user_mcp" json:"mcp_client_id"`

	// OauthConfigID references the template OAuth config.
	OauthConfigID string `gorm:"type:varchar(255);not null;index" json:"oauth_config_id"`

	// AccessToken is the user's encrypted OAuth access token.
	AccessToken string `gorm:"type:text;not null" json:"-"`

	// RefreshToken is the user's encrypted OAuth refresh token.
	RefreshToken string `gorm:"type:text" json:"-"`

	// TokenType is the token type, e.g. "Bearer".
	TokenType string `gorm:"type:varchar(50);not null" json:"token_type"`

	// ExpiresAt is the token expiry.
	ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"`

	// Scopes is a JSON array of granted scopes.
	Scopes string `gorm:"type:text" json:"scopes"`

	// LastRefreshedAt is the last refresh time.
	LastRefreshedAt *time.Time `gorm:"index" json:"last_refreshed_at,omitempty"`

	EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`

	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}

// TableName reports the database table that backs TableOauthUserToken.
func (TableOauthUserToken) TableName() string { return "oauth_user_tokens" }
|
||||
|
||||
func (t *TableOauthUserToken) BeforeSave(tx *gorm.DB) error {
|
||||
if t.TokenType == "" {
|
||||
t.TokenType = "Bearer"
|
||||
}
|
||||
if t.SessionToken != "" {
|
||||
t.SessionTokenHash = encrypt.HashSHA256(t.SessionToken)
|
||||
}
|
||||
if encrypt.IsEnabled() {
|
||||
if err := encryptString(&t.AccessToken); err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth user access token: %w", err)
|
||||
}
|
||||
if err := encryptString(&t.RefreshToken); err != nil {
|
||||
return fmt.Errorf("failed to encrypt oauth user refresh token: %w", err)
|
||||
}
|
||||
t.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TableOauthUserToken) AfterFind(tx *gorm.DB) error {
|
||||
if t.EncryptionStatus == EncryptionStatusEncrypted {
|
||||
if err := decryptString(&t.AccessToken); err != nil {
|
||||
return fmt.Errorf("failed to decrypt oauth user access token: %w", err)
|
||||
}
|
||||
if err := decryptString(&t.RefreshToken); err != nil {
|
||||
return fmt.Errorf("failed to decrypt oauth user refresh token: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------- Per-User OAuth Authorization Server Tables ----------
|
||||
|
||||
// TablePerUserOAuthClient is the persistence model for dynamically registered
// OAuth clients (RFC 7591). MCP clients (such as Claude Code) register with
// Bifrost's OAuth authorization server to obtain the client_id they use in the
// authorization code flow.
type TablePerUserOAuthClient struct {
	ID         string `gorm:"type:varchar(255);primaryKey" json:"id"`
	ClientID   string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"`
	ClientName string `gorm:"type:varchar(255)" json:"client_name"`

	// RedirectURIs holds a JSON array of the redirect URIs allowed for this client.
	RedirectURIs string `gorm:"type:text;not null" json:"redirect_uris"`
	// GrantTypes holds a JSON array of grant types.
	GrantTypes string `gorm:"type:text" json:"grant_types"`

	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
|
||||
|
||||
// TableName returns the table name for per-user OAuth clients.
|
||||
func (TablePerUserOAuthClient) TableName() string {
|
||||
return "oauth_per_user_clients"
|
||||
}
|
||||
|
||||
// TablePerUserOAuthSession stores Bifrost-issued access tokens for authenticated
|
||||
// MCP connections. When a user authenticates via Bifrost's OAuth flow, a session
|
||||
// is created. The access token is included in all subsequent MCP requests.
|
||||
// Upstream provider tokens are linked via the oauth_user_tokens table.
|
||||
type TablePerUserOAuthSession struct {
|
||||
ID string `gorm:"type:varchar(255);primaryKey" json:"id"`
|
||||
AccessToken string `gorm:"type:text;not null" json:"-"` // Bifrost-issued access token (encrypted)
|
||||
AccessTokenHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash for secure lookups
|
||||
RefreshToken string `gorm:"type:text" json:"-"` // Bifrost-issued refresh token (encrypted, optional)
|
||||
RefreshTokenHash string `gorm:"type:varchar(64);index" json:"-"` // SHA-256 hash for secure lookups (not unique — refresh tokens are optional)
|
||||
ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"` // Which OAuth client registered this session
|
||||
VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // Linked VK identity (set when VK is present during auth)
|
||||
VirtualKey *TableVirtualKey `gorm:"foreignKey:VirtualKeyID" json:"-"` // Linked VK identity (server-only, not serialized)
|
||||
UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Linked enterprise user identity (set when user ID is present)
|
||||
ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"`
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for per-user OAuth sessions.
|
||||
func (TablePerUserOAuthSession) TableName() string {
|
||||
return "oauth_per_user_sessions"
|
||||
}
|
||||
|
||||
// BeforeSave encrypts sensitive fields.
|
||||
func (s *TablePerUserOAuthSession) BeforeSave(tx *gorm.DB) error {
|
||||
if s.AccessToken != "" {
|
||||
s.AccessTokenHash = encrypt.HashSHA256(s.AccessToken)
|
||||
}
|
||||
if s.RefreshToken != "" {
|
||||
s.RefreshTokenHash = encrypt.HashSHA256(s.RefreshToken)
|
||||
}
|
||||
if encrypt.IsEnabled() {
|
||||
if err := encryptString(&s.AccessToken); err != nil {
|
||||
return fmt.Errorf("failed to encrypt per-user oauth access token: %w", err)
|
||||
}
|
||||
if s.RefreshToken != "" {
|
||||
if err := encryptString(&s.RefreshToken); err != nil {
|
||||
return fmt.Errorf("failed to encrypt per-user oauth refresh token: %w", err)
|
||||
}
|
||||
}
|
||||
s.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind decrypts sensitive fields.
|
||||
func (s *TablePerUserOAuthSession) AfterFind(tx *gorm.DB) error {
|
||||
if s.EncryptionStatus == EncryptionStatusEncrypted {
|
||||
if err := decryptString(&s.AccessToken); err != nil {
|
||||
return fmt.Errorf("failed to decrypt per-user oauth access token: %w", err)
|
||||
}
|
||||
if s.RefreshToken != "" {
|
||||
if err := decryptString(&s.RefreshToken); err != nil {
|
||||
return fmt.Errorf("failed to decrypt per-user oauth refresh token: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TablePerUserOAuthCode is the persistence model for authorization codes
// issued during the OAuth flow. Codes are short-lived (5 minutes) and
// single-use.
type TablePerUserOAuthCode struct {
	ID string `gorm:"type:varchar(255);primaryKey" json:"id"`

	// Code is the authorization code itself; it is never serialized to JSON.
	Code string `gorm:"type:text;not null" json:"-"`
	// CodeHash is the SHA-256 hash of Code used for secure lookups.
	CodeHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"`

	ClientID    string `gorm:"type:varchar(255);not null;index" json:"client_id"`
	RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"`

	// CodeChallenge is the PKCE S256 challenge.
	CodeChallenge string `gorm:"type:varchar(255);not null" json:"-"`
	// Scopes holds a JSON array of the requested scopes.
	Scopes string `gorm:"type:text" json:"scopes"`
	// SessionID links to the TablePerUserOAuthSession created during consent submit.
	SessionID string `gorm:"type:varchar(255);index" json:"-"`
	// ExpiresAt is the end of the code's 5 minute TTL.
	ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"`
	// Used is the single-use flag.
	Used bool `gorm:"default:false;not null" json:"used"`

	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
}
|
||||
|
||||
// BeforeSave hashes the code for secure lookups.
|
||||
func (c *TablePerUserOAuthCode) BeforeSave(tx *gorm.DB) error {
|
||||
if c.Code != "" {
|
||||
c.CodeHash = encrypt.HashSHA256(c.Code)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableName returns the table name for per-user OAuth authorization codes.
|
||||
func (TablePerUserOAuthCode) TableName() string {
|
||||
return "oauth_per_user_codes"
|
||||
}
|
||||
|
||||
// TablePerUserOAuthPendingFlow is the persistence model for the OAuth
// parameters kept between the authorize step and the final code issuance. It
// carries state through the multi-step consent screen (VK entry plus per-MCP
// upstream auth) before a real authorization code is issued.
type TablePerUserOAuthPendingFlow struct {
	ID string `gorm:"type:varchar(255);primaryKey" json:"id"`

	// ClientID is the registered OAuth client taken from the authorize request.
	ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"`
	// RedirectURI is the client's callback URL.
	RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"`
	// CodeChallenge is the PKCE S256 challenge, echoed into the final code.
	CodeChallenge string `gorm:"type:varchar(255);not null" json:"-"`
	// State is the original OAuth state, echoed back on the final redirect.
	State string `gorm:"type:text;not null" json:"-"`
	// VirtualKeyID is set if the user chose a VK identity.
	VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"`
	// UserID is set if the user chose a User ID identity.
	UserID *string `gorm:"type:varchar(255);index" json:"user_id"`
	// BrowserSecretHash is the SHA-256 hash of the browser-binding cookie secret.
	BrowserSecretHash string `gorm:"type:varchar(255)" json:"-"`
	// ExpiresAt is the end of the flow's 15 minute TTL.
	ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"`

	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
|
||||
|
||||
// TableName returns the table name for per-user OAuth pending flows.
|
||||
func (TablePerUserOAuthPendingFlow) TableName() string {
|
||||
return "oauth_per_user_pending_flows"
|
||||
}
|
||||
87
framework/configstore/tables/plugin.go
Normal file
87
framework/configstore/tables/plugin.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TablePlugin represents a plugin configuration in the database
|
||||
|
||||
type TablePlugin struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Path *string `json:"path,omitempty"`
|
||||
ConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized plugin.Config
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
Version int16 `gorm:"not null;default:1" json:"version"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
IsCustom bool `gorm:"not null;default:false" json:"isCustom"`
|
||||
|
||||
Placement *schemas.PluginPlacement `gorm:"column:placement;type:varchar(20);null" json:"placement,omitempty"`
|
||||
Order *int `gorm:"column:exec_order;type:int;null" json:"order,omitempty"`
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
|
||||
// Virtual fields for runtime use (not stored in DB)
|
||||
Config any `gorm:"-" json:"config,omitempty"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TablePlugin) TableName() string { return "config_plugins" }
|
||||
|
||||
// BeforeSave is a GORM hook that serializes the plugin Config into a JSON column and
|
||||
// encrypts it before writing to the database. Empty configs ("{}") are not encrypted.
|
||||
func (p *TablePlugin) BeforeSave(tx *gorm.DB) error {
|
||||
if p.Config != nil {
|
||||
data, err := json.Marshal(p.Config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.ConfigJSON = string(data)
|
||||
} else {
|
||||
p.ConfigJSON = "{}"
|
||||
}
|
||||
|
||||
// Encrypt config after serialization
|
||||
if encrypt.IsEnabled() && p.ConfigJSON != "" && p.ConfigJSON != "{}" {
|
||||
encrypted, err := encrypt.Encrypt(p.ConfigJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt plugin config: %w", err)
|
||||
}
|
||||
p.ConfigJSON = encrypted
|
||||
p.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind is a GORM hook that decrypts the plugin config JSON (if encrypted) and
|
||||
// deserializes it back into the runtime Config field after reading from the database.
|
||||
func (p *TablePlugin) AfterFind(tx *gorm.DB) error {
|
||||
if p.EncryptionStatus == "encrypted" && p.ConfigJSON != "" {
|
||||
decrypted, err := encrypt.Decrypt(p.ConfigJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt plugin config: %w", err)
|
||||
}
|
||||
p.ConfigJSON = decrypted
|
||||
}
|
||||
if p.ConfigJSON != "" {
|
||||
if err := json.Unmarshal([]byte(p.ConfigJSON), &p.Config); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
p.Config = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
55
framework/configstore/tables/pricingoverride.go
Normal file
55
framework/configstore/tables/pricingoverride.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TablePricingOverride is the persistence model for governance pricing overrides.
|
||||
type TablePricingOverride struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
|
||||
Name string `gorm:"type:varchar(255);not null" json:"name"`
|
||||
ScopeKind string `gorm:"type:varchar(50);index:idx_pricing_override_scope;not null" json:"scope_kind"`
|
||||
VirtualKeyID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"virtual_key_id,omitempty"`
|
||||
ProviderID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"provider_id,omitempty"`
|
||||
ProviderKeyID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"provider_key_id,omitempty"`
|
||||
MatchType string `gorm:"type:varchar(20);index:idx_pricing_override_match;not null" json:"match_type"`
|
||||
Pattern string `gorm:"type:varchar(255);not null" json:"pattern"`
|
||||
RequestTypesJSON string `gorm:"type:text" json:"-"`
|
||||
PricingPatchJSON string `gorm:"type:text" json:"pricing_patch,omitempty"`
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash,omitempty"`
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
|
||||
RequestTypes []schemas.RequestType `gorm:"-" json:"request_types,omitempty"`
|
||||
}
|
||||
|
||||
// TableName returns the backing table name for governance pricing overrides.
|
||||
func (TablePricingOverride) TableName() string { return "governance_pricing_overrides" }
|
||||
|
||||
// BeforeSave serializes virtual fields into their JSON columns before persistence.
|
||||
func (p *TablePricingOverride) BeforeSave(tx *gorm.DB) error {
|
||||
if len(p.RequestTypes) > 0 {
|
||||
b, err := json.Marshal(p.RequestTypes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.RequestTypesJSON = string(b)
|
||||
} else {
|
||||
p.RequestTypesJSON = "[]"
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind restores virtual fields from their persisted JSON columns.
|
||||
func (p *TablePricingOverride) AfterFind(tx *gorm.DB) error {
|
||||
if p.RequestTypesJSON != "" {
|
||||
if err := json.Unmarshal([]byte(p.RequestTypesJSON), &p.RequestTypes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
112
framework/configstore/tables/promptSessions.go
Normal file
112
framework/configstore/tables/promptSessions.go
Normal file
@@ -0,0 +1,112 @@
|
||||
// Package tables provides tables for the configstore
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TablePromptSession represents a mutable working draft/session for a prompt
|
||||
// Sessions belong to a prompt and can optionally be based on a specific version
|
||||
type TablePromptSession struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
PromptID string `gorm:"type:varchar(36);not null;index" json:"prompt_id"`
|
||||
Prompt *TablePrompt `gorm:"foreignKey:PromptID" json:"prompt,omitempty"`
|
||||
VersionID *uint `gorm:"index" json:"version_id,omitempty"` // Optional - session may or may not be based on a version
|
||||
Version *TablePromptVersion `gorm:"foreignKey:VersionID;constraint:OnDelete:SET NULL" json:"version,omitempty"`
|
||||
Name string `gorm:"type:varchar(255)" json:"name"`
|
||||
ModelParamsJSON *string `gorm:"type:text;column:model_params_json" json:"-"`
|
||||
ModelParams ModelParams `gorm:"-" json:"model_params"`
|
||||
Provider string `gorm:"type:varchar(100)" json:"provider"`
|
||||
Model string `gorm:"type:varchar(100)" json:"model"`
|
||||
VariablesJSON *string `gorm:"type:text;column:variables_json" json:"-"`
|
||||
Variables PromptVariables `gorm:"-" json:"variables,omitempty"` // {key: value} map for Jinja2 variables
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
|
||||
// Relationships
|
||||
Messages []TablePromptSessionMessage `gorm:"foreignKey:SessionID;constraint:OnDelete:CASCADE" json:"messages,omitempty"`
|
||||
}
|
||||
|
||||
// TableName for TablePromptSession
|
||||
func (TablePromptSession) TableName() string { return "prompt_sessions" }
|
||||
|
||||
// BeforeSave GORM hook to serialize JSON fields
|
||||
func (s *TablePromptSession) BeforeSave(tx *gorm.DB) error {
|
||||
data, err := json.Marshal(s.ModelParams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
paramsStr := string(data)
|
||||
s.ModelParamsJSON = ¶msStr
|
||||
|
||||
if s.Variables != nil {
|
||||
varsData, err := json.Marshal(s.Variables)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
varsStr := string(varsData)
|
||||
s.VariablesJSON = &varsStr
|
||||
} else {
|
||||
s.VariablesJSON = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind GORM hook to deserialize JSON fields
|
||||
func (s *TablePromptSession) AfterFind(tx *gorm.DB) error {
|
||||
if s.ModelParamsJSON != nil && *s.ModelParamsJSON != "" {
|
||||
dec := json.NewDecoder(strings.NewReader(*s.ModelParamsJSON))
|
||||
dec.UseNumber()
|
||||
if err := dec.Decode(&s.ModelParams); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if s.VariablesJSON != nil && *s.VariablesJSON != "" {
|
||||
var vars PromptVariables
|
||||
if err := json.Unmarshal([]byte(*s.VariablesJSON), &vars); err != nil {
|
||||
return err
|
||||
}
|
||||
s.Variables = vars
|
||||
} else {
|
||||
s.Variables = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TablePromptSessionMessage represents a message in a mutable prompt session
|
||||
type TablePromptSessionMessage struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
PromptID string `gorm:"type:varchar(36);not null;index" json:"prompt_id"`
|
||||
SessionID uint `gorm:"not null;index;uniqueIndex:idx_session_order" json:"session_id"`
|
||||
Session *TablePromptSession `gorm:"foreignKey:SessionID" json:"-"`
|
||||
OrderIndex int `gorm:"not null;uniqueIndex:idx_session_order" json:"order_index"`
|
||||
MessageJSON string `gorm:"type:text;not null;column:message_json" json:"-"`
|
||||
Message PromptMessage `gorm:"-" json:"message"`
|
||||
}
|
||||
|
||||
// TableName for TablePromptSessionMessage
|
||||
func (TablePromptSessionMessage) TableName() string { return "prompt_session_messages" }
|
||||
|
||||
// BeforeSave GORM hook to serialize JSON fields
|
||||
func (m *TablePromptSessionMessage) BeforeSave(tx *gorm.DB) error {
|
||||
data, err := json.Marshal(m.Message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.MessageJSON = string(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind GORM hook to deserialize JSON fields
|
||||
func (m *TablePromptSessionMessage) AfterFind(tx *gorm.DB) error {
|
||||
if m.MessageJSON != "" {
|
||||
if err := json.Unmarshal([]byte(m.MessageJSON), &m.Message); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
120
framework/configstore/tables/promptVersions.go
Normal file
120
framework/configstore/tables/promptVersions.go
Normal file
@@ -0,0 +1,120 @@
|
||||
// Package tables provides tables for the configstore
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TablePromptVersion represents an immutable version of a prompt
|
||||
// Once created, a version cannot be modified - to make changes, create a new version
|
||||
type TablePromptVersion struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
PromptID string `gorm:"type:varchar(36);not null;index;uniqueIndex:idx_prompt_version" json:"prompt_id"`
|
||||
Prompt *TablePrompt `gorm:"foreignKey:PromptID" json:"prompt,omitempty"`
|
||||
VersionNumber int `gorm:"not null;uniqueIndex:idx_prompt_version" json:"version_number"`
|
||||
CommitMessage string `gorm:"type:text" json:"commit_message"`
|
||||
ModelParamsJSON *string `gorm:"type:text;column:model_params_json" json:"-"`
|
||||
ModelParams ModelParams `gorm:"-" json:"model_params"`
|
||||
Provider string `gorm:"type:varchar(100)" json:"provider"`
|
||||
Model string `gorm:"type:varchar(100)" json:"model"`
|
||||
VariablesJSON *string `gorm:"type:text;column:variables_json" json:"-"`
|
||||
Variables PromptVariables `gorm:"-" json:"variables,omitempty"` // {key: value} map for Jinja2 variables
|
||||
IsLatest bool `gorm:"not null;default:false" json:"is_latest"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
// No UpdatedAt - versions are immutable
|
||||
|
||||
// Relationships
|
||||
Messages []TablePromptVersionMessage `gorm:"foreignKey:VersionID;constraint:OnDelete:CASCADE" json:"messages,omitempty"`
|
||||
}
|
||||
|
||||
// TableName for TablePromptVersion
|
||||
func (TablePromptVersion) TableName() string { return "prompt_versions" }
|
||||
|
||||
// ModelParams represents model configuration parameters as a flexible map
|
||||
// so that any provider-specific params (response_format, seed, logprobs, etc.) are preserved.
|
||||
type ModelParams map[string]any
|
||||
|
||||
// PromptVariables represents a map of Jinja2 variable names to their values.
|
||||
// Sessions store full {key: value} pairs; versions store {key: ""} (keys only).
|
||||
type PromptVariables map[string]string // stored as JSON text by the BeforeSave hooks of the session and version tables
|
||||
|
||||
// BeforeSave GORM hook to serialize JSON fields
|
||||
func (v *TablePromptVersion) BeforeSave(tx *gorm.DB) error {
|
||||
if v.ModelParams != nil {
|
||||
data, err := json.Marshal(v.ModelParams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
paramsStr := string(data)
|
||||
v.ModelParamsJSON = ¶msStr
|
||||
}
|
||||
if v.Variables != nil {
|
||||
varsData, err := json.Marshal(v.Variables)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
varsStr := string(varsData)
|
||||
v.VariablesJSON = &varsStr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind GORM hook to deserialize JSON fields
|
||||
func (v *TablePromptVersion) AfterFind(tx *gorm.DB) error {
|
||||
if v.ModelParamsJSON != nil && *v.ModelParamsJSON != "" {
|
||||
dec := json.NewDecoder(strings.NewReader(*v.ModelParamsJSON))
|
||||
dec.UseNumber()
|
||||
if err := dec.Decode(&v.ModelParams); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if v.VariablesJSON != nil && *v.VariablesJSON != "" {
|
||||
if err := json.Unmarshal([]byte(*v.VariablesJSON), &v.Variables); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TablePromptVersionMessage represents a message in an immutable prompt version
|
||||
type TablePromptVersionMessage struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
PromptID string `gorm:"type:varchar(36);not null;index" json:"prompt_id"`
|
||||
VersionID uint `gorm:"not null;index;uniqueIndex:idx_version_order" json:"version_id"`
|
||||
Version *TablePromptVersion `gorm:"foreignKey:VersionID" json:"-"`
|
||||
OrderIndex int `gorm:"not null;uniqueIndex:idx_version_order" json:"order_index"`
|
||||
MessageJSON string `gorm:"type:text;not null;column:message_json" json:"-"`
|
||||
Message PromptMessage `gorm:"-" json:"message"`
|
||||
}
|
||||
|
||||
// TableName for TablePromptVersionMessage
|
||||
func (TablePromptVersionMessage) TableName() string { return "prompt_version_messages" }
|
||||
|
||||
// PromptMessage is a raw JSON message stored in the database.
|
||||
// The frontend handles serialization/deserialization of the message format.
|
||||
// The backend treats it as opaque JSON to remain format-agnostic and backward-compatible.
|
||||
type PromptMessage = json.RawMessage // alias (not a new type), so it marshals exactly like json.RawMessage
|
||||
|
||||
// BeforeSave GORM hook to serialize JSON fields
|
||||
func (m *TablePromptVersionMessage) BeforeSave(tx *gorm.DB) error {
|
||||
data, err := json.Marshal(m.Message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.MessageJSON = string(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind GORM hook to deserialize JSON fields
|
||||
func (m *TablePromptVersionMessage) AfterFind(tx *gorm.DB) error {
|
||||
if m.MessageJSON != "" {
|
||||
if err := json.Unmarshal([]byte(m.MessageJSON), &m.Message); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
27
framework/configstore/tables/prompts.go
Normal file
27
framework/configstore/tables/prompts.go
Normal file
@@ -0,0 +1,27 @@
|
||||
// Package tables provides tables for the configstore
|
||||
package tables
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// TablePrompt represents a prompt entity that can have multiple versions and sessions
|
||||
type TablePrompt struct {
|
||||
ID string `gorm:"type:varchar(36);primaryKey" json:"id"`
|
||||
Name string `gorm:"type:varchar(255);not null" json:"name"`
|
||||
FolderID *string `gorm:"type:varchar(36);index" json:"folder_id,omitempty"`
|
||||
Folder *TableFolder `gorm:"foreignKey:FolderID;constraint:OnDelete:CASCADE" json:"folder,omitempty"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
ConfigHash string `gorm:"type:varchar(64)" json:"-"`
|
||||
|
||||
// Relationships
|
||||
Versions []TablePromptVersion `gorm:"foreignKey:PromptID;constraint:OnDelete:CASCADE" json:"versions,omitempty"`
|
||||
Sessions []TablePromptSession `gorm:"foreignKey:PromptID;constraint:OnDelete:CASCADE" json:"sessions,omitempty"`
|
||||
|
||||
// Virtual fields (not stored in DB)
|
||||
LatestVersion *TablePromptVersion `gorm:"-" json:"latest_version,omitempty"`
|
||||
}
|
||||
|
||||
// TableName for TablePrompt
|
||||
func (TablePrompt) TableName() string { return "prompts" }
|
||||
184
framework/configstore/tables/provider.go
Normal file
184
framework/configstore/tables/provider.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableProvider represents a provider configuration in the database
|
||||
// NOTE: Any changes to the provider configuration should be reflected in the GenerateConfigHash function
|
||||
// That helps us detect changes between config file and database config
|
||||
type TableProvider struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
Name string `gorm:"type:varchar(50);uniqueIndex;not null" json:"name"` // ModelProvider as string
|
||||
NetworkConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.NetworkConfig
|
||||
ConcurrencyBufferJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.ConcurrencyAndBufferSize
|
||||
ProxyConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.ProxyConfig
|
||||
CustomProviderConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.CustomProviderConfig
|
||||
OpenAIConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.OpenAIConfig
|
||||
SendBackRawRequest bool `json:"send_back_raw_request"`
|
||||
SendBackRawResponse bool `json:"send_back_raw_response"`
|
||||
StoreRawRequestResponse bool `json:"store_raw_request_response"`
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
|
||||
// Relationships
|
||||
Keys []TableKey `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"keys"`
|
||||
|
||||
// Virtual fields for runtime use (not stored in DB)
|
||||
NetworkConfig *schemas.NetworkConfig `gorm:"-" json:"network_config,omitempty"`
|
||||
ConcurrencyAndBufferSize *schemas.ConcurrencyAndBufferSize `gorm:"-" json:"concurrency_and_buffer_size,omitempty"`
|
||||
ProxyConfig *schemas.ProxyConfig `gorm:"-" json:"proxy_config,omitempty"`
|
||||
|
||||
// Custom provider fields
|
||||
CustomProviderConfig *schemas.CustomProviderConfig `gorm:"-" json:"custom_provider_config,omitempty"`
|
||||
OpenAIConfig *schemas.OpenAIConfig `gorm:"-" json:"openai_config,omitempty"`
|
||||
|
||||
// Foreign keys
|
||||
Models []TableModel `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models"`
|
||||
|
||||
// Governance fields - Budget and Rate Limit for provider-level governance
|
||||
BudgetID *string `gorm:"type:varchar(255);index:idx_provider_budget" json:"budget_id,omitempty"`
|
||||
RateLimitID *string `gorm:"type:varchar(255);index:idx_provider_rate_limit" json:"rate_limit_id,omitempty"`
|
||||
|
||||
// Governance relationships
|
||||
Budget *TableBudget `gorm:"foreignKey:BudgetID;onDelete:CASCADE" json:"budget,omitempty"`
|
||||
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
// Model discovery status tracking for keyless providers
|
||||
Status string `gorm:"type:varchar(50);default:'unknown'" json:"status"`
|
||||
Description string `gorm:"type:text" json:"description,omitempty"`
|
||||
|
||||
EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
|
||||
}
|
||||
|
||||
// TableName returns the database table name for provider configurations.
|
||||
func (TableProvider) TableName() string {
	return "config_providers"
}
|
||||
|
||||
// BeforeSave is a GORM hook that serializes runtime config structs into JSON columns,
|
||||
// validates governance fields, and encrypts the proxy configuration before writing
|
||||
// to the database.
|
||||
func (p *TableProvider) BeforeSave(tx *gorm.DB) error {
|
||||
if p.NetworkConfig != nil {
|
||||
data, err := json.Marshal(p.NetworkConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.NetworkConfigJSON = string(data)
|
||||
}
|
||||
if p.ConcurrencyAndBufferSize != nil {
|
||||
data, err := json.Marshal(p.ConcurrencyAndBufferSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.ConcurrencyBufferJSON = string(data)
|
||||
}
|
||||
if p.ProxyConfig != nil {
|
||||
data, err := json.Marshal(p.ProxyConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.ProxyConfigJSON = string(data)
|
||||
}
|
||||
if p.CustomProviderConfig != nil && p.CustomProviderConfig.BaseProviderType == "" {
|
||||
return fmt.Errorf("base_provider_type is required when custom_provider_config is set")
|
||||
}
|
||||
if p.CustomProviderConfig != nil {
|
||||
data, err := json.Marshal(p.CustomProviderConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.CustomProviderConfigJSON = string(data)
|
||||
}
|
||||
if p.OpenAIConfig != nil {
|
||||
data, err := json.Marshal(p.OpenAIConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.OpenAIConfigJSON = string(data)
|
||||
} else {
|
||||
p.OpenAIConfigJSON = ""
|
||||
}
|
||||
// Validate governance fields
|
||||
if p.BudgetID != nil && strings.TrimSpace(*p.BudgetID) == "" {
|
||||
return fmt.Errorf("budget_id cannot be an empty string")
|
||||
}
|
||||
if p.RateLimitID != nil && strings.TrimSpace(*p.RateLimitID) == "" {
|
||||
return fmt.Errorf("rate_limit_id cannot be an empty string")
|
||||
}
|
||||
|
||||
// Encrypt proxy config after serialization (only if there's data to encrypt)
|
||||
if encrypt.IsEnabled() && p.ProxyConfigJSON != "" {
|
||||
encrypted, err := encrypt.Encrypt(p.ProxyConfigJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt proxy config: %w", err)
|
||||
}
|
||||
p.ProxyConfigJSON = encrypted
|
||||
p.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind is a GORM hook that decrypts the proxy configuration (if encrypted) and
|
||||
// deserializes JSON columns back into runtime config structs after reading from the database.
|
||||
func (p *TableProvider) AfterFind(tx *gorm.DB) error {
|
||||
if p.NetworkConfigJSON != "" {
|
||||
var config schemas.NetworkConfig
|
||||
if err := json.Unmarshal([]byte(p.NetworkConfigJSON), &config); err != nil {
|
||||
return err
|
||||
}
|
||||
p.NetworkConfig = &config
|
||||
}
|
||||
|
||||
if p.ConcurrencyBufferJSON != "" {
|
||||
var config schemas.ConcurrencyAndBufferSize
|
||||
if err := json.Unmarshal([]byte(p.ConcurrencyBufferJSON), &config); err != nil {
|
||||
return err
|
||||
}
|
||||
p.ConcurrencyAndBufferSize = &config
|
||||
}
|
||||
|
||||
if p.EncryptionStatus == "encrypted" && p.ProxyConfigJSON != "" {
|
||||
decrypted, err := encrypt.Decrypt(p.ProxyConfigJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt proxy config: %w", err)
|
||||
}
|
||||
p.ProxyConfigJSON = decrypted
|
||||
}
|
||||
if p.ProxyConfigJSON != "" {
|
||||
var proxyConfig schemas.ProxyConfig
|
||||
if err := json.Unmarshal([]byte(p.ProxyConfigJSON), &proxyConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
p.ProxyConfig = &proxyConfig
|
||||
}
|
||||
|
||||
if p.CustomProviderConfigJSON != "" {
|
||||
var customConfig schemas.CustomProviderConfig
|
||||
if err := json.Unmarshal([]byte(p.CustomProviderConfigJSON), &customConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
p.CustomProviderConfig = &customConfig
|
||||
}
|
||||
|
||||
if p.OpenAIConfigJSON != "" {
|
||||
var openaiConfig schemas.OpenAIConfig
|
||||
if err := json.Unmarshal([]byte(p.OpenAIConfigJSON), &openaiConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
p.OpenAIConfig = &openaiConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
79
framework/configstore/tables/ratelimit.go
Normal file
79
framework/configstore/tables/ratelimit.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableRateLimit is the GORM model for a rate-limit rule. Token and request
// limits are tracked independently; each has an optional maximum, a
// reset-duration string, a running usage counter and the time of its last reset.
type TableRateLimit struct {
	// ID is the primary key.
	ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`

	// TokenMaxLimit is the maximum number of tokens allowed; nil when unset.
	TokenMaxLimit *int64 `gorm:"default:null" json:"token_max_limit,omitempty"`

	// TokenResetDuration is the token window length, e.g. "30s", "5m", "1h", "1d", "1w", "1M", "1Y".
	TokenResetDuration *string `gorm:"type:varchar(50)" json:"token_reset_duration,omitempty"`

	// TokenCurrentUsage is the current token usage.
	TokenCurrentUsage int64 `gorm:"default:0" json:"token_current_usage"`

	// TokenLastReset is the last time the token counter was reset.
	TokenLastReset time.Time `gorm:"index" json:"token_last_reset"`

	// RequestMaxLimit is the maximum number of requests allowed; nil when unset.
	RequestMaxLimit *int64 `gorm:"default:null" json:"request_max_limit,omitempty"`

	// RequestResetDuration is the request window length, e.g. "30s", "5m", "1h", "1d", "1w", "1M", "1Y".
	RequestResetDuration *string `gorm:"type:varchar(50)" json:"request_reset_duration,omitempty"`

	// RequestCurrentUsage is the current request usage.
	RequestCurrentUsage int64 `gorm:"default:0" json:"request_current_usage"`

	// RequestLastReset is the last time the request counter was reset.
	RequestLastReset time.Time `gorm:"index" json:"request_last_reset"`

	// CalendarAligned requests resets at clean calendar boundaries.
	// NOTE(review): the previous comment described "all budgets under this VK";
	// confirm the exact semantics for rate limits.
	CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"`

	// ConfigHash is used to detect changes synced from the config.json file;
	// it is updated every time config.json is synced.
	ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`

	// CreatedAt and UpdatedAt are the row timestamps.
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableRateLimit) TableName() string { return "governance_rate_limits" }
|
||||
|
||||
// BeforeSave hook for RateLimit to validate reset duration formats
|
||||
func (rl *TableRateLimit) BeforeSave(tx *gorm.DB) error {
|
||||
// Validate token reset duration if provided
|
||||
if rl.TokenResetDuration != nil {
|
||||
if d, err := ParseDuration(*rl.TokenResetDuration); err != nil {
|
||||
return fmt.Errorf("invalid token reset duration format: %s", *rl.TokenResetDuration)
|
||||
} else if d <= 0 {
|
||||
return fmt.Errorf("token reset duration cannot be zero or negative: %s", *rl.TokenResetDuration)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate request reset duration if provided
|
||||
if rl.RequestResetDuration != nil {
|
||||
if d, err := ParseDuration(*rl.RequestResetDuration); err != nil {
|
||||
return fmt.Errorf("invalid request reset duration format: %s", *rl.RequestResetDuration)
|
||||
} else if d <= 0 {
|
||||
return fmt.Errorf("request reset duration cannot be zero or negative: %s", *rl.RequestResetDuration)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that if a max limit is set, a reset duration is also provided
|
||||
if rl.TokenMaxLimit != nil && rl.TokenResetDuration == nil {
|
||||
return fmt.Errorf("token_reset_duration is required when token_max_limit is set")
|
||||
}
|
||||
|
||||
if rl.RequestMaxLimit != nil && rl.RequestResetDuration == nil {
|
||||
return fmt.Errorf("request_reset_duration is required when request_max_limit is set")
|
||||
}
|
||||
|
||||
// Making sure token limit is greater than zero
|
||||
if rl.TokenMaxLimit != nil && *rl.TokenMaxLimit <= 0 {
|
||||
return fmt.Errorf("token_max_limit cannot be zero or negative: %d", *rl.TokenMaxLimit)
|
||||
}
|
||||
|
||||
// Making sure request limit is greater than zero
|
||||
if rl.RequestMaxLimit != nil && *rl.RequestMaxLimit <= 0 {
|
||||
return fmt.Errorf("request_max_limit cannot be zero or negative: %d", *rl.RequestMaxLimit)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
99
framework/configstore/tables/routing_rules.go
Normal file
99
framework/configstore/tables/routing_rules.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableRoutingRule represents a routing rule in the database
|
||||
type TableRoutingRule struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
|
||||
ConfigHash string `gorm:"type:varchar(255)" json:"config_hash"` // Hash of config.json version, used for change detection
|
||||
Name string `gorm:"type:varchar(255);not null;uniqueIndex:idx_routing_rule_scope_name" json:"name"`
|
||||
Description string `gorm:"type:text" json:"description"`
|
||||
Enabled bool `gorm:"not null;default:true" json:"enabled"`
|
||||
CelExpression string `gorm:"type:text;not null" json:"cel_expression"`
|
||||
|
||||
// Routing Targets (output) — 1:many relationship; weights must sum to 1
|
||||
Targets []TableRoutingTarget `gorm:"foreignKey:RuleID;constraint:OnDelete:CASCADE" json:"targets"`
|
||||
|
||||
Fallbacks *string `gorm:"type:text" json:"-"` // JSON array of fallback chains
|
||||
ParsedFallbacks []string `gorm:"-" json:"fallbacks,omitempty"` // Parsed fallbacks from JSON
|
||||
|
||||
Query *string `gorm:"type:text" json:"-"`
|
||||
ParsedQuery map[string]any `gorm:"-" json:"query,omitempty"`
|
||||
|
||||
// Scope: where this rule applies
|
||||
Scope string `gorm:"type:varchar(50);not null;uniqueIndex:idx_routing_rule_scope_name" json:"scope"` // "global" | "team" | "customer" | "virtual_key"
|
||||
ScopeID *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_rule_scope_name" json:"scope_id"` // nil for global, otherwise entity ID
|
||||
|
||||
// Chaining
|
||||
ChainRule bool `gorm:"not null;default:false" json:"chain_rule"` // If true, re-evaluates routing chain after this rule matches
|
||||
|
||||
// Execution
|
||||
Priority int `gorm:"type:int;not null;default:0;index" json:"priority"` // Lower = evaluated first within scope
|
||||
|
||||
// Timestamps
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName for TableRoutingRule
|
||||
func (TableRoutingRule) TableName() string { return "routing_rules" }
|
||||
|
||||
// BeforeSave hook for TableRoutingRule to serialize JSON fields
|
||||
func (r *TableRoutingRule) BeforeSave(tx *gorm.DB) error {
|
||||
if len(r.ParsedFallbacks) > 0 {
|
||||
data, err := sonic.Marshal(r.ParsedFallbacks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Fallbacks = bifrost.Ptr(string(data))
|
||||
} else {
|
||||
r.Fallbacks = nil
|
||||
}
|
||||
if r.ParsedQuery != nil {
|
||||
data, err := sonic.Marshal(r.ParsedQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Query = bifrost.Ptr(string(data))
|
||||
} else {
|
||||
r.Query = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind hook for TableRoutingRule to deserialize JSON fields
|
||||
func (r *TableRoutingRule) AfterFind(tx *gorm.DB) error {
|
||||
if r.Fallbacks != nil && strings.TrimSpace(*r.Fallbacks) != "" {
|
||||
if err := sonic.Unmarshal([]byte(*r.Fallbacks), &r.ParsedFallbacks); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if r.Query != nil && strings.TrimSpace(*r.Query) != "" {
|
||||
if err := sonic.Unmarshal([]byte(*r.Query), &r.ParsedQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableRoutingTarget is one weighted target of a routing rule, used for
// probabilistic routing. A rule may have several targets; each weight sets the
// probability of that target being chosen, and the weights of a rule must sum
// to 1. RuleID, Provider, Model and KeyID together form the
// idx_routing_target_config unique index, which prevents duplicate targets.
type TableRoutingTarget struct {
	// RuleID references the owning routing rule; it is hidden from JSON output.
	RuleID string `gorm:"type:varchar(255);not null;index;uniqueIndex:idx_routing_target_config" json:"-"`

	// Provider overrides the provider; nil = use incoming provider.
	Provider *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_target_config" json:"provider,omitempty"`

	// Model overrides the model; nil = use incoming model.
	Model *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_target_config" json:"model,omitempty"`

	// KeyID pins a key; nil = no key pin.
	KeyID *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_target_config" json:"key_id,omitempty"`

	// Weight is the selection weight; must sum to 1 across all targets in a rule.
	Weight float64 `gorm:"not null;default:1" json:"weight"`
}
|
||||
|
||||
// TableName for TableRoutingTarget
|
||||
func (TableRoutingTarget) TableName() string { return "routing_targets" }
|
||||
48
framework/configstore/tables/sessions.go
Normal file
48
framework/configstore/tables/sessions.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SessionsTable is the GORM model for a session row.
type SessionsTable struct {
	// ID is the auto-incremented primary key.
	ID int `gorm:"primaryKey;autoIncrement" json:"id"`

	// Token is the session token; the column carries a unique index.
	Token string `gorm:"type:text;not null;uniqueIndex" json:"token"`

	// ExpiresAt is the expiry time of the session.
	ExpiresAt time.Time `gorm:"index;not null" json:"expires_at,omitempty"`

	// CreatedAt and UpdatedAt are the row timestamps.
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`

	// EncryptionStatus records whether Token is stored encrypted; the column
	// defaults to 'plain_text' and is hidden from JSON output.
	EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`

	// TokenHash is a uniquely indexed hash of the token, hidden from JSON output.
	TokenHash string `gorm:"type:varchar(64);index:idx_session_token_hash,unique" json:"-"`
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (SessionsTable) TableName() string { return "sessions" }
|
||||
|
||||
// BeforeSave hook to hash and encrypt the session token
|
||||
func (s *SessionsTable) BeforeSave(tx *gorm.DB) error {
|
||||
// Hash must be computed before encryption (from plaintext value)
|
||||
if s.Token != "" {
|
||||
s.TokenHash = encrypt.HashSHA256(s.Token)
|
||||
}
|
||||
if encrypt.IsEnabled() && s.Token != "" {
|
||||
if err := encryptString(&s.Token); err != nil {
|
||||
return fmt.Errorf("failed to encrypt session token: %w", err)
|
||||
}
|
||||
s.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind hook to decrypt the session token
|
||||
func (s *SessionsTable) AfterFind(tx *gorm.DB) error {
|
||||
if s.EncryptionStatus == EncryptionStatusEncrypted {
|
||||
if err := decryptString(&s.Token); err != nil {
|
||||
return fmt.Errorf("failed to decrypt session token: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
96
framework/configstore/tables/team.go
Normal file
96
framework/configstore/tables/team.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableTeam represents a team entity with budget, rate limit and customer association
|
||||
type TableTeam struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
|
||||
Name string `gorm:"type:varchar(255);not null" json:"name"`
|
||||
CustomerID *string `gorm:"type:varchar(255);index" json:"customer_id,omitempty"` // A team can belong to a customer
|
||||
RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`
|
||||
|
||||
// Relationships
|
||||
Customer *TableCustomer `gorm:"foreignKey:CustomerID" json:"customer,omitempty"`
|
||||
Budgets []TableBudget `gorm:"foreignKey:TeamID;constraint:OnDelete:CASCADE" json:"budgets,omitempty"` // Multiple budgets with different reset intervals
|
||||
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID" json:"rate_limit,omitempty"`
|
||||
VirtualKeys []TableVirtualKey `gorm:"foreignKey:TeamID" json:"virtual_keys,omitempty"`
|
||||
|
||||
// Computed (not a DB column) — populated via correlated subquery in query layer, hence no migration
|
||||
VirtualKeyCount int64 `gorm:"->;-:migration" json:"virtual_key_count"`
|
||||
|
||||
Profile *string `gorm:"type:text" json:"-"`
|
||||
ParsedProfile map[string]any `gorm:"-" json:"profile"`
|
||||
|
||||
Config *string `gorm:"type:text" json:"-"`
|
||||
ParsedConfig map[string]any `gorm:"-" json:"config"`
|
||||
|
||||
Claims *string `gorm:"type:text" json:"-"`
|
||||
ParsedClaims map[string]any `gorm:"-" json:"claims"`
|
||||
|
||||
// Config hash is used to detect the changes synced from config.json file
|
||||
// Every time we sync the config.json file, we will update the config hash
|
||||
ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableTeam) TableName() string { return "governance_teams" }
|
||||
|
||||
// BeforeSave hook for TableTeam to serialize JSON fields
|
||||
func (t *TableTeam) BeforeSave(tx *gorm.DB) error {
|
||||
if t.ParsedProfile != nil {
|
||||
data, err := json.Marshal(t.ParsedProfile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.Profile = new(string(data))
|
||||
} else {
|
||||
t.Profile = nil
|
||||
}
|
||||
if t.ParsedConfig != nil {
|
||||
data, err := json.Marshal(t.ParsedConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.Config = new(string(data))
|
||||
} else {
|
||||
t.Config = nil
|
||||
}
|
||||
if t.ParsedClaims != nil {
|
||||
data, err := json.Marshal(t.ParsedClaims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.Claims = new(string(data))
|
||||
} else {
|
||||
t.Claims = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind hook for TableTeam to deserialize JSON fields
|
||||
func (t *TableTeam) AfterFind(tx *gorm.DB) error {
|
||||
if t.Profile != nil {
|
||||
if err := json.Unmarshal([]byte(*t.Profile), &t.ParsedProfile); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if t.Config != nil {
|
||||
if err := json.Unmarshal([]byte(*t.Config), &t.ParsedConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if t.Claims != nil {
|
||||
if err := json.Unmarshal([]byte(*t.Claims), &t.ParsedClaims); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
91
framework/configstore/tables/utils.go
Normal file
91
framework/configstore/tables/utils.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IsCalendarAlignableDuration reports whether the duration string ends in a
// unit that has a natural calendar boundary: day ("d"), week ("w"),
// month ("M") or year ("Y"). Only the final character is inspected, so
// sub-day durations such as "1h" or "30m" — and the empty string — yield false.
func IsCalendarAlignableDuration(duration string) bool {
	n := len(duration)
	if n == 0 {
		return false
	}
	unit := duration[n-1]
	return unit == 'd' || unit == 'w' || unit == 'M' || unit == 'Y'
}
|
||||
|
||||
// GetCalendarPeriodStart returns the start of the calendar period that
// contains t, based on the unit suffix (last character) of duration. The
// numeric part of duration is ignored. Boundaries are computed in UTC:
//   - "Nd" → midnight UTC on the current day
//   - "Nw" → midnight UTC on the most recent Monday
//   - "NM" → midnight UTC on the 1st of the current month
//   - "NY" → midnight UTC on Jan 1 of the current year
//
// An empty duration returns t unchanged. Any other suffix (e.g. "1h", "30m")
// returns the same instant converted to UTC, since sub-day periods have no
// natural calendar boundary.
func GetCalendarPeriodStart(duration string, t time.Time) time.Time {
	if len(duration) == 0 {
		return t
	}

	utc := t.UTC()
	year, month, day := utc.Date()

	switch duration[len(duration)-1] {
	case 'd':
		return time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
	case 'w':
		// time.Weekday has Sunday = 0; shift so that Monday = 0.
		sinceMonday := (int(utc.Weekday()) + 6) % 7
		my, mm, md := utc.AddDate(0, 0, -sinceMonday).Date()
		return time.Date(my, mm, md, 0, 0, 0, 0, time.UTC)
	case 'M':
		return time.Date(year, month, 1, 0, 0, 0, 0, time.UTC)
	case 'Y':
		return time.Date(year, time.January, 1, 0, 0, 0, 0, time.UTC)
	}
	return utc
}
|
||||
|
||||
// ParseDuration parses a duration string. In addition to everything accepted
// by time.ParseDuration (e.g. "30s", "5m", "1h"), it understands a single
// number followed by one calendar-scale unit:
//   - "d" → days (24 hours)
//   - "w" → weeks (7 days)
//   - "M" → months, approximated as 30 days
//   - "Y" → years, approximated as 365 days
//
// The numeric part of a calendar-scale duration may carry a sign and a decimal
// fraction ("1.5d", "-2w") but must not contain further units: inputs such as
// "1h2d" are rejected rather than silently reinterpreted. An error is also
// returned when the input is empty or when the result would overflow
// time.Duration.
func ParseDuration(duration string) (time.Duration, error) {
	if duration == "" {
		return 0, fmt.Errorf("duration is empty")
	}

	// Map the trailing unit to its name (for error messages) and its length in hours.
	var (
		unitName     string
		hoursPerUnit time.Duration
	)
	switch duration[len(duration)-1] {
	case 'd':
		unitName, hoursPerUnit = "day", 24
	case 'w':
		unitName, hoursPerUnit = "week", 24*7
	case 'M':
		unitName, hoursPerUnit = "month", 24*30 // Approximate month as 30 days
	case 'Y':
		unitName, hoursPerUnit = "year", 24*365 // Approximate year as 365 days
	default:
		return time.ParseDuration(duration)
	}

	count := duration[:len(duration)-1]

	// The count is parsed by appending "h" and delegating to time.ParseDuration.
	// Restrict it to an optionally signed decimal number first; otherwise
	// embedded units (e.g. "1h2" from "1h2d") would be accepted and summed.
	digits := count
	if digits != "" && (digits[0] == '+' || digits[0] == '-') {
		digits = digits[1:]
	}
	for i := 0; i < len(digits); i++ {
		if c := digits[i]; (c < '0' || c > '9') && c != '.' {
			return 0, fmt.Errorf("invalid %s duration: %s", unitName, duration)
		}
	}

	base, err := time.ParseDuration(count + "h")
	if err != nil {
		return 0, fmt.Errorf("invalid %s duration: %s", unitName, duration)
	}

	// Guard the multiplication: int64 nanoseconds wrap silently on overflow.
	const maxDuration = time.Duration(1<<63 - 1)
	if limit := maxDuration / hoursPerUnit; base > limit || base < -limit {
		return 0, fmt.Errorf("invalid %s duration: %s", unitName, duration)
	}

	return base * hoursPerUnit, nil
}
|
||||
47
framework/configstore/tables/vectorstore.go
Normal file
47
framework/configstore/tables/vectorstore.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableVectorStoreConfig is the GORM model holding the vector store (cache
// plugin) configuration.
type TableVectorStoreConfig struct {
	// ID is the auto-incremented primary key.
	ID uint `gorm:"primaryKey;autoIncrement" json:"id"`

	// Enabled turns the vector store on.
	Enabled bool `json:"enabled"`

	// Type names the backend: "weaviate, redis, qdrant."
	Type string `gorm:"type:varchar(50);not null" json:"type"`

	// TTLSeconds is the TTL in seconds (default: 5 minutes).
	TTLSeconds int `gorm:"default:300" json:"ttl_seconds"`

	// CacheByModel includes the model in the cache key.
	CacheByModel bool `gorm:"" json:"cache_by_model"`

	// CacheByProvider includes the provider in the cache key.
	CacheByProvider bool `gorm:"" json:"cache_by_provider"`

	// Config is the JSON serialized schemas.RedisVectorStoreConfig.
	Config *string `gorm:"type:text" json:"config"`

	// EncryptionStatus records whether Config is stored encrypted; the column
	// defaults to 'plain_text' and is hidden from JSON output.
	EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`

	// CreatedAt and UpdatedAt are the row timestamps.
	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableVectorStoreConfig) TableName() string { return "config_vector_store" }
|
||||
|
||||
// BeforeSave hook to encrypt sensitive config
|
||||
func (vs *TableVectorStoreConfig) BeforeSave(tx *gorm.DB) error {
|
||||
if encrypt.IsEnabled() && vs.Config != nil && *vs.Config != "" {
|
||||
if err := encryptString(vs.Config); err != nil {
|
||||
return fmt.Errorf("failed to encrypt vector store config: %w", err)
|
||||
}
|
||||
vs.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind hook to decrypt sensitive config
|
||||
func (vs *TableVectorStoreConfig) AfterFind(tx *gorm.DB) error {
|
||||
if vs.EncryptionStatus == EncryptionStatusEncrypted && vs.Config != nil && *vs.Config != "" {
|
||||
if err := decryptString(vs.Config); err != nil {
|
||||
return fmt.Errorf("failed to decrypt vector store config: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
269
framework/configstore/tables/virtualkey.go
Normal file
269
framework/configstore/tables/virtualkey.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TableVirtualKeyProviderConfigKey is the join table backing the many2many
// relationship between TableVirtualKeyProviderConfig and TableKey. Both
// columns form the primary key and share the idx_vk_provider_config_key
// unique index.
type TableVirtualKeyProviderConfigKey struct {
	// TableVirtualKeyProviderConfigID is the provider-config side of the join.
	TableVirtualKeyProviderConfigID uint `gorm:"primaryKey;uniqueIndex:idx_vk_provider_config_key"`

	// TableKeyID is the key side of the join.
	TableKeyID uint `gorm:"primaryKey;uniqueIndex:idx_vk_provider_config_key"`
}
|
||||
|
||||
// TableName sets the table name for the join table
|
||||
func (TableVirtualKeyProviderConfigKey) TableName() string {
|
||||
return "governance_virtual_key_provider_config_keys"
|
||||
}
|
||||
|
||||
// TableVirtualKeyProviderConfig represents a provider configuration for a virtual key
|
||||
type TableVirtualKeyProviderConfig struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
VirtualKeyID string `gorm:"type:varchar(255);not null" json:"virtual_key_id"`
|
||||
Provider string `gorm:"type:varchar(50);not null" json:"provider"`
|
||||
Weight *float64 `json:"weight"`
|
||||
AllowedModels schemas.WhiteList `gorm:"type:text;serializer:json" json:"allowed_models"` // ["*"] allows all models; empty denies all (deny-by-default)
|
||||
AllowAllKeys bool `gorm:"default:false" json:"allow_all_keys"` // True means all keys allowed; false with empty Keys means no keys allowed (deny-by-default)
|
||||
RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`
|
||||
|
||||
// Relationships
|
||||
RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
|
||||
Budgets []TableBudget `gorm:"foreignKey:ProviderConfigID;constraint:OnDelete:CASCADE" json:"budgets,omitempty"` // Multiple budgets with different reset intervals
|
||||
Keys []TableKey `gorm:"many2many:governance_virtual_key_provider_config_keys;constraint:OnDelete:CASCADE" json:"keys"` // Empty means all keys allowed for this provider
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableVirtualKeyProviderConfig) TableName() string {
|
||||
return "governance_virtual_key_provider_configs"
|
||||
}
|
||||
|
||||
// UnmarshalJSON custom unmarshaller to handle "key_ids" ([]string) config-file format
|
||||
func (pc *TableVirtualKeyProviderConfig) UnmarshalJSON(data []byte) error {
|
||||
type Alias TableVirtualKeyProviderConfig
|
||||
type TempProviderConfig struct {
|
||||
Alias
|
||||
KeyIDs []string `json:"key_ids"` // Config file format: key identifiers (TableKey.KeyID); use ["*"] to allow all keys, empty denies all
|
||||
}
|
||||
|
||||
var temp TempProviderConfig
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy all standard fields
|
||||
*pc = TableVirtualKeyProviderConfig(temp.Alias)
|
||||
|
||||
// If key_ids is provided, convert to Keys or set AllowAllKeys
|
||||
if len(temp.KeyIDs) > 0 && len(pc.Keys) == 0 {
|
||||
// ["*"] means allow all keys
|
||||
if len(temp.KeyIDs) == 1 && temp.KeyIDs[0] == "*" {
|
||||
pc.AllowAllKeys = true
|
||||
pc.Keys = nil
|
||||
} else {
|
||||
pc.AllowAllKeys = false
|
||||
pc.Keys = make([]TableKey, len(temp.KeyIDs))
|
||||
for i, keyID := range temp.KeyIDs {
|
||||
pc.Keys[i] = TableKey{KeyID: keyID}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeforeSave validates WhiteList fields before GORM persists the record.
|
||||
func (pc *TableVirtualKeyProviderConfig) BeforeSave(tx *gorm.DB) error {
|
||||
if err := pc.AllowedModels.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid allowed_models: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON custom marshaller to ensure AllowedModels is always an array (never null)
|
||||
func (pc TableVirtualKeyProviderConfig) MarshalJSON() ([]byte, error) {
|
||||
type Alias TableVirtualKeyProviderConfig
|
||||
|
||||
// Ensure AllowedModels is an empty slice instead of nil
|
||||
allowedModels := pc.AllowedModels
|
||||
if allowedModels == nil {
|
||||
allowedModels = []string{}
|
||||
}
|
||||
|
||||
return json.Marshal(&struct {
|
||||
Alias
|
||||
AllowedModels []string `json:"allowed_models"`
|
||||
}{
|
||||
Alias: Alias(pc),
|
||||
AllowedModels: allowedModels,
|
||||
})
|
||||
}
|
||||
|
||||
// AfterFind hook for TableVirtualKeyProviderConfig to clear sensitive data from associated keys
|
||||
func (pc *TableVirtualKeyProviderConfig) AfterFind(tx *gorm.DB) error {
|
||||
if pc.Keys != nil {
|
||||
// Clear sensitive data from associated keys, keeping only key IDs and non-sensitive metadata
|
||||
for i := range pc.Keys {
|
||||
key := &pc.Keys[i]
|
||||
|
||||
// Clear the actual API key value
|
||||
key.Value = *schemas.NewEnvVar("")
|
||||
|
||||
// Clear all Azure-related sensitive fields
|
||||
key.AzureEndpoint = nil
|
||||
key.AzureAPIVersion = nil
|
||||
key.AzureClientID = nil
|
||||
key.AzureClientSecret = nil
|
||||
key.AzureTenantID = nil
|
||||
key.AzureScopesJSON = nil
|
||||
key.AzureKeyConfig = nil
|
||||
|
||||
// Clear all Vertex-related sensitive fields
|
||||
key.VertexProjectID = nil
|
||||
key.VertexProjectNumber = nil
|
||||
key.VertexRegion = nil
|
||||
key.VertexAuthCredentials = nil
|
||||
key.VertexKeyConfig = nil
|
||||
|
||||
// Clear all Bedrock-related sensitive fields
|
||||
key.BedrockAccessKey = nil
|
||||
key.BedrockSecretKey = nil
|
||||
key.BedrockSessionToken = nil
|
||||
key.BedrockRegion = nil
|
||||
key.BedrockARN = nil
|
||||
key.BedrockRoleARN = nil
|
||||
key.BedrockExternalID = nil
|
||||
key.BedrockRoleSessionName = nil
|
||||
key.BedrockKeyConfig = nil
|
||||
|
||||
pc.Keys[i] = *key
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type TableVirtualKeyMCPConfig struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
VirtualKeyID string `gorm:"type:varchar(255);not null;uniqueIndex:idx_vk_mcpclient" json:"virtual_key_id"`
|
||||
MCPClientID uint `gorm:"not null;uniqueIndex:idx_vk_mcpclient" json:"mcp_client_id"`
|
||||
MCPClient TableMCPClient `gorm:"foreignKey:MCPClientID" json:"mcp_client"`
|
||||
ToolsToExecute schemas.WhiteList `gorm:"type:text;serializer:json" json:"tools_to_execute"`
|
||||
|
||||
// MCPClientName is used during config file parsing to resolve the MCP client by name.
|
||||
// This field is not persisted to the database - it's only used to capture
|
||||
// "mcp_client_name" from config.json and then resolve it to MCPClientID.
|
||||
MCPClientName string `gorm:"-" json:"-"`
|
||||
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableVirtualKeyMCPConfig) TableName() string {
|
||||
return "governance_virtual_key_mcp_configs"
|
||||
}
|
||||
|
||||
// BeforeSave validates WhiteList fields before GORM persists the record.
|
||||
func (mc *TableVirtualKeyMCPConfig) BeforeSave(tx *gorm.DB) error {
|
||||
if err := mc.ToolsToExecute.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid tools_to_execute: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON custom unmarshaller to handle both "mcp_client_id" (database format)
|
||||
// and "mcp_client_name" (config file format) for MCP client references.
|
||||
func (mc *TableVirtualKeyMCPConfig) UnmarshalJSON(data []byte) error {
|
||||
// Temporary struct to capture all fields including mcp_client_name
|
||||
type Alias TableVirtualKeyMCPConfig
|
||||
type TempMCPConfig struct {
|
||||
Alias
|
||||
MCPClientName string `json:"mcp_client_name"` // Config file format: MCP client name
|
||||
}
|
||||
var temp TempMCPConfig
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
// Copy all standard fields
|
||||
*mc = TableVirtualKeyMCPConfig(temp.Alias)
|
||||
// Capture mcp_client_name for later resolution to MCPClientID
|
||||
if temp.MCPClientName != "" {
|
||||
mc.MCPClientName = temp.MCPClientName
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableVirtualKey represents a virtual key with budget, rate limits, and team/customer association.
//
// A virtual key may reference a team or a customer, but not both; BeforeSave
// rejects rows where both TeamID and CustomerID are set.
type TableVirtualKey struct {
	ID              string                          `gorm:"primaryKey;type:varchar(255)" json:"id"`
	Name            string                          `gorm:"uniqueIndex:idx_virtual_key_name;type:varchar(255);not null" json:"name"`
	Description     string                          `gorm:"type:text" json:"description,omitempty"`
	Value           string                          `gorm:"uniqueIndex:idx_virtual_key_value;type:text;not null" json:"value"` // The virtual key value; BeforeSave encrypts it when encryption is enabled
	IsActive        bool                            `gorm:"default:true" json:"is_active"`
	ProviderConfigs []TableVirtualKeyProviderConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"provider_configs"` // Empty means no providers allowed (deny-by-default)
	MCPConfigs      []TableVirtualKeyMCPConfig      `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"mcp_configs"`

	// Foreign key relationships (mutually exclusive: either TeamID or CustomerID, not both)
	TeamID      *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"`
	CustomerID  *string `gorm:"type:varchar(255);index" json:"customer_id,omitempty"`
	RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"`

	// Deprecated: calendar alignment is a property of the budget and rate limit,
	// not of the virtual key. A migration moves it to the budget/rate-limit tables,
	// after which this field is no longer referenced.
	CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` // When true, all budgets under this VK reset at clean calendar boundaries

	// Relationships
	Team     *TableTeam     `gorm:"foreignKey:TeamID" json:"team,omitempty"`
	Customer *TableCustomer `gorm:"foreignKey:CustomerID" json:"customer,omitempty"`
	// NOTE(review): this tag uses `onDelete:CASCADE`, while the other relations in
	// this struct use `constraint:OnDelete:CASCADE` — confirm GORM honors this form.
	RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"`
	Budgets   []TableBudget   `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"budgets,omitempty"` // Multiple budgets with different reset intervals

	// Config hash is used to detect the changes synced from config.json file.
	// Every time we sync the config.json file, we will update the config hash.
	ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"`

	// EncryptionStatus and ValueHash are maintained by BeforeSave and are never
	// serialized to JSON. ValueHash holds the SHA-256 hash of the plaintext Value.
	EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"`
	ValueHash        string `gorm:"type:varchar(64);index:idx_virtual_key_value_hash,unique" json:"-"`

	CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
	UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
}
|
||||
|
||||
// TableName sets the table name for each model
|
||||
func (TableVirtualKey) TableName() string { return "governance_virtual_keys" }
|
||||
|
||||
// BeforeSave is a GORM hook that enforces mutual exclusion (team vs customer), computes
|
||||
// a SHA-256 hash of the plaintext value for indexed lookups, and encrypts the virtual key
|
||||
// value before writing to the database.
|
||||
func (vk *TableVirtualKey) BeforeSave(tx *gorm.DB) error {
|
||||
// Enforce mutual exclusion: VK can belong to either Team OR Customer, not both
|
||||
if vk.TeamID != nil && vk.CustomerID != nil {
|
||||
return fmt.Errorf("virtual key cannot belong to both team and customer")
|
||||
}
|
||||
|
||||
// Hash must be computed before encryption (from plaintext value)
|
||||
if vk.Value != "" {
|
||||
vk.ValueHash = encrypt.HashSHA256(vk.Value)
|
||||
}
|
||||
if encrypt.IsEnabled() && vk.Value != "" {
|
||||
if err := encryptString(&vk.Value); err != nil {
|
||||
return fmt.Errorf("failed to encrypt virtual key value: %w", err)
|
||||
}
|
||||
vk.EncryptionStatus = EncryptionStatusEncrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind is a GORM hook that decrypts the virtual key value after reading from the database.
|
||||
func (vk *TableVirtualKey) AfterFind(tx *gorm.DB) error {
|
||||
if vk.EncryptionStatus == EncryptionStatusEncrypted {
|
||||
if err := decryptString(&vk.Value); err != nil {
|
||||
return fmt.Errorf("failed to decrypt virtual key value: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
40
framework/configstore/utils.go
Normal file
40
framework/configstore/utils.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package configstore
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// marshalToString encodes v as JSON and returns the encoding as a string.
// An untyped nil yields the empty string rather than "null"; a marshalling
// failure yields the empty string together with the error.
func marshalToString(v any) (string, error) {
	if v == nil {
		return "", nil
	}

	encoded, err := json.Marshal(v)
	if err != nil {
		return "", err
	}
	return string(encoded), nil
}
|
||||
|
||||
// marshalToStringPtr marshals the given value to a JSON string and returns a pointer to the string.
|
||||
func marshalToStringPtr(v any) (*string, error) {
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
data, err := marshalToString(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// deepCopy clones in by round-tripping it through JSON. Only data that
// survives JSON encoding is carried over to the copy. When encoding or
// decoding fails, the error is returned along with whatever was decoded so
// far (the zero value if encoding failed).
func deepCopy[T any](in T) (T, error) {
	var clone T

	encoded, err := json.Marshal(in)
	if err == nil {
		err = json.Unmarshal(encoded, &clone)
	}
	return clone, err
}
|
||||
121
framework/docker-compose.yml
Normal file
121
framework/docker-compose.yml
Normal file
@@ -0,0 +1,121 @@
|
||||
# Bifrost Framework Development Services
#
# Supported Vector Stores:
# - Weaviate: Runs locally via this docker-compose (port 9000)
# - Redis: Runs locally via this docker-compose (port 6379)
# - Qdrant: Runs locally via this docker-compose (REST: 6333, gRPC: 6334)
# - Pinecone: Runs locally via Pinecone Local emulator (port 5081)
#   For production, use cloud service with PINECONE_API_KEY and PINECONE_INDEX_HOST
#   See: https://docs.pinecone.io/guides/operations/local-development
#
services:
  postgres:
    image: postgres:16-alpine
    container_name: bifrost-postgres-fw
    environment:
      POSTGRES_USER: bifrost
      POSTGRES_PASSWORD: bifrost_password
      POSTGRES_DB: bifrost
      PGDATA: /var/lib/postgresql/data/pgdata
    ports:
      - "5432:5432"
    volumes:
      - postgres_data:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U bifrost -d bifrost"]
      interval: 10s
      timeout: 5s
      retries: 5
    restart: unless-stopped
    networks:
      - bifrost_network

  redis:
    image: redis/redis-stack:latest
    container_name: bifrost-redis
    ports:
      - "6379:6379"
    # Mount the redis_data volume declared at the bottom of this file so that
    # Redis data survives container re-creation (it was declared but unused).
    volumes:
      - redis_data:/data
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 10s
      timeout: 5s
      retries: 5
    restart: unless-stopped
    networks:
      - bifrost_network

  weaviate:
    image: cr.weaviate.io/semitechnologies/weaviate:1.25.0
    container_name: bifrost-weaviate
    ports:
      - "9000:8080"
      - "50051:50051"
    environment:
      QUERY_DEFAULTS_LIMIT: 25
      AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true'
      PERSISTENCE_DATA_PATH: '/var/lib/weaviate'
      DEFAULT_VECTORIZER_MODULE: 'none'
      CLUSTER_HOSTNAME: 'node1'
    volumes:
      - weaviate_data:/var/lib/weaviate
    healthcheck:
      test: ["CMD", "wget", "--spider", "-q", "http://localhost:8080/v1/.well-known/ready"]
      interval: 10s
      timeout: 5s
      retries: 5
    restart: unless-stopped
    networks:
      - bifrost_network

  pinecone:
    image: ghcr.io/pinecone-io/pinecone-index:latest
    container_name: bifrost-pinecone
    environment:
      PORT: 5081
      INDEX_TYPE: serverless
      VECTOR_TYPE: dense
      DIMENSION: 1536 # Matches text-embedding-3-small dimension
      METRIC: cosine
    ports:
      - "5081:5081"
    platform: linux/amd64
    healthcheck:
      test: ["CMD", "wget", "--spider", "-q", "http://localhost:5081/describe_index_stats"]
      interval: 10s
      timeout: 5s
      retries: 5
    restart: unless-stopped
    networks:
      - bifrost_network

  qdrant:
    image: qdrant/qdrant:v1.16.3
    container_name: bifrost-qdrant
    ports:
      - "6333:6333"
      - "6334:6334"
    volumes:
      - qdrant_data:/qdrant/storage
    healthcheck:
      test: ["CMD", "wget", "--spider", "-q", "http://localhost:6333/readyz"]
      interval: 10s
      timeout: 5s
      retries: 5
    restart: unless-stopped
    networks:
      - bifrost_network

networks:
  bifrost_network:
    driver: bridge

volumes:
  postgres_data:
    driver: local
  weaviate_data:
    driver: local
  redis_data:
    driver: local
  qdrant_data:
    driver: local
|
||||
|
||||
155
framework/encrypt/encrypt.go
Normal file
155
framework/encrypt/encrypt.go
Normal file
@@ -0,0 +1,155 @@
|
||||
// Package encrypt provides reversible AES-256-GCM encryption and decryption utilities
|
||||
// for securing sensitive data like API keys and credentials.
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var encryptionKey []byte
|
||||
var logger schemas.Logger
|
||||
|
||||
var ErrEncryptionKeyNotInitialized = errors.New("encryption key is not initialized")
|
||||
|
||||
// Init initializes the encryption key using Argon2id KDF to derive a secure 32-byte key
|
||||
// from the provided passphrase. This ensures strong entropy regardless of passphrase length.
|
||||
// The function accepts any passphrase but warns if it's too short (< 16 bytes).
|
||||
func Init(key string, _logger schemas.Logger) {
|
||||
logger = _logger
|
||||
if key == "" {
|
||||
encryptionKey = nil
|
||||
logger.Warn("encryption key is not set, encryption will be disabled. To set encryption key: use the encryption_key field in the configuration file or set the BIFROST_ENCRYPTION_KEY environment variable. Note that - once encryption key is set, it cannot be changed later unless you clean up the database.")
|
||||
return
|
||||
}
|
||||
|
||||
// Warn if passphrase is too short
|
||||
if len(key) < 16 {
|
||||
logger.Warn("encryption passphrase is shorter than 16 bytes, consider using a longer passphrase for better security")
|
||||
}
|
||||
|
||||
// Derive a secure 32-byte key using Argon2id KDF
|
||||
// We use a fixed salt since this is a system-wide encryption key (not per-user passwords)
|
||||
// Argon2id parameters: time=1, memory=64MB, threads=4, keyLen=32
|
||||
// This provides strong security while maintaining reasonable performance for initialization
|
||||
salt := []byte("bifrost-encryption-v1-salt-2024")
|
||||
encryptionKey = argon2.IDKey([]byte(key), salt, 1, 64*1024, 4, 32)
|
||||
}
|
||||
|
||||
// CompareHash compares a hash and a password
|
||||
func CompareHash(hash string, password string) (bool, error) {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
if err != nil {
|
||||
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to compare hash: %w", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Hash hashes a password using bcrypt
|
||||
func Hash(password string) (string, error) {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
return string(hashedPassword), nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts a plaintext string using AES-256-GCM and returns a base64-encoded ciphertext
|
||||
func Encrypt(plaintext string) (string, error) {
|
||||
if encryptionKey == nil {
|
||||
return plaintext, nil
|
||||
}
|
||||
if plaintext == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(encryptionKey)
|
||||
if err != nil {
|
||||
return plaintext, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return plaintext, fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
// Create a nonce (number used once)
|
||||
nonce := make([]byte, aesGCM.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return plaintext, fmt.Errorf("failed to read nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the data
|
||||
ciphertext := aesGCM.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
|
||||
// Encode to base64 for storage
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// IsEnabled returns true if the encryption key has been initialized
|
||||
func IsEnabled() bool {
|
||||
return encryptionKey != nil
|
||||
}
|
||||
|
||||
// HashSHA256 returns the lowercase hex encoding of the SHA-256 digest of value.
// The result is deterministic, which makes it usable for hash-based lookups on
// encrypted columns (e.g., virtual key value, session token).
func HashSHA256(value string) string {
	hasher := sha256.New()
	hasher.Write([]byte(value)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
|
||||
|
||||
// Decrypt decrypts a base64-encoded ciphertext using AES-256-GCM and returns the plaintext
|
||||
func Decrypt(ciphertext string) (string, error) {
|
||||
if encryptionKey == nil {
|
||||
return ciphertext, ErrEncryptionKeyNotInitialized
|
||||
}
|
||||
if ciphertext == "" {
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// Decode from base64
|
||||
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode base64: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(encryptionKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
// Extract nonce
|
||||
nonceSize := aesGCM.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:]
|
||||
|
||||
// Decrypt the data
|
||||
plaintext, err := aesGCM.Open(nil, nonce, ciphertextBytes, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decrypt: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
245
framework/encrypt/encrypt_test.go
Normal file
245
framework/encrypt/encrypt_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
// Set a test encryption key
|
||||
testKey := "test-encryption-key-for-testing-32bytes"
|
||||
Init(testKey, bifrost.NewDefaultLogger(schemas.LogLevelInfo))
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
plaintext string
|
||||
}{
|
||||
{
|
||||
name: "Simple text",
|
||||
plaintext: "hello world",
|
||||
},
|
||||
{
|
||||
name: "AWS Access Key",
|
||||
plaintext: "AKIAIOSFODNN7EXAMPLE",
|
||||
},
|
||||
{
|
||||
name: "AWS Secret Key",
|
||||
plaintext: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
plaintext: "",
|
||||
},
|
||||
{
|
||||
name: "Special characters",
|
||||
plaintext: "!@#$%^&*()_+-=[]{}|;':\",./<>?`~",
|
||||
},
|
||||
{
|
||||
name: "Long text",
|
||||
plaintext: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Encrypt
|
||||
encrypted, err := Encrypt(tc.plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encrypt: %v", err)
|
||||
}
|
||||
|
||||
// For empty strings, encryption should return empty
|
||||
if tc.plaintext == "" {
|
||||
if encrypted != "" {
|
||||
t.Errorf("Expected empty string for empty input, got: %s", encrypted)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Encrypted text should be different from plaintext
|
||||
if encrypted == tc.plaintext {
|
||||
t.Errorf("Encrypted text should be different from plaintext")
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
decrypted, err := Decrypt(encrypted)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decrypt: %v", err)
|
||||
}
|
||||
|
||||
// Decrypted text should match original plaintext
|
||||
if decrypted != tc.plaintext {
|
||||
t.Errorf("Decrypted text does not match original.\nExpected: %s\nGot: %s", tc.plaintext, decrypted)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDeterminism(t *testing.T) {
|
||||
// Set a test encryption key
|
||||
testKey := "test-encryption-key-for-testing-32bytes"
|
||||
Init(testKey, bifrost.NewDefaultLogger(schemas.LogLevelInfo))
|
||||
|
||||
plaintext := "test-plaintext"
|
||||
|
||||
// Encrypt the same text twice
|
||||
encrypted1, err := Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encrypt: %v", err)
|
||||
}
|
||||
encrypted2, err := Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encrypt: %v", err)
|
||||
}
|
||||
|
||||
// They should be different (due to random nonce)
|
||||
if encrypted1 == encrypted2 {
|
||||
t.Errorf("Two encryptions of the same plaintext should produce different ciphertexts (due to random nonce)")
|
||||
}
|
||||
|
||||
// But both should decrypt to the same plaintext
|
||||
decrypted1, err := Decrypt(encrypted1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decrypt first: %v", err)
|
||||
}
|
||||
decrypted2, err := Decrypt(encrypted2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decrypt second: %v", err)
|
||||
}
|
||||
|
||||
if decrypted1 != plaintext || decrypted2 != plaintext {
|
||||
t.Errorf("Both decryptions should match original plaintext")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptInvalidData(t *testing.T) {
|
||||
// Set a test encryption key
|
||||
testKey := "test-encryption-key-for-testing-32bytes"
|
||||
Init(testKey, bifrost.NewDefaultLogger(schemas.LogLevelInfo))
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
ciphertext string
|
||||
}{
|
||||
{
|
||||
name: "Invalid base64",
|
||||
ciphertext: "not-valid-base64!@#$",
|
||||
},
|
||||
{
|
||||
name: "Valid base64 but invalid ciphertext",
|
||||
ciphertext: "YWJjZGVmZ2hpamtsbW5vcA==",
|
||||
},
|
||||
{
|
||||
name: "Too short ciphertext",
|
||||
ciphertext: "YWJj",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := Decrypt(tc.ciphertext)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error when decrypting invalid data, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKDFWithVariousKeyLengths(t *testing.T) {
|
||||
// Test that keys of various lengths work correctly with KDF
|
||||
testCases := []struct {
|
||||
name string
|
||||
key string
|
||||
}{
|
||||
{
|
||||
name: "Short key (8 bytes)",
|
||||
key: "shortkey",
|
||||
},
|
||||
{
|
||||
name: "Medium key (16 bytes)",
|
||||
key: "medium-key-16byt",
|
||||
},
|
||||
{
|
||||
name: "Long key (32 bytes)",
|
||||
key: "this-is-a-32-byte-long-key!!",
|
||||
},
|
||||
{
|
||||
name: "Very long key (64 bytes)",
|
||||
key: "this-is-a-very-long-key-that-is-definitely-more-than-64-bytes",
|
||||
},
|
||||
}
|
||||
|
||||
plaintext := "test-data-for-encryption"
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Initialize with this key
|
||||
Init(tc.key, bifrost.NewDefaultLogger(schemas.LogLevelInfo))
|
||||
|
||||
// Encrypt
|
||||
encrypted, err := Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encrypt: %v", err)
|
||||
}
|
||||
|
||||
// Should produce valid ciphertext
|
||||
if encrypted == plaintext {
|
||||
t.Errorf("Encrypted text should be different from plaintext")
|
||||
}
|
||||
|
||||
// Decrypt should work
|
||||
decrypted, err := Decrypt(encrypted)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decrypt with %s: %v", tc.name, err)
|
||||
}
|
||||
|
||||
if decrypted != plaintext {
|
||||
t.Errorf("Decrypted text does not match original.\nExpected: %s\nGot: %s", plaintext, decrypted)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKDFDeterministic(t *testing.T) {
|
||||
// Test that the same passphrase always produces the same derived key
|
||||
passphrase := "test-passphrase"
|
||||
plaintext := "test-data"
|
||||
|
||||
// Initialize with passphrase and encrypt
|
||||
Init(passphrase, bifrost.NewDefaultLogger(schemas.LogLevelInfo))
|
||||
encrypted1, err := Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encrypt: %v", err)
|
||||
}
|
||||
|
||||
// Re-initialize with same passphrase (simulating restart)
|
||||
Init(passphrase, bifrost.NewDefaultLogger(schemas.LogLevelInfo))
|
||||
|
||||
// Should be able to decrypt the previously encrypted data
|
||||
decrypted, err := Decrypt(encrypted1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decrypt after re-initialization: %v", err)
|
||||
}
|
||||
|
||||
if decrypted != plaintext {
|
||||
t.Errorf("Decrypted text does not match original after re-initialization.\nExpected: %s\nGot: %s", plaintext, decrypted)
|
||||
}
|
||||
|
||||
// Encrypt again with same passphrase
|
||||
encrypted2, err := Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encrypt: %v", err)
|
||||
}
|
||||
|
||||
// Should be able to decrypt both (even though they're different due to nonce)
|
||||
decrypted2, err := Decrypt(encrypted2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decrypt second encryption: %v", err)
|
||||
}
|
||||
|
||||
if decrypted2 != plaintext {
|
||||
t.Errorf("Second decryption does not match original.\nExpected: %s\nGot: %s", plaintext, decrypted2)
|
||||
}
|
||||
}
|
||||
23
framework/envutils/utils.go
Normal file
23
framework/envutils/utils.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package envutils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProcessEnvValue resolves values written as "env.NAME" to the content of the
// environment variable NAME.
//
// Surrounding whitespace is ignored when looking for the "env." prefix and
// when reading the variable name. A value without the prefix is returned
// exactly as given (untrimmed). An error is returned when the name is empty or
// the variable is not set; a variable set to the empty string is valid.
func ProcessEnvValue(value string) (string, error) {
	rest, isEnvRef := strings.CutPrefix(strings.TrimSpace(value), "env.")
	if !isEnvRef {
		return value, nil
	}

	name := strings.TrimSpace(rest)
	if name == "" {
		return "", fmt.Errorf("environment variable name missing in %q", value)
	}

	resolved, found := os.LookupEnv(name)
	if !found {
		return "", fmt.Errorf("environment variable %s not found", name)
	}
	return resolved, nil
}
|
||||
161
framework/go.mod
Normal file
161
framework/go.mod
Normal file
@@ -0,0 +1,161 @@
|
||||
module github.com/maximhq/bifrost/framework
|
||||
|
||||
go 1.26.2
|
||||
|
||||
require (
|
||||
cloud.google.com/go/storage v1.61.3
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.97.3
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/maximhq/bifrost/core v1.5.4
|
||||
github.com/pinecone-io/go-pinecone/v5 v5.3.0
|
||||
github.com/qdrant/go-client v1.16.2
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/weaviate/weaviate v1.36.5
|
||||
github.com/weaviate/weaviate-go-client/v5 v5.7.1
|
||||
golang.org/x/crypto v0.49.0
|
||||
golang.org/x/sync v0.20.0
|
||||
google.golang.org/api v0.274.0
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
gorm.io/gorm v1.31.1
|
||||
)
|
||||
|
||||
require (
|
||||
cel.dev/expr v0.25.1 // indirect
|
||||
cloud.google.com/go v0.123.0 // indirect
|
||||
cloud.google.com/go/auth v0.18.2 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/iam v1.5.3 // indirect
|
||||
cloud.google.com/go/monitoring v1.24.3 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 // indirect
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.13 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect
|
||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 // indirect
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.1.4 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-openapi/swag/cmdutils v0.25.4 // indirect
|
||||
github.com/go-openapi/swag/conv v0.25.4 // indirect
|
||||
github.com/go-openapi/swag/fileutils v0.25.4 // indirect
|
||||
github.com/go-openapi/swag/jsonname v0.25.4 // indirect
|
||||
github.com/go-openapi/swag/jsonutils v0.25.4 // indirect
|
||||
github.com/go-openapi/swag/loading v0.25.4 // indirect
|
||||
github.com/go-openapi/swag/mangling v0.25.4 // indirect
|
||||
github.com/go-openapi/swag/netutils v0.25.4 // indirect
|
||||
github.com/go-openapi/swag/stringutils v0.25.4 // indirect
|
||||
github.com/go-openapi/swag/typeutils v0.25.4 // indirect
|
||||
github.com/go-openapi/swag/yamlutils v0.25.4 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.19.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/pgx/v5 v5.9.1 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||
github.com/oapi-codegen/runtime v1.1.1 // indirect
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
|
||||
github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.40.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect
|
||||
go.opentelemetry.io/otel v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||
go.starlark.net v0.0.0-20260102030733-3fee463870c9 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.11
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.14
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.2 // indirect
|
||||
github.com/bytedance/sonic v1.15.0
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/go-openapi/analysis v0.24.2 // indirect
|
||||
github.com/go-openapi/errors v0.22.5 // indirect
|
||||
github.com/go-openapi/jsonpointer v0.22.4 // indirect
|
||||
github.com/go-openapi/jsonreference v0.21.4 // indirect
|
||||
github.com/go-openapi/loads v0.23.2 // indirect
|
||||
github.com/go-openapi/runtime v0.29.2 // indirect
|
||||
github.com/go-openapi/spec v0.22.2 // indirect
|
||||
github.com/go-openapi/strfmt v0.25.0 // indirect
|
||||
github.com/go-openapi/swag v0.25.4 // indirect
|
||||
github.com/go-openapi/validate v0.25.1 // indirect
|
||||
github.com/invopop/jsonschema v0.13.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/klauspost/compress v1.18.2 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/mailru/easyjson v0.9.1 // indirect
|
||||
github.com/mark3labs/mcp-go v0.43.2 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.32 // indirect
|
||||
github.com/oklog/ulid v1.3.1 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/rs/zerolog v1.34.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasthttp v1.68.0
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
go.mongodb.org/mongo-driver v1.17.6 // indirect
|
||||
golang.org/x/arch v0.23.0 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
|
||||
google.golang.org/grpc v1.80.0 // indirect
|
||||
google.golang.org/protobuf v1.36.11
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
)
|
||||
387
framework/go.sum
Normal file
387
framework/go.sum
Normal file
@@ -0,0 +1,387 @@
|
||||
cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4=
|
||||
cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM=
|
||||
cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||
cloud.google.com/go/iam v1.5.3 h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc=
|
||||
cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU=
|
||||
cloud.google.com/go/logging v1.13.2 h1:qqlHCBvieJT9Cdq4QqYx1KPadCQ2noD4FK02eNqHAjA=
|
||||
cloud.google.com/go/logging v1.13.2/go.mod h1:zaybliM3yun1J8mU2dVQ1/qDzjbOqEijZCn6hSBtKak=
|
||||
cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8=
|
||||
cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk=
|
||||
cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE=
|
||||
cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI=
|
||||
cloud.google.com/go/storage v1.61.3 h1:VS//ZfBuPGDvakfD9xyPW1RGF1Vy3BWUoVZXgW1KMOg=
|
||||
cloud.google.com/go/storage v1.61.3/go.mod h1:JtqK8BBB7TWv0HVGHubtUdzYYrakOQIsMLffZ2Z/HWk=
|
||||
cloud.google.com/go/trace v1.11.7 h1:kDNDX8JkaAG3R2nq1lIdkb7FCSi1rCmsEtKVsty7p+U=
|
||||
cloud.google.com/go/trace v1.11.7/go.mod h1:TNn9d5V3fQVf6s4SCveVMIBS2LJUqo73GACmq/Tky0s=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
|
||||
github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM=
|
||||
github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 h1:DHa2U07rk8syqvCge0QIGMCE1WxGj9njT44GH7zNJLQ=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 h1:UnDZ/zFfG1JhH/DqxIZYU/1CUAlTUScoXD/LcM2Ykk8=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0/go.mod h1:IA1C1U7jO/ENqm/vhi7V9YYpBsp+IMyqNrEN94N7tVc=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.55.0 h1:7t/qx5Ost0s0wbA/VDrByOooURhp+ikYwv20i9Y07TQ=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.55.0/go.mod h1:vB2GH9GAYYJTO3mEn8oYwzEdhlayZIdQz6zdzgUIRvA=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 h1:0s6TxfCu2KHkkZPnBfsQ2y5qia0jl3MMrmBhu3nCOYk=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0/go.mod h1:Mf6O40IAyB9zR/1J8nGDDPirZQQPbYJni8Yisy7NTMc=
|
||||
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
|
||||
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.11 h1:ftxI5sgz8jZkckuUHXfC/wMUc8u3fG1vQS0plr2F2Zs=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.11/go.mod h1:twF11+6ps9aNRKEDimksp923o44w/Thk9+8YIlzWMmo=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 h1:clHU5fm//kWS1C2HgtgWxfQbFbx4b6rx+5jzhgX9HrI=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22 h1:rWyie/PxDRIdhNf4DzRk0lvjVOqFJuNnO8WwaIRVxzQ=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22/go.mod h1:zd/JsJ4P7oGfUhXn1VyLqaRZwPmZwg44Jf2dS84Dm3Y=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.13 h1:JRaIgADQS/U6uXDqlPiefP32yXTda7Kqfx+LgspooZM=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.13/go.mod h1:CEuVn5WqOMilYl+tbccq8+N2ieCy0gVn3OtRb0vBNNM=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21 h1:ZlvrNcHSFFWURB8avufQq9gFsheUgjVD9536obIknfM=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21/go.mod h1:cv3TNhVrssKR0O/xxLJVRfd2oazSnZnkUeTf6ctUwfQ=
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.97.3 h1:HwxWTbTrIHm5qY+CAEur0s/figc3qwvLWsNkF4RPToo=
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.97.3/go.mod h1:uoA43SdFwacedBfSgfFSjjCvYe8aYBS7EnU5GZ/YKMM=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw=
|
||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/buger/jsonparser v1.1.2 h1:frqHqw7otoVbk5M8LlE/L7HTnIq2v9RX6EJ48i9AxJk=
|
||||
github.com/buger/jsonparser v1.1.2/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
|
||||
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
|
||||
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
|
||||
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
||||
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
||||
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 h1:6xNmx7iTtyBRev0+D/Tv1FZd4SCg8axKApyNyRsAt/w=
|
||||
github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5/go.mod h1:KdCmV+x/BuvyMxRnYBlmVaq4OLiKW6iRQfvC62cvdkI=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA=
|
||||
github.com/envoyproxy/go-control-plane v0.14.0/go.mod h1:NcS5X47pLl/hfqxU70yPwL9ZMkUlwlKxtAohpi2wBEU=
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g=
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98=
|
||||
github.com/envoyproxy/go-control-plane/ratelimit v0.1.0 h1:/G9QYbddjL25KvtKTv3an9lx6VBE2cnb8wp1vEGNYGI=
|
||||
github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA=
|
||||
github.com/fasthttp/websocket v1.5.12 h1:e4RGPpWW2HTbL3zV0Y/t7g0ub294LkiuXXUuTOUInlE=
|
||||
github.com/fasthttp/websocket v1.5.12/go.mod h1:I+liyL7/4moHojiOgUOIKEWm9EIxHqxZChS+aMFltyg=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA=
|
||||
github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-openapi/analysis v0.24.2 h1:6p7WXEuKy1llDgOH8FooVeO+Uq2za9qoAOq4ZN08B50=
|
||||
github.com/go-openapi/analysis v0.24.2/go.mod h1:x27OOHKANE0lutg2ml4kzYLoHGMKgRm1Cj2ijVOjJuE=
|
||||
github.com/go-openapi/errors v0.22.5 h1:Yfv4O/PRYpNF3BNmVkEizcHb3uLVVsrDt3LNdgAKRY4=
|
||||
github.com/go-openapi/errors v0.22.5/go.mod h1:z9S8ASTUqx7+CP1Q8dD8ewGH/1JWFFLX/2PmAYNQLgk=
|
||||
github.com/go-openapi/jsonpointer v0.22.4 h1:dZtK82WlNpVLDW2jlA1YCiVJFVqkED1MegOUy9kR5T4=
|
||||
github.com/go-openapi/jsonpointer v0.22.4/go.mod h1:elX9+UgznpFhgBuaMQ7iu4lvvX1nvNsesQ3oxmYTw80=
|
||||
github.com/go-openapi/jsonreference v0.21.4 h1:24qaE2y9bx/q3uRK/qN+TDwbok1NhbSmGjjySRCHtC8=
|
||||
github.com/go-openapi/jsonreference v0.21.4/go.mod h1:rIENPTjDbLpzQmQWCj5kKj3ZlmEh+EFVbz3RTUh30/4=
|
||||
github.com/go-openapi/loads v0.23.2 h1:rJXAcP7g1+lWyBHC7iTY+WAF0rprtM+pm8Jxv1uQJp4=
|
||||
github.com/go-openapi/loads v0.23.2/go.mod h1:IEVw1GfRt/P2Pplkelxzj9BYFajiWOtY2nHZNj4UnWY=
|
||||
github.com/go-openapi/runtime v0.29.2 h1:UmwSGWNmWQqKm1c2MGgXVpC2FTGwPDQeUsBMufc5Yj0=
|
||||
github.com/go-openapi/runtime v0.29.2/go.mod h1:biq5kJXRJKBJxTDJXAa00DOTa/anflQPhT0/wmjuy+0=
|
||||
github.com/go-openapi/spec v0.22.2 h1:KEU4Fb+Lp1qg0V4MxrSCPv403ZjBl8Lx1a83gIPU8Qc=
|
||||
github.com/go-openapi/spec v0.22.2/go.mod h1:iIImLODL2loCh3Vnox8TY2YWYJZjMAKYyLH2Mu8lOZs=
|
||||
github.com/go-openapi/strfmt v0.25.0 h1:7R0RX7mbKLa9EYCTHRcCuIPcaqlyQiWNPTXwClK0saQ=
|
||||
github.com/go-openapi/strfmt v0.25.0/go.mod h1:nNXct7OzbwrMY9+5tLX4I21pzcmE6ccMGXl3jFdPfn8=
|
||||
github.com/go-openapi/swag v0.25.4 h1:OyUPUFYDPDBMkqyxOTkqDYFnrhuhi9NR6QVUvIochMU=
|
||||
github.com/go-openapi/swag v0.25.4/go.mod h1:zNfJ9WZABGHCFg2RnY0S4IOkAcVTzJ6z2Bi+Q4i6qFQ=
|
||||
github.com/go-openapi/swag/cmdutils v0.25.4 h1:8rYhB5n6WawR192/BfUu2iVlxqVR9aRgGJP6WaBoW+4=
|
||||
github.com/go-openapi/swag/cmdutils v0.25.4/go.mod h1:pdae/AFo6WxLl5L0rq87eRzVPm/XRHM3MoYgRMvG4A0=
|
||||
github.com/go-openapi/swag/conv v0.25.4 h1:/Dd7p0LZXczgUcC/Ikm1+YqVzkEeCc9LnOWjfkpkfe4=
|
||||
github.com/go-openapi/swag/conv v0.25.4/go.mod h1:3LXfie/lwoAv0NHoEuY1hjoFAYkvlqI/Bn5EQDD3PPU=
|
||||
github.com/go-openapi/swag/fileutils v0.25.4 h1:2oI0XNW5y6UWZTC7vAxC8hmsK/tOkWXHJQH4lKjqw+Y=
|
||||
github.com/go-openapi/swag/fileutils v0.25.4/go.mod h1:cdOT/PKbwcysVQ9Tpr0q20lQKH7MGhOEb6EwmHOirUk=
|
||||
github.com/go-openapi/swag/jsonname v0.25.4 h1:bZH0+MsS03MbnwBXYhuTttMOqk+5KcQ9869Vye1bNHI=
|
||||
github.com/go-openapi/swag/jsonname v0.25.4/go.mod h1:GPVEk9CWVhNvWhZgrnvRA6utbAltopbKwDu8mXNUMag=
|
||||
github.com/go-openapi/swag/jsonutils v0.25.4 h1:VSchfbGhD4UTf4vCdR2F4TLBdLwHyUDTd1/q4i+jGZA=
|
||||
github.com/go-openapi/swag/jsonutils v0.25.4/go.mod h1:7OYGXpvVFPn4PpaSdPHJBtF0iGnbEaTk8AvBkoWnaAY=
|
||||
github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.4 h1:IACsSvBhiNJwlDix7wq39SS2Fh7lUOCJRmx/4SN4sVo=
|
||||
github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.4/go.mod h1:Mt0Ost9l3cUzVv4OEZG+WSeoHwjWLnarzMePNDAOBiM=
|
||||
github.com/go-openapi/swag/loading v0.25.4 h1:jN4MvLj0X6yhCDduRsxDDw1aHe+ZWoLjW+9ZQWIKn2s=
|
||||
github.com/go-openapi/swag/loading v0.25.4/go.mod h1:rpUM1ZiyEP9+mNLIQUdMiD7dCETXvkkC30z53i+ftTE=
|
||||
github.com/go-openapi/swag/mangling v0.25.4 h1:2b9kBJk9JvPgxr36V23FxJLdwBrpijI26Bx5JH4Hp48=
|
||||
github.com/go-openapi/swag/mangling v0.25.4/go.mod h1:6dxwu6QyORHpIIApsdZgb6wBk/DPU15MdyYj/ikn0Hg=
|
||||
github.com/go-openapi/swag/netutils v0.25.4 h1:Gqe6K71bGRb3ZQLusdI8p/y1KLgV4M/k+/HzVSqT8H0=
|
||||
github.com/go-openapi/swag/netutils v0.25.4/go.mod h1:m2W8dtdaoX7oj9rEttLyTeEFFEBvnAx9qHd5nJEBzYg=
|
||||
github.com/go-openapi/swag/stringutils v0.25.4 h1:O6dU1Rd8bej4HPA3/CLPciNBBDwZj9HiEpdVsb8B5A8=
|
||||
github.com/go-openapi/swag/stringutils v0.25.4/go.mod h1:GTsRvhJW5xM5gkgiFe0fV3PUlFm0dr8vki6/VSRaZK0=
|
||||
github.com/go-openapi/swag/typeutils v0.25.4 h1:1/fbZOUN472NTc39zpa+YGHn3jzHWhv42wAJSN91wRw=
|
||||
github.com/go-openapi/swag/typeutils v0.25.4/go.mod h1:Ou7g//Wx8tTLS9vG0UmzfCsjZjKhpjxayRKTHXf2pTE=
|
||||
github.com/go-openapi/swag/yamlutils v0.25.4 h1:6jdaeSItEUb7ioS9lFoCZ65Cne1/RZtPBZ9A56h92Sw=
|
||||
github.com/go-openapi/swag/yamlutils v0.25.4/go.mod h1:MNzq1ulQu+yd8Kl7wPOut/YHAAU/H6hL91fF+E2RFwc=
|
||||
github.com/go-openapi/testify/enable/yaml/v2 v2.0.2 h1:0+Y41Pz1NkbTHz8NngxTuAXxEodtNSI1WG1c/m5Akw4=
|
||||
github.com/go-openapi/testify/enable/yaml/v2 v2.0.2/go.mod h1:kme83333GCtJQHXQ8UKX3IBZu6z8T5Dvy5+CW3NLUUg=
|
||||
github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6Ub6wls=
|
||||
github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54=
|
||||
github.com/go-openapi/validate v0.25.1 h1:sSACUI6Jcnbo5IWqbYHgjibrhhmt3vR6lCzKZnmAgBw=
|
||||
github.com/go-openapi/validate v0.25.1/go.mod h1:RMVyVFYte0gbSTaZ0N4KmTn6u/kClvAFp+mAVfS/DQc=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc=
|
||||
github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0=
|
||||
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/gax-go/v2 v2.19.0 h1:fYQaUOiGwll0cGj7jmHT/0nPlcrZDFPrZRhTsoCr8hE=
|
||||
github.com/googleapis/gax-go/v2 v2.19.0/go.mod h1:w2ROXVdfGEVFXzmlciUU4EdjHgWvB5h2n6x/8XSTTJA=
|
||||
github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68=
|
||||
github.com/hajimehoshi/go-mp3 v0.3.4/go.mod h1:fRtZraRFcWb0pu7ok0LqyFhCUrPeMsGRSVop0eemFmo=
|
||||
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
|
||||
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc=
|
||||
github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
|
||||
github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU=
|
||||
github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k=
|
||||
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
|
||||
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8=
|
||||
github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I=
|
||||
github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
||||
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/maximhq/bifrost/core v1.5.4 h1:hf0BhoHVVpY1EQ4FkyRzW4IBYjrolxdZV0ucgWfHhcE=
|
||||
github.com/maximhq/bifrost/core v1.5.4/go.mod h1:z1/vOalbDAD7v7sYbXQsqR+2qIFP0jKOSIStw6Q4P4U=
|
||||
github.com/oapi-codegen/runtime v1.1.1 h1:EXLHh0DXIJnWhdRPN2w4MXAzFyE4CskzhNLUmtpMYro=
|
||||
github.com/oapi-codegen/runtime v1.1.1/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg=
|
||||
github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4=
|
||||
github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U=
|
||||
github.com/pinecone-io/go-pinecone/v5 v5.3.0 h1:0YQlEtmXGWK/I8ztkOVM6PuBYgFJZhjSdb0ddU+bHPE=
|
||||
github.com/pinecone-io/go-pinecone/v5 v5.3.0/go.mod h1:6Fg85fcyvMUQFf9KW7zniN81kelSYvsjF+KPLdc1MGA=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo=
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/qdrant/go-client v1.16.2 h1:UUMJJfvXTByhwhH1DwWdbkhZ2cTdvSqVkXSIfBrVWSg=
|
||||
github.com/qdrant/go-client v1.16.2/go.mod h1:I+EL3h4HRoRTeHtbfOd/4kDXwCukZfkd41j/9wryGkw=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/savsgio/gotils v0.0.0-20250408102913-196191ec6287 h1:qIQ0tWF9vxGtkJa24bR+2i53WBCz1nW/Pc47oVYauC4=
|
||||
github.com/savsgio/gotils v0.0.0-20250408102913-196191ec6287/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg=
|
||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
||||
github.com/spiffe/go-spiffe/v2 v2.6.0 h1:l+DolpxNWYgruGQVV0xsfeya3CsC7m8iBzDnMpsbLuo=
|
||||
github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs=
|
||||
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/objx v0.5.3 h1:jmXUvGomnU1o3W/V5h2VEradbpJDwGrzugQQvL0POH4=
|
||||
github.com/stretchr/objx v0.5.3/go.mod h1:rDQraq+vQZU7Fde9LOZLr8Tax6zZvy4kuNKF+QYS+U0=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.68.0 h1:v12Nx16iepr8r9ySOwqI+5RBJ/DqTxhOy1HrHoDFnok=
|
||||
github.com/valyala/fasthttp v1.68.0/go.mod h1:5EXiRfYQAoiO/khu4oU9VISC/eVY6JqmSpPJoHCKsz4=
|
||||
github.com/weaviate/weaviate v1.36.5 h1:lCiuEfQ08+5wK0DkTCUBb6ayNep9QpBH6JJhmZaRfzk=
|
||||
github.com/weaviate/weaviate v1.36.5/go.mod h1:ljzrgEmGKn3CRzDdcxvhmBUUZIcghwIYd1Lmn54f3Z8=
|
||||
github.com/weaviate/weaviate-go-client/v5 v5.7.1 h1:vEMxh486QqRqWaq58UEe/TiTbGbo9T5x7ZPFd5QENvQ=
|
||||
github.com/weaviate/weaviate-go-client/v5 v5.7.1/go.mod h1:T/JDErjN074GrnYIa0AgK1TGUGP/6A/8vqXNPlv4c6E=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
go.mongodb.org/mongo-driver v1.17.6 h1:87JUG1wZfWsr6rIz3ZmpH90rL5tea7O3IHuSwHUpsss=
|
||||
go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.40.0 h1:Awaf8gmW99tZTOWqkLCOl6aw1/rxAWVlHsHIZ3fT2sA=
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.40.0/go.mod h1:99OY9ZCqyLkzJLTh5XhECpLRSxcZl+ZDKBEO+jMBFR4=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 h1:YH4g8lQroajqUwWbq/tr2QX1JFmEXaDLgG+ew9bLMWo=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0/go.mod h1:fvPi2qXDqFs8M4B4fmJhE92TyQs9Ydjlg3RvfUp+NbQ=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg=
|
||||
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
|
||||
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0 h1:ZrPRak/kS4xI3AVXy8F7pipuDXmDsrO8Lg+yQjBLjw0=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0/go.mod h1:3y6kQCWztq6hyW8Z9YxQDDm0Je9AJoFar2G0yDcmhRk=
|
||||
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
|
||||
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.starlark.net v0.0.0-20260102030733-3fee463870c9 h1:nV1OyvU+0CYrp5eKfQ3rD03TpFYYhH08z31NK1HmtTk=
|
||||
go.starlark.net v0.0.0-20260102030733-3fee463870c9/go.mod h1:YKMCv9b1WrfWmeqdV5MAuEHWsu5iC+fe6kYl2sQjdI8=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/arch v0.23.0 h1:lKF64A2jF6Zd8L0knGltUnegD62JMFBiCPBmQpToHhg=
|
||||
golang.org/x/arch v0.23.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/api v0.274.0 h1:aYhycS5QQCwxHLwfEHRRLf9yNsfvp1JadKKWBE54RFA=
|
||||
google.golang.org/api v0.274.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
|
||||
google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5 h1:JNfk58HZ8lfmXbYK2vx/UvsqIL59TzByCxPIX4TDmsE=
|
||||
google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5/go.mod h1:x5julN69+ED4PcFk/XWayw35O0lf/nGa4aNgODCmNmw=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
|
||||
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
|
||||
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
||||
432
framework/kvstore/kvstore.go
Normal file
432
framework/kvstore/kvstore.go
Normal file
@@ -0,0 +1,432 @@
|
||||
package kvstore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
)
|
||||
|
||||
// Sentinel errors returned by Store operations. Compare with errors.Is.
var (
	// ErrClosed is returned when an operation is attempted on a closed store.
	ErrClosed = errors.New("kvstore is closed")
	// ErrEmptyKey is returned when an operation is given an empty key.
	ErrEmptyKey = errors.New("key cannot be empty")
	// ErrNotFound is returned when a key is absent or its entry has expired.
	ErrNotFound = errors.New("key not found")
	// ErrInvalidTTL is returned when a negative TTL is supplied.
	ErrInvalidTTL = errors.New("ttl cannot be negative")
)
|
||||
|
||||
const (
	// defaultCleanupInterval is the sweep period used when Config leaves
	// CleanupInterval unset or non-positive.
	defaultCleanupInterval = 30 * time.Second

	// noExpirationUnixNanos is the expiresAt value that marks an entry as
	// never expiring.
	noExpirationUnixNanos int64 = 0
)
|
||||
|
||||
// Config holds the tunables for an in-memory Store.
type Config struct {
	// CleanupInterval is the period between background sweeps that remove
	// expired entries. Values <= 0 fall back to 30s.
	CleanupInterval time.Duration

	// DefaultTTL is the lifetime applied by Set. Zero means entries written
	// through Set never expire; negative values are rejected by New.
	DefaultTTL time.Duration
}
|
||||
|
||||
// entry is a single stored value together with its write and expiry times.
type entry struct {
	// value is the payload handed to Set/SetWithTTL, or the decoded remote value.
	value any
	// writtenAt is the write time in Unix nanoseconds; 0 means not written yet.
	writtenAt int64
	// expiresAt is the expiry time in Unix nanoseconds; 0 means no expiration.
	expiresAt int64
}
|
||||
|
||||
// Store is an in-memory KV store with optional TTL support.
|
||||
type Store struct {
|
||||
mu sync.RWMutex
|
||||
data map[string]entry
|
||||
|
||||
defaultTTL time.Duration
|
||||
cleanupInterval time.Duration
|
||||
|
||||
closed atomic.Bool
|
||||
stopCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
cleanupWg sync.WaitGroup
|
||||
|
||||
delegate SyncDelegate
|
||||
decoders map[string]TypeDecoder
|
||||
decoderMu sync.RWMutex
|
||||
}
|
||||
|
||||
// SyncDelegate receives a callback for each local mutation so it can be
// replicated to other nodes. Callbacks run synchronously, after the local
// change has been applied.
//
// All timestamps are absolute Unix nanoseconds. writtenAt and deletedAt are
// what remote nodes use for last-write-wins conflict resolution; an expiresAt
// of 0 means the entry never expires.
type SyncDelegate interface {
	// OnSet reports that key now holds the JSON-encoded value valueJSON.
	OnSet(key string, valueJSON []byte, writtenAt int64, expiresAt int64)
	// OnDelete reports that key was removed at deletedAt.
	OnDelete(key string, deletedAt int64)
}
|
||||
|
||||
// TypeDecoder turns the JSON encoding of a value back into a concrete Go
// value. Decoders are associated with key prefixes through RegisterDecoder.
type TypeDecoder func(data []byte) (any, error)
|
||||
|
||||
// SetDelegate plugs in the cluster sync implementation.
|
||||
func (s *Store) SetDelegate(d SyncDelegate) {
|
||||
s.delegate = d
|
||||
}
|
||||
|
||||
// RegisterDecoder registers a decoder for keys matching the given prefix.
|
||||
// Used by the receiving side to reconstruct concrete types from gossip payloads.
|
||||
func (s *Store) RegisterDecoder(keyPrefix string, decoder TypeDecoder) {
|
||||
s.decoderMu.Lock()
|
||||
s.decoders[keyPrefix] = decoder
|
||||
s.decoderMu.Unlock()
|
||||
}
|
||||
|
||||
// New creates a new in-memory KV store.
|
||||
func New(cfg Config) (*Store, error) {
|
||||
if cfg.DefaultTTL < 0 {
|
||||
return nil, ErrInvalidTTL
|
||||
}
|
||||
|
||||
cleanupInterval := cfg.CleanupInterval
|
||||
if cleanupInterval <= 0 {
|
||||
cleanupInterval = defaultCleanupInterval
|
||||
}
|
||||
|
||||
s := &Store{
|
||||
data: make(map[string]entry),
|
||||
defaultTTL: cfg.DefaultTTL,
|
||||
cleanupInterval: cleanupInterval,
|
||||
stopCh: make(chan struct{}),
|
||||
decoders: make(map[string]TypeDecoder),
|
||||
}
|
||||
|
||||
s.cleanupWg.Add(1)
|
||||
go s.cleanupLoop()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Set stores a value using the store's default TTL.
|
||||
func (s *Store) Set(key string, value any) error {
|
||||
return s.SetWithTTL(key, value, s.defaultTTL)
|
||||
}
|
||||
|
||||
// SetWithTTL stores a value with an explicit TTL.
|
||||
// ttl=0 means no expiration.
|
||||
func (s *Store) SetWithTTL(key string, value any, ttl time.Duration) error {
|
||||
if err := s.validateMutable(key, ttl); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now().UnixNano()
|
||||
var expiresAt int64
|
||||
if ttl > 0 {
|
||||
expiresAt = now + int64(ttl)
|
||||
}
|
||||
|
||||
var valueJSON []byte
|
||||
var err error
|
||||
|
||||
if s.delegate != nil {
|
||||
valueJSON, err = sonic.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.data[key] = entry{
|
||||
value: value,
|
||||
writtenAt: now,
|
||||
expiresAt: expiresAt,
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if s.delegate != nil {
|
||||
s.delegate.OnSet(key, valueJSON, now, expiresAt)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetNXWithTTL atomically sets a value with TTL only if the key does not exist.
|
||||
// Returns true if the key was set, false if the key already existed.
|
||||
// ttl=0 means no expiration.
|
||||
func (s *Store) SetNXWithTTL(key string, value any, ttl time.Duration) (bool, error) {
|
||||
if err := s.validateMutable(key, ttl); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
now := time.Now().UnixNano()
|
||||
var expiresAt int64
|
||||
if ttl > 0 {
|
||||
expiresAt = now + int64(ttl)
|
||||
}
|
||||
|
||||
var valueJSON []byte
|
||||
var err error
|
||||
|
||||
if s.delegate != nil {
|
||||
valueJSON, err = sonic.Marshal(value)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
|
||||
// Check if key exists and is not expired
|
||||
if existing, ok := s.data[key]; ok {
|
||||
if !isExpired(existing, now) {
|
||||
s.mu.Unlock()
|
||||
return false, nil // Key already exists
|
||||
}
|
||||
// Key exists but is expired, allow overwrite
|
||||
}
|
||||
|
||||
// Key doesn't exist or is expired, set it
|
||||
s.data[key] = entry{
|
||||
value: value,
|
||||
writtenAt: now,
|
||||
expiresAt: expiresAt,
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if s.delegate != nil {
|
||||
s.delegate.OnSet(key, valueJSON, now, expiresAt)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// SetRemote applies a remotely-gossiped entry without triggering OnSet.
|
||||
// writtenAt and expiresAt must be absolute Unix nanosecond timestamps.
|
||||
// If the local entry was written more recently than writtenAt the update is
|
||||
// silently skipped (last-write-wins by wall clock on the writing node).
|
||||
func (s *Store) SetRemote(key string, valueJSON []byte, writtenAt int64, expiresAt int64) error {
|
||||
if key == "" {
|
||||
return ErrEmptyKey
|
||||
}
|
||||
if s.closed.Load() {
|
||||
return ErrClosed
|
||||
}
|
||||
|
||||
value := s.decodeValue(key, valueJSON)
|
||||
|
||||
s.mu.Lock()
|
||||
if existing, ok := s.data[key]; ok && existing.writtenAt > writtenAt {
|
||||
s.mu.Unlock()
|
||||
return nil // stale gossip — local entry is newer
|
||||
}
|
||||
s.data[key] = entry{value: value, writtenAt: writtenAt, expiresAt: expiresAt}
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a value by key.
|
||||
func (s *Store) Get(key string) (any, error) {
|
||||
if key == "" {
|
||||
return nil, ErrEmptyKey
|
||||
}
|
||||
if s.closed.Load() {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
s.mu.RLock()
|
||||
e, ok := s.data[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if isExpired(e, now) {
|
||||
s.mu.Lock()
|
||||
if latest, exists := s.data[key]; exists && isExpired(latest, time.Now().UnixNano()) {
|
||||
delete(s.data, key)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return e.value, nil
|
||||
}
|
||||
|
||||
// GetAndDelete retrieves and deletes a key atomically.
|
||||
func (s *Store) GetAndDelete(key string) (any, error) {
|
||||
if key == "" {
|
||||
return nil, ErrEmptyKey
|
||||
}
|
||||
if s.closed.Load() {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
s.mu.Lock()
|
||||
e, ok := s.data[key]
|
||||
if ok {
|
||||
delete(s.data, key)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !ok || isExpired(e, now) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if s.delegate != nil {
|
||||
s.delegate.OnDelete(key, now)
|
||||
}
|
||||
return e.value, nil
|
||||
}
|
||||
|
||||
// Delete removes a key.
|
||||
func (s *Store) Delete(key string) (bool, error) {
|
||||
if key == "" {
|
||||
return false, ErrEmptyKey
|
||||
}
|
||||
if s.closed.Load() {
|
||||
return false, ErrClosed
|
||||
}
|
||||
|
||||
deletedAt := time.Now().UnixNano()
|
||||
|
||||
s.mu.Lock()
|
||||
_, ok := s.data[key]
|
||||
if ok {
|
||||
delete(s.data, key)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
if s.delegate != nil {
|
||||
s.delegate.OnDelete(key, deletedAt)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// DeleteRemote applies a remotely-gossiped delete without triggering OnDelete.
|
||||
// deletedAt is the absolute Unix nanosecond timestamp when the delete was issued.
|
||||
// The delete is skipped if the local entry was written after the delete intent
|
||||
// (last-write-wins).
|
||||
func (s *Store) DeleteRemote(key string, deletedAt int64) error {
|
||||
if key == "" {
|
||||
return ErrEmptyKey
|
||||
}
|
||||
if s.closed.Load() {
|
||||
return ErrClosed
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
if existing, ok := s.data[key]; ok && existing.writtenAt > deletedAt {
|
||||
s.mu.Unlock()
|
||||
return nil // entry was written after the delete intent — write wins
|
||||
}
|
||||
delete(s.data, key)
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Len returns the number of currently non-expired keys.
|
||||
func (s *Store) Len() int {
|
||||
if s.closed.Load() {
|
||||
return 0
|
||||
}
|
||||
|
||||
now := time.Now().UnixNano()
|
||||
total := 0
|
||||
|
||||
s.mu.RLock()
|
||||
for _, v := range s.data {
|
||||
if isExpired(v, now) {
|
||||
continue
|
||||
}
|
||||
total++
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
// Close stops background cleanup and prevents further operations.
|
||||
func (s *Store) Close() error {
|
||||
s.stopOnce.Do(func() {
|
||||
s.closed.Store(true)
|
||||
close(s.stopCh)
|
||||
})
|
||||
s.cleanupWg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) cleanupLoop() {
|
||||
defer s.cleanupWg.Done()
|
||||
|
||||
ticker := time.NewTicker(s.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.cleanupExpired()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) cleanupExpired() {
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
s.mu.Lock()
|
||||
for k, v := range s.data {
|
||||
if isExpired(v, now) {
|
||||
delete(s.data, k)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Store) validateMutable(key string, ttl time.Duration) error {
|
||||
if key == "" {
|
||||
return ErrEmptyKey
|
||||
}
|
||||
if ttl < 0 {
|
||||
return ErrInvalidTTL
|
||||
}
|
||||
if s.closed.Load() {
|
||||
return ErrClosed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeValue uses the registered decoder for the key's prefix, falling back
|
||||
// to raw []byte if no decoder matches.
|
||||
func (s *Store) decodeValue(key string, valueJSON []byte) any {
|
||||
s.decoderMu.RLock()
|
||||
|
||||
var bestPrefix string
|
||||
var bestDecode TypeDecoder
|
||||
for prefix, decode := range s.decoders {
|
||||
if strings.HasPrefix(key, prefix) && len(prefix) > len(bestPrefix) {
|
||||
bestPrefix = prefix
|
||||
bestDecode = decode
|
||||
}
|
||||
}
|
||||
|
||||
s.decoderMu.RUnlock()
|
||||
|
||||
if bestDecode != nil {
|
||||
if v, err := bestDecode(valueJSON); err == nil {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return valueJSON
|
||||
}
|
||||
|
||||
func isExpired(e entry, nowUnixNano int64) bool {
|
||||
return e.expiresAt != noExpirationUnixNanos && nowUnixNano >= e.expiresAt
|
||||
}
|
||||
100
framework/kvstore/kvstore_test.go
Normal file
100
framework/kvstore/kvstore_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package kvstore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestStoreSetGetDelete(t *testing.T) {
|
||||
store, err := New(Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create store: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
if err := store.Set("k1", "v1"); err != nil {
|
||||
t.Fatalf("set failed: %v", err)
|
||||
}
|
||||
|
||||
v, err := store.Get("k1")
|
||||
if err != nil {
|
||||
t.Fatalf("get failed: %v", err)
|
||||
}
|
||||
if v.(string) != "v1" {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
|
||||
deleted, err := store.Delete("k1")
|
||||
if err != nil {
|
||||
t.Fatalf("delete failed: %v", err)
|
||||
}
|
||||
if !deleted {
|
||||
t.Fatal("expected key to be deleted")
|
||||
}
|
||||
|
||||
if _, err := store.Get("k1"); !errors.Is(err, ErrNotFound) {
|
||||
t.Fatalf("expected ErrNotFound, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreTTLExpiration(t *testing.T) {
|
||||
store, err := New(Config{
|
||||
CleanupInterval: 10 * time.Millisecond,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create store: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
if err := store.SetWithTTL("exp", "value", 25*time.Millisecond); err != nil {
|
||||
t.Fatalf("set with ttl failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if _, err := store.Get("exp"); !errors.Is(err, ErrNotFound) {
|
||||
t.Fatalf("expected ErrNotFound after expiry, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreGetAndDelete(t *testing.T) {
|
||||
store, err := New(Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create store: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
if err := store.Set("k", "v"); err != nil {
|
||||
t.Fatalf("set failed: %v", err)
|
||||
}
|
||||
|
||||
v, err := store.GetAndDelete("k")
|
||||
if err != nil {
|
||||
t.Fatalf("get and delete failed: %v", err)
|
||||
}
|
||||
if v.(string) != "v" {
|
||||
t.Fatalf("unexpected value: %v", v)
|
||||
}
|
||||
|
||||
if _, err := store.Get("k"); !errors.Is(err, ErrNotFound) {
|
||||
t.Fatalf("expected missing key after get-and-delete, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreClose(t *testing.T) {
|
||||
store, err := New(Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create store: %v", err)
|
||||
}
|
||||
if err := store.Close(); err != nil {
|
||||
t.Fatalf("close failed: %v", err)
|
||||
}
|
||||
|
||||
if err := store.Set("k", "v"); !errors.Is(err, ErrClosed) {
|
||||
t.Fatalf("expected ErrClosed on set, got: %v", err)
|
||||
}
|
||||
if _, err := store.Get("k"); !errors.Is(err, ErrClosed) {
|
||||
t.Fatalf("expected ErrClosed on get, got: %v", err)
|
||||
}
|
||||
}
|
||||
14
framework/list.go
Normal file
14
framework/list.go
Normal file
@@ -0,0 +1,14 @@
|
||||
// Package framework provides a list of dependencies that are required for the framework to work.
|
||||
package framework
|
||||
|
||||
// FrameworkDependency names a storage component the framework depends on.
type FrameworkDependency string

// Known framework dependencies.
const (
	// FrameworkDependencyVectorStore indicates that a VectorStore implementation is required.
	FrameworkDependencyVectorStore FrameworkDependency = "vector_store"

	// FrameworkDependencyConfigStore indicates that a ConfigStore implementation is required.
	FrameworkDependencyConfigStore FrameworkDependency = "config_store"

	// FrameworkDependencyLogsStore indicates that a LogsStore implementation is required.
	FrameworkDependencyLogsStore FrameworkDependency = "logs_store"
)
|
||||
318
framework/logstore/asyncjob.go
Normal file
318
framework/logstore/asyncjob.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package logstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/google/uuid"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
	// DefaultAsyncJobResultTTL is how long, in seconds, an async job result is
	// kept when the caller does not supply a positive TTL (3600s = 1 hour).
	DefaultAsyncJobResultTTL = 3600
)

const (
	// asyncJobCleanupInterval is the period between cleanup passes.
	asyncJobCleanupInterval = time.Minute
	// asyncJobCleanupTimeout bounds the duration of a single cleanup pass.
	asyncJobCleanupTimeout = time.Minute
	// asyncJobStaleProcessingHours is the age, in hours, after which a job
	// still in "processing" is treated as stale by the cleaner.
	asyncJobStaleProcessingHours = 24
)
|
||||
|
||||
// --- AsyncJobExecutor ---
|
||||
|
||||
// AsyncOperation represents a function that can be executed asynchronously.
|
||||
// It returns the response and an optional BifrostError.
|
||||
type AsyncOperation func(ctx *schemas.BifrostContext) (any, *schemas.BifrostError)
|
||||
|
||||
// GovernanceStore is an interface that provides access to the governance store.
|
||||
type GovernanceStore interface {
|
||||
GetVirtualKey(ctx context.Context, vkValue string) (*configstoreTables.TableVirtualKey, bool)
|
||||
}
|
||||
|
||||
// AsyncJobExecutor manages async job creation and background execution.
|
||||
type AsyncJobExecutor struct {
|
||||
logstore LogStore
|
||||
governanceStore GovernanceStore
|
||||
logger schemas.Logger
|
||||
}
|
||||
|
||||
// NewAsyncJobExecutor creates a new AsyncJobExecutor.
|
||||
func NewAsyncJobExecutor(logstore LogStore, governanceStore GovernanceStore, logger schemas.Logger) *AsyncJobExecutor {
|
||||
return &AsyncJobExecutor{
|
||||
logstore: logstore,
|
||||
governanceStore: governanceStore,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// RetrieveJob retrieves a job by its ID.
|
||||
func (e *AsyncJobExecutor) RetrieveJob(ctx context.Context, jobID string, vkValue *string, operationType schemas.RequestType) (*AsyncJob, error) {
|
||||
job, err := e.logstore.FindAsyncJobByID(ctx, jobID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
return nil, fmt.Errorf("job not found or expired")
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %w", ErrJobInternal, err)
|
||||
}
|
||||
if job.VirtualKeyID != nil {
|
||||
if vkValue == nil {
|
||||
return nil, fmt.Errorf("virtual key is required")
|
||||
}
|
||||
vk, ok := e.governanceStore.GetVirtualKey(ctx, *vkValue)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("virtual key not found")
|
||||
}
|
||||
if *job.VirtualKeyID != vk.ID {
|
||||
return nil, fmt.Errorf("virtual key mismatch")
|
||||
}
|
||||
}
|
||||
if job.RequestType != operationType {
|
||||
return nil, fmt.Errorf("operation type mismatch")
|
||||
}
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// SubmitJob creates a pending job, starts background execution, and returns the job record.
|
||||
func (e *AsyncJobExecutor) SubmitJob(bifrostCtx *schemas.BifrostContext, resultTTL int, operation AsyncOperation, operationType schemas.RequestType) (*AsyncJob, error) {
|
||||
if resultTTL <= 0 {
|
||||
resultTTL = DefaultAsyncJobResultTTL
|
||||
}
|
||||
|
||||
virtualKeyValue := getVirtualKeyFromContext(bifrostCtx)
|
||||
|
||||
var virtualKeyID *string
|
||||
if virtualKeyValue != nil {
|
||||
vk, ok := e.governanceStore.GetVirtualKey(bifrostCtx, *virtualKeyValue)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("virtual key not found")
|
||||
}
|
||||
virtualKeyID = &vk.ID
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
job := &AsyncJob{
|
||||
ID: uuid.New().String(),
|
||||
Status: schemas.AsyncJobStatusPending,
|
||||
RequestType: operationType,
|
||||
VirtualKeyID: virtualKeyID,
|
||||
ResultTTL: resultTTL,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if err := e.logstore.CreateAsyncJob(ctx, job); err != nil {
|
||||
return nil, fmt.Errorf("failed to create async job: %w", err)
|
||||
}
|
||||
|
||||
var contextValues map[any]any
|
||||
if bifrostCtx != nil {
|
||||
contextValues = bifrostCtx.GetUserValues()
|
||||
}
|
||||
go e.executeJob(job.ID, job.ResultTTL, operation, contextValues)
|
||||
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// executeJob runs the operation in the background and updates the job record.
|
||||
func (e *AsyncJobExecutor) executeJob(jobID string, resultTTL int, operation AsyncOperation, contextValues map[any]any) {
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
|
||||
// Restore original request context values (virtual key, tracing headers, etc.)
|
||||
for k, v := range contextValues {
|
||||
ctx.SetValue(k, v)
|
||||
}
|
||||
|
||||
// Clear trace context inherited from the original HTTP request.
|
||||
ctx.ClearValue(schemas.BifrostContextKeyTraceID)
|
||||
ctx.ClearValue(schemas.BifrostContextKeyParentSpanID)
|
||||
ctx.ClearValue(schemas.BifrostContextKeySpanID)
|
||||
|
||||
markFailed := func(msg string) {
|
||||
now := time.Now().UTC()
|
||||
expiresAt := now.Add(time.Duration(resultTTL) * time.Second)
|
||||
errJSON, _ := sonic.Marshal(&schemas.BifrostError{Error: &schemas.ErrorField{Message: msg}})
|
||||
if err := e.logstore.UpdateAsyncJob(ctx, jobID, map[string]any{
|
||||
"status": schemas.AsyncJobStatusFailed,
|
||||
"status_code": fasthttp.StatusInternalServerError,
|
||||
"error": string(errJSON),
|
||||
"completed_at": now,
|
||||
"expires_at": expiresAt,
|
||||
}); err != nil {
|
||||
e.logger.Warn("failed to update async job to failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// The bifrost execution flow is very stable and panics are not expected.
|
||||
// This recover is purely defensive to ensure the job always reaches a terminal
|
||||
// state rather than being stuck in "processing" if an unexpected panic occurs.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
e.logger.Warn("async job %s panicked: %v", jobID, r)
|
||||
markFailed(fmt.Sprintf("internal error: %v", r))
|
||||
}
|
||||
}()
|
||||
|
||||
// Mark as processing
|
||||
if err := e.logstore.UpdateAsyncJob(ctx, jobID, map[string]interface{}{
|
||||
"status": schemas.AsyncJobStatusProcessing,
|
||||
}); err != nil {
|
||||
e.logger.Warn("failed to update async job: %v", err)
|
||||
}
|
||||
|
||||
ctx.SetValue(schemas.BifrostIsAsyncRequest, true)
|
||||
|
||||
// Execute the operation
|
||||
resp, bifrostErr := operation(ctx)
|
||||
|
||||
now := time.Now().UTC()
|
||||
expiresAt := now.Add(time.Duration(resultTTL) * time.Second)
|
||||
|
||||
if bifrostErr != nil {
|
||||
errJSON, err := sonic.Marshal(bifrostErr)
|
||||
if err != nil {
|
||||
e.logger.Warn("failed to marshal bifrost error: %v", err)
|
||||
markFailed(fmt.Sprintf("failed to serialize error response: %v", err))
|
||||
return
|
||||
}
|
||||
statusCode := fasthttp.StatusInternalServerError
|
||||
if bifrostErr.StatusCode != nil {
|
||||
statusCode = *bifrostErr.StatusCode
|
||||
}
|
||||
if err := e.logstore.UpdateAsyncJob(ctx, jobID, map[string]interface{}{
|
||||
"status": schemas.AsyncJobStatusFailed,
|
||||
"status_code": statusCode,
|
||||
"error": string(errJSON),
|
||||
"completed_at": now,
|
||||
"expires_at": expiresAt,
|
||||
}); err != nil {
|
||||
e.logger.Warn("failed to update async job: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
respJSON, err := sonic.Marshal(resp)
|
||||
if err != nil {
|
||||
e.logger.Warn("failed to marshal result: %v", err)
|
||||
markFailed(fmt.Sprintf("failed to serialize result: %v", err))
|
||||
return
|
||||
}
|
||||
if err := e.logstore.UpdateAsyncJob(ctx, jobID, map[string]interface{}{
|
||||
"status": schemas.AsyncJobStatusCompleted,
|
||||
"status_code": fasthttp.StatusOK,
|
||||
"response": string(respJSON),
|
||||
"completed_at": now,
|
||||
"expires_at": expiresAt,
|
||||
}); err != nil {
|
||||
e.logger.Warn("failed to update async job: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Cleaner ---
|
||||
|
||||
// AsyncJobCleaner manages the cleanup of expired async jobs.
|
||||
type AsyncJobCleaner struct {
|
||||
store LogStore
|
||||
logger schemas.Logger
|
||||
stopCleanup chan struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewAsyncJobCleaner creates a new AsyncJobCleaner instance.
|
||||
func NewAsyncJobCleaner(store LogStore, logger schemas.Logger) *AsyncJobCleaner {
|
||||
return &AsyncJobCleaner{
|
||||
store: store,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// StartCleanupRoutine starts a goroutine that periodically cleans up expired async jobs.
|
||||
func (c *AsyncJobCleaner) StartCleanupRoutine() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.stopCleanup != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.stopCleanup = make(chan struct{})
|
||||
stopCh := c.stopCleanup
|
||||
|
||||
go func() {
|
||||
// Run initial cleanup
|
||||
ctx, cancel := context.WithTimeout(context.Background(), asyncJobCleanupTimeout)
|
||||
c.cleanupExpiredJobs(ctx)
|
||||
cancel()
|
||||
|
||||
ticker := time.NewTicker(asyncJobCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), asyncJobCleanupTimeout)
|
||||
c.cleanupExpiredJobs(ctx)
|
||||
cancel()
|
||||
case <-stopCh:
|
||||
c.logger.Debug("async job cleanup routine stopped")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
c.logger.Debug("async job cleanup routine started (interval: %s)", asyncJobCleanupInterval)
|
||||
}
|
||||
|
||||
// StopCleanupRoutine gracefully stops the cleanup goroutine.
|
||||
func (c *AsyncJobCleaner) StopCleanupRoutine() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.stopCleanup == nil {
|
||||
c.logger.Debug("async job cleanup routine already stopped")
|
||||
return
|
||||
}
|
||||
|
||||
close(c.stopCleanup)
|
||||
c.stopCleanup = nil
|
||||
}
|
||||
|
||||
// cleanupExpiredJobs deletes expired async jobs and stale processing jobs.
|
||||
func (c *AsyncJobCleaner) cleanupExpiredJobs(ctx context.Context) {
|
||||
deleted, err := c.store.DeleteExpiredAsyncJobs(ctx)
|
||||
if err != nil {
|
||||
c.logger.Warn("failed to delete expired async jobs: %v", err)
|
||||
} else if deleted > 0 {
|
||||
c.logger.Debug("async job cleanup completed: deleted %d expired jobs", deleted)
|
||||
}
|
||||
|
||||
// Clean up jobs stuck in "processing" for more than 24 hours
|
||||
// This handles edge cases like marshal failures or server crashes
|
||||
staleSince := time.Now().UTC().Add(-asyncJobStaleProcessingHours * time.Hour)
|
||||
staleDeleted, err := c.store.DeleteStaleAsyncJobs(ctx, staleSince)
|
||||
if err != nil {
|
||||
c.logger.Warn("failed to delete stale processing async jobs: %v", err)
|
||||
} else if staleDeleted > 0 {
|
||||
c.logger.Warn("async job cleanup: deleted %d stale processing jobs (stuck > %dh)", staleDeleted, asyncJobStaleProcessingHours)
|
||||
}
|
||||
}
|
||||
|
||||
// getVirtualKeyFromContext extracts the virtual key value from context.
|
||||
// Returns nil if no VK is present (e.g., direct key mode or no governance),
|
||||
// or if the context itself is nil (callers like SubmitJob may be invoked with
|
||||
// a nil ctx by background paths that don't carry a VK).
|
||||
func getVirtualKeyFromContext(ctx *schemas.BifrostContext) *string {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
vkValue := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)
|
||||
if vkValue == "" {
|
||||
return nil
|
||||
}
|
||||
return &vkValue
|
||||
}
|
||||
213
framework/logstore/asyncjob_test.go
Normal file
213
framework/logstore/asyncjob_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package logstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type asyncTestLogger struct{}
|
||||
|
||||
func (asyncTestLogger) Debug(string, ...any) {}
|
||||
func (asyncTestLogger) Info(string, ...any) {}
|
||||
func (asyncTestLogger) Warn(string, ...any) {}
|
||||
func (asyncTestLogger) Error(string, ...any) {}
|
||||
func (asyncTestLogger) Fatal(string, ...any) {}
|
||||
func (asyncTestLogger) SetLevel(schemas.LogLevel) {}
|
||||
func (asyncTestLogger) SetOutputType(schemas.LoggerOutputType) {}
|
||||
func (asyncTestLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
type testGovernanceStore struct {
|
||||
virtualKeys map[string]*configstoreTables.TableVirtualKey
|
||||
}
|
||||
|
||||
func (t *testGovernanceStore) GetVirtualKey(_ context.Context, vkValue string) (*configstoreTables.TableVirtualKey, bool) {
|
||||
vk, ok := t.virtualKeys[vkValue]
|
||||
return vk, ok
|
||||
}
|
||||
|
||||
func newTestAsyncExecutor(t *testing.T) *AsyncJobExecutor {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
store, err := newSqliteLogStore(ctx, &SQLiteConfig{Path: ":memory:"}, asyncTestLogger{})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { store.Close(ctx) })
|
||||
|
||||
govStore := &testGovernanceStore{
|
||||
virtualKeys: map[string]*configstoreTables.TableVirtualKey{
|
||||
"sk-bf-test": {ID: "vk-123", Value: "sk-bf-test"},
|
||||
},
|
||||
}
|
||||
|
||||
return NewAsyncJobExecutor(store, govStore, asyncTestLogger{})
|
||||
}
|
||||
|
||||
// waitForJobCompletion blocks until done reports true, polling every 10ms.
// It fails the test if the flag is still false after roughly two seconds.
func waitForJobCompletion(t *testing.T, done *atomic.Bool) {
	t.Helper()

	const (
		timeout      = 2 * time.Second
		pollInterval = 10 * time.Millisecond
	)
	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(pollInterval) {
		if done.Load() {
			return
		}
	}
	t.Fatal("timed out waiting for async job execution")
}
|
||||
|
||||
// waitForJobStatus polls FindAsyncJobByID until the job reaches a terminal
|
||||
// status (completed or failed), or times out. This avoids a fragile time.Sleep
|
||||
// between the operation callback completing and the DB update finishing.
|
||||
// Processing is intermediate and must not be treated as terminal.
|
||||
func waitForJobStatus(t *testing.T, store LogStore, jobID string) *AsyncJob {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
job, err := store.FindAsyncJobByID(context.Background(), jobID)
|
||||
if err == nil && (job.Status == schemas.AsyncJobStatusCompleted || job.Status == schemas.AsyncJobStatusFailed) {
|
||||
return job
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
t.Fatal("timed out waiting for async job to reach terminal status")
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSubmitJob_PropagatesContextValues(t *testing.T) {
|
||||
executor := newTestAsyncExecutor(t)
|
||||
|
||||
capturedCtx := schemas.NewBifrostContext(context.Background(), <-time.After(1*time.Minute))
|
||||
capturedCtx.SetValue(schemas.BifrostContextKeyVirtualKey, "sk-bf-test")
|
||||
capturedCtx.SetValue(schemas.BifrostContextKey("x-bf-eh-custom"), "custom-value")
|
||||
capturedCtx.SetValue(schemas.BifrostContextKey("x-bf-prom-env"), "production")
|
||||
var done atomic.Bool
|
||||
|
||||
operation := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
capturedCtx = bgCtx
|
||||
done.Store(true)
|
||||
return map[string]string{"status": "ok"}, nil
|
||||
}
|
||||
|
||||
job, err := executor.SubmitJob(capturedCtx, 3600, operation, schemas.ChatCompletionRequest)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, job)
|
||||
|
||||
waitForJobCompletion(t, &done)
|
||||
|
||||
assert.Equal(t, "sk-bf-test", capturedCtx.Value(schemas.BifrostContextKeyVirtualKey))
|
||||
assert.Equal(t, "production", capturedCtx.Value(schemas.BifrostContextKey("x-bf-prom-env")))
|
||||
assert.Equal(t, "custom-value", capturedCtx.Value(schemas.BifrostContextKey("x-bf-eh-custom")))
|
||||
assert.Equal(t, true, capturedCtx.Value(schemas.BifrostIsAsyncRequest))
|
||||
}
|
||||
|
||||
func TestSubmitJob_NilContextValues(t *testing.T) {
|
||||
executor := newTestAsyncExecutor(t)
|
||||
|
||||
var capturedCtx *schemas.BifrostContext
|
||||
var done atomic.Bool
|
||||
|
||||
operation := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
capturedCtx = bgCtx
|
||||
done.Store(true)
|
||||
return map[string]string{"status": "ok"}, nil
|
||||
}
|
||||
|
||||
job, err := executor.SubmitJob(capturedCtx, 3600, operation, schemas.ChatCompletionRequest)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, job)
|
||||
|
||||
waitForJobCompletion(t, &done)
|
||||
|
||||
assert.NotNil(t, capturedCtx)
|
||||
assert.Equal(t, true, capturedCtx.Value(schemas.BifrostIsAsyncRequest))
|
||||
}
|
||||
|
||||
func TestSubmitJob_EmptyContextValues(t *testing.T) {
|
||||
executor := newTestAsyncExecutor(t)
|
||||
|
||||
var capturedCtx *schemas.BifrostContext
|
||||
var done atomic.Bool
|
||||
|
||||
operation := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
capturedCtx = bgCtx
|
||||
done.Store(true)
|
||||
return map[string]string{"status": "ok"}, nil
|
||||
}
|
||||
|
||||
job, err := executor.SubmitJob(capturedCtx, 3600, operation, schemas.ChatCompletionRequest)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, job)
|
||||
|
||||
waitForJobCompletion(t, &done)
|
||||
|
||||
assert.NotNil(t, capturedCtx)
|
||||
assert.Equal(t, true, capturedCtx.Value(schemas.BifrostIsAsyncRequest))
|
||||
}
|
||||
|
||||
func TestSubmitJob_AsyncFlagOverridesContextValues(t *testing.T) {
|
||||
executor := newTestAsyncExecutor(t)
|
||||
|
||||
inputCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
inputCtx.SetValue(schemas.BifrostIsAsyncRequest, false)
|
||||
|
||||
var capturedCtx *schemas.BifrostContext
|
||||
var done atomic.Bool
|
||||
operation := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
capturedCtx = bgCtx
|
||||
done.Store(true)
|
||||
return map[string]string{"status": "ok"}, nil
|
||||
}
|
||||
|
||||
job, err := executor.SubmitJob(inputCtx, 3600, operation, schemas.ChatCompletionRequest)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, job)
|
||||
|
||||
waitForJobCompletion(t, &done)
|
||||
|
||||
// BifrostIsAsyncRequest must be true — set AFTER restoring context values
|
||||
assert.Equal(t, true, capturedCtx.Value(schemas.BifrostIsAsyncRequest))
|
||||
}
|
||||
|
||||
func TestSubmitJob_OperationFailure_PreservesContext(t *testing.T) {
|
||||
executor := newTestAsyncExecutor(t)
|
||||
|
||||
inputCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
inputCtx.SetValue(schemas.BifrostContextKeyVirtualKey, "sk-bf-test")
|
||||
|
||||
var capturedCtx *schemas.BifrostContext
|
||||
var done atomic.Bool
|
||||
|
||||
statusCode := fasthttp.StatusBadRequest
|
||||
operation := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
capturedCtx = bgCtx
|
||||
done.Store(true)
|
||||
return nil, &schemas.BifrostError{
|
||||
StatusCode: &statusCode,
|
||||
Error: &schemas.ErrorField{Message: "test error"},
|
||||
}
|
||||
}
|
||||
|
||||
job, err := executor.SubmitJob(inputCtx, 3600, operation, schemas.ChatCompletionRequest)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, job)
|
||||
|
||||
waitForJobCompletion(t, &done)
|
||||
|
||||
// Context values should still be available even when operation fails
|
||||
assert.Equal(t, "sk-bf-test", capturedCtx.Value(schemas.BifrostContextKeyVirtualKey))
|
||||
assert.Equal(t, true, capturedCtx.Value(schemas.BifrostIsAsyncRequest))
|
||||
|
||||
// Verify job was marked as failed — poll until DB update completes
|
||||
retrievedJob := waitForJobStatus(t, executor.logstore, job.ID)
|
||||
assert.Equal(t, schemas.AsyncJobStatusFailed, retrievedJob.Status)
|
||||
}
|
||||
161
framework/logstore/cleaner.go
Normal file
161
framework/logstore/cleaner.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package logstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
const (
	// cleanupInterval is the base delay between two log cleanup runs.
	cleanupInterval = 24 * time.Hour
	// minJitter and maxJitter bound the random delay added to cleanupInterval.
	minJitter = 15 * time.Minute
	maxJitter = 30 * time.Minute
	// batchSize is the number of logs requested per delete batch.
	batchSize = 100
	// defaultRetentionDays is used when the configured retention is below one day.
	defaultRetentionDays = 365
)
|
||||
|
||||
// LogRetentionManager is the storage-side contract the log cleaner relies on
// to remove old logs.
type LogRetentionManager interface {
	// DeleteLogsBatch is called repeatedly by the cleaner with a cutoff time
	// and a batch size; it reports how many logs it deleted. The cleaner stops
	// once a call deletes fewer than batchSize logs or returns an error.
	DeleteLogsBatch(ctx context.Context, cutoff time.Time, batchSize int) (deletedCount int64, err error)
}
|
||||
|
||||
// CleanerConfig carries the settings of the log cleaner.
type CleanerConfig struct {
	// RetentionDays is how many days of logs to keep. Values below 1 make the
	// cleaner fall back to its default retention.
	RetentionDays int
}
|
||||
|
||||
// LogsCleaner manages the cleanup of old logs
|
||||
type LogsCleaner struct {
|
||||
manager LogRetentionManager
|
||||
config CleanerConfig
|
||||
logger schemas.Logger
|
||||
stopCleanup chan struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewLogsCleaner creates a new LogsCleaner instance
|
||||
func NewLogsCleaner(manager LogRetentionManager, config CleanerConfig, logger schemas.Logger) *LogsCleaner {
|
||||
return &LogsCleaner{
|
||||
manager: manager,
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// StartCleanupRoutine starts a goroutine that periodically cleans up old logs
|
||||
func (c *LogsCleaner) StartCleanupRoutine() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Return early if already running
|
||||
if c.stopCleanup != nil {
|
||||
c.logger.Debug("log cleanup routine already running")
|
||||
return
|
||||
}
|
||||
|
||||
c.stopCleanup = make(chan struct{})
|
||||
stopCh := c.stopCleanup
|
||||
|
||||
go func() {
|
||||
// At the beginning, we will cleanup the logs
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
c.cleanupOldLogs(ctx)
|
||||
cancel()
|
||||
// Calculate initial delay with jitter
|
||||
timer := time.NewTimer(calculateNextRunDuration())
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
// Run cleanup
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
c.cleanupOldLogs(ctx)
|
||||
cancel()
|
||||
|
||||
// Reset timer with new jitter for next run
|
||||
timer.Reset(calculateNextRunDuration())
|
||||
|
||||
case <-stopCh:
|
||||
c.logger.Info("log cleanup routine stopped")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
c.logger.Info("log cleanup routine started")
|
||||
}
|
||||
|
||||
// StopCleanupRoutine gracefully stops the cleanup goroutine
|
||||
func (c *LogsCleaner) StopCleanupRoutine() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Return early if already stopped
|
||||
if c.stopCleanup == nil {
|
||||
c.logger.Debug("log cleanup routine already stopped")
|
||||
return
|
||||
}
|
||||
|
||||
close(c.stopCleanup)
|
||||
c.stopCleanup = nil
|
||||
}
|
||||
|
||||
// cleanupOldLogs deletes logs older than the retention period in batches
|
||||
func (c *LogsCleaner) cleanupOldLogs(ctx context.Context) {
|
||||
retentionDays := c.config.RetentionDays
|
||||
if retentionDays < 1 {
|
||||
retentionDays = defaultRetentionDays
|
||||
}
|
||||
|
||||
// Calculate cutoff time
|
||||
cutoff := time.Now().UTC().AddDate(0, 0, -retentionDays)
|
||||
c.logger.Info("starting log cleanup: deleting logs older than %s (retention: %d days)", cutoff.Format(time.RFC3339), retentionDays)
|
||||
|
||||
totalDeleted := int64(0)
|
||||
batchCount := 0
|
||||
|
||||
for {
|
||||
// Check if context is cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.logger.Warn("log cleanup cancelled: %v", ctx.Err())
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Delete logs in batches using the manager
|
||||
deleted, err := c.manager.DeleteLogsBatch(ctx, cutoff, batchSize)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to delete old logs: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if deleted == 0 {
|
||||
// No more logs to delete
|
||||
break
|
||||
}
|
||||
|
||||
totalDeleted += deleted
|
||||
batchCount++
|
||||
c.logger.Debug("deleted batch %d: %d logs", batchCount, deleted)
|
||||
|
||||
// If we deleted fewer than the batch size, we're done
|
||||
if deleted < int64(batchSize) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if totalDeleted > 0 {
|
||||
c.logger.Info("log cleanup completed: deleted %d logs in %d batches", totalDeleted, batchCount)
|
||||
} else {
|
||||
c.logger.Debug("log cleanup completed: no old logs to delete")
|
||||
}
|
||||
}
|
||||
|
||||
// calculateNextRunDuration returns 24 hours plus a random jitter between 15-30 minutes
|
||||
func calculateNextRunDuration() time.Duration {
|
||||
jitter := minJitter + time.Duration(rand.Int63n(int64(maxJitter-minJitter)))
|
||||
return cleanupInterval + jitter
|
||||
}
|
||||
68
framework/logstore/config.go
Normal file
68
framework/logstore/config.go
Normal file
@@ -0,0 +1,68 @@
|
||||
// Package logstore provides a logs store for Bifrost.
|
||||
package logstore
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/maximhq/bifrost/framework/objectstore"
|
||||
)
|
||||
|
||||
// Config represents the configuration for the logs store.
|
||||
type Config struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Type LogStoreType `json:"type"`
|
||||
RetentionDays int `json:"retention_days"`
|
||||
Config any `json:"config"`
|
||||
ObjectStorage *objectstore.Config `json:"object_storage,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON is the custom unmarshal logic for Config
|
||||
func (c *Config) UnmarshalJSON(data []byte) error {
|
||||
// First, unmarshal into a temporary struct to get the basic fields
|
||||
type TempConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Type LogStoreType `json:"type"`
|
||||
Config json.RawMessage `json:"config"` // Keep as raw JSON
|
||||
RetentionDays int `json:"retention_days"`
|
||||
ObjectStorage *objectstore.Config `json:"object_storage,omitempty"`
|
||||
}
|
||||
|
||||
var temp TempConfig
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal logs config: %w", err)
|
||||
}
|
||||
|
||||
// Set basic fields
|
||||
c.Enabled = temp.Enabled
|
||||
c.Type = temp.Type
|
||||
c.RetentionDays = temp.RetentionDays
|
||||
c.ObjectStorage = temp.ObjectStorage
|
||||
if !temp.Enabled {
|
||||
c.Config = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse the config field based on type
|
||||
switch temp.Type {
|
||||
case LogStoreTypeSQLite:
|
||||
if len(temp.Config) == 0 {
|
||||
return fmt.Errorf("missing sqlite config payload")
|
||||
}
|
||||
var sqliteConfig SQLiteConfig
|
||||
if err := json.Unmarshal(temp.Config, &sqliteConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal sqlite config: %w", err)
|
||||
}
|
||||
c.Config = &sqliteConfig
|
||||
case LogStoreTypePostgres:
|
||||
var postgresConfig PostgresConfig
|
||||
var err error
|
||||
if err = json.Unmarshal(temp.Config, &postgresConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal postgres config: %w", err)
|
||||
}
|
||||
c.Config = &postgresConfig
|
||||
default:
|
||||
return fmt.Errorf("unknown log store type: %s", temp.Type)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
8
framework/logstore/errors.go
Normal file
8
framework/logstore/errors.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package logstore
|
||||
|
||||
import "fmt"
|
||||
|
||||
// ErrNotFound is the sentinel error for a log that does not exist; callers
// match it with errors.Is.
var ErrNotFound = fmt.Errorf("log not found")

// ErrJobInternal is the sentinel error for an internal job store failure.
// NOTE(review): its producers are not visible in this file; confirm usage.
var ErrJobInternal = fmt.Errorf("internal job store error")
|
||||
613
framework/logstore/hybrid.go
Normal file
613
framework/logstore/hybrid.go
Normal file
@@ -0,0 +1,613 @@
|
||||
package logstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/objectstore"
|
||||
)
|
||||
|
||||
const (
	// defaultUploadWorkers is the number of upload goroutines started per store.
	defaultUploadWorkers = 10
	// defaultUploadQueueSize is the capacity of the upload queue channel.
	defaultUploadQueueSize = 5000
	// maxContentSummaryBytes is the limit passed to truncateTag for the content summary.
	maxContentSummaryBytes = 2048
	// defaultMaxUploadQueueBytes caps the total payload bytes awaiting upload (1 GiB).
	defaultMaxUploadQueueBytes = 1 << 30
)
|
||||
|
||||
// uploadWork is one queued object-store upload for a single log entry.
type uploadWork struct {
	// logID and timestamp identify the log and are used to build its object key.
	logID     string
	timestamp time.Time
	// payload is the JSON-encoded payload to store.
	payload []byte
	// tags are passed along with the payload on upload.
	tags map[string]string
}
|
||||
|
||||
// HybridLogStore wraps an existing LogStore and offloads large payload
|
||||
// fields to object storage while keeping a lightweight index in the DB.
|
||||
//
|
||||
// Method routing:
|
||||
// - Delegated directly (40+ methods): all analytics, search, histogram, ranking,
|
||||
// distinct, MCP, async job methods
|
||||
// - Intercepted: Create, CreateIfNotExists, BatchCreateIfNotExists, FindByID,
|
||||
// Update, DeleteLog, DeleteLogs, DeleteLogsBatch, Close
|
||||
type HybridLogStore struct {
|
||||
inner LogStore
|
||||
objects objectstore.ObjectStore
|
||||
prefix string
|
||||
logger schemas.Logger
|
||||
uploadQueue chan *uploadWork
|
||||
wg sync.WaitGroup
|
||||
closed atomic.Bool
|
||||
droppedUploads atomic.Int64
|
||||
pendingBytes atomic.Int64
|
||||
}
|
||||
|
||||
// newHybridLogStore creates a HybridLogStore wrapping the given inner store.
|
||||
func newHybridLogStore(inner LogStore, objects objectstore.ObjectStore, prefix string, logger schemas.Logger) *HybridLogStore {
|
||||
h := &HybridLogStore{
|
||||
inner: inner,
|
||||
objects: objects,
|
||||
prefix: prefix,
|
||||
logger: logger,
|
||||
uploadQueue: make(chan *uploadWork, defaultUploadQueueSize),
|
||||
}
|
||||
// Start upload workers.
|
||||
for i := 0; i < defaultUploadWorkers; i++ {
|
||||
h.wg.Add(1)
|
||||
go h.uploadWorker()
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// uploadWorker processes async S3 upload jobs from the queue.
|
||||
func (h *HybridLogStore) uploadWorker() {
|
||||
defer h.wg.Done()
|
||||
for work := range h.uploadQueue {
|
||||
h.processUpload(work)
|
||||
}
|
||||
}
|
||||
|
||||
// processUpload uploads a single payload to object storage.
|
||||
// This is fire-and-forget by design: on Put failure the upload is dropped and
|
||||
// counted in droppedUploads. The DB row retains has_object=false, so FindByID
|
||||
// falls back to whatever data the DB holds. Retries are intentionally omitted
|
||||
// to keep S3 latency from cascading into the write path.
|
||||
func (h *HybridLogStore) processUpload(work *uploadWork) {
|
||||
payloadSize := int64(len(work.payload))
|
||||
defer h.pendingBytes.Add(-payloadSize)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
h.logger.Error("objectstore: panic in upload worker (recovered): %v", r)
|
||||
h.droppedUploads.Add(1)
|
||||
}
|
||||
}()
|
||||
|
||||
key := ObjectKey(h.prefix, work.timestamp, work.logID)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := h.objects.Put(ctx, key, work.payload, work.tags); err != nil {
|
||||
h.logger.Warn("objectstore: failed to upload log %s: %v", work.logID, err)
|
||||
h.droppedUploads.Add(1)
|
||||
return
|
||||
}
|
||||
|
||||
// Mark the DB row as having an object. Use a fresh context so that a slow
|
||||
// Put doesn't starve the DB update of its deadline. Retry up to 3 times
|
||||
// with exponential backoff to avoid orphaning the uploaded object.
|
||||
for attempt := 0; attempt < 3; attempt++ {
|
||||
dbCtx, dbCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
err := h.inner.Update(dbCtx, work.logID, map[string]interface{}{"has_object": true})
|
||||
dbCancel()
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
h.logger.Warn("objectstore: failed to set has_object for log %s (attempt %d/3): %v", work.logID, attempt+1, err)
|
||||
if attempt < 2 {
|
||||
time.Sleep(time.Duration(1<<attempt) * time.Second) // 1s, 2s backoff
|
||||
}
|
||||
}
|
||||
h.logger.Error("objectstore: failed to set has_object for log %s after 3 attempts; payload orphaned in object store", work.logID)
|
||||
h.droppedUploads.Add(1)
|
||||
}
|
||||
|
||||
// isPayloadEmpty reports whether payload carries no data, i.e. it has no
// entries or every value is the empty string. Such payloads are not worth an
// object-store PUT (e.g. initial "processing" entries with no input/output yet).
func isPayloadEmpty(payload map[string]string) bool {
	for field := range payload {
		if len(payload[field]) > 0 {
			return false
		}
	}
	return true
}
|
||||
|
||||
// enqueueUpload pushes an upload job onto the queue. If the queue is full,
|
||||
// the job is dropped to prevent S3 slowness from cascading.
|
||||
func (h *HybridLogStore) enqueueUpload(logID string, timestamp time.Time, payload map[string]string, tags map[string]string) {
|
||||
if h.closed.Load() || isPayloadEmpty(payload) {
|
||||
return
|
||||
}
|
||||
// Recover from send-on-closed-channel panic: Close() may interleave
|
||||
// between the closed check above and the channel send below.
|
||||
// Same pattern as plugins/logging/writer.go enqueueLogEntry.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
h.droppedUploads.Add(1)
|
||||
}
|
||||
}()
|
||||
data, err := sonic.Marshal(payload)
|
||||
if err != nil {
|
||||
h.logger.Warn("objectstore: failed to marshal payload for log %s: %v", logID, err)
|
||||
h.droppedUploads.Add(1)
|
||||
return
|
||||
}
|
||||
if h.pendingBytes.Load()+int64(len(data)) > defaultMaxUploadQueueBytes {
|
||||
h.droppedUploads.Add(1)
|
||||
h.logger.Warn("objectstore: upload queue memory limit reached, dropping upload for log %s", logID)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case h.uploadQueue <- &uploadWork{
|
||||
logID: logID,
|
||||
timestamp: timestamp,
|
||||
payload: data,
|
||||
tags: tags,
|
||||
}:
|
||||
h.pendingBytes.Add(int64(len(data)))
|
||||
default:
|
||||
h.droppedUploads.Add(1)
|
||||
h.logger.Warn("objectstore: upload queue full, dropping upload for log %s", logID)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Intercepted methods ---
|
||||
|
||||
// prepareDBEntry builds the lightweight DB entry by extracting the content
|
||||
// summary, trimming input history to the last user message, and clearing
|
||||
// payload fields. Must be called after SerializeFields() populates the
|
||||
// Parsed fields.
|
||||
func prepareDBEntry(dbEntry *Log) {
|
||||
idx := findLastUserMessageIndex(dbEntry.InputHistoryParsed)
|
||||
|
||||
// Content summary: extract text from the found user message.
|
||||
// Falls back to BuildInputContentSummary for non-chat inputs (speech, image, etc.).
|
||||
if idx >= 0 {
|
||||
dbEntry.ContentSummary = extractChatMessageText(&dbEntry.InputHistoryParsed[idx])
|
||||
} else {
|
||||
dbEntry.ContentSummary = dbEntry.BuildInputContentSummary()
|
||||
}
|
||||
// Bound content summary to prevent large prompts from bloating the DB row.
|
||||
dbEntry.ContentSummary = truncateTag(dbEntry.ContentSummary, maxContentSummaryBytes)
|
||||
|
||||
// Serialize last user message before ClearPayload zeros everything.
|
||||
// msgs[idx:idx+1] reuses the backing array — no heap alloc, no struct copy.
|
||||
var lastUserMessage string
|
||||
if idx >= 0 {
|
||||
lastUserMessage, _ = sonic.MarshalString(dbEntry.InputHistoryParsed[idx : idx+1])
|
||||
}
|
||||
|
||||
ClearPayload(dbEntry)
|
||||
|
||||
// Restore last user message so list queries can display it without S3.
|
||||
dbEntry.InputHistory = lastUserMessage
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) Create(ctx context.Context, entry *Log) error {
|
||||
if err := entry.SerializeFields(); err != nil {
|
||||
return fmt.Errorf("logstore: serialize before extract: %w", err)
|
||||
}
|
||||
payload := ExtractPayload(entry)
|
||||
tags := BuildTags(entry)
|
||||
// Work on a shallow copy so the caller's entry is preserved on DB failure.
|
||||
dbEntry := *entry
|
||||
prepareDBEntry(&dbEntry)
|
||||
if err := h.inner.Create(ctx, &dbEntry); err != nil {
|
||||
return err
|
||||
}
|
||||
entry.ContentSummary = dbEntry.ContentSummary
|
||||
h.enqueueUpload(entry.ID, entry.Timestamp, payload, tags)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) CreateIfNotExists(ctx context.Context, entry *Log) error {
|
||||
if err := entry.SerializeFields(); err != nil {
|
||||
return fmt.Errorf("logstore: serialize before extract: %w", err)
|
||||
}
|
||||
payload := ExtractPayload(entry)
|
||||
tags := BuildTags(entry)
|
||||
// Work on a shallow copy so the caller's entry is preserved on DB failure.
|
||||
dbEntry := *entry
|
||||
prepareDBEntry(&dbEntry)
|
||||
if err := h.inner.CreateIfNotExists(ctx, &dbEntry); err != nil {
|
||||
return err
|
||||
}
|
||||
entry.ContentSummary = dbEntry.ContentSummary
|
||||
h.enqueueUpload(entry.ID, entry.Timestamp, payload, tags)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) BatchCreateIfNotExists(ctx context.Context, entries []*Log) error {
|
||||
type pendingUpload struct {
|
||||
logID string
|
||||
timestamp time.Time
|
||||
payload map[string]string
|
||||
tags map[string]string
|
||||
}
|
||||
var uploads []pendingUpload
|
||||
|
||||
dbEntries := make([]*Log, len(entries))
|
||||
for i, entry := range entries {
|
||||
if err := entry.SerializeFields(); err != nil {
|
||||
return fmt.Errorf("logstore: serialize before extract: %w", err)
|
||||
}
|
||||
payload := ExtractPayload(entry)
|
||||
tags := BuildTags(entry)
|
||||
// Work on a shallow copy so the caller's entries are preserved on DB failure.
|
||||
dbEntry := *entry
|
||||
prepareDBEntry(&dbEntry)
|
||||
dbEntries[i] = &dbEntry
|
||||
uploads = append(uploads, pendingUpload{
|
||||
logID: entry.ID,
|
||||
timestamp: entry.Timestamp,
|
||||
payload: payload,
|
||||
tags: tags,
|
||||
})
|
||||
}
|
||||
|
||||
if err := h.inner.BatchCreateIfNotExists(ctx, dbEntries); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, entry := range entries {
|
||||
entry.ContentSummary = dbEntries[i].ContentSummary
|
||||
}
|
||||
|
||||
for _, u := range uploads {
|
||||
h.enqueueUpload(u.logID, u.timestamp, u.payload, u.tags)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) FindByID(ctx context.Context, id string) (*Log, error) {
|
||||
log, err := h.inner.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.hydrateLog(ctx, log)
|
||||
return log, nil
|
||||
}
|
||||
|
||||
// hydrateLog fetches the offloaded payload from object storage and merges it
|
||||
// back into the Log struct. It is a no-op when HasObject is false.
|
||||
//
|
||||
// When requestedFields is non-empty, only the payload fields present in that
|
||||
// projection are kept after merge — unrequested payload fields are cleared to
|
||||
// honour projection semantics and avoid pulling large blobs unnecessarily.
|
||||
func (h *HybridLogStore) hydrateLog(ctx context.Context, log *Log, requestedFields ...string) {
|
||||
if log == nil || !log.HasObject {
|
||||
return
|
||||
}
|
||||
key := ObjectKey(h.prefix, log.Timestamp, log.ID)
|
||||
data, err := h.objects.Get(ctx, key)
|
||||
if err != nil {
|
||||
h.logger.Warn("objectstore: failed to fetch payload for log %s: %v", log.ID, err)
|
||||
return // Graceful degradation
|
||||
}
|
||||
if mergeErr := MergePayloadFromJSON(log, data); mergeErr != nil {
|
||||
h.logger.Warn("objectstore: failed to merge payload for log %s: %v", log.ID, mergeErr)
|
||||
return
|
||||
}
|
||||
pruneUnrequestedPayloadFields(log, requestedFields)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) Update(ctx context.Context, id string, entry any) error {
|
||||
// Pass through to inner store for index field updates.
|
||||
// Payload fields in the update map are handled separately by the logging plugin.
|
||||
return h.inner.Update(ctx, id, entry)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) DeleteLog(ctx context.Context, id string) error {
|
||||
log, findErr := h.inner.FindByID(ctx, id)
|
||||
if findErr != nil && !errors.Is(findErr, ErrNotFound) {
|
||||
return findErr
|
||||
}
|
||||
if err := h.inner.DeleteLog(ctx, id); err != nil {
|
||||
return err
|
||||
}
|
||||
if log != nil && log.HasObject {
|
||||
key := ObjectKey(h.prefix, log.Timestamp, log.ID)
|
||||
if delErr := h.objects.Delete(ctx, key); delErr != nil {
|
||||
h.logger.Warn("objectstore: failed to delete object for log %s: %v", id, delErr)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) DeleteLogs(ctx context.Context, ids []string) error {
|
||||
// Collect keys for S3 deletion before removing from DB.
|
||||
var keys []string
|
||||
for _, id := range ids {
|
||||
log, findErr := h.inner.FindByID(ctx, id)
|
||||
if findErr != nil && !errors.Is(findErr, ErrNotFound) {
|
||||
return findErr
|
||||
}
|
||||
if log != nil && log.HasObject {
|
||||
keys = append(keys, ObjectKey(h.prefix, log.Timestamp, log.ID))
|
||||
}
|
||||
}
|
||||
if err := h.inner.DeleteLogs(ctx, ids); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
if delErr := h.objects.DeleteBatch(ctx, keys); delErr != nil {
|
||||
h.logger.Warn("objectstore: failed to batch delete %d objects: %v", len(keys), delErr)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) DeleteLogsBatch(ctx context.Context, cutoff time.Time, batchSize int) (int64, error) {
|
||||
// Delegate to inner — S3 objects will be cleaned up by lifecycle policies.
|
||||
return h.inner.DeleteLogsBatch(ctx, cutoff, batchSize)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) Close(ctx context.Context) error {
|
||||
h.closed.Store(true)
|
||||
close(h.uploadQueue)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
h.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
h.logger.Warn("objectstore: shutdown cancelled before upload queue drained: %v", ctx.Err())
|
||||
// Still wait for workers to finish so we don't close dependencies mid-flight.
|
||||
<-done
|
||||
}
|
||||
if err := h.objects.Close(); err != nil {
|
||||
h.logger.Warn("objectstore: error closing object store: %v", err)
|
||||
}
|
||||
return h.inner.Close(ctx)
|
||||
}
|
||||
|
||||
// DroppedUploads returns the number of S3 uploads that were dropped.
|
||||
func (h *HybridLogStore) DroppedUploads() int64 {
|
||||
return h.droppedUploads.Load()
|
||||
}
|
||||
|
||||
// --- Delegated methods (pass through to inner store unchanged) ---
|
||||
|
||||
func (h *HybridLogStore) Ping(ctx context.Context) error {
|
||||
return h.inner.Ping(ctx)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) FindFirst(ctx context.Context, query any, fields ...string) (*Log, error) {
|
||||
needsHydration := len(fields) == 0 || fieldsNeedHydration(fields)
|
||||
if needsHydration && len(fields) > 0 {
|
||||
fields = ensureHydrationFields(fields)
|
||||
}
|
||||
log, err := h.inner.FindFirst(ctx, query, fields...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if needsHydration {
|
||||
h.hydrateLog(ctx, log, fields...)
|
||||
}
|
||||
return log, nil
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) FindAll(ctx context.Context, query any, fields ...string) ([]*Log, error) {
|
||||
needsHydration := len(fields) == 0 || fieldsNeedHydration(fields)
|
||||
if needsHydration && len(fields) > 0 {
|
||||
fields = ensureHydrationFields(fields)
|
||||
}
|
||||
logs, err := h.inner.FindAll(ctx, query, fields...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if needsHydration {
|
||||
for _, log := range logs {
|
||||
h.hydrateLog(ctx, log, fields...)
|
||||
}
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) FindAllDistinct(ctx context.Context, query any, fields ...string) ([]*Log, error) {
|
||||
return h.inner.FindAllDistinct(ctx, query, fields...)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) HasLogs(ctx context.Context) (bool, error) {
|
||||
return h.inner.HasLogs(ctx)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) SearchLogs(ctx context.Context, filters SearchFilters, pagination PaginationOptions) (*SearchResult, error) {
|
||||
return h.inner.SearchLogs(ctx, filters, pagination)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetSessionLogs(ctx context.Context, sessionID string, pagination PaginationOptions) (*SessionDetailResult, error) {
|
||||
return h.inner.GetSessionLogs(ctx, sessionID, pagination)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetSessionSummary(ctx context.Context, sessionID string) (*SessionSummaryResult, error) {
|
||||
return h.inner.GetSessionSummary(ctx, sessionID)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetStats(ctx context.Context, filters SearchFilters) (*SearchStats, error) {
|
||||
return h.inner.GetStats(ctx, filters)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*HistogramResult, error) {
|
||||
return h.inner.GetHistogram(ctx, filters, bucketSizeSeconds)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*TokenHistogramResult, error) {
|
||||
return h.inner.GetTokenHistogram(ctx, filters, bucketSizeSeconds)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*CostHistogramResult, error) {
|
||||
return h.inner.GetCostHistogram(ctx, filters, bucketSizeSeconds)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetModelHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ModelHistogramResult, error) {
|
||||
return h.inner.GetModelHistogram(ctx, filters, bucketSizeSeconds)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*LatencyHistogramResult, error) {
|
||||
return h.inner.GetLatencyHistogram(ctx, filters, bucketSizeSeconds)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetProviderCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderCostHistogramResult, error) {
|
||||
return h.inner.GetProviderCostHistogram(ctx, filters, bucketSizeSeconds)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetProviderTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderTokenHistogramResult, error) {
|
||||
return h.inner.GetProviderTokenHistogram(ctx, filters, bucketSizeSeconds)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetProviderLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderLatencyHistogramResult, error) {
|
||||
return h.inner.GetProviderLatencyHistogram(ctx, filters, bucketSizeSeconds)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetModelRankings(ctx context.Context, filters SearchFilters) (*ModelRankingResult, error) {
|
||||
return h.inner.GetModelRankings(ctx, filters)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetUserRankings(ctx context.Context, filters SearchFilters) (*UserRankingResult, error) {
|
||||
return h.inner.GetUserRankings(ctx, filters)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetDimensionCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionCostHistogramResult, error) {
|
||||
return h.inner.GetDimensionCostHistogram(ctx, filters, bucketSizeSeconds, dimension)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetDimensionTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionTokenHistogramResult, error) {
|
||||
return h.inner.GetDimensionTokenHistogram(ctx, filters, bucketSizeSeconds, dimension)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetDimensionLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionLatencyHistogramResult, error) {
|
||||
return h.inner.GetDimensionLatencyHistogram(ctx, filters, bucketSizeSeconds, dimension)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) BulkUpdateCost(ctx context.Context, updates map[string]float64) error {
|
||||
return h.inner.BulkUpdateCost(ctx, updates)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) Flush(ctx context.Context, since time.Time) error {
|
||||
return h.inner.Flush(ctx, since)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) IsLogEntryPresent(ctx context.Context, id string) (bool, error) {
|
||||
return h.inner.IsLogEntryPresent(ctx, id)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetDistinctAliases(ctx context.Context) ([]string, error) {
|
||||
return h.inner.GetDistinctAliases(ctx)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetDistinctModels(ctx context.Context) ([]string, error) {
|
||||
return h.inner.GetDistinctModels(ctx)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetDistinctKeyPairs(ctx context.Context, idCol, nameCol string) ([]KeyPairResult, error) {
|
||||
return h.inner.GetDistinctKeyPairs(ctx, idCol, nameCol)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetDistinctRoutingEngines(ctx context.Context) ([]string, error) {
|
||||
return h.inner.GetDistinctRoutingEngines(ctx)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetDistinctMetadataKeys(ctx context.Context) (map[string][]string, error) {
|
||||
return h.inner.GetDistinctMetadataKeys(ctx)
|
||||
}
|
||||
|
||||
// MCP Tool Log methods — delegated directly.
|
||||
|
||||
func (h *HybridLogStore) GetMCPHistogram(ctx context.Context, filters MCPToolLogSearchFilters, bucketSizeSeconds int64) (*MCPHistogramResult, error) {
|
||||
return h.inner.GetMCPHistogram(ctx, filters, bucketSizeSeconds)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetMCPCostHistogram(ctx context.Context, filters MCPToolLogSearchFilters, bucketSizeSeconds int64) (*MCPCostHistogramResult, error) {
|
||||
return h.inner.GetMCPCostHistogram(ctx, filters, bucketSizeSeconds)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetMCPTopTools(ctx context.Context, filters MCPToolLogSearchFilters, limit int) (*MCPTopToolsResult, error) {
|
||||
return h.inner.GetMCPTopTools(ctx, filters, limit)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) CreateMCPToolLog(ctx context.Context, entry *MCPToolLog) error {
|
||||
return h.inner.CreateMCPToolLog(ctx, entry)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) FindMCPToolLog(ctx context.Context, id string) (*MCPToolLog, error) {
|
||||
return h.inner.FindMCPToolLog(ctx, id)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) UpdateMCPToolLog(ctx context.Context, id string, entry any) error {
|
||||
return h.inner.UpdateMCPToolLog(ctx, id, entry)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) SearchMCPToolLogs(ctx context.Context, filters MCPToolLogSearchFilters, pagination PaginationOptions) (*MCPToolLogSearchResult, error) {
|
||||
return h.inner.SearchMCPToolLogs(ctx, filters, pagination)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetMCPToolLogStats(ctx context.Context, filters MCPToolLogSearchFilters) (*MCPToolLogStats, error) {
|
||||
return h.inner.GetMCPToolLogStats(ctx, filters)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) HasMCPToolLogs(ctx context.Context) (bool, error) {
|
||||
return h.inner.HasMCPToolLogs(ctx)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) DeleteMCPToolLogs(ctx context.Context, ids []string) error {
|
||||
return h.inner.DeleteMCPToolLogs(ctx, ids)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) FlushMCPToolLogs(ctx context.Context, since time.Time) error {
|
||||
return h.inner.FlushMCPToolLogs(ctx, since)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetAvailableToolNames(ctx context.Context) ([]string, error) {
|
||||
return h.inner.GetAvailableToolNames(ctx)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetAvailableServerLabels(ctx context.Context) ([]string, error) {
|
||||
return h.inner.GetAvailableServerLabels(ctx)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) GetAvailableMCPVirtualKeys(ctx context.Context) ([]MCPToolLog, error) {
|
||||
return h.inner.GetAvailableMCPVirtualKeys(ctx)
|
||||
}
|
||||
|
||||
// Async Job methods — delegated directly.
|
||||
|
||||
func (h *HybridLogStore) CreateAsyncJob(ctx context.Context, job *AsyncJob) error {
|
||||
return h.inner.CreateAsyncJob(ctx, job)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) FindAsyncJobByID(ctx context.Context, id string) (*AsyncJob, error) {
|
||||
return h.inner.FindAsyncJobByID(ctx, id)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) UpdateAsyncJob(ctx context.Context, id string, updates map[string]interface{}) error {
|
||||
return h.inner.UpdateAsyncJob(ctx, id, updates)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) DeleteExpiredAsyncJobs(ctx context.Context) (int64, error) {
|
||||
return h.inner.DeleteExpiredAsyncJobs(ctx)
|
||||
}
|
||||
|
||||
func (h *HybridLogStore) DeleteStaleAsyncJobs(ctx context.Context, staleSince time.Time) (int64, error) {
|
||||
return h.inner.DeleteStaleAsyncJobs(ctx, staleSince)
|
||||
}
|
||||
332
framework/logstore/hybrid_test.go
Normal file
332
framework/logstore/hybrid_test.go
Normal file
@@ -0,0 +1,332 @@
|
||||
package logstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/objectstore"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type hybridTestLogger struct{}
|
||||
|
||||
func (hybridTestLogger) Debug(string, ...any) {}
|
||||
func (hybridTestLogger) Info(string, ...any) {}
|
||||
func (hybridTestLogger) Warn(string, ...any) {}
|
||||
func (hybridTestLogger) Error(string, ...any) {}
|
||||
func (hybridTestLogger) Fatal(string, ...any) {}
|
||||
func (hybridTestLogger) SetLevel(schemas.LogLevel) {}
|
||||
func (hybridTestLogger) SetOutputType(schemas.LoggerOutputType) {}
|
||||
func (hybridTestLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
func newTestHybrid(t *testing.T) (*HybridLogStore, LogStore, *objectstore.InMemoryObjectStore) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
// Create SQLite inner store.
|
||||
inner, err := newSqliteLogStore(ctx, &SQLiteConfig{Path: ":memory:"}, hybridTestLogger{})
|
||||
require.NoError(t, err)
|
||||
|
||||
objStore := objectstore.NewInMemoryObjectStore()
|
||||
hybrid := newHybridLogStore(inner, objStore, "test", hybridTestLogger{})
|
||||
return hybrid, inner, objStore
|
||||
}
|
||||
|
||||
// waitForUploads polls done every 10ms until it reports true, and fails the
// test with t.Fatal if that has not happened within 2 seconds.
func waitForUploads(t *testing.T, done func() bool) {
	t.Helper()

	const (
		timeout      = 2 * time.Second
		pollInterval = 10 * time.Millisecond
	)

	for start := time.Now(); time.Since(start) < timeout; time.Sleep(pollInterval) {
		if done() {
			return
		}
	}
	t.Fatal("timed out waiting for upload state")
}
|
||||
|
||||
func TestHybrid_CreateAndFindByID(t *testing.T) {
|
||||
hybrid, _, objStore := newTestHybrid(t)
|
||||
defer hybrid.Close(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
inputContent := "Hello, how are you?"
|
||||
entry := &Log{
|
||||
ID: "log-1",
|
||||
Timestamp: time.Now().UTC(),
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3-sonnet",
|
||||
Status: "success",
|
||||
Object: "chat.completion",
|
||||
InputHistoryParsed: []schemas.ChatMessage{
|
||||
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: &inputContent}},
|
||||
},
|
||||
OutputMessageParsed: &schemas.ChatMessage{
|
||||
Content: &schemas.ChatMessageContent{ContentStr: strPtr("I'm fine, thanks!")},
|
||||
},
|
||||
}
|
||||
|
||||
// Serialize fields so TEXT columns are populated (simulating what GORM BeforeCreate does).
|
||||
require.NoError(t, entry.SerializeFields())
|
||||
|
||||
err := hybrid.CreateIfNotExists(ctx, entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
waitForUploads(t, func() bool { return objStore.Len() == 1 })
|
||||
|
||||
// Verify object was uploaded.
|
||||
assert.Equal(t, 1, objStore.Len(), "expected 1 object in store")
|
||||
|
||||
// FindByID should return hydrated log with payload.
|
||||
found, err := hybrid.FindByID(ctx, "log-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "log-1", found.ID)
|
||||
assert.True(t, found.HasObject)
|
||||
assert.NotEmpty(t, found.InputHistory, "InputHistory should be hydrated from S3")
|
||||
assert.NotEmpty(t, found.OutputMessage, "OutputMessage should be hydrated from S3")
|
||||
|
||||
// Content summary should contain input text but the output should be in the payload.
|
||||
assert.Contains(t, found.ContentSummary, "Hello, how are you?")
|
||||
}
|
||||
|
||||
func TestHybrid_EmptyPayloadSkipsUpload(t *testing.T) {
|
||||
hybrid, _, objStore := newTestHybrid(t)
|
||||
defer hybrid.Close(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
entry := &Log{
|
||||
ID: "log-processing",
|
||||
Timestamp: time.Now().UTC(),
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
Status: "processing",
|
||||
Object: "chat.completion",
|
||||
}
|
||||
|
||||
err := hybrid.CreateIfNotExists(ctx, entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
waitForUploads(t, func() bool { return len(hybrid.uploadQueue) == 0 })
|
||||
|
||||
// No upload when all payload fields are empty (e.g. initial "processing" entries).
|
||||
assert.Equal(t, 0, objStore.Len(), "empty-payload entries should not be uploaded")
|
||||
}
|
||||
|
||||
func TestHybrid_BatchCreateIfNotExists(t *testing.T) {
|
||||
hybrid, _, objStore := newTestHybrid(t)
|
||||
defer hybrid.Close(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
entries := make([]*Log, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
content := "input message"
|
||||
entries[i] = &Log{
|
||||
ID: "batch-" + string(rune('a'+i)),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3",
|
||||
Status: "success",
|
||||
Object: "chat.completion",
|
||||
InputHistoryParsed: []schemas.ChatMessage{
|
||||
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: &content}},
|
||||
},
|
||||
}
|
||||
require.NoError(t, entries[i].SerializeFields())
|
||||
}
|
||||
|
||||
err := hybrid.BatchCreateIfNotExists(ctx, entries)
|
||||
require.NoError(t, err)
|
||||
|
||||
waitForUploads(t, func() bool { return objStore.Len() == 3 })
|
||||
assert.Equal(t, 3, objStore.Len())
|
||||
}
|
||||
|
||||
func TestHybrid_FindByID_NoObject(t *testing.T) {
|
||||
hybrid, inner, _ := newTestHybrid(t)
|
||||
defer hybrid.Close(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert directly into inner store (simulating legacy data without object).
|
||||
entry := &Log{
|
||||
ID: "legacy-1",
|
||||
Timestamp: time.Now().UTC(),
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
Status: "success",
|
||||
Object: "chat.completion",
|
||||
InputHistory: `[{"role":"user","content":"legacy input"}]`,
|
||||
HasObject: false,
|
||||
}
|
||||
require.NoError(t, inner.CreateIfNotExists(ctx, entry))
|
||||
|
||||
found, err := hybrid.FindByID(ctx, "legacy-1")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, found.HasObject)
|
||||
// Legacy data: payload is in DB.
|
||||
assert.NotEmpty(t, found.InputHistory)
|
||||
}
|
||||
|
||||
func TestHybrid_FindByID_GracefulDegradation(t *testing.T) {
|
||||
hybrid, _, objStore := newTestHybrid(t)
|
||||
defer hybrid.Close(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
content := "test input"
|
||||
entry := &Log{
|
||||
ID: "degrade-1",
|
||||
Timestamp: time.Now().UTC(),
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3",
|
||||
Status: "success",
|
||||
Object: "chat.completion",
|
||||
InputHistoryParsed: []schemas.ChatMessage{
|
||||
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: &content}},
|
||||
},
|
||||
}
|
||||
require.NoError(t, entry.SerializeFields())
|
||||
require.NoError(t, hybrid.CreateIfNotExists(ctx, entry))
|
||||
waitForUploads(t, func() bool { return objStore.Len() == 1 })
|
||||
|
||||
// Simulate S3 failure.
|
||||
objStore.GetErr = assert.AnError
|
||||
|
||||
found, err := hybrid.FindByID(ctx, "degrade-1")
|
||||
require.NoError(t, err, "FindByID should succeed even when S3 fails")
|
||||
assert.True(t, found.HasObject)
|
||||
// When S3 fails, the DB data is returned. The DB retains the last message
|
||||
// in input_history for list views, so it won't be empty.
|
||||
assert.NotEmpty(t, found.InputHistory, "last message should be retained in DB")
|
||||
// But other payload fields (output_message, params, etc.) should be empty.
|
||||
assert.Empty(t, found.OutputMessage, "output should be empty when S3 fails")
|
||||
}
|
||||
|
||||
func TestHybrid_PutFailureDropsUpload(t *testing.T) {
|
||||
hybrid, _, objStore := newTestHybrid(t)
|
||||
defer hybrid.Close(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
// Simulate S3 write failure.
|
||||
objStore.PutErr = assert.AnError
|
||||
|
||||
content := "important input"
|
||||
entry := &Log{
|
||||
ID: "put-fail-1",
|
||||
Timestamp: time.Now().UTC(),
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3",
|
||||
Status: "success",
|
||||
Object: "chat.completion",
|
||||
InputHistoryParsed: []schemas.ChatMessage{
|
||||
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: &content}},
|
||||
},
|
||||
}
|
||||
require.NoError(t, entry.SerializeFields())
|
||||
require.NoError(t, hybrid.CreateIfNotExists(ctx, entry))
|
||||
waitForUploads(t, func() bool { return hybrid.DroppedUploads() == 1 })
|
||||
|
||||
// Upload should have been dropped.
|
||||
assert.Equal(t, 0, objStore.Len(), "no object should be stored when Put fails")
|
||||
assert.Equal(t, int64(1), hybrid.DroppedUploads(), "dropped upload should be counted")
|
||||
|
||||
// DB row exists but has_object remains false since the upload failed.
|
||||
found, err := hybrid.FindByID(ctx, "put-fail-1")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, found.HasObject, "has_object should remain false when upload fails")
|
||||
}
|
||||
|
||||
func TestHybrid_DeleteLog(t *testing.T) {
|
||||
hybrid, _, objStore := newTestHybrid(t)
|
||||
defer hybrid.Close(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
entry := &Log{
|
||||
ID: "del-1",
|
||||
Timestamp: time.Now().UTC(),
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3",
|
||||
Status: "success",
|
||||
Object: "chat.completion",
|
||||
InputHistory: `[{"role":"user","content":"delete me"}]`,
|
||||
}
|
||||
require.NoError(t, entry.SerializeFields())
|
||||
require.NoError(t, hybrid.CreateIfNotExists(ctx, entry))
|
||||
waitForUploads(t, func() bool { return objStore.Len() == 1 })
|
||||
assert.Equal(t, 1, objStore.Len())
|
||||
|
||||
err := hybrid.DeleteLog(ctx, "del-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Object should be deleted from S3.
|
||||
assert.Equal(t, 0, objStore.Len())
|
||||
|
||||
// DB should also be empty.
|
||||
_, err = hybrid.FindByID(ctx, "del-1")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHybrid_Tags(t *testing.T) {
|
||||
hybrid, _, objStore := newTestHybrid(t)
|
||||
defer hybrid.Close(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
ts := time.Date(2026, 4, 3, 14, 30, 0, 0, time.UTC)
|
||||
vkID := "vk_test"
|
||||
entry := &Log{
|
||||
ID: "tag-1",
|
||||
Timestamp: ts,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3",
|
||||
Status: "error",
|
||||
Object: "chat.completion",
|
||||
VirtualKeyID: &vkID,
|
||||
Stream: true,
|
||||
InputHistory: `[{"role":"user","content":"test"}]`,
|
||||
}
|
||||
require.NoError(t, entry.SerializeFields())
|
||||
require.NoError(t, hybrid.CreateIfNotExists(ctx, entry))
|
||||
waitForUploads(t, func() bool { return objStore.Len() == 1 })
|
||||
|
||||
key := ObjectKey("test", ts, "tag-1")
|
||||
tags := objStore.GetTags(key)
|
||||
assert.Equal(t, "anthropic", tags["provider"])
|
||||
assert.Equal(t, "error", tags["status"])
|
||||
assert.Equal(t, "true", tags["has_error"])
|
||||
assert.Equal(t, "true", tags["stream"])
|
||||
assert.Equal(t, "vk_test", tags["virtual_key_id"])
|
||||
assert.Equal(t, "2026-04-03", tags["date"])
|
||||
}
|
||||
|
||||
func TestHybrid_ContentSummaryIsInputOnly(t *testing.T) {
|
||||
hybrid, inner, _ := newTestHybrid(t)
|
||||
defer hybrid.Close(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
inputText := "What is the capital of France?"
|
||||
outputText := "The capital of France is Paris."
|
||||
entry := &Log{
|
||||
ID: "summary-1",
|
||||
Timestamp: time.Now().UTC(),
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3",
|
||||
Status: "success",
|
||||
Object: "chat.completion",
|
||||
InputHistoryParsed: []schemas.ChatMessage{
|
||||
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: &inputText}},
|
||||
},
|
||||
OutputMessageParsed: &schemas.ChatMessage{
|
||||
Content: &schemas.ChatMessageContent{ContentStr: &outputText},
|
||||
},
|
||||
}
|
||||
require.NoError(t, entry.SerializeFields())
|
||||
require.NoError(t, hybrid.CreateIfNotExists(ctx, entry))
|
||||
|
||||
// Read from inner DB to check content_summary.
|
||||
dbLog, err := inner.FindByID(ctx, "summary-1")
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, dbLog.ContentSummary, "capital of France")
|
||||
assert.NotContains(t, dbLog.ContentSummary, "Paris", "content_summary should not contain output text")
|
||||
}
|
||||
45
framework/logstore/logger.go
Normal file
45
framework/logstore/logger.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package logstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
gormLibLogger "gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// gormLogger adapts a schemas.Logger so it can be used as GORM's logger.
|
||||
type gormLogger struct {
|
||||
logger schemas.Logger // destination for forwarded Info/Warn/Error messages
|
||||
}
|
||||
|
||||
// LogMode sets the log mode for the logger.
|
||||
func (l *gormLogger) LogMode(level gormLibLogger.LogLevel) gormLibLogger.Interface {
|
||||
// NOOP
|
||||
return l
|
||||
}
|
||||
|
||||
// Info logs an info message.
|
||||
func (l *gormLogger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||
l.logger.Info(msg, data...)
|
||||
}
|
||||
|
||||
// Warn logs a warning message.
|
||||
func (l *gormLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||
l.logger.Warn(msg, data...)
|
||||
}
|
||||
|
||||
// Error logs an error message.
|
||||
func (l *gormLogger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||
l.logger.Error(msg, data...)
|
||||
}
|
||||
|
||||
// Trace logs a trace message.
|
||||
func (l *gormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
// NOOP
|
||||
}
|
||||
|
||||
// newGormLogger creates a new GormLogger.
|
||||
func newGormLogger(l schemas.Logger) *gormLogger {
|
||||
return &gormLogger{logger: l}
|
||||
}
|
||||
1188
framework/logstore/matviews.go
Normal file
1188
framework/logstore/matviews.go
Normal file
File diff suppressed because it is too large
Load Diff
2622
framework/logstore/migrations.go
Normal file
2622
framework/logstore/migrations.go
Normal file
File diff suppressed because it is too large
Load Diff
437
framework/logstore/migrations_test.go
Normal file
437
framework/logstore/migrations_test.go
Normal file
@@ -0,0 +1,437 @@
|
||||
package logstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// postgresDSN is the connection string the Postgres-backed tests use. It matches
|
||||
// the postgres service in tests/docker-compose.yml and framework/docker-compose.yml.
|
||||
const postgresDSN = "host=localhost user=bifrost password=bifrost_password dbname=bifrost port=5432 sslmode=disable"
|
||||
|
||||
// trySetupPostgresDB attempts to connect to Postgres and returns the connection.
|
||||
// Returns nil if Postgres is unavailable.
|
||||
func trySetupPostgresDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Verify the connection is actually live before proceeding.
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// setupLogsTableForGINIndexTest creates the logs table in a pre-migration state
|
||||
// (with metadata column but without the GIN index) for testing the GIN index migration.
|
||||
func setupLogsTableForGINIndexTest(t *testing.T, db *gorm.DB) {
|
||||
t.Helper()
|
||||
|
||||
// Drop existing tables and migration tracking in the correct order.
|
||||
// Preserve the shared migrations table — only clear its rows.
|
||||
db.Exec("DROP INDEX IF EXISTS idx_logs_metadata_gin")
|
||||
db.Exec("DROP TABLE IF EXISTS logs")
|
||||
db.Exec("CREATE TABLE IF NOT EXISTS migrations (id VARCHAR(255) PRIMARY KEY)")
|
||||
db.Exec("DELETE FROM migrations")
|
||||
|
||||
// Create a minimal logs table with only the columns needed for the test
|
||||
err := db.Exec(`
|
||||
CREATE TABLE logs (
|
||||
id VARCHAR(255) PRIMARY KEY,
|
||||
timestamp TIMESTAMP NOT NULL,
|
||||
object_type VARCHAR(255) NOT NULL,
|
||||
provider VARCHAR(255) NOT NULL,
|
||||
model VARCHAR(255) NOT NULL,
|
||||
status VARCHAR(50) NOT NULL,
|
||||
metadata TEXT,
|
||||
created_at TIMESTAMP NOT NULL
|
||||
)
|
||||
`).Error
|
||||
require.NoError(t, err, "Failed to create logs table")
|
||||
|
||||
// The migrator will create the migrations table automatically when it runs
|
||||
|
||||
// Clean up tables after the test
|
||||
t.Cleanup(func() {
|
||||
db.Exec("DROP INDEX IF EXISTS idx_logs_metadata_gin")
|
||||
db.Exec("DROP TABLE IF EXISTS logs")
|
||||
db.Exec("DELETE FROM migrations")
|
||||
})
|
||||
}
|
||||
|
||||
// insertTestLog inserts a test log entry with the given metadata value.
|
||||
func insertTestLog(t *testing.T, db *gorm.DB, id string, metadata *string) {
|
||||
t.Helper()
|
||||
now := time.Now()
|
||||
|
||||
var metadataVal interface{}
|
||||
if metadata != nil {
|
||||
metadataVal = *metadata
|
||||
}
|
||||
|
||||
err := db.Exec(`
|
||||
INSERT INTO logs (id, timestamp, object_type, provider, model, status, metadata, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, id, now, "chat_completion", "openai", "gpt-4", "success", metadataVal, now).Error
|
||||
require.NoError(t, err, "Failed to insert test log %s", id)
|
||||
}
|
||||
|
||||
// getMetadataValue retrieves the metadata value for a given log ID.
|
||||
func getMetadataValue(t *testing.T, db *gorm.DB, id string) *string {
|
||||
t.Helper()
|
||||
var result struct {
|
||||
Metadata *string
|
||||
}
|
||||
err := db.Table("logs").Select("metadata").Where("id = ?", id).Scan(&result).Error
|
||||
require.NoError(t, err, "Failed to get metadata for log %s", id)
|
||||
return result.Metadata
|
||||
}
|
||||
|
||||
// indexExists checks if the GIN index exists on the logs table.
|
||||
func indexExists(t *testing.T, db *gorm.DB, indexName string) bool {
|
||||
t.Helper()
|
||||
var count int64
|
||||
err := db.Raw(`
|
||||
SELECT COUNT(*) FROM pg_indexes
|
||||
WHERE tablename = 'logs' AND indexname = ?
|
||||
`, indexName).Scan(&count).Error
|
||||
require.NoError(t, err, "Failed to check index existence")
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func TestMigrationAddMetadataGINIndex_ValidJSON(t *testing.T) {
|
||||
db := trySetupPostgresDB(t)
|
||||
if db == nil {
|
||||
t.Skip("Postgres not available, skipping test")
|
||||
}
|
||||
|
||||
setupLogsTableForGINIndexTest(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert logs with valid JSON object metadata (arrays are not supported)
|
||||
validJSON1 := `{"key": "value"}`
|
||||
validJSON2 := `{"nested": {"foo": "bar"}, "array": [1, 2, 3]}`
|
||||
validJSON3 := `{"empty": {}}`
|
||||
validJSON4 := `{"number": 42, "bool": true, "null": null}`
|
||||
|
||||
insertTestLog(t, db, "log-valid-1", &validJSON1)
|
||||
insertTestLog(t, db, "log-valid-2", &validJSON2)
|
||||
insertTestLog(t, db, "log-valid-3", &validJSON3)
|
||||
insertTestLog(t, db, "log-valid-4", &validJSON4)
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get SQL DB: %v", err)
|
||||
}
|
||||
conn, err := sqlDB.Conn(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get SQL connection: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
// Run the migration (cleanup only) then ensure the index is built.
|
||||
err = migrationAddMetadataGINIndex(ctx, db)
|
||||
require.NoError(t, err, "Migration should succeed")
|
||||
err = ensureMetadataGINIndex(ctx, conn)
|
||||
require.NoError(t, err, "GIN index creation should succeed")
|
||||
|
||||
// Verify all valid JSON object values are preserved
|
||||
meta1 := getMetadataValue(t, db, "log-valid-1")
|
||||
assert.NotNil(t, meta1, "Valid JSON object should be preserved")
|
||||
assert.Equal(t, validJSON1, *meta1)
|
||||
|
||||
meta2 := getMetadataValue(t, db, "log-valid-2")
|
||||
assert.NotNil(t, meta2, "Valid JSON object should be preserved")
|
||||
assert.Equal(t, validJSON2, *meta2)
|
||||
|
||||
meta3 := getMetadataValue(t, db, "log-valid-3")
|
||||
assert.NotNil(t, meta3, "Valid JSON object with nested empty object should be preserved")
|
||||
assert.Equal(t, validJSON3, *meta3)
|
||||
|
||||
meta4 := getMetadataValue(t, db, "log-valid-4")
|
||||
assert.NotNil(t, meta4, "Valid JSON object with various types should be preserved")
|
||||
assert.Equal(t, validJSON4, *meta4)
|
||||
|
||||
// Verify the GIN index was created
|
||||
assert.True(t, indexExists(t, db, "idx_logs_metadata_gin"), "GIN index should be created")
|
||||
}
|
||||
|
||||
func TestMigrationAddMetadataGINIndex_InvalidJSON(t *testing.T) {
|
||||
db := trySetupPostgresDB(t)
|
||||
if db == nil {
|
||||
t.Skip("Postgres not available, skipping test")
|
||||
}
|
||||
|
||||
setupLogsTableForGINIndexTest(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert logs with invalid JSON metadata (not valid JSON objects)
|
||||
invalid1 := `{"key": invalid}` // Unquoted value
|
||||
invalid2 := `{key: "value"}` // Unquoted key
|
||||
invalid3 := `{"key": "value",}` // Trailing comma
|
||||
invalid4 := `just a string` // Plain text
|
||||
invalid5 := `` // Empty string
|
||||
invalid6 := `{"unclosed": "brace"` // Unclosed brace
|
||||
invalid7 := `{"key": undefined}` // JavaScript undefined
|
||||
invalid8 := `{'single': 'quotes'}` // Single quotes
|
||||
invalid9 := `[NULL]` // Literal string [NULL] (not valid JSON)
|
||||
invalid10 := `NULL` // Literal string NULL (not valid JSON)
|
||||
invalid11 := `null` // Valid JSON but not a JSON object
|
||||
invalid12 := `[1, 2, 3]` // Valid JSON array but not a JSON object
|
||||
|
||||
insertTestLog(t, db, "log-invalid-1", &invalid1)
|
||||
insertTestLog(t, db, "log-invalid-2", &invalid2)
|
||||
insertTestLog(t, db, "log-invalid-3", &invalid3)
|
||||
insertTestLog(t, db, "log-invalid-4", &invalid4)
|
||||
insertTestLog(t, db, "log-invalid-5", &invalid5)
|
||||
insertTestLog(t, db, "log-invalid-6", &invalid6)
|
||||
insertTestLog(t, db, "log-invalid-7", &invalid7)
|
||||
insertTestLog(t, db, "log-invalid-8", &invalid8)
|
||||
insertTestLog(t, db, "log-invalid-9", &invalid9)
|
||||
insertTestLog(t, db, "log-invalid-10", &invalid10)
|
||||
insertTestLog(t, db, "log-invalid-11", &invalid11)
|
||||
insertTestLog(t, db, "log-invalid-12", &invalid12)
|
||||
insertTestLog(t, db, "log-actual-null", nil) // Actual SQL NULL
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get SQL DB: %v", err)
|
||||
}
|
||||
conn, err := sqlDB.Conn(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get SQL connection: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
// Run the migration (cleanup only) then ensure the index is built.
|
||||
err = migrationAddMetadataGINIndex(ctx, db)
|
||||
require.NoError(t, err, "Migration should succeed even with invalid JSON")
|
||||
err = ensureMetadataGINIndex(ctx, conn)
|
||||
require.NoError(t, err, "GIN index creation should succeed after invalid JSON cleanup")
|
||||
|
||||
// Verify all non-object values were set to NULL (only JSON objects are supported)
|
||||
for i := 1; i <= 12; i++ {
|
||||
id := fmt.Sprintf("log-invalid-%d", i)
|
||||
meta := getMetadataValue(t, db, id)
|
||||
assert.Nil(t, meta, "Non-object JSON for %s should be set to NULL", id)
|
||||
}
|
||||
|
||||
// Verify actual SQL NULL remains NULL
|
||||
metaActualNull := getMetadataValue(t, db, "log-actual-null")
|
||||
assert.Nil(t, metaActualNull, "Actual NULL should remain NULL")
|
||||
|
||||
// Verify the GIN index was created
|
||||
assert.True(t, indexExists(t, db, "idx_logs_metadata_gin"), "GIN index should be created")
|
||||
}
|
||||
|
||||
func TestMigrationAddMetadataGINIndex_MixedData(t *testing.T) {
|
||||
db := trySetupPostgresDB(t)
|
||||
if db == nil {
|
||||
t.Skip("Postgres not available, skipping test")
|
||||
}
|
||||
|
||||
setupLogsTableForGINIndexTest(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert a mix of valid JSON, invalid JSON, and NULL metadata
|
||||
validJSON := `{"environment": "production", "version": "1.0.0"}`
|
||||
invalidJSON := `{"broken": invalid_value}`
|
||||
|
||||
insertTestLog(t, db, "log-mixed-valid", &validJSON)
|
||||
insertTestLog(t, db, "log-mixed-invalid", &invalidJSON)
|
||||
insertTestLog(t, db, "log-mixed-null", nil)
|
||||
|
||||
// Run the migration (cleanup only) then ensure the index is built.
|
||||
err := migrationAddMetadataGINIndex(ctx, db)
|
||||
require.NoError(t, err, "Migration should succeed")
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get SQL DB: %v", err)
|
||||
}
|
||||
conn, err := sqlDB.Conn(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get SQL connection: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
err = ensureMetadataGINIndex(ctx, conn)
|
||||
require.NoError(t, err, "GIN index creation should succeed")
|
||||
|
||||
// Verify valid JSON is preserved
|
||||
metaValid := getMetadataValue(t, db, "log-mixed-valid")
|
||||
assert.NotNil(t, metaValid, "Valid JSON should be preserved")
|
||||
assert.Equal(t, validJSON, *metaValid)
|
||||
|
||||
// Verify invalid JSON is cleaned to NULL
|
||||
metaInvalid := getMetadataValue(t, db, "log-mixed-invalid")
|
||||
assert.Nil(t, metaInvalid, "Invalid JSON should be set to NULL")
|
||||
|
||||
// Verify NULL remains NULL
|
||||
metaNull := getMetadataValue(t, db, "log-mixed-null")
|
||||
assert.Nil(t, metaNull, "NULL metadata should remain NULL")
|
||||
|
||||
// Verify the GIN index was created
|
||||
assert.True(t, indexExists(t, db, "idx_logs_metadata_gin"), "GIN index should be created")
|
||||
}
|
||||
|
||||
func TestMigrationAddMetadataGINIndex_Idempotent(t *testing.T) {
|
||||
db := trySetupPostgresDB(t)
|
||||
if db == nil {
|
||||
t.Skip("Postgres not available, skipping test")
|
||||
}
|
||||
|
||||
setupLogsTableForGINIndexTest(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert a log with valid JSON
|
||||
validJSON := `{"test": "idempotent"}`
|
||||
insertTestLog(t, db, "log-idempotent", &validJSON)
|
||||
|
||||
// Run the migration (cleanup only) then ensure the index is built.
|
||||
err := migrationAddMetadataGINIndex(ctx, db)
|
||||
require.NoError(t, err, "First migration should succeed")
|
||||
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get SQL DB: %v", err)
|
||||
}
|
||||
conn, err := sqlDB.Conn(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get SQL connection: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
err = ensureMetadataGINIndex(ctx, conn)
|
||||
require.NoError(t, err, "GIN index creation should succeed")
|
||||
|
||||
// Verify index exists
|
||||
assert.True(t, indexExists(t, db, "idx_logs_metadata_gin"), "GIN index should exist after first migration")
|
||||
|
||||
// Verify metadata is preserved
|
||||
meta1 := getMetadataValue(t, db, "log-idempotent")
|
||||
assert.NotNil(t, meta1)
|
||||
assert.Equal(t, validJSON, *meta1)
|
||||
|
||||
// Run the migration second time (should be idempotent due to gomigrate tracking)
|
||||
err = migrationAddMetadataGINIndex(ctx, db)
|
||||
require.NoError(t, err, "Second migration should succeed (idempotent)")
|
||||
err = ensureMetadataGINIndex(ctx, conn)
|
||||
require.NoError(t, err, "ensureMetadataGINIndex should be a no-op when index already exists")
|
||||
|
||||
// Verify index still exists
|
||||
assert.True(t, indexExists(t, db, "idx_logs_metadata_gin"), "GIN index should exist after second migration")
|
||||
|
||||
// Verify metadata is still preserved
|
||||
meta2 := getMetadataValue(t, db, "log-idempotent")
|
||||
assert.NotNil(t, meta2)
|
||||
assert.Equal(t, validJSON, *meta2)
|
||||
}
|
||||
|
||||
func TestMigrationAddMetadataGINIndex_EmptyTable(t *testing.T) {
|
||||
db := trySetupPostgresDB(t)
|
||||
if db == nil {
|
||||
t.Skip("Postgres not available, skipping test")
|
||||
}
|
||||
|
||||
setupLogsTableForGINIndexTest(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Run the migration (cleanup only) then ensure the index is built.
|
||||
err := migrationAddMetadataGINIndex(ctx, db)
|
||||
require.NoError(t, err, "Migration should succeed on empty table")
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get SQL DB: %v", err)
|
||||
}
|
||||
conn, err := sqlDB.Conn(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get SQL connection: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
err = ensureMetadataGINIndex(ctx, conn)
|
||||
require.NoError(t, err, "GIN index creation should succeed on empty table")
|
||||
|
||||
// Verify the GIN index was created
|
||||
assert.True(t, indexExists(t, db, "idx_logs_metadata_gin"), "GIN index should be created even on empty table")
|
||||
}
|
||||
|
||||
func TestMigrationAddMetadataGINIndex_EdgeCases(t *testing.T) {
|
||||
db := trySetupPostgresDB(t)
|
||||
if db == nil {
|
||||
t.Skip("Postgres not available, skipping test")
|
||||
}
|
||||
|
||||
setupLogsTableForGINIndexTest(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test edge cases that might be tricky (only JSON objects are supported)
|
||||
emptyObject := `{}`
|
||||
emptyArray := `[]` // Not a JSON object, should be nullified
|
||||
whitespaceJSON := ` {"key": "value"} ` // Valid JSON with surrounding whitespace
|
||||
unicodeJSON := `{"emoji": "🎉", "chinese": "中文"}`
|
||||
largeNumber := `{"bignum": 99999999999999999999}`
|
||||
scientificNotation := `{"sci": 1.23e10}`
|
||||
|
||||
insertTestLog(t, db, "log-edge-empty-obj", &emptyObject)
|
||||
insertTestLog(t, db, "log-edge-empty-arr", &emptyArray)
|
||||
insertTestLog(t, db, "log-edge-whitespace", &whitespaceJSON)
|
||||
insertTestLog(t, db, "log-edge-unicode", &unicodeJSON)
|
||||
insertTestLog(t, db, "log-edge-large-num", &largeNumber)
|
||||
insertTestLog(t, db, "log-edge-scientific", &scientificNotation)
|
||||
|
||||
// Run the migration (cleanup only) then ensure the index is built.
|
||||
err := migrationAddMetadataGINIndex(ctx, db)
|
||||
require.NoError(t, err, "Migration should succeed")
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get SQL DB: %v", err)
|
||||
}
|
||||
conn, err := sqlDB.Conn(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get SQL connection: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
err = ensureMetadataGINIndex(ctx, conn)
|
||||
require.NoError(t, err, "GIN index creation should succeed")
|
||||
|
||||
// Verify all edge cases are handled correctly
|
||||
// Empty object should be preserved, but empty array is not a JSON object
|
||||
assert.NotNil(t, getMetadataValue(t, db, "log-edge-empty-obj"), "Empty object should be preserved")
|
||||
assert.Nil(t, getMetadataValue(t, db, "log-edge-empty-arr"), "Empty array should be nullified (not a JSON object)")
|
||||
|
||||
// Whitespace JSON should be preserved (Postgres handles it)
|
||||
meta := getMetadataValue(t, db, "log-edge-whitespace")
|
||||
assert.NotNil(t, meta, "Whitespace JSON object should be preserved")
|
||||
|
||||
// Unicode should be preserved
|
||||
assert.NotNil(t, getMetadataValue(t, db, "log-edge-unicode"), "Unicode JSON object should be preserved")
|
||||
|
||||
// Large numbers and scientific notation should be preserved
|
||||
assert.NotNil(t, getMetadataValue(t, db, "log-edge-large-num"), "Large number JSON object should be preserved")
|
||||
assert.NotNil(t, getMetadataValue(t, db, "log-edge-scientific"), "Scientific notation JSON object should be preserved")
|
||||
|
||||
// Verify the GIN index was created
|
||||
assert.True(t, indexExists(t, db, "idx_logs_metadata_gin"), "GIN index should be created")
|
||||
}
|
||||
618
framework/logstore/payload.go
Normal file
618
framework/logstore/payload.go
Normal file
@@ -0,0 +1,618 @@
|
||||
package logstore
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// payloadFields is the ordered list of DB column names for the large TEXT
// fields that hybrid mode offloads to object storage. Analytics queries
// (histograms, search, rankings) never need these columns; they are only
// required for individual log detail views (FindByID).
var payloadFields = []string{
	// Conversation input and output.
	"input_history",
	"responses_input_history",
	"output_message",
	"responses_output",
	"embedding_output",
	"rerank_output",
	"ocr_input",
	"ocr_output",

	// Request parameters and tooling.
	"params",
	"tools",
	"tool_calls",

	// Modality-specific inputs.
	"speech_input",
	"transcription_input",
	"image_generation_input",
	"image_edit_input",
	"image_variation_input",
	"video_generation_input",

	// Modality-specific outputs.
	"speech_output",
	"transcription_output",
	"image_generation_output",
	"list_models_output",
	"video_generation_output",
	"video_retrieve_output",
	"video_download_output",
	"video_list_output",
	"video_delete_output",

	// Diagnostics, raw bodies and routing information.
	"cache_debug",
	"token_usage",
	"error_details",
	"raw_request",
	"raw_response",
	"passthrough_request_body",
	"passthrough_response_body",
	"routing_engine_logs",
}
|
||||
|
||||
// ExtractPayload reads the serialized TEXT payload fields from a Log into a map.
|
||||
// The map keys are the DB column names.
|
||||
func ExtractPayload(l *Log) map[string]string {
|
||||
m := make(map[string]string, len(payloadFields))
|
||||
m["input_history"] = l.InputHistory
|
||||
m["responses_input_history"] = l.ResponsesInputHistory
|
||||
m["output_message"] = l.OutputMessage
|
||||
m["responses_output"] = l.ResponsesOutput
|
||||
m["embedding_output"] = l.EmbeddingOutput
|
||||
m["rerank_output"] = l.RerankOutput
|
||||
m["ocr_input"] = l.OCRInput
|
||||
m["ocr_output"] = l.OCROutput
|
||||
m["params"] = l.Params
|
||||
m["tools"] = l.Tools
|
||||
m["tool_calls"] = l.ToolCalls
|
||||
m["speech_input"] = l.SpeechInput
|
||||
m["transcription_input"] = l.TranscriptionInput
|
||||
m["image_generation_input"] = l.ImageGenerationInput
|
||||
m["image_edit_input"] = l.ImageEditInput
|
||||
m["image_variation_input"] = l.ImageVariationInput
|
||||
m["video_generation_input"] = l.VideoGenerationInput
|
||||
m["speech_output"] = l.SpeechOutput
|
||||
m["transcription_output"] = l.TranscriptionOutput
|
||||
m["image_generation_output"] = l.ImageGenerationOutput
|
||||
m["list_models_output"] = l.ListModelsOutput
|
||||
m["video_generation_output"] = l.VideoGenerationOutput
|
||||
m["video_retrieve_output"] = l.VideoRetrieveOutput
|
||||
m["video_download_output"] = l.VideoDownloadOutput
|
||||
m["video_list_output"] = l.VideoListOutput
|
||||
m["video_delete_output"] = l.VideoDeleteOutput
|
||||
m["cache_debug"] = l.CacheDebug
|
||||
m["token_usage"] = l.TokenUsage
|
||||
m["error_details"] = l.ErrorDetails
|
||||
m["raw_request"] = l.RawRequest
|
||||
m["raw_response"] = l.RawResponse
|
||||
m["passthrough_request_body"] = l.PassthroughRequestBody
|
||||
m["passthrough_response_body"] = l.PassthroughResponseBody
|
||||
m["routing_engine_logs"] = l.RoutingEngineLogs
|
||||
return m
|
||||
}
|
||||
|
||||
// ClearPayload zeros out both the TEXT payload columns and the Parsed virtual
|
||||
// fields on a Log struct. Clearing the Parsed fields is necessary to prevent
|
||||
// GORM's BeforeCreate/SerializeFields from re-populating TEXT columns.
|
||||
// After calling this, the struct only contains index-weight data suitable
|
||||
// for a lightweight DB INSERT.
|
||||
func ClearPayload(l *Log) {
|
||||
// Clear serialized TEXT columns.
|
||||
l.InputHistory = ""
|
||||
l.ResponsesInputHistory = ""
|
||||
l.OutputMessage = ""
|
||||
l.ResponsesOutput = ""
|
||||
l.EmbeddingOutput = ""
|
||||
l.RerankOutput = ""
|
||||
l.OCRInput = ""
|
||||
l.OCROutput = ""
|
||||
l.Params = ""
|
||||
l.Tools = ""
|
||||
l.ToolCalls = ""
|
||||
l.SpeechInput = ""
|
||||
l.TranscriptionInput = ""
|
||||
l.ImageGenerationInput = ""
|
||||
l.ImageEditInput = ""
|
||||
l.ImageVariationInput = ""
|
||||
l.VideoGenerationInput = ""
|
||||
l.SpeechOutput = ""
|
||||
l.TranscriptionOutput = ""
|
||||
l.ImageGenerationOutput = ""
|
||||
l.ListModelsOutput = ""
|
||||
l.VideoGenerationOutput = ""
|
||||
l.VideoRetrieveOutput = ""
|
||||
l.VideoDownloadOutput = ""
|
||||
l.VideoListOutput = ""
|
||||
l.VideoDeleteOutput = ""
|
||||
l.CacheDebug = ""
|
||||
l.TokenUsage = ""
|
||||
l.ErrorDetails = ""
|
||||
l.RawRequest = ""
|
||||
l.RawResponse = ""
|
||||
l.PassthroughRequestBody = ""
|
||||
l.PassthroughResponseBody = ""
|
||||
l.RoutingEngineLogs = ""
|
||||
|
||||
// Clear Parsed virtual fields so GORM's SerializeFields won't re-serialize them.
|
||||
l.InputHistoryParsed = nil
|
||||
l.ResponsesInputHistoryParsed = nil
|
||||
l.OutputMessageParsed = nil
|
||||
l.ResponsesOutputParsed = nil
|
||||
l.EmbeddingOutputParsed = nil
|
||||
l.RerankOutputParsed = nil
|
||||
l.OCRInputParsed = nil
|
||||
l.OCROutputParsed = nil
|
||||
l.ParamsParsed = nil
|
||||
l.ToolsParsed = nil
|
||||
l.ToolCallsParsed = nil
|
||||
l.SpeechInputParsed = nil
|
||||
l.TranscriptionInputParsed = nil
|
||||
l.ImageGenerationInputParsed = nil
|
||||
l.ImageEditInputParsed = nil
|
||||
l.ImageVariationInputParsed = nil
|
||||
l.VideoGenerationInputParsed = nil
|
||||
l.SpeechOutputParsed = nil
|
||||
l.TranscriptionOutputParsed = nil
|
||||
l.ImageGenerationOutputParsed = nil
|
||||
l.ListModelsOutputParsed = nil
|
||||
l.VideoGenerationOutputParsed = nil
|
||||
l.VideoRetrieveOutputParsed = nil
|
||||
l.VideoDownloadOutputParsed = nil
|
||||
l.VideoListOutputParsed = nil
|
||||
l.VideoDeleteOutputParsed = nil
|
||||
l.CacheDebugParsed = nil
|
||||
l.TokenUsageParsed = nil
|
||||
l.ErrorDetailsParsed = nil
|
||||
}
|
||||
|
||||
// MergePayloadFromJSON takes a JSON payload (as marshaled by MarshalPayload)
|
||||
// and merges the fields back into the Log struct's serialized TEXT columns,
|
||||
// then calls DeserializeFields to populate the Parsed virtual fields.
|
||||
func MergePayloadFromJSON(l *Log, data []byte) error {
|
||||
var m map[string]string
|
||||
if err := sonic.Unmarshal(data, &m); err != nil {
|
||||
return fmt.Errorf("logstore: unmarshal payload: %w", err)
|
||||
}
|
||||
if v, ok := m["input_history"]; ok && v != "" {
|
||||
l.InputHistory = v
|
||||
}
|
||||
if v, ok := m["responses_input_history"]; ok && v != "" {
|
||||
l.ResponsesInputHistory = v
|
||||
}
|
||||
if v, ok := m["output_message"]; ok && v != "" {
|
||||
l.OutputMessage = v
|
||||
}
|
||||
if v, ok := m["responses_output"]; ok && v != "" {
|
||||
l.ResponsesOutput = v
|
||||
}
|
||||
if v, ok := m["embedding_output"]; ok && v != "" {
|
||||
l.EmbeddingOutput = v
|
||||
}
|
||||
if v, ok := m["rerank_output"]; ok && v != "" {
|
||||
l.RerankOutput = v
|
||||
}
|
||||
if v, ok := m["ocr_input"]; ok && v != "" {
|
||||
l.OCRInput = v
|
||||
}
|
||||
if v, ok := m["ocr_output"]; ok && v != "" {
|
||||
l.OCROutput = v
|
||||
}
|
||||
if v, ok := m["params"]; ok && v != "" {
|
||||
l.Params = v
|
||||
}
|
||||
if v, ok := m["tools"]; ok && v != "" {
|
||||
l.Tools = v
|
||||
}
|
||||
if v, ok := m["tool_calls"]; ok && v != "" {
|
||||
l.ToolCalls = v
|
||||
}
|
||||
if v, ok := m["speech_input"]; ok && v != "" {
|
||||
l.SpeechInput = v
|
||||
}
|
||||
if v, ok := m["transcription_input"]; ok && v != "" {
|
||||
l.TranscriptionInput = v
|
||||
}
|
||||
if v, ok := m["image_generation_input"]; ok && v != "" {
|
||||
l.ImageGenerationInput = v
|
||||
}
|
||||
if v, ok := m["image_edit_input"]; ok && v != "" {
|
||||
l.ImageEditInput = v
|
||||
}
|
||||
if v, ok := m["image_variation_input"]; ok && v != "" {
|
||||
l.ImageVariationInput = v
|
||||
}
|
||||
if v, ok := m["video_generation_input"]; ok && v != "" {
|
||||
l.VideoGenerationInput = v
|
||||
}
|
||||
if v, ok := m["speech_output"]; ok && v != "" {
|
||||
l.SpeechOutput = v
|
||||
}
|
||||
if v, ok := m["transcription_output"]; ok && v != "" {
|
||||
l.TranscriptionOutput = v
|
||||
}
|
||||
if v, ok := m["image_generation_output"]; ok && v != "" {
|
||||
l.ImageGenerationOutput = v
|
||||
}
|
||||
if v, ok := m["list_models_output"]; ok && v != "" {
|
||||
l.ListModelsOutput = v
|
||||
}
|
||||
if v, ok := m["video_generation_output"]; ok && v != "" {
|
||||
l.VideoGenerationOutput = v
|
||||
}
|
||||
if v, ok := m["video_retrieve_output"]; ok && v != "" {
|
||||
l.VideoRetrieveOutput = v
|
||||
}
|
||||
if v, ok := m["video_download_output"]; ok && v != "" {
|
||||
l.VideoDownloadOutput = v
|
||||
}
|
||||
if v, ok := m["video_list_output"]; ok && v != "" {
|
||||
l.VideoListOutput = v
|
||||
}
|
||||
if v, ok := m["video_delete_output"]; ok && v != "" {
|
||||
l.VideoDeleteOutput = v
|
||||
}
|
||||
if v, ok := m["cache_debug"]; ok && v != "" {
|
||||
l.CacheDebug = v
|
||||
}
|
||||
if v, ok := m["token_usage"]; ok && v != "" {
|
||||
l.TokenUsage = v
|
||||
}
|
||||
if v, ok := m["error_details"]; ok && v != "" {
|
||||
l.ErrorDetails = v
|
||||
}
|
||||
if v, ok := m["raw_request"]; ok && v != "" {
|
||||
l.RawRequest = v
|
||||
}
|
||||
if v, ok := m["raw_response"]; ok && v != "" {
|
||||
l.RawResponse = v
|
||||
}
|
||||
if v, ok := m["passthrough_request_body"]; ok && v != "" {
|
||||
l.PassthroughRequestBody = v
|
||||
}
|
||||
if v, ok := m["passthrough_response_body"]; ok && v != "" {
|
||||
l.PassthroughResponseBody = v
|
||||
}
|
||||
if v, ok := m["routing_engine_logs"]; ok && v != "" {
|
||||
l.RoutingEngineLogs = v
|
||||
}
|
||||
return l.DeserializeFields()
|
||||
}
|
||||
|
||||
// MarshalPayload serializes the payload map (from ExtractPayload) to JSON.
|
||||
func MarshalPayload(payload map[string]string) ([]byte, error) {
|
||||
return sonic.Marshal(payload)
|
||||
}
|
||||
|
||||
// BuildInputContentSummary extracts the last user message text from input fields.
|
||||
// This is used in hybrid mode for the content_summary column, which powers
|
||||
// full-text search and serves as a display fallback in the log list table.
|
||||
// Only the last message is kept — the full conversation history lives in
|
||||
// object storage and is merged back on FindByID.
|
||||
func (l *Log) BuildInputContentSummary() string {
|
||||
// Chat completions: last user message
|
||||
if idx := findLastUserMessageIndex(l.InputHistoryParsed); idx >= 0 {
|
||||
if text := extractChatMessageText(&l.InputHistoryParsed[idx]); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
|
||||
// Responses API: last user message
|
||||
for i := len(l.ResponsesInputHistoryParsed) - 1; i >= 0; i-- {
|
||||
if l.ResponsesInputHistoryParsed[i].Role != nil && *l.ResponsesInputHistoryParsed[i].Role == schemas.ResponsesInputMessageRoleUser {
|
||||
if text := extractResponsesMessageText(&l.ResponsesInputHistoryParsed[i]); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Speech input
|
||||
if l.SpeechInputParsed != nil && l.SpeechInputParsed.Input != "" {
|
||||
return l.SpeechInputParsed.Input
|
||||
}
|
||||
|
||||
// Image generation input prompt
|
||||
if l.ImageGenerationInputParsed != nil && l.ImageGenerationInputParsed.Prompt != "" {
|
||||
return l.ImageGenerationInputParsed.Prompt
|
||||
}
|
||||
|
||||
// Image edit input prompt
|
||||
if l.ImageEditInputParsed != nil && l.ImageEditInputParsed.Prompt != "" {
|
||||
return l.ImageEditInputParsed.Prompt
|
||||
}
|
||||
|
||||
// Video generation input prompt
|
||||
if l.VideoGenerationInputParsed != nil && l.VideoGenerationInputParsed.Prompt != "" {
|
||||
return l.VideoGenerationInputParsed.Prompt
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractChatMessageText returns the text content from a ChatMessage.
|
||||
// It prefers ContentStr; falls back to the last text ContentBlock.
|
||||
func extractChatMessageText(msg *schemas.ChatMessage) string {
|
||||
if msg.Content == nil {
|
||||
return ""
|
||||
}
|
||||
if msg.Content.ContentStr != nil && *msg.Content.ContentStr != "" {
|
||||
return *msg.Content.ContentStr
|
||||
}
|
||||
if msg.Content.ContentBlocks != nil {
|
||||
var lastText string
|
||||
for _, block := range msg.Content.ContentBlocks {
|
||||
if block.Text != nil && *block.Text != "" {
|
||||
lastText = *block.Text
|
||||
}
|
||||
}
|
||||
return lastText
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractResponsesMessageText returns the text content from a ResponsesMessage.
|
||||
// It prefers ContentStr; falls back to the last text ContentBlock.
|
||||
func extractResponsesMessageText(msg *schemas.ResponsesMessage) string {
|
||||
if msg.Content == nil {
|
||||
return ""
|
||||
}
|
||||
if msg.Content.ContentStr != nil && *msg.Content.ContentStr != "" {
|
||||
return *msg.Content.ContentStr
|
||||
}
|
||||
if msg.Content.ContentBlocks != nil {
|
||||
var lastText string
|
||||
for _, block := range msg.Content.ContentBlocks {
|
||||
if block.Text != nil && *block.Text != "" {
|
||||
lastText = *block.Text
|
||||
}
|
||||
}
|
||||
return lastText
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// findLastUserMessageIndex returns the index of the last ChatMessage with
|
||||
// role "user", or -1 if none exists. Used by both BuildInputContentSummary
|
||||
// and prepareDBEntry to avoid scanning the slice twice.
|
||||
func findLastUserMessageIndex(msgs []schemas.ChatMessage) int {
|
||||
for i := len(msgs) - 1; i >= 0; i-- {
|
||||
if msgs[i].Role == schemas.ChatMessageRoleUser {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// BuildTags creates the S3 object tag map from a Log's index fields.
|
||||
// S3 allows max 10 tags per object; chosen for lifecycle rules and
|
||||
// S3 Metadata Tables queryability.
|
||||
func BuildTags(l *Log) map[string]string {
|
||||
tags := make(map[string]string, 10)
|
||||
if l.Provider != "" {
|
||||
tags["provider"] = l.Provider
|
||||
}
|
||||
if l.Model != "" {
|
||||
tags["model"] = truncateTag(l.Model, 256)
|
||||
}
|
||||
if l.Status != "" {
|
||||
tags["status"] = l.Status
|
||||
}
|
||||
if l.Object != "" {
|
||||
tags["object_type"] = l.Object
|
||||
}
|
||||
if l.VirtualKeyID != nil && *l.VirtualKeyID != "" {
|
||||
tags["virtual_key_id"] = truncateTag(*l.VirtualKeyID, 256)
|
||||
}
|
||||
if l.SelectedKeyID != "" {
|
||||
tags["selected_key_id"] = truncateTag(l.SelectedKeyID, 256)
|
||||
}
|
||||
if l.RoutingRuleID != nil && *l.RoutingRuleID != "" {
|
||||
tags["routing_rule_id"] = truncateTag(*l.RoutingRuleID, 256)
|
||||
}
|
||||
if l.Stream {
|
||||
tags["stream"] = "true"
|
||||
} else {
|
||||
tags["stream"] = "false"
|
||||
}
|
||||
tags["has_error"] = "false"
|
||||
if l.Status == "error" {
|
||||
tags["has_error"] = "true"
|
||||
}
|
||||
tags["date"] = l.Timestamp.UTC().Format("2006-01-02")
|
||||
return tags
|
||||
}
|
||||
|
||||
// ObjectKey builds the S3 object key under which a log entry's payload is
// stored. The timestamp is converted to UTC and laid out as
// <prefix>/logs/YYYY/MM/DD/HH/<logID>.json.gz.
func ObjectKey(prefix string, timestamp time.Time, logID string) string {
	utc := timestamp.UTC()
	year, month, day := utc.Date()
	return fmt.Sprintf("%s/logs/%04d/%02d/%02d/%02d/%s.json.gz",
		prefix, year, int(month), day, utc.Hour(), logID)
}
|
||||
|
||||
// PayloadFieldNames returns the list of DB column names that are payload fields.
|
||||
func PayloadFieldNames() []string {
|
||||
cp := make([]string, len(payloadFields))
|
||||
copy(cp, payloadFields)
|
||||
return cp
|
||||
}
|
||||
|
||||
// payloadFieldSet is a set for O(1) lookup of payload field names.
|
||||
var payloadFieldSet = func() map[string]struct{} {
|
||||
s := make(map[string]struct{}, len(payloadFields))
|
||||
for _, f := range payloadFields {
|
||||
s[f] = struct{}{}
|
||||
}
|
||||
return s
|
||||
}()
|
||||
|
||||
// fieldsNeedHydration returns true if any of the requested fields are
|
||||
// payload fields that have been offloaded to object storage.
|
||||
func fieldsNeedHydration(fields []string) bool {
|
||||
if len(fields) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, f := range fields {
|
||||
if _, ok := payloadFieldSet[f]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ensureHydrationFields returns a projection that contains every entry of
// fields followed by whichever of "id", "timestamp" and "has_object" were not
// already present, so hydrateLog can function correctly.
//
// The caller's slice is never written to: when something has to be added the
// result is built in freshly allocated storage. Appending directly to fields
// could otherwise overwrite elements that sit beyond len(fields) in the
// caller's backing array. When nothing is missing, fields is returned as-is.
func ensureHydrationFields(fields []string) []string {
	required := [3]string{"id", "timestamp", "has_object"}

	have := make(map[string]struct{}, len(fields))
	for _, f := range fields {
		have[f] = struct{}{}
	}

	// Collect the missing names first; the array keeps their required order.
	var missing [len(required)]string
	n := 0
	for _, r := range required {
		if _, ok := have[r]; !ok {
			missing[n] = r
			n++
		}
	}
	if n == 0 {
		return fields
	}

	out := make([]string, 0, len(fields)+n)
	out = append(out, fields...)
	return append(out, missing[:n]...)
}
|
||||
|
||||
// pruneUnrequestedPayloadFields clears payload fields that were not in the
|
||||
// caller's field projection. This ensures hydration doesn't break projection
|
||||
// semantics by populating unrequested fields with large blobs.
|
||||
// A nil/empty requestedFields means "no projection" — everything is kept.
|
||||
func pruneUnrequestedPayloadFields(l *Log, requestedFields []string) {
|
||||
if len(requestedFields) == 0 {
|
||||
return
|
||||
}
|
||||
requested := make(map[string]struct{}, len(requestedFields))
|
||||
for _, f := range requestedFields {
|
||||
requested[f] = struct{}{}
|
||||
}
|
||||
for _, pf := range payloadFields {
|
||||
if _, ok := requested[pf]; !ok {
|
||||
clearPayloadField(l, pf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clearPayloadField zeros a single payload field (serialized TEXT column and
|
||||
// its Parsed counterpart, if any) by column name.
|
||||
func clearPayloadField(l *Log, name string) {
|
||||
switch name {
|
||||
case "input_history":
|
||||
l.InputHistory = ""
|
||||
l.InputHistoryParsed = nil
|
||||
case "responses_input_history":
|
||||
l.ResponsesInputHistory = ""
|
||||
l.ResponsesInputHistoryParsed = nil
|
||||
case "output_message":
|
||||
l.OutputMessage = ""
|
||||
l.OutputMessageParsed = nil
|
||||
case "responses_output":
|
||||
l.ResponsesOutput = ""
|
||||
l.ResponsesOutputParsed = nil
|
||||
case "embedding_output":
|
||||
l.EmbeddingOutput = ""
|
||||
l.EmbeddingOutputParsed = nil
|
||||
case "rerank_output":
|
||||
l.RerankOutput = ""
|
||||
l.RerankOutputParsed = nil
|
||||
case "ocr_input":
|
||||
l.OCRInput = ""
|
||||
l.OCRInputParsed = nil
|
||||
case "ocr_output":
|
||||
l.OCROutput = ""
|
||||
l.OCROutputParsed = nil
|
||||
case "params":
|
||||
l.Params = ""
|
||||
l.ParamsParsed = nil
|
||||
case "tools":
|
||||
l.Tools = ""
|
||||
l.ToolsParsed = nil
|
||||
case "tool_calls":
|
||||
l.ToolCalls = ""
|
||||
l.ToolCallsParsed = nil
|
||||
case "speech_input":
|
||||
l.SpeechInput = ""
|
||||
l.SpeechInputParsed = nil
|
||||
case "transcription_input":
|
||||
l.TranscriptionInput = ""
|
||||
l.TranscriptionInputParsed = nil
|
||||
case "image_generation_input":
|
||||
l.ImageGenerationInput = ""
|
||||
l.ImageGenerationInputParsed = nil
|
||||
case "image_edit_input":
|
||||
l.ImageEditInput = ""
|
||||
l.ImageEditInputParsed = nil
|
||||
case "image_variation_input":
|
||||
l.ImageVariationInput = ""
|
||||
l.ImageVariationInputParsed = nil
|
||||
case "video_generation_input":
|
||||
l.VideoGenerationInput = ""
|
||||
l.VideoGenerationInputParsed = nil
|
||||
case "speech_output":
|
||||
l.SpeechOutput = ""
|
||||
l.SpeechOutputParsed = nil
|
||||
case "transcription_output":
|
||||
l.TranscriptionOutput = ""
|
||||
l.TranscriptionOutputParsed = nil
|
||||
case "image_generation_output":
|
||||
l.ImageGenerationOutput = ""
|
||||
l.ImageGenerationOutputParsed = nil
|
||||
case "list_models_output":
|
||||
l.ListModelsOutput = ""
|
||||
l.ListModelsOutputParsed = nil
|
||||
case "video_generation_output":
|
||||
l.VideoGenerationOutput = ""
|
||||
l.VideoGenerationOutputParsed = nil
|
||||
case "video_retrieve_output":
|
||||
l.VideoRetrieveOutput = ""
|
||||
l.VideoRetrieveOutputParsed = nil
|
||||
case "video_download_output":
|
||||
l.VideoDownloadOutput = ""
|
||||
l.VideoDownloadOutputParsed = nil
|
||||
case "video_list_output":
|
||||
l.VideoListOutput = ""
|
||||
l.VideoListOutputParsed = nil
|
||||
case "video_delete_output":
|
||||
l.VideoDeleteOutput = ""
|
||||
l.VideoDeleteOutputParsed = nil
|
||||
case "cache_debug":
|
||||
l.CacheDebug = ""
|
||||
l.CacheDebugParsed = nil
|
||||
case "token_usage":
|
||||
l.TokenUsage = ""
|
||||
l.TokenUsageParsed = nil
|
||||
case "error_details":
|
||||
l.ErrorDetails = ""
|
||||
l.ErrorDetailsParsed = nil
|
||||
case "raw_request":
|
||||
l.RawRequest = ""
|
||||
case "raw_response":
|
||||
l.RawResponse = ""
|
||||
case "passthrough_request_body":
|
||||
l.PassthroughRequestBody = ""
|
||||
case "passthrough_response_body":
|
||||
l.PassthroughResponseBody = ""
|
||||
case "routing_engine_logs":
|
||||
l.RoutingEngineLogs = ""
|
||||
}
|
||||
}
|
||||
|
||||
// truncateTag ensures a tag value doesn't exceed the given max length.
//
// maxLen is measured in bytes. When s is longer, the longest prefix of s that
// fits in maxLen bytes and ends on a rune boundary is returned, so a
// multi-byte character is never split. A maxLen of zero or less yields "".
func truncateTag(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	// Walk rune by rune using the number of bytes each rune actually occupies
	// in s. DecodeRuneInString reports width 1 for an invalid byte, whereas
	// utf8.RuneLen(utf8.RuneError) is 3; using the latter would let the offset
	// drift past the real position and cut inside a later multi-byte rune.
	end := 0
	for end < len(s) {
		_, size := utf8.DecodeRuneInString(s[end:])
		if end+size > maxLen {
			break
		}
		end += size
	}
	return s[:end]
}
|
||||
156
framework/logstore/payload_test.go
Normal file
156
framework/logstore/payload_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package logstore
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExtractPayload_RoundTrip(t *testing.T) {
|
||||
log := &Log{
|
||||
ID: "test-1",
|
||||
InputHistory: `[{"role":"user","content":"hello"}]`,
|
||||
ResponsesInputHistory: `[{"role":"user","content":"hi"}]`,
|
||||
OutputMessage: `{"role":"assistant","content":"world"}`,
|
||||
ResponsesOutput: `[{"role":"assistant","content":"there"}]`,
|
||||
EmbeddingOutput: `[{"embedding":[0.1]}]`,
|
||||
RerankOutput: `[{"score":0.9}]`,
|
||||
Params: `{"temperature":0.7}`,
|
||||
Tools: `[{"name":"tool1"}]`,
|
||||
ToolCalls: `[{"id":"tc1"}]`,
|
||||
SpeechInput: `{"input":"text"}`,
|
||||
TranscriptionInput: `{"file":"test.mp3"}`,
|
||||
ImageGenerationInput: `{"prompt":"cat"}`,
|
||||
ImageEditInput: `{"prompt":"edit cat"}`,
|
||||
ImageVariationInput: `{"image":"base64img"}`,
|
||||
VideoGenerationInput: `{"prompt":"dog"}`,
|
||||
SpeechOutput: `{"audio":"base64"}`,
|
||||
TranscriptionOutput: `{"text":"hello"}`,
|
||||
ImageGenerationOutput: `{"url":"http://img"}`,
|
||||
ListModelsOutput: `[{"id":"model1"}]`,
|
||||
VideoGenerationOutput: `{"id":"vid1"}`,
|
||||
VideoRetrieveOutput: `{"status":"ready"}`,
|
||||
VideoDownloadOutput: `{"url":"http://vid"}`,
|
||||
VideoListOutput: `{"videos":[]}`,
|
||||
VideoDeleteOutput: `{"deleted":true}`,
|
||||
CacheDebug: `{"hit":true}`,
|
||||
TokenUsage: `{"total_tokens":100}`,
|
||||
ErrorDetails: `{"error":"bad"}`,
|
||||
RawRequest: `{"method":"POST"}`,
|
||||
RawResponse: `{"status":200}`,
|
||||
PassthroughRequestBody: `body-req`,
|
||||
PassthroughResponseBody: `body-resp`,
|
||||
RoutingEngineLogs: `routing log`,
|
||||
}
|
||||
|
||||
payload := ExtractPayload(log)
|
||||
assert.Equal(t, len(payloadFields), len(payload), "payload map should have all payload fields")
|
||||
assert.Equal(t, `[{"role":"user","content":"hello"}]`, payload["input_history"])
|
||||
assert.Equal(t, `{"role":"assistant","content":"world"}`, payload["output_message"])
|
||||
assert.Equal(t, `routing log`, payload["routing_engine_logs"])
|
||||
|
||||
// Clear and verify.
|
||||
ClearPayload(log)
|
||||
assert.Empty(t, log.InputHistory)
|
||||
assert.Empty(t, log.OutputMessage)
|
||||
assert.Empty(t, log.RawRequest)
|
||||
assert.Empty(t, log.RoutingEngineLogs)
|
||||
|
||||
// Marshal and merge back.
|
||||
data, err := MarshalPayload(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = MergePayloadFromJSON(log, data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `[{"role":"user","content":"hello"}]`, log.InputHistory)
|
||||
assert.Equal(t, `{"role":"assistant","content":"world"}`, log.OutputMessage)
|
||||
assert.Equal(t, `routing log`, log.RoutingEngineLogs)
|
||||
}
|
||||
|
||||
func TestClearPayload_DoesNotTouchIndexFields(t *testing.T) {
|
||||
log := &Log{
|
||||
ID: "test-1",
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3",
|
||||
Status: "success",
|
||||
InputHistory: `[{"role":"user","content":"hello"}]`,
|
||||
}
|
||||
ClearPayload(log)
|
||||
assert.Equal(t, "test-1", log.ID)
|
||||
assert.Equal(t, "anthropic", log.Provider)
|
||||
assert.Equal(t, "claude-3", log.Model)
|
||||
assert.Equal(t, "success", log.Status)
|
||||
assert.Empty(t, log.InputHistory)
|
||||
}
|
||||
|
||||
func TestBuildInputContentSummary(t *testing.T) {
|
||||
content := "What is the weather?"
|
||||
log := &Log{
|
||||
InputHistoryParsed: []schemas.ChatMessage{
|
||||
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: &content}},
|
||||
},
|
||||
OutputMessageParsed: &schemas.ChatMessage{
|
||||
Content: &schemas.ChatMessageContent{ContentStr: strPtr("It's sunny")},
|
||||
},
|
||||
}
|
||||
|
||||
summary := log.BuildInputContentSummary()
|
||||
assert.Contains(t, summary, "What is the weather?")
|
||||
assert.NotContains(t, summary, "It's sunny", "BuildInputContentSummary should not include output")
|
||||
}
|
||||
|
||||
func TestBuildTags(t *testing.T) {
|
||||
vkID := "vk_123"
|
||||
rrID := "rr_456"
|
||||
log := &Log{
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3-sonnet",
|
||||
Status: "success",
|
||||
Object: "chat.completion",
|
||||
VirtualKeyID: &vkID,
|
||||
SelectedKeyID: "sk_789",
|
||||
RoutingRuleID: &rrID,
|
||||
Stream: true,
|
||||
Timestamp: time.Date(2026, 4, 3, 14, 0, 0, 0, time.UTC),
|
||||
}
|
||||
|
||||
tags := BuildTags(log)
|
||||
assert.Equal(t, "anthropic", tags["provider"])
|
||||
assert.Equal(t, "claude-3-sonnet", tags["model"])
|
||||
assert.Equal(t, "success", tags["status"])
|
||||
assert.Equal(t, "chat.completion", tags["object_type"])
|
||||
assert.Equal(t, "vk_123", tags["virtual_key_id"])
|
||||
assert.Equal(t, "sk_789", tags["selected_key_id"])
|
||||
assert.Equal(t, "rr_456", tags["routing_rule_id"])
|
||||
assert.Equal(t, "true", tags["stream"])
|
||||
assert.Equal(t, "false", tags["has_error"])
|
||||
assert.Equal(t, "2026-04-03", tags["date"])
|
||||
assert.LessOrEqual(t, len(tags), 10, "S3 allows max 10 tags")
|
||||
}
|
||||
|
||||
func TestBuildTags_ErrorStatus(t *testing.T) {
|
||||
log := &Log{Status: "error", Timestamp: time.Now()}
|
||||
tags := BuildTags(log)
|
||||
assert.Equal(t, "true", tags["has_error"])
|
||||
}
|
||||
|
||||
func TestObjectKey(t *testing.T) {
|
||||
ts := time.Date(2026, 4, 3, 14, 0, 0, 0, time.UTC)
|
||||
key := ObjectKey("bifrost", ts, "req_abc123")
|
||||
assert.Equal(t, "bifrost/logs/2026/04/03/14/req_abc123.json.gz", key)
|
||||
}
|
||||
|
||||
func TestPayloadFieldNames(t *testing.T) {
|
||||
fields := PayloadFieldNames()
|
||||
assert.True(t, len(fields) > 0)
|
||||
// Verify it's a copy.
|
||||
fields[0] = "modified"
|
||||
assert.NotEqual(t, "modified", payloadFields[0])
|
||||
}
|
||||
|
||||
// strPtr returns a pointer to a fresh string holding the value of s.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
|
||||
189
framework/logstore/postgres.go
Normal file
189
framework/logstore/postgres.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package logstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PostgresConfig represents the configuration for a Postgres database.
//
// newPostgresLogStore rejects the config unless Host, Port, User, Password,
// DBName and SSLMode are all non-nil and resolve to a non-empty value.
type PostgresConfig struct {
	// Connection settings, each read through EnvVar.GetValue when the DSN is built.
	Host     *schemas.EnvVar `json:"host"`
	Port     *schemas.EnvVar `json:"port"`
	User     *schemas.EnvVar `json:"user"`
	Password *schemas.EnvVar `json:"password"`
	DBName   *schemas.EnvVar `json:"db_name"`
	SSLMode  *schemas.EnvVar `json:"ssl_mode"`
	// MaxIdleConns caps idle connections in the runtime pool; 0 selects the default of 5.
	MaxIdleConns int `json:"max_idle_conns"`
	// MaxOpenConns caps open connections in the runtime pool; 0 selects the default of 50.
	MaxOpenConns int `json:"max_open_conns"`
}
|
||||
|
||||
// newPostgresLogStore creates a new Postgres log store.
|
||||
//
|
||||
// Uses a two-pool lifecycle to avoid SQLSTATE 0A000 ("cached plan must not
|
||||
// change result type"): a throwaway pool runs the version check and schema
|
||||
// migrations and is closed immediately, then a fresh runtime pool is opened
|
||||
// for query traffic and the async index / matview builders. The runtime
|
||||
// pool's connections never see pre-migration schema, so their cached
|
||||
// prepared-plans stay valid for the life of the process.
|
||||
func newPostgresLogStore(ctx context.Context, config *PostgresConfig, logger schemas.Logger) (LogStore, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
// Validate required config
|
||||
if config.Host == nil || config.Host.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres host is required")
|
||||
}
|
||||
if config.Port == nil || config.Port.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres port is required")
|
||||
}
|
||||
if config.User == nil || config.User.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres user is required")
|
||||
}
|
||||
if config.Password == nil || config.Password.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres password is required")
|
||||
}
|
||||
if config.DBName == nil || config.DBName.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres db name is required")
|
||||
}
|
||||
if config.SSLMode == nil || config.SSLMode.GetValue() == "" {
|
||||
return nil, fmt.Errorf("postgres ssl mode is required")
|
||||
}
|
||||
dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", config.Host.GetValue(), config.Port.GetValue(), config.User.GetValue(), config.Password.GetValue(), config.DBName.GetValue(), config.SSLMode.GetValue())
|
||||
|
||||
openPool := func() (*gorm.DB, error) {
|
||||
return gorm.Open(postgres.New(postgres.Config{DSN: dsn}), &gorm.Config{
|
||||
Logger: newGormLogger(logger),
|
||||
})
|
||||
}
|
||||
|
||||
// closePoolStrict returns the close error so callers can abort startup
|
||||
// when the throwaway migration pool doesn't tear down cleanly — a half-
|
||||
// closed pool weakens the guarantee that no cached plans survive DDL.
|
||||
closePool := func(db *gorm.DB) error {
|
||||
if db == nil {
|
||||
return nil
|
||||
}
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
|
||||
// Throwaway pool for the version gate and schema migrations. Closing it
|
||||
// before the runtime pool opens guarantees no cached plan survives DDL.
|
||||
mDb, err := openPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Postgres version gate: refuse to start below 16 (matviews, partitioning,
|
||||
// and some JSON operators we rely on depend on 16+).
|
||||
var pgVersionNum int
|
||||
if err := mDb.Raw("SELECT current_setting('server_version_num')::int").Scan(&pgVersionNum).Error; err != nil {
|
||||
_ = closePool(mDb)
|
||||
return nil, err
|
||||
}
|
||||
if pgVersionNum < 160000 {
|
||||
_ = closePool(mDb)
|
||||
return nil, fmt.Errorf("postgres version is lower than 16, please upgrade to 16 or higher")
|
||||
}
|
||||
|
||||
if err := triggerMigrations(ctx, mDb); err != nil {
|
||||
_ = closePool(mDb)
|
||||
return nil, err
|
||||
}
|
||||
if err := closePool(mDb); err != nil {
|
||||
return nil, fmt.Errorf("close migration db connection: %w", err)
|
||||
}
|
||||
|
||||
// Runtime pool. Opens against post-migration schema.
|
||||
db, err := openPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
closePool(db)
|
||||
return nil, err
|
||||
}
|
||||
// Set MaxIdleConns (default: 5)
|
||||
maxIdleConns := config.MaxIdleConns
|
||||
if maxIdleConns == 0 {
|
||||
maxIdleConns = 5
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(maxIdleConns)
|
||||
|
||||
// Set MaxOpenConns (default: 50)
|
||||
maxOpenConns := config.MaxOpenConns
|
||||
if maxOpenConns == 0 {
|
||||
maxOpenConns = 50
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(maxOpenConns)
|
||||
d := &RDBLogStore{db: db, logger: logger}
|
||||
|
||||
// Run all index builds sequentially in a single goroutine to prevent
|
||||
// deadlocks from concurrent CREATE INDEX CONCURRENTLY on the same table.
|
||||
// Each function is idempotent and acquires its own advisory lock for
|
||||
// cross-node serialization. Running in a goroutine avoids blocking pod startup.
|
||||
go func() {
|
||||
if db.Dialector.Name() != "postgres" {
|
||||
return
|
||||
}
|
||||
// Acquire advisory lock to serialize GIN index builds across cluster nodes.
|
||||
lock, err := acquireIndexLock(context.Background(), db)
|
||||
if err != nil {
|
||||
// Lock is taken by another node, so we will skip the index build
|
||||
return
|
||||
}
|
||||
defer lock.release(context.Background())
|
||||
|
||||
if err := ensureMetadataGINIndex(context.Background(), lock.conn); err != nil {
|
||||
logger.Warn(fmt.Sprintf("logstore: metadata GIN index build failed: %s (queries will still work without the index)", err))
|
||||
} else {
|
||||
logger.Info("logstore: metadata GIN index is ready")
|
||||
}
|
||||
|
||||
if err := ensureDashboardEnhancements(context.Background(), lock.conn); err != nil {
|
||||
logger.Warn(fmt.Sprintf("logstore: dashboard enhancements failed: %s (dashboard will still work with partial data)", err))
|
||||
} else {
|
||||
logger.Info("logstore: dashboard enhancements completed")
|
||||
}
|
||||
|
||||
if err := ensurePerformanceIndexes(context.Background(), lock.conn); err != nil {
|
||||
logger.Warn(fmt.Sprintf("logstore: performance index build failed: %s (queries will still work without the indexes)", err))
|
||||
} else {
|
||||
logger.Info("logstore: performance indexes are ready")
|
||||
}
|
||||
}()
|
||||
|
||||
// Create materialized views and start periodic refresh for dashboard queries.
|
||||
go func() {
|
||||
if db.Dialector.Name() != "postgres" {
|
||||
return
|
||||
}
|
||||
if err := ensureMatViews(context.Background(), db); err != nil {
|
||||
logger.Warn(fmt.Sprintf("logstore: matview creation failed: %s (dashboard queries will use raw tables)", err))
|
||||
return
|
||||
}
|
||||
if err := refreshMatViews(context.Background(), db); err != nil {
|
||||
logger.Warn(fmt.Sprintf("logstore: initial matview refresh failed: %s", err))
|
||||
} else {
|
||||
logger.Info("logstore: materialized views are ready")
|
||||
// Signal that matviews are ready for query use. Until this point,
|
||||
// canUseMatView() returns false so all queries use raw tables.
|
||||
d.matViewsReady.Store(true)
|
||||
}
|
||||
startMatViewRefresher(context.Background(), db, 30*time.Second, logger, &d.matViewsReady)
|
||||
}()
|
||||
|
||||
return d, nil
|
||||
}
|
||||
3545
framework/logstore/rdb.go
Normal file
3545
framework/logstore/rdb.go
Normal file
File diff suppressed because it is too large
Load Diff
260
framework/logstore/rdb_perf_test.go
Normal file
260
framework/logstore/rdb_perf_test.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package logstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// testLogger is a logger stub for tests: every logging and configuration
// method discards its arguments. It is passed as the logger when building the
// SQLite test store (presumably to satisfy schemas.Logger — confirm against
// that interface).
type testLogger struct{}

// The level-specific logging methods are all no-ops.
func (testLogger) Debug(string, ...any) {}
func (testLogger) Info(string, ...any)  {}
func (testLogger) Warn(string, ...any)  {}
func (testLogger) Error(string, ...any) {}
func (testLogger) Fatal(string, ...any) {}

// Configuration setters are accepted and ignored.
func (testLogger) SetLevel(schemas.LogLevel)              {}
func (testLogger) SetOutputType(schemas.LoggerOutputType) {}

// LogHTTPRequest always returns schemas.NoopLogEvent.
func (testLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
	return schemas.NoopLogEvent
}
|
||||
|
||||
func newTestSQLiteStore(t *testing.T) *RDBLogStore {
|
||||
t.Helper()
|
||||
|
||||
store, err := newSqliteLogStore(context.Background(), &SQLiteConfig{
|
||||
Path: filepath.Join(t.TempDir(), "logs.db"),
|
||||
}, testLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("newSqliteLogStore() error = %v", err)
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
func TestLogCreateSerializesFields(t *testing.T) {
|
||||
store := newTestSQLiteStore(t)
|
||||
prompt := "hello"
|
||||
reply := "world"
|
||||
|
||||
entry := &Log{
|
||||
ID: "log-1",
|
||||
Timestamp: time.Now().UTC(),
|
||||
Object: "chat_completion",
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
Status: "success",
|
||||
InputHistoryParsed: []schemas.ChatMessage{{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &prompt,
|
||||
},
|
||||
}},
|
||||
OutputMessageParsed: &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &reply,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := store.Create(context.Background(), entry); err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
logEntry, err := store.FindByID(context.Background(), entry.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindByID() error = %v", err)
|
||||
}
|
||||
if logEntry.InputHistory == "" {
|
||||
t.Fatalf("expected InputHistory to be serialized")
|
||||
}
|
||||
if logEntry.OutputMessage == "" {
|
||||
t.Fatalf("expected OutputMessage to be serialized")
|
||||
}
|
||||
if logEntry.ContentSummary == "" {
|
||||
t.Fatalf("expected ContentSummary to be populated")
|
||||
}
|
||||
if logEntry.CreatedAt.IsZero() {
|
||||
t.Fatalf("expected CreatedAt to be populated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPToolLogCreateSerializesFields(t *testing.T) {
|
||||
store := newTestSQLiteStore(t)
|
||||
|
||||
entry := &MCPToolLog{
|
||||
ID: "mcp-1",
|
||||
Timestamp: time.Now().UTC(),
|
||||
ToolName: "echo",
|
||||
Status: "success",
|
||||
ArgumentsParsed: map[string]any{
|
||||
"message": "hello",
|
||||
},
|
||||
ResultParsed: map[string]any{
|
||||
"ok": true,
|
||||
},
|
||||
}
|
||||
|
||||
if err := store.CreateMCPToolLog(context.Background(), entry); err != nil {
|
||||
t.Fatalf("CreateMCPToolLog() error = %v", err)
|
||||
}
|
||||
|
||||
logEntry, err := store.FindMCPToolLog(context.Background(), entry.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindMCPToolLog() error = %v", err)
|
||||
}
|
||||
if logEntry.Arguments == "" {
|
||||
t.Fatalf("expected Arguments to be serialized")
|
||||
}
|
||||
if logEntry.Result == "" {
|
||||
t.Fatalf("expected Result to be serialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBulkUpdateCostPostgresSQL(t *testing.T) {
|
||||
updates := map[string]float64{
|
||||
"log-a": 1.25,
|
||||
"log-b": 2.5,
|
||||
}
|
||||
|
||||
query, args := buildBulkUpdateCostPostgresSQL([]string{"log-a", "log-b"}, updates)
|
||||
wantQuery := "UPDATE logs SET cost = v.cost FROM (VALUES ($1::text,$2::float8),($3::text,$4::float8)) AS v(id, cost) WHERE logs.id = v.id"
|
||||
wantArgs := []interface{}{"log-a", 1.25, "log-b", 2.5}
|
||||
|
||||
if query != wantQuery {
|
||||
t.Fatalf("query mismatch\n got: %s\nwant: %s", query, wantQuery)
|
||||
}
|
||||
if !reflect.DeepEqual(args, wantArgs) {
|
||||
t.Fatalf("args mismatch\n got: %#v\nwant: %#v", args, wantArgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSerializesStructEntry(t *testing.T) {
|
||||
store := newTestSQLiteStore(t)
|
||||
now := time.Now().UTC()
|
||||
entry := &Log{
|
||||
ID: "log-update",
|
||||
Timestamp: now,
|
||||
Object: "chat_completion",
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
Status: "processing",
|
||||
}
|
||||
|
||||
if err := store.Create(context.Background(), entry); err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
reply := "updated response"
|
||||
if err := store.Update(context.Background(), entry.ID, Log{
|
||||
Status: "success",
|
||||
OutputMessageParsed: &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &reply,
|
||||
},
|
||||
},
|
||||
TokenUsageParsed: &schemas.BifrostLLMUsage{
|
||||
PromptTokens: 3,
|
||||
CompletionTokens: 7,
|
||||
TotalTokens: 10,
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
|
||||
logEntry, err := store.FindByID(context.Background(), entry.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindByID() error = %v", err)
|
||||
}
|
||||
if logEntry.OutputMessage == "" {
|
||||
t.Fatalf("expected OutputMessage to be serialized on Update")
|
||||
}
|
||||
if logEntry.TokenUsage == "" {
|
||||
t.Fatalf("expected TokenUsage to be serialized on Update")
|
||||
}
|
||||
if logEntry.TotalTokens != 10 {
|
||||
t.Fatalf("expected TotalTokens to be updated, got %d", logEntry.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMCPToolLogSerializesStructEntry(t *testing.T) {
|
||||
store := newTestSQLiteStore(t)
|
||||
now := time.Now().UTC()
|
||||
entry := &MCPToolLog{
|
||||
ID: "mcp-update",
|
||||
Timestamp: now,
|
||||
ToolName: "echo",
|
||||
Status: "processing",
|
||||
}
|
||||
|
||||
if err := store.CreateMCPToolLog(context.Background(), entry); err != nil {
|
||||
t.Fatalf("CreateMCPToolLog() error = %v", err)
|
||||
}
|
||||
|
||||
if err := store.UpdateMCPToolLog(context.Background(), entry.ID, MCPToolLog{
|
||||
Status: "success",
|
||||
ResultParsed: map[string]any{
|
||||
"message": "done",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("UpdateMCPToolLog() error = %v", err)
|
||||
}
|
||||
|
||||
logEntry, err := store.FindMCPToolLog(context.Background(), entry.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("FindMCPToolLog() error = %v", err)
|
||||
}
|
||||
if logEntry.Result == "" {
|
||||
t.Fatalf("expected Result to be serialized on UpdateMCPToolLog")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBulkUpdateCostSQLiteFallback(t *testing.T) {
|
||||
store := newTestSQLiteStore(t)
|
||||
now := time.Now().UTC()
|
||||
entries := []*Log{
|
||||
{
|
||||
ID: "log-a",
|
||||
Timestamp: now,
|
||||
Object: "chat_completion",
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
Status: "success",
|
||||
},
|
||||
{
|
||||
ID: "log-b",
|
||||
Timestamp: now,
|
||||
Object: "chat_completion",
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
Status: "success",
|
||||
},
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if err := store.Create(context.Background(), entry); err != nil {
|
||||
t.Fatalf("Create(%s) error = %v", entry.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := store.BulkUpdateCost(context.Background(), map[string]float64{
|
||||
"log-a": 1.5,
|
||||
"log-b": 2.5,
|
||||
}); err != nil {
|
||||
t.Fatalf("BulkUpdateCost() error = %v", err)
|
||||
}
|
||||
|
||||
for id, wantCost := range map[string]float64{"log-a": 1.5, "log-b": 2.5} {
|
||||
logEntry, err := store.FindByID(context.Background(), id)
|
||||
if err != nil {
|
||||
t.Fatalf("FindByID(%s) error = %v", id, err)
|
||||
}
|
||||
if logEntry.Cost == nil || *logEntry.Cost != wantCost {
|
||||
t.Fatalf("cost mismatch for %s: got %v want %v", id, logEntry.Cost, wantCost)
|
||||
}
|
||||
}
|
||||
}
|
||||
585
framework/logstore/rdb_postgres_perf_test.go
Normal file
585
framework/logstore/rdb_postgres_perf_test.go
Normal file
@@ -0,0 +1,585 @@
|
||||
package logstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// setupPerfTestDB connects to Postgres, runs migrations, and returns the store.
|
||||
func setupPerfTestDB(t *testing.T) (*RDBLogStore, *gorm.DB) {
|
||||
t.Helper()
|
||||
db := trySetupPostgresDB(t)
|
||||
if db == nil {
|
||||
t.Skip("Postgres not available, skipping test")
|
||||
}
|
||||
|
||||
// Clean slate — drop test-owned tables but preserve the shared migrations
|
||||
// table so concurrent test packages (e.g. configstore) are not disrupted.
|
||||
db.Exec("DROP MATERIALIZED VIEW IF EXISTS mv_logs_hourly CASCADE")
|
||||
db.Exec("DROP MATERIALIZED VIEW IF EXISTS mv_logs_filterdata CASCADE")
|
||||
db.Exec("DROP TABLE IF EXISTS mcp_tool_logs CASCADE")
|
||||
db.Exec("DROP TABLE IF EXISTS async_jobs CASCADE")
|
||||
db.Exec("DROP TABLE IF EXISTS logs CASCADE")
|
||||
db.Exec("CREATE TABLE IF NOT EXISTS migrations (id VARCHAR(255) PRIMARY KEY)")
|
||||
db.Exec("DELETE FROM migrations")
|
||||
|
||||
ctx := context.Background()
|
||||
err := triggerMigrations(ctx, db)
|
||||
require.NoError(t, err, "migrations should succeed")
|
||||
|
||||
err = ensureMatViews(ctx, db)
|
||||
require.NoError(t, err, "matview creation should succeed")
|
||||
|
||||
store := &RDBLogStore{db: db}
|
||||
|
||||
t.Cleanup(func() {
|
||||
for _, idx := range performanceIndexes {
|
||||
db.Exec("DROP INDEX IF EXISTS " + idx.name)
|
||||
}
|
||||
db.Exec("DROP MATERIALIZED VIEW IF EXISTS mv_logs_hourly CASCADE")
|
||||
db.Exec("DROP MATERIALIZED VIEW IF EXISTS mv_logs_filterdata CASCADE")
|
||||
db.Exec("DROP TABLE IF EXISTS mcp_tool_logs CASCADE")
|
||||
db.Exec("DROP TABLE IF EXISTS async_jobs CASCADE")
|
||||
db.Exec("DROP TABLE IF EXISTS logs CASCADE")
|
||||
db.Exec("DELETE FROM migrations")
|
||||
})
|
||||
|
||||
return store, db
|
||||
}
|
||||
|
||||
// acquirePerfTestSQLConn returns a dedicated connection for ensurePerformanceIndexes (CONCURRENTLY + session SET).
|
||||
func acquirePerfTestSQLConn(t *testing.T, ctx context.Context, db *gorm.DB) *sql.Conn {
|
||||
t.Helper()
|
||||
sqlDB, err := db.DB()
|
||||
require.NoError(t, err)
|
||||
conn, err := sqlDB.Conn(ctx)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
return conn
|
||||
}
|
||||
|
||||
// logOpts carries the column values insertPerfLog writes into the logs table.
// insertPerfLog substitutes defaults for an empty Provider ("openai"),
// Status ("success") and Model ("gpt-4"); every other field is inserted as
// given, so unset strings are written as "".
type logOpts struct {
	Model    string
	Provider string
	Status   string
	// Timestamp is used for both the timestamp and created_at columns.
	Timestamp          time.Time
	RoutingEnginesUsed string
	Metadata           string
	ContentSummary     string
	VirtualKeyID       string
	VirtualKeyName     string
	SelectedKeyID      string
	SelectedKeyName    string
	RoutingRuleID      string
	RoutingRuleName    string
}
|
||||
|
||||
func insertPerfLog(t *testing.T, db *gorm.DB, opts logOpts) {
|
||||
t.Helper()
|
||||
if opts.Provider == "" {
|
||||
opts.Provider = "openai"
|
||||
}
|
||||
if opts.Status == "" {
|
||||
opts.Status = "success"
|
||||
}
|
||||
if opts.Model == "" {
|
||||
opts.Model = "gpt-4"
|
||||
}
|
||||
id := uuid.New().String()
|
||||
err := db.Exec(`
|
||||
INSERT INTO logs (id, timestamp, object_type, provider, model, status,
|
||||
routing_engines_used, metadata, content_summary,
|
||||
virtual_key_id, virtual_key_name, selected_key_id, selected_key_name,
|
||||
routing_rule_id, routing_rule_name, created_at, latency, cost,
|
||||
prompt_tokens, completion_tokens, total_tokens)
|
||||
VALUES (?, ?, 'chat_completion', ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 100, 0.01, 10, 5, 15)
|
||||
`, id, opts.Timestamp, opts.Provider, opts.Model, opts.Status,
|
||||
opts.RoutingEnginesUsed, opts.Metadata, opts.ContentSummary,
|
||||
opts.VirtualKeyID, opts.VirtualKeyName, opts.SelectedKeyID, opts.SelectedKeyName,
|
||||
opts.RoutingRuleID, opts.RoutingRuleName, opts.Timestamp).Error
|
||||
require.NoError(t, err, "Failed to insert test log")
|
||||
}
|
||||
|
||||
// mcpLogOpts carries the column values insertPerfMCPLog writes into the
// mcp_tool_logs table. No defaults are applied: unset strings are inserted
// as "".
type mcpLogOpts struct {
	ToolName    string
	ServerLabel string
	// Timestamp is used for both the timestamp and created_at columns.
	Timestamp      time.Time
	VirtualKeyID   string
	VirtualKeyName string
	Arguments      string
	Result         string
}
|
||||
|
||||
func insertPerfMCPLog(t *testing.T, db *gorm.DB, opts mcpLogOpts) {
|
||||
t.Helper()
|
||||
id := uuid.New().String()
|
||||
err := db.Exec(`
|
||||
INSERT INTO mcp_tool_logs (id, llm_request_id, tool_name, server_label,
|
||||
timestamp, status, latency, cost,
|
||||
virtual_key_id, virtual_key_name, arguments, result, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, 'success', 50, 0.001, ?, ?, ?, ?, ?)
|
||||
`, id, uuid.New().String(), opts.ToolName, opts.ServerLabel,
|
||||
opts.Timestamp, opts.VirtualKeyID, opts.VirtualKeyName,
|
||||
opts.Arguments, opts.Result, opts.Timestamp).Error
|
||||
require.NoError(t, err, "Failed to insert MCP test log")
|
||||
}
|
||||
|
||||
// refreshTestMatViews refreshes materialized views after inserting test data.
|
||||
// This is needed because matviews are populated at creation time and don't
|
||||
// automatically reflect new inserts until explicitly refreshed.
|
||||
func refreshTestMatViews(t *testing.T, db *gorm.DB) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
err := refreshMatViews(ctx, db)
|
||||
require.NoError(t, err, "Failed to refresh materialized views")
|
||||
}
|
||||
|
||||
// ---------- Phase 1: Defensive Limits ----------
|
||||
|
||||
func TestSearchLogs_LimitClamping(t *testing.T) {
|
||||
store, db := setupPerfTestDB(t)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
insertPerfLog(t, db, logOpts{Timestamp: now})
|
||||
}
|
||||
refreshTestMatViews(t, db)
|
||||
|
||||
// Limit=0 should be clamped (not return 0 results)
|
||||
result, err := store.SearchLogs(ctx, SearchFilters{}, PaginationOptions{Limit: 0})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, len(result.Logs), "Limit=0 should be clamped")
|
||||
|
||||
// Limit=2 should return 2
|
||||
result, err = store.SearchLogs(ctx, SearchFilters{}, PaginationOptions{Limit: 2})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, len(result.Logs))
|
||||
|
||||
// Limit=-1 should be clamped
|
||||
result, err = store.SearchLogs(ctx, SearchFilters{}, PaginationOptions{Limit: -1})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, len(result.Logs), "Limit=-1 should be clamped")
|
||||
|
||||
// Limit=2000 should be clamped to 1000
|
||||
result, err = store.SearchLogs(ctx, SearchFilters{}, PaginationOptions{Limit: 2000})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, len(result.Logs))
|
||||
}
|
||||
|
||||
func TestSearchMCPToolLogs_LimitClamping(t *testing.T) {
|
||||
store, db := setupPerfTestDB(t)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
insertPerfMCPLog(t, db, mcpLogOpts{
|
||||
ToolName: "search", ServerLabel: "s1", Timestamp: now,
|
||||
VirtualKeyID: "vk-1", VirtualKeyName: "key-1",
|
||||
})
|
||||
}
|
||||
|
||||
result, err := store.SearchMCPToolLogs(ctx, MCPToolLogSearchFilters{}, PaginationOptions{Limit: 0})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, len(result.Logs), "Limit=0 should be clamped")
|
||||
|
||||
result, err = store.SearchMCPToolLogs(ctx, MCPToolLogSearchFilters{}, PaginationOptions{Limit: 3})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, len(result.Logs))
|
||||
}
|
||||
|
||||
func TestGetModelRankings_HasLimit(t *testing.T) {
|
||||
store, db := setupPerfTestDB(t)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
start := now.Add(-1 * time.Hour)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
insertPerfLog(t, db, logOpts{
|
||||
Model: fmt.Sprintf("model-%d", i), Timestamp: now,
|
||||
})
|
||||
}
|
||||
refreshTestMatViews(t, db)
|
||||
|
||||
result, err := store.GetModelRankings(ctx, SearchFilters{StartTime: &start, EndTime: &now})
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, len(result.Rankings), defaultMaxRankingsLimit)
|
||||
assert.Equal(t, 5, len(result.Rankings))
|
||||
}
|
||||
|
||||
func TestDeleteExpiredAsyncJobs_BatchDeletes(t *testing.T) {
|
||||
store, db := setupPerfTestDB(t)
|
||||
ctx := context.Background()
|
||||
past := time.Now().UTC().Add(-1 * time.Hour)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
err := db.Exec(`
|
||||
INSERT INTO async_jobs (id, status, request_type, virtual_key_id, expires_at, created_at)
|
||||
VALUES (?, 'completed', 'chat_completion', 'vk-1', ?, ?)
|
||||
`, uuid.New().String(), past, past).Error
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
deleted, err := store.DeleteExpiredAsyncJobs(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(5), deleted)
|
||||
|
||||
var count int64
|
||||
db.Model(&AsyncJob{}).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
// ---------- Phase 2: Time-scoped filter data ----------
|
||||
|
||||
func TestGetDistinctModels_TimeCutoff(t *testing.T) {
|
||||
store, db := setupPerfTestDB(t)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
recent := now.Add(-7 * 24 * time.Hour)
|
||||
old := now.Add(-60 * 24 * time.Hour)
|
||||
|
||||
insertPerfLog(t, db, logOpts{Model: "recent-model", Timestamp: recent})
|
||||
insertPerfLog(t, db, logOpts{Model: "old-model", Timestamp: old})
|
||||
refreshTestMatViews(t, db)
|
||||
|
||||
models, err := store.GetDistinctModels(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, models, "recent-model")
|
||||
assert.NotContains(t, models, "old-model")
|
||||
}
|
||||
|
||||
func TestGetDistinctKeyPairs_TimeCutoff(t *testing.T) {
|
||||
store, db := setupPerfTestDB(t)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
recent := now.Add(-7 * 24 * time.Hour)
|
||||
old := now.Add(-60 * 24 * time.Hour)
|
||||
|
||||
insertPerfLog(t, db, logOpts{
|
||||
Timestamp: recent, VirtualKeyID: "vk-recent", VirtualKeyName: "Recent Key",
|
||||
})
|
||||
insertPerfLog(t, db, logOpts{
|
||||
Timestamp: old, VirtualKeyID: "vk-old", VirtualKeyName: "Old Key",
|
||||
})
|
||||
refreshTestMatViews(t, db)
|
||||
|
||||
pairs, err := store.GetDistinctKeyPairs(ctx, "virtual_key_id", "virtual_key_name")
|
||||
require.NoError(t, err)
|
||||
|
||||
var ids []string
|
||||
for _, p := range pairs {
|
||||
ids = append(ids, p.ID)
|
||||
}
|
||||
assert.Contains(t, ids, "vk-recent")
|
||||
assert.NotContains(t, ids, "vk-old")
|
||||
}
|
||||
|
||||
func TestGetDistinctRoutingEngines_TimeCutoff(t *testing.T) {
|
||||
store, db := setupPerfTestDB(t)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
recent := now.Add(-7 * 24 * time.Hour)
|
||||
old := now.Add(-60 * 24 * time.Hour)
|
||||
|
||||
insertPerfLog(t, db, logOpts{
|
||||
Timestamp: recent, RoutingEnginesUsed: "loadbalancing,governance",
|
||||
})
|
||||
insertPerfLog(t, db, logOpts{
|
||||
Timestamp: old, RoutingEnginesUsed: "routing-rule",
|
||||
})
|
||||
refreshTestMatViews(t, db)
|
||||
|
||||
engines, err := store.GetDistinctRoutingEngines(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, engines, "loadbalancing")
|
||||
assert.Contains(t, engines, "governance")
|
||||
assert.NotContains(t, engines, "routing-rule")
|
||||
}
|
||||
|
||||
func TestGetDistinctMetadataKeys_TimeCutoff(t *testing.T) {
|
||||
store, db := setupPerfTestDB(t)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
recent := now.Add(-7 * 24 * time.Hour)
|
||||
old := now.Add(-60 * 24 * time.Hour)
|
||||
|
||||
insertPerfLog(t, db, logOpts{
|
||||
Timestamp: recent, Metadata: `{"env": "production"}`,
|
||||
})
|
||||
insertPerfLog(t, db, logOpts{
|
||||
Timestamp: old, Metadata: `{"old_key": "old_value"}`,
|
||||
})
|
||||
|
||||
keys, err := store.GetDistinctMetadataKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, keys, "env")
|
||||
assert.NotContains(t, keys, "old_key")
|
||||
}
|
||||
|
||||
func TestGetAvailableToolNames_TimeCutoff(t *testing.T) {
|
||||
store, db := setupPerfTestDB(t)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
recent := now.Add(-7 * 24 * time.Hour)
|
||||
old := now.Add(-60 * 24 * time.Hour)
|
||||
|
||||
insertPerfMCPLog(t, db, mcpLogOpts{
|
||||
ToolName: "recent-tool", ServerLabel: "s1", Timestamp: recent,
|
||||
VirtualKeyID: "vk-1", VirtualKeyName: "k1",
|
||||
})
|
||||
insertPerfMCPLog(t, db, mcpLogOpts{
|
||||
ToolName: "old-tool", ServerLabel: "s1", Timestamp: old,
|
||||
VirtualKeyID: "vk-1", VirtualKeyName: "k1",
|
||||
})
|
||||
|
||||
tools, err := store.GetAvailableToolNames(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, tools, "recent-tool")
|
||||
assert.NotContains(t, tools, "old-tool")
|
||||
}
|
||||
|
||||
func TestGetAvailableServerLabels_TimeCutoff(t *testing.T) {
|
||||
store, db := setupPerfTestDB(t)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
recent := now.Add(-7 * 24 * time.Hour)
|
||||
old := now.Add(-60 * 24 * time.Hour)
|
||||
|
||||
insertPerfMCPLog(t, db, mcpLogOpts{
|
||||
ToolName: "t1", ServerLabel: "recent-server", Timestamp: recent,
|
||||
VirtualKeyID: "vk-1", VirtualKeyName: "k1",
|
||||
})
|
||||
insertPerfMCPLog(t, db, mcpLogOpts{
|
||||
ToolName: "t2", ServerLabel: "old-server", Timestamp: old,
|
||||
VirtualKeyID: "vk-1", VirtualKeyName: "k1",
|
||||
})
|
||||
|
||||
labels, err := store.GetAvailableServerLabels(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, labels, "recent-server")
|
||||
assert.NotContains(t, labels, "old-server")
|
||||
}
|
||||
|
||||
func TestGetAvailableMCPVirtualKeys_TimeCutoff(t *testing.T) {
|
||||
store, db := setupPerfTestDB(t)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
recent := now.Add(-7 * 24 * time.Hour)
|
||||
old := now.Add(-60 * 24 * time.Hour)
|
||||
|
||||
insertPerfMCPLog(t, db, mcpLogOpts{
|
||||
ToolName: "t1", ServerLabel: "s1", Timestamp: recent,
|
||||
VirtualKeyID: "vk-recent", VirtualKeyName: "Recent VK",
|
||||
})
|
||||
insertPerfMCPLog(t, db, mcpLogOpts{
|
||||
ToolName: "t2", ServerLabel: "s1", Timestamp: old,
|
||||
VirtualKeyID: "vk-old", VirtualKeyName: "Old VK",
|
||||
})
|
||||
|
||||
keys, err := store.GetAvailableMCPVirtualKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
var ids []string
|
||||
for _, k := range keys {
|
||||
if k.VirtualKeyID != nil {
|
||||
ids = append(ids, *k.VirtualKeyID)
|
||||
}
|
||||
}
|
||||
assert.Contains(t, ids, "vk-recent")
|
||||
assert.NotContains(t, ids, "vk-old")
|
||||
}
|
||||
|
||||
// ---------- Phase 3: Routing engine filter + indexes ----------
|
||||
|
||||
func TestRoutingEngineFilter_Postgres(t *testing.T) {
|
||||
store, db := setupPerfTestDB(t)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
start := now.Add(-1 * time.Hour)
|
||||
|
||||
insertPerfLog(t, db, logOpts{
|
||||
Model: "m1", Timestamp: now, RoutingEnginesUsed: "loadbalancing,governance",
|
||||
})
|
||||
insertPerfLog(t, db, logOpts{
|
||||
Model: "m2", Timestamp: now, RoutingEnginesUsed: "routing-rule",
|
||||
})
|
||||
insertPerfLog(t, db, logOpts{
|
||||
Model: "m3", Timestamp: now, RoutingEnginesUsed: "loadbalancing",
|
||||
})
|
||||
|
||||
// Single engine filter
|
||||
result, err := store.SearchLogs(ctx, SearchFilters{
|
||||
RoutingEngineUsed: []string{"loadbalancing"},
|
||||
StartTime: &start, EndTime: &now,
|
||||
}, PaginationOptions{Limit: 100})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, len(result.Logs), "Should find 2 logs with loadbalancing")
|
||||
|
||||
result, err = store.SearchLogs(ctx, SearchFilters{
|
||||
RoutingEngineUsed: []string{"governance"},
|
||||
StartTime: &start, EndTime: &now,
|
||||
}, PaginationOptions{Limit: 100})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(result.Logs), "Should find 1 log with governance")
|
||||
|
||||
result, err = store.SearchLogs(ctx, SearchFilters{
|
||||
RoutingEngineUsed: []string{"routing-rule"},
|
||||
StartTime: &start, EndTime: &now,
|
||||
}, PaginationOptions{Limit: 100})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(result.Logs), "Should find 1 log with routing-rule")
|
||||
|
||||
// Multiple engines (OR)
|
||||
result, err = store.SearchLogs(ctx, SearchFilters{
|
||||
RoutingEngineUsed: []string{"loadbalancing", "routing-rule"},
|
||||
StartTime: &start, EndTime: &now,
|
||||
}, PaginationOptions{Limit: 100})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, len(result.Logs), "Should find all 3 with loadbalancing OR routing-rule")
|
||||
|
||||
// Non-existent engine
|
||||
result, err = store.SearchLogs(ctx, SearchFilters{
|
||||
RoutingEngineUsed: []string{"nonexistent"},
|
||||
StartTime: &start, EndTime: &now,
|
||||
}, PaginationOptions{Limit: 100})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, len(result.Logs))
|
||||
}
|
||||
|
||||
func TestEnsurePerformanceIndexes(t *testing.T) {
|
||||
db := trySetupPostgresDB(t)
|
||||
if db == nil {
|
||||
t.Skip("Postgres not available, skipping test")
|
||||
}
|
||||
|
||||
db.Exec("DROP TABLE IF EXISTS mcp_tool_logs CASCADE")
|
||||
db.Exec("DROP TABLE IF EXISTS async_jobs CASCADE")
|
||||
db.Exec("DROP TABLE IF EXISTS logs CASCADE")
|
||||
db.Exec("CREATE TABLE IF NOT EXISTS migrations (id VARCHAR(255) PRIMARY KEY)")
|
||||
db.Exec("DELETE FROM migrations")
|
||||
|
||||
ctx := context.Background()
|
||||
err := triggerMigrations(ctx, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
for _, idx := range performanceIndexes {
|
||||
db.Exec("DROP INDEX IF EXISTS " + idx.name)
|
||||
}
|
||||
db.Exec("DROP TABLE IF EXISTS mcp_tool_logs CASCADE")
|
||||
db.Exec("DROP TABLE IF EXISTS async_jobs CASCADE")
|
||||
db.Exec("DROP TABLE IF EXISTS logs CASCADE")
|
||||
db.Exec("DELETE FROM migrations")
|
||||
})
|
||||
|
||||
conn := acquirePerfTestSQLConn(t, ctx, db)
|
||||
// First run
|
||||
err = ensurePerformanceIndexes(ctx, conn)
|
||||
require.NoError(t, err, "ensurePerformanceIndexes should succeed")
|
||||
|
||||
// Verify all indexes exist and are valid
|
||||
for _, idx := range performanceIndexes {
|
||||
var indexValid bool
|
||||
err := db.Raw(`
|
||||
SELECT COALESCE(bool_and(pi.indisvalid), false)
|
||||
FROM pg_class pc
|
||||
JOIN pg_index pi ON pi.indrelid = pc.oid
|
||||
JOIN pg_class ic ON ic.oid = pi.indexrelid
|
||||
WHERE pc.relname = ?
|
||||
AND ic.relname = ?
|
||||
`, idx.table, idx.name).Scan(&indexValid).Error
|
||||
require.NoError(t, err)
|
||||
assert.True(t, indexValid, "Index %s should be valid", idx.name)
|
||||
}
|
||||
|
||||
// Idempotent — second run should be a no-op
|
||||
err = ensurePerformanceIndexes(ctx, conn)
|
||||
require.NoError(t, err, "ensurePerformanceIndexes should be idempotent")
|
||||
}
|
||||
|
||||
func TestContentSearch_Postgres(t *testing.T) {
|
||||
store, db := setupPerfTestDB(t)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
start := now.Add(-1 * time.Hour)
|
||||
|
||||
// Build indexes
|
||||
conn := acquirePerfTestSQLConn(t, ctx, db)
|
||||
|
||||
err := ensurePerformanceIndexes(ctx, conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
insertPerfLog(t, db, logOpts{
|
||||
Timestamp: now,
|
||||
ContentSummary: "The quick brown fox jumps over the lazy dog",
|
||||
})
|
||||
insertPerfLog(t, db, logOpts{
|
||||
Timestamp: now,
|
||||
ContentSummary: "Hello world this is a test message",
|
||||
})
|
||||
|
||||
result, err := store.SearchLogs(ctx, SearchFilters{
|
||||
ContentSearch: "brown fox",
|
||||
StartTime: &start, EndTime: &now,
|
||||
}, PaginationOptions{Limit: 100})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(result.Logs), "Should find 1 log matching 'brown fox'")
|
||||
|
||||
result, err = store.SearchLogs(ctx, SearchFilters{
|
||||
ContentSearch: "test message",
|
||||
StartTime: &start, EndTime: &now,
|
||||
}, PaginationOptions{Limit: 100})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(result.Logs), "Should find 1 log matching 'test message'")
|
||||
|
||||
result, err = store.SearchLogs(ctx, SearchFilters{
|
||||
ContentSearch: "nonexistent phrase",
|
||||
StartTime: &start, EndTime: &now,
|
||||
}, PaginationOptions{Limit: 100})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, len(result.Logs))
|
||||
}
|
||||
|
||||
func TestMCPContentSearch_Postgres(t *testing.T) {
|
||||
store, db := setupPerfTestDB(t)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Build indexes
|
||||
conn := acquirePerfTestSQLConn(t, ctx, db)
|
||||
err := ensurePerformanceIndexes(ctx, conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
insertPerfMCPLog(t, db, mcpLogOpts{
|
||||
ToolName: "search", ServerLabel: "s1", Timestamp: now,
|
||||
VirtualKeyID: "vk-1", VirtualKeyName: "k1",
|
||||
Arguments: `{"query": "weather in london"}`,
|
||||
Result: `{"temperature": 15}`,
|
||||
})
|
||||
insertPerfMCPLog(t, db, mcpLogOpts{
|
||||
ToolName: "calc", ServerLabel: "s1", Timestamp: now,
|
||||
VirtualKeyID: "vk-1", VirtualKeyName: "k1",
|
||||
Arguments: `{"expression": "2+2"}`,
|
||||
Result: `{"answer": 4}`,
|
||||
})
|
||||
|
||||
result, err := store.SearchMCPToolLogs(ctx, MCPToolLogSearchFilters{
|
||||
ContentSearch: "london",
|
||||
}, PaginationOptions{Limit: 100})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(result.Logs), "Should find 1 MCP log matching 'london'")
|
||||
|
||||
result, err = store.SearchMCPToolLogs(ctx, MCPToolLogSearchFilters{
|
||||
ContentSearch: "temperature",
|
||||
}, PaginationOptions{Limit: 100})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(result.Logs), "Should find 1 MCP log matching 'temperature' in result")
|
||||
}
|
||||
47
framework/logstore/sqlite.go
Normal file
47
framework/logstore/sqlite.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package logstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SQLiteConfig holds the settings needed to open a SQLite-backed log store.
type SQLiteConfig struct {
	// Path is the filesystem location of the database file. newSqliteLogStore
	// creates the file when it does not exist yet.
	Path string `json:"path"`
}
|
||||
|
||||
// newSqliteLogStore creates a new SQLite log store.
|
||||
func newSqliteLogStore(ctx context.Context, config *SQLiteConfig, logger schemas.Logger) (*RDBLogStore, error) {
|
||||
if _, err := os.Stat(config.Path); os.IsNotExist(err) {
|
||||
// Create DB file
|
||||
f, err := os.Create(config.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = f.Close()
|
||||
}
|
||||
// Configure SQLite with proper settings to handle concurrent access
|
||||
dsn := fmt.Sprintf("%s?_journal_mode=WAL&_synchronous=NORMAL&_cache_size=10000&_busy_timeout=60000&_wal_autocheckpoint=1000&_foreign_keys=1", config.Path)
|
||||
logger.Debug("opening DB with dsn: %s", dsn)
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{
|
||||
Logger: newGormLogger(logger),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.Debug("db opened for logstore")
|
||||
|
||||
s := &RDBLogStore{db: db, logger: logger}
|
||||
// Run migrations
|
||||
if err := triggerMigrations(ctx, db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
140
framework/logstore/store.go
Normal file
140
framework/logstore/store.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package logstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/objectstore"
|
||||
)
|
||||
|
||||
// LogStoreType names the backend a log store is built on.
type LogStoreType string

// Backends that NewLogStore knows how to construct.
const (
	// LogStoreTypeSQLite selects the SQLite-backed store.
	LogStoreTypeSQLite LogStoreType = "sqlite"
	// LogStoreTypePostgres selects the Postgres-backed store.
	LogStoreTypePostgres LogStoreType = "postgres"
)
|
||||
|
||||
// LogStore is the interface for the log store.
|
||||
type LogStore interface {
|
||||
Ping(ctx context.Context) error
|
||||
Create(ctx context.Context, entry *Log) error
|
||||
CreateIfNotExists(ctx context.Context, entry *Log) error
|
||||
BatchCreateIfNotExists(ctx context.Context, entries []*Log) error
|
||||
FindByID(ctx context.Context, id string) (*Log, error)
|
||||
IsLogEntryPresent(ctx context.Context, id string) (bool, error)
|
||||
FindFirst(ctx context.Context, query any, fields ...string) (*Log, error)
|
||||
FindAll(ctx context.Context, query any, fields ...string) ([]*Log, error)
|
||||
FindAllDistinct(ctx context.Context, query any, fields ...string) ([]*Log, error)
|
||||
HasLogs(ctx context.Context) (bool, error)
|
||||
SearchLogs(ctx context.Context, filters SearchFilters, pagination PaginationOptions) (*SearchResult, error)
|
||||
GetSessionLogs(ctx context.Context, sessionID string, pagination PaginationOptions) (*SessionDetailResult, error)
|
||||
GetSessionSummary(ctx context.Context, sessionID string) (*SessionSummaryResult, error)
|
||||
GetStats(ctx context.Context, filters SearchFilters) (*SearchStats, error)
|
||||
GetHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*HistogramResult, error)
|
||||
GetTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*TokenHistogramResult, error)
|
||||
GetCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*CostHistogramResult, error)
|
||||
GetModelHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ModelHistogramResult, error)
|
||||
GetLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*LatencyHistogramResult, error)
|
||||
GetProviderCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderCostHistogramResult, error)
|
||||
GetProviderTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderTokenHistogramResult, error)
|
||||
GetProviderLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderLatencyHistogramResult, error)
|
||||
GetModelRankings(ctx context.Context, filters SearchFilters) (*ModelRankingResult, error)
|
||||
GetUserRankings(ctx context.Context, filters SearchFilters) (*UserRankingResult, error)
|
||||
// GetDimensionCostHistogram returns time-bucketed cost data grouped by the specified dimension (e.g., team_id, customer_id).
|
||||
GetDimensionCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionCostHistogramResult, error)
|
||||
// GetDimensionTokenHistogram returns time-bucketed token usage grouped by the specified dimension.
|
||||
GetDimensionTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionTokenHistogramResult, error)
|
||||
// GetDimensionLatencyHistogram returns time-bucketed latency percentiles grouped by the specified dimension.
|
||||
GetDimensionLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionLatencyHistogramResult, error)
|
||||
Update(ctx context.Context, id string, entry any) error
|
||||
BulkUpdateCost(ctx context.Context, updates map[string]float64) error
|
||||
Flush(ctx context.Context, since time.Time) error
|
||||
Close(ctx context.Context) error
|
||||
DeleteLog(ctx context.Context, id string) error
|
||||
DeleteLogs(ctx context.Context, ids []string) error
|
||||
DeleteLogsBatch(ctx context.Context, cutoff time.Time, batchSize int) (deletedCount int64, err error)
|
||||
|
||||
// Distinct value methods for filter data
|
||||
GetDistinctModels(ctx context.Context) ([]string, error)
|
||||
GetDistinctAliases(ctx context.Context) ([]string, error)
|
||||
GetDistinctKeyPairs(ctx context.Context, idCol, nameCol string) ([]KeyPairResult, error)
|
||||
GetDistinctRoutingEngines(ctx context.Context) ([]string, error)
|
||||
GetDistinctMetadataKeys(ctx context.Context) (map[string][]string, error)
|
||||
|
||||
// MCP Tool Log histogram methods
|
||||
GetMCPHistogram(ctx context.Context, filters MCPToolLogSearchFilters, bucketSizeSeconds int64) (*MCPHistogramResult, error)
|
||||
GetMCPCostHistogram(ctx context.Context, filters MCPToolLogSearchFilters, bucketSizeSeconds int64) (*MCPCostHistogramResult, error)
|
||||
GetMCPTopTools(ctx context.Context, filters MCPToolLogSearchFilters, limit int) (*MCPTopToolsResult, error)
|
||||
|
||||
// MCP Tool Log methods
|
||||
CreateMCPToolLog(ctx context.Context, entry *MCPToolLog) error
|
||||
FindMCPToolLog(ctx context.Context, id string) (*MCPToolLog, error)
|
||||
UpdateMCPToolLog(ctx context.Context, id string, entry any) error
|
||||
SearchMCPToolLogs(ctx context.Context, filters MCPToolLogSearchFilters, pagination PaginationOptions) (*MCPToolLogSearchResult, error)
|
||||
GetMCPToolLogStats(ctx context.Context, filters MCPToolLogSearchFilters) (*MCPToolLogStats, error)
|
||||
HasMCPToolLogs(ctx context.Context) (bool, error)
|
||||
DeleteMCPToolLogs(ctx context.Context, ids []string) error
|
||||
FlushMCPToolLogs(ctx context.Context, since time.Time) error
|
||||
GetAvailableToolNames(ctx context.Context) ([]string, error)
|
||||
GetAvailableServerLabels(ctx context.Context) ([]string, error)
|
||||
GetAvailableMCPVirtualKeys(ctx context.Context) ([]MCPToolLog, error)
|
||||
|
||||
// Async Job methods
|
||||
CreateAsyncJob(ctx context.Context, job *AsyncJob) error
|
||||
FindAsyncJobByID(ctx context.Context, id string) (*AsyncJob, error)
|
||||
UpdateAsyncJob(ctx context.Context, id string, updates map[string]interface{}) error
|
||||
DeleteExpiredAsyncJobs(ctx context.Context) (int64, error)
|
||||
DeleteStaleAsyncJobs(ctx context.Context, staleSince time.Time) (int64, error)
|
||||
}
|
||||
|
||||
// NewLogStore creates a new log store based on the configuration.
|
||||
// When ObjectStorage is configured, the returned store is wrapped with a
|
||||
// HybridLogStore that offloads payloads to S3-compatible object storage.
|
||||
func NewLogStore(ctx context.Context, config *Config, logger schemas.Logger) (LogStore, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("logstore: config is nil")
|
||||
}
|
||||
|
||||
var inner LogStore
|
||||
var err error
|
||||
|
||||
switch config.Type {
|
||||
case LogStoreTypeSQLite:
|
||||
if sqliteConfig, ok := config.Config.(*SQLiteConfig); ok {
|
||||
inner, err = newSqliteLogStore(ctx, sqliteConfig, logger)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid sqlite config: %T", config.Config)
|
||||
}
|
||||
case LogStoreTypePostgres:
|
||||
if postgresConfig, ok := config.Config.(*PostgresConfig); ok {
|
||||
inner, err = newPostgresLogStore(ctx, postgresConfig, logger)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid postgres config: %T", config.Config)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported log store type: %s", config.Type)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Optionally wrap with hybrid decorator for object storage offloading.
|
||||
if config.ObjectStorage != nil {
|
||||
objStore, objErr := objectstore.NewObjectStore(ctx, config.ObjectStorage, logger)
|
||||
if objErr != nil {
|
||||
_ = inner.Close(ctx)
|
||||
return nil, fmt.Errorf("failed to create object store: %w", objErr)
|
||||
}
|
||||
if err := objStore.Ping(ctx); err != nil {
|
||||
_ = objStore.Close()
|
||||
_ = inner.Close(ctx)
|
||||
return nil, fmt.Errorf("failed to ping object store: %w", err)
|
||||
}
|
||||
return newHybridLogStore(inner, objStore, config.ObjectStorage.GetPrefix(), logger), nil
|
||||
}
|
||||
return inner, nil
|
||||
}
|
||||
1480
framework/logstore/tables.go
Normal file
1480
framework/logstore/tables.go
Normal file
File diff suppressed because it is too large
Load Diff
90
framework/mcpcatalog/main.go
Normal file
90
framework/mcpcatalog/main.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package mcpcatalog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"maps"
|
||||
"sync"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
type MCPCatalog struct {
|
||||
mu sync.RWMutex
|
||||
pricingData MCPPricingData
|
||||
logger schemas.Logger
|
||||
}
|
||||
|
||||
// PricingEntry is the price of one tool call on one MCP server.
type PricingEntry struct {
	Server           string  `json:"server"`
	ToolName         string  `json:"tool_name"`
	CostPerExecution float64 `json:"cost_per_execution"`
}

// MCPPricingData maps a "{server_label}/{tool_name}" key to its PricingEntry.
type MCPPricingData map[string]PricingEntry

// Config carries the initial pricing data handed to Init.
type Config struct {
	PricingData MCPPricingData
}
|
||||
|
||||
// Init initializes the MCP catalog
|
||||
func Init(ctx context.Context, config *Config, logger schemas.Logger) (*MCPCatalog, error) {
|
||||
logger.Info("initializing MCP catalog...")
|
||||
|
||||
pricingData := MCPPricingData{}
|
||||
|
||||
if config != nil && config.PricingData != nil {
|
||||
// Defensively copy the pricing map to prevent external mutations
|
||||
pricingData = make(MCPPricingData, len(config.PricingData))
|
||||
maps.Copy(pricingData, config.PricingData)
|
||||
}
|
||||
|
||||
return &MCPCatalog{
|
||||
logger: logger,
|
||||
pricingData: pricingData,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetAllPricingData returns all the pricing data
|
||||
func (mc *MCPCatalog) GetAllPricingData() MCPPricingData {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
// Create a defensive copy to prevent callers from mutating shared state
|
||||
copy := make(MCPPricingData, len(mc.pricingData))
|
||||
maps.Copy(copy, mc.pricingData)
|
||||
return copy
|
||||
}
|
||||
|
||||
// GetPricingData returns the pricing data for the given server and tool name
|
||||
func (mc *MCPCatalog) GetPricingData(server string, toolName string) (PricingEntry, bool) {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
pricing, ok := mc.pricingData[fmt.Sprintf("%s/%s", server, toolName)]
|
||||
return pricing, ok
|
||||
}
|
||||
|
||||
// UpdatePricingData updates the pricing data for the given server and tool name
|
||||
func (mc *MCPCatalog) UpdatePricingData(server string, toolName string, costPerExecution float64) {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
mc.pricingData[fmt.Sprintf("%s/%s", server, toolName)] = PricingEntry{
|
||||
Server: server,
|
||||
ToolName: toolName,
|
||||
CostPerExecution: costPerExecution,
|
||||
}
|
||||
}
|
||||
|
||||
// DeletePricingData deletes the pricing data for the given server and tool name
|
||||
func (mc *MCPCatalog) DeletePricingData(server string, toolName string) {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
delete(mc.pricingData, fmt.Sprintf("%s/%s", server, toolName))
|
||||
}
|
||||
|
||||
// Cleanup cleans up the MCP catalog
|
||||
func (mc *MCPCatalog) Cleanup() {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
mc.pricingData = nil
|
||||
}
|
||||
618
framework/migrator/migrator.go
Normal file
618
framework/migrator/migrator.go
Normal file
@@ -0,0 +1,618 @@
|
||||
// Portions of this file are derived from https://github.com/go-gormigrate/gormigrate
|
||||
// MIT License
|
||||
// Copyright (c) 2016 Andrey Nering
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
// of this software and associated documentation files (the "Software"), to deal
|
||||
// in the Software without restriction, including without limitation the rights
|
||||
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
// copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
|
||||
package migrator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// initSchemaMigrationID is the migration ID string "SCHEMA_INIT".
// NOTE(review): presumably the ID recorded for the init-schema step and the
// one rejected via ReservedIDError; confirm against its uses.
const initSchemaMigrationID = "SCHEMA_INIT"
|
||||
|
||||
// MigrateFunc is the func signature for migrating.
|
||||
type MigrateFunc func(*gorm.DB) error
|
||||
|
||||
// RollbackFunc is the func signature for rollbacking.
|
||||
type RollbackFunc func(*gorm.DB) error
|
||||
|
||||
// InitSchemaFunc is the func signature for initializing the schema.
|
||||
type InitSchemaFunc func(*gorm.DB) error
|
||||
|
||||
// Options holds the settings shared by every migration run.
type Options struct {
	// TableName names the table in which applied migrations are recorded.
	TableName string
	// IDColumnName names the column that stores the migration ID.
	IDColumnName string
	// IDColumnSize is the length of the migration ID column.
	IDColumnSize int
	// SequenceColumnName names the auto-incrementing numeric column.
	SequenceColumnName string
	// AppliedAtColumnName names the column recording when a migration was applied.
	AppliedAtColumnName string
	// StatusColumnName names the column recording the migration status (success/failure).
	StatusColumnName string
	// UseTransaction runs the migrations inside a single transaction.
	// Note that not every database supports DDL statements in a transaction.
	UseTransaction bool
	// ValidateUnknownMigrations makes migrate fail when the database contains
	// migration IDs that are not known to the migrator.
	ValidateUnknownMigrations bool
}
|
||||
|
||||
// Migration represents a database migration (a modification to be made on the database).
|
||||
type Migration struct {
|
||||
// ID is the migration identifier. Usually a timestamp like "201601021504".
|
||||
ID string
|
||||
// Migrate is a function that will be executed while running this migration.
|
||||
Migrate MigrateFunc
|
||||
// Rollback will be executed on rollback. Can be nil.
|
||||
Rollback RollbackFunc
|
||||
}
|
||||
|
||||
// Gormigrate represents a collection of all migrations of a database schema.
|
||||
type Gormigrate struct {
|
||||
db *gorm.DB
|
||||
tx *gorm.DB
|
||||
options *Options
|
||||
migrations []*Migration
|
||||
initSchema InitSchemaFunc
|
||||
}
|
||||
|
||||
// ReservedIDError is the error produced when a registered migration uses an
// ID that the migrator reserves for itself.
type ReservedIDError struct {
	// ID is the offending migration identifier.
	ID string
}

// Error implements the error interface; the message quotes the reserved ID.
func (e *ReservedIDError) Error() string {
	const format = `gormigrate: Reserved migration ID: "%s"`
	return fmt.Sprintf(format, e.ID)
}
|
||||
|
||||
// DuplicatedIDError is the error produced when two or more registered
// migrations share the same ID.
type DuplicatedIDError struct {
	// ID is the identifier that appears more than once.
	ID string
}

// Error implements the error interface; the message quotes the duplicated ID.
func (e *DuplicatedIDError) Error() string {
	const format = `gormigrate: Duplicated migration ID: "%s"`
	return fmt.Sprintf(format, e.ID)
}
|
||||
|
||||
var (
|
||||
// DefaultOptions can be used if you don't want to think about options.
|
||||
DefaultOptions = &Options{
|
||||
TableName: "migrations",
|
||||
IDColumnName: "id",
|
||||
IDColumnSize: 255,
|
||||
SequenceColumnName: "sequence",
|
||||
AppliedAtColumnName: "applied_at",
|
||||
StatusColumnName: "status",
|
||||
UseTransaction: true,
|
||||
ValidateUnknownMigrations: false,
|
||||
}
|
||||
|
||||
// ErrRollbackImpossible is returned when trying to rollback a migration
|
||||
// that has no rollback function.
|
||||
ErrRollbackImpossible = errors.New("gormigrate: It's impossible to rollback this migration")
|
||||
|
||||
// ErrNoMigrationDefined is returned when no migration is defined.
|
||||
ErrNoMigrationDefined = errors.New("gormigrate: No migration defined")
|
||||
|
||||
// ErrMissingID is returned when the ID of a migration is equal to ""
|
||||
ErrMissingID = errors.New("gormigrate: Missing ID in migration")
|
||||
|
||||
// ErrNoRunMigration is returned when no previously run migration was found while
|
||||
// running RollbackLast
|
||||
ErrNoRunMigration = errors.New("gormigrate: Could not find last run migration")
|
||||
|
||||
// ErrMigrationIDDoesNotExist is returned when migrating or rolling back to a migration ID that
|
||||
// does not exist in the list of migrations
|
||||
ErrMigrationIDDoesNotExist = errors.New("gormigrate: Tried to migrate to an ID that doesn't exist")
|
||||
|
||||
// ErrUnknownPastMigration is returned if a migration exists in the DB that doesn't exist in the code
|
||||
ErrUnknownPastMigration = errors.New("gormigrate: Found migration in DB that does not exist in code")
|
||||
)
|
||||
|
||||
// New returns a new Gormigrate.
|
||||
func New(db *gorm.DB, options *Options, migrations []*Migration) *Gormigrate {
|
||||
if options == nil {
|
||||
options = DefaultOptions
|
||||
}
|
||||
if options.TableName == "" {
|
||||
options.TableName = DefaultOptions.TableName
|
||||
}
|
||||
if options.IDColumnName == "" {
|
||||
options.IDColumnName = DefaultOptions.IDColumnName
|
||||
}
|
||||
if options.IDColumnSize == 0 {
|
||||
options.IDColumnSize = DefaultOptions.IDColumnSize
|
||||
}
|
||||
if options.SequenceColumnName == "" {
|
||||
options.SequenceColumnName = DefaultOptions.SequenceColumnName
|
||||
}
|
||||
if options.AppliedAtColumnName == "" {
|
||||
options.AppliedAtColumnName = DefaultOptions.AppliedAtColumnName
|
||||
}
|
||||
if options.StatusColumnName == "" {
|
||||
options.StatusColumnName = DefaultOptions.StatusColumnName
|
||||
}
|
||||
return &Gormigrate{
|
||||
db: db,
|
||||
options: options,
|
||||
migrations: migrations,
|
||||
}
|
||||
}
|
||||
|
||||
// InitSchema sets a function that is run if no migration is found.
|
||||
// The idea is to avoid running all migrations when a new clean database
|
||||
// is being migrated. In this function you should create all tables and
|
||||
// foreign key necessary to your application.
|
||||
func (g *Gormigrate) InitSchema(initSchema InitSchemaFunc) {
|
||||
g.initSchema = initSchema
|
||||
}
|
||||
|
||||
// Migrate executes all migrations that did not run yet.
|
||||
func (g *Gormigrate) Migrate() error {
|
||||
if !g.hasMigrations() {
|
||||
return ErrNoMigrationDefined
|
||||
}
|
||||
var targetMigrationID string
|
||||
if len(g.migrations) > 0 {
|
||||
targetMigrationID = g.migrations[len(g.migrations)-1].ID
|
||||
}
|
||||
return g.migrate(targetMigrationID)
|
||||
}
|
||||
|
||||
// MigrateTo executes all migrations that did not run yet up to the migration that matches `migrationID`.
|
||||
func (g *Gormigrate) MigrateTo(migrationID string) error {
|
||||
if err := g.checkIDExist(migrationID); err != nil {
|
||||
return err
|
||||
}
|
||||
return g.migrate(migrationID)
|
||||
}
|
||||
|
||||
func (g *Gormigrate) migrate(migrationID string) error {
|
||||
if !g.hasMigrations() {
|
||||
return ErrNoMigrationDefined
|
||||
}
|
||||
|
||||
if err := g.checkReservedID(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := g.checkDuplicatedID(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g.begin()
|
||||
defer g.rollback()
|
||||
|
||||
if err := g.createMigrationTableIfNotExists(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if g.options.ValidateUnknownMigrations {
|
||||
unknownMigrations, err := g.unknownMigrationsHaveHappened()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if unknownMigrations {
|
||||
return ErrUnknownPastMigration
|
||||
}
|
||||
}
|
||||
|
||||
if g.initSchema != nil {
|
||||
canInitializeSchema, err := g.canInitializeSchema()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if canInitializeSchema {
|
||||
if err := g.runInitSchema(); err != nil {
|
||||
return err
|
||||
}
|
||||
return g.commit()
|
||||
}
|
||||
}
|
||||
|
||||
for _, migration := range g.migrations {
|
||||
if err := g.runMigration(migration); err != nil {
|
||||
return err
|
||||
}
|
||||
if migrationID != "" && migration.ID == migrationID {
|
||||
break
|
||||
}
|
||||
}
|
||||
return g.commit()
|
||||
}
|
||||
|
||||
// There are migrations to apply if either there's a defined
|
||||
// initSchema function or if the list of migrations is not empty.
|
||||
func (g *Gormigrate) hasMigrations() bool {
|
||||
return g.initSchema != nil || len(g.migrations) > 0
|
||||
}
|
||||
|
||||
// Check whether any migration is using a reserved ID.
|
||||
// For now there's only have one reserved ID, but there may be more in the future.
|
||||
func (g *Gormigrate) checkReservedID() error {
|
||||
for _, m := range g.migrations {
|
||||
if m.ID == initSchemaMigrationID {
|
||||
return &ReservedIDError{ID: m.ID}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gormigrate) checkDuplicatedID() error {
|
||||
lookup := make(map[string]struct{}, len(g.migrations))
|
||||
for _, m := range g.migrations {
|
||||
if _, ok := lookup[m.ID]; ok {
|
||||
return &DuplicatedIDError{ID: m.ID}
|
||||
}
|
||||
lookup[m.ID] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gormigrate) checkIDExist(migrationID string) error {
|
||||
for _, migrate := range g.migrations {
|
||||
if migrate.ID == migrationID {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ErrMigrationIDDoesNotExist
|
||||
}
|
||||
|
||||
// RollbackLast undo the last migration
|
||||
func (g *Gormigrate) RollbackLast() error {
|
||||
if len(g.migrations) == 0 {
|
||||
return ErrNoMigrationDefined
|
||||
}
|
||||
|
||||
g.begin()
|
||||
defer g.rollback()
|
||||
|
||||
lastRunMigration, err := g.getLastRunMigration()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := g.rollbackMigration(lastRunMigration); err != nil {
|
||||
return err
|
||||
}
|
||||
return g.commit()
|
||||
}
|
||||
|
||||
// RollbackTo undoes migrations up to the given migration that matches the `migrationID`.
|
||||
// Migration with the matching `migrationID` is not rolled back.
|
||||
func (g *Gormigrate) RollbackTo(migrationID string) error {
|
||||
if len(g.migrations) == 0 {
|
||||
return ErrNoMigrationDefined
|
||||
}
|
||||
|
||||
if err := g.checkIDExist(migrationID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g.begin()
|
||||
defer g.rollback()
|
||||
|
||||
for i := len(g.migrations) - 1; i >= 0; i-- {
|
||||
migration := g.migrations[i]
|
||||
if migration.ID == migrationID {
|
||||
break
|
||||
}
|
||||
migrationRan, err := g.migrationRan(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if migrationRan {
|
||||
if err := g.rollbackMigration(migration); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return g.commit()
|
||||
}
|
||||
|
||||
func (g *Gormigrate) getLastRunMigration() (*Migration, error) {
|
||||
for i := len(g.migrations) - 1; i >= 0; i-- {
|
||||
migration := g.migrations[i]
|
||||
|
||||
migrationRan, err := g.migrationRan(migration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if migrationRan {
|
||||
return migration, nil
|
||||
}
|
||||
}
|
||||
return nil, ErrNoRunMigration
|
||||
}
|
||||
|
||||
// RollbackMigration undo a migration.
|
||||
func (g *Gormigrate) RollbackMigration(m *Migration) error {
|
||||
g.begin()
|
||||
defer g.rollback()
|
||||
|
||||
if err := g.rollbackMigration(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return g.commit()
|
||||
}
|
||||
|
||||
func (g *Gormigrate) rollbackMigration(m *Migration) error {
|
||||
if m.Rollback == nil {
|
||||
return ErrRollbackImpossible
|
||||
}
|
||||
|
||||
if err := m.Rollback(g.tx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cond := fmt.Sprintf("%s = ?", g.options.IDColumnName)
|
||||
return g.tx.Table(g.options.TableName).Where(cond, m.ID).Delete(g.model()).Error
|
||||
}
|
||||
|
||||
func (g *Gormigrate) runInitSchema() error {
|
||||
if err := g.initSchema(g.tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := g.insertMigration(initSchemaMigrationID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, migration := range g.migrations {
|
||||
if err := g.insertMigration(migration.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gormigrate) runMigration(migration *Migration) error {
|
||||
if len(migration.ID) == 0 {
|
||||
return ErrMissingID
|
||||
}
|
||||
|
||||
migrationRan, err := g.migrationRan(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !migrationRan {
|
||||
if err := migration.Migrate(g.tx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := g.insertMigration(migration.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// model returns pointer to dynamically created gorm migration model struct value
|
||||
func (g *Gormigrate) model() any {
|
||||
fields := []reflect.StructField{
|
||||
{
|
||||
Name: "ID",
|
||||
Type: reflect.TypeOf(""),
|
||||
Tag: reflect.StructTag(fmt.Sprintf(
|
||||
`gorm:"primaryKey;column:%s;size:%d"`,
|
||||
g.options.IDColumnName,
|
||||
g.options.IDColumnSize,
|
||||
)),
|
||||
},
|
||||
{
|
||||
Name: "Sequence",
|
||||
Type: reflect.TypeOf(int64(0)),
|
||||
Tag: reflect.StructTag(fmt.Sprintf(`gorm:"column:%s"`, g.options.SequenceColumnName)),
|
||||
},
|
||||
{
|
||||
Name: "AppliedAt",
|
||||
Type: reflect.TypeOf(time.Time{}),
|
||||
Tag: reflect.StructTag(fmt.Sprintf(`gorm:"column:%s"`, g.options.AppliedAtColumnName)),
|
||||
},
|
||||
{
|
||||
Name: "Status",
|
||||
Type: reflect.TypeOf(""),
|
||||
Tag: reflect.StructTag(fmt.Sprintf(`gorm:"column:%s;size:20"`, g.options.StatusColumnName)),
|
||||
},
|
||||
}
|
||||
structType := reflect.StructOf(fields)
|
||||
structValue := reflect.New(structType).Elem()
|
||||
return structValue.Addr().Interface()
|
||||
}
|
||||
|
||||
func (g *Gormigrate) createMigrationTableIfNotExists() error {
|
||||
if err := g.tx.Table(g.options.TableName).AutoMigrate(g.model()); err != nil {
|
||||
return err
|
||||
}
|
||||
return g.backfillMigrationMetadata()
|
||||
}
|
||||
|
||||
// backfillMigrationMetadata populates sequence, applied_at, and status for
|
||||
// rows that predate the addition of these columns (all marked as success
|
||||
// with the same timestamp). Rows are sequenced by their natural insertion
|
||||
// order (rowid for SQLite, ctid for PostgreSQL) so that the sequence column
|
||||
// reflects the actual order migrations were originally applied.
|
||||
func (g *Gormigrate) backfillMigrationMetadata() error {
|
||||
var orderCol string
|
||||
switch g.tx.Dialector.Name() {
|
||||
case "sqlite":
|
||||
orderCol = "rowid"
|
||||
case "postgres":
|
||||
orderCol = "ctid"
|
||||
default:
|
||||
orderCol = g.options.IDColumnName
|
||||
}
|
||||
|
||||
var ids []string
|
||||
err := g.tx.Table(g.options.TableName).
|
||||
Where(fmt.Sprintf("%s IS NULL OR %s = ''", g.options.StatusColumnName, g.options.StatusColumnName)).
|
||||
Order(orderCol).
|
||||
Pluck(g.options.IDColumnName, &ids).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
var maxSeq int64
|
||||
if err := g.tx.Table(g.options.TableName).
|
||||
Select(fmt.Sprintf("COALESCE(MAX(%s), 0)", g.options.SequenceColumnName)).
|
||||
Scan(&maxSeq).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, id := range ids {
|
||||
err := g.tx.Table(g.options.TableName).
|
||||
Where(fmt.Sprintf("%s = ?", g.options.IDColumnName), id).
|
||||
Updates(map[string]interface{}{
|
||||
g.options.SequenceColumnName: maxSeq + int64(i) + 1,
|
||||
g.options.AppliedAtColumnName: now,
|
||||
g.options.StatusColumnName: "success",
|
||||
}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gormigrate) migrationRan(m *Migration) (bool, error) {
|
||||
var count int64
|
||||
err := g.tx.
|
||||
Table(g.options.TableName).
|
||||
Where(fmt.Sprintf("%s = ?", g.options.IDColumnName), m.ID).
|
||||
Count(&count).
|
||||
Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// The schema can be initialised only if it hasn't been initialised yet
|
||||
// and no other migration has been applied already.
|
||||
func (g *Gormigrate) canInitializeSchema() (bool, error) {
|
||||
migrationRan, err := g.migrationRan(&Migration{ID: initSchemaMigrationID})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if migrationRan {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// If the ID doesn't exist, we also want the list of migrations to be empty
|
||||
var count int64
|
||||
err = g.tx.
|
||||
Table(g.options.TableName).
|
||||
Count(&count).
|
||||
Error
|
||||
return count == 0, err
|
||||
}
|
||||
|
||||
func (g *Gormigrate) unknownMigrationsHaveHappened() (bool, error) {
|
||||
rows, err := g.tx.Table(g.options.TableName).Select(g.options.IDColumnName).Rows()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
g.tx.Logger.Error(context.TODO(), err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
validIDSet := make(map[string]struct{}, len(g.migrations)+1)
|
||||
validIDSet[initSchemaMigrationID] = struct{}{}
|
||||
for _, migration := range g.migrations {
|
||||
validIDSet[migration.ID] = struct{}{}
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var pastMigrationID string
|
||||
if err := rows.Scan(&pastMigrationID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if _, ok := validIDSet[pastMigrationID]; !ok {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (g *Gormigrate) nextSequence() (int64, error) {
|
||||
var maxSeq int64
|
||||
err := g.tx.Table(g.options.TableName).
|
||||
Select(fmt.Sprintf("COALESCE(MAX(%s), 0)", g.options.SequenceColumnName)).
|
||||
Scan(&maxSeq).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return maxSeq + 1, nil
|
||||
}
|
||||
|
||||
func (g *Gormigrate) insertMigration(id string) error {
|
||||
seq, err := g.nextSequence()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record := g.model()
|
||||
v := reflect.ValueOf(record).Elem()
|
||||
v.FieldByName("ID").SetString(id)
|
||||
v.FieldByName("Sequence").SetInt(seq)
|
||||
v.FieldByName("AppliedAt").Set(reflect.ValueOf(time.Now()))
|
||||
v.FieldByName("Status").SetString("success")
|
||||
return g.tx.Table(g.options.TableName).Create(record).Error
|
||||
}
|
||||
|
||||
func (g *Gormigrate) begin() {
|
||||
if g.options.UseTransaction {
|
||||
g.tx = g.db.Begin()
|
||||
} else {
|
||||
g.tx = g.db
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gormigrate) commit() error {
|
||||
if g.options.UseTransaction {
|
||||
return g.tx.Commit().Error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gormigrate) rollback() {
|
||||
if g.options.UseTransaction {
|
||||
g.tx.Rollback()
|
||||
}
|
||||
}
|
||||
223
framework/modelcatalog/capabilities_test.go
Normal file
223
framework/modelcatalog/capabilities_test.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
func TestGetModelCapabilityEntryForModel_PrefersChatThenResponsesThenCompletion(t *testing.T) {
|
||||
contextLengthChat := 128000
|
||||
maxInputTokensChat := 64000
|
||||
maxOutputTokensChat := 16000
|
||||
modality := "text"
|
||||
|
||||
mc := &ModelCatalog{
|
||||
pricingData: map[string]configstoreTables.TableModelPricing{
|
||||
makeKey("gpt-4o", "openai", "responses"): {
|
||||
Model: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "responses",
|
||||
ContextLength: capabilityIntPtr(200000),
|
||||
MaxInputTokens: capabilityIntPtr(100000),
|
||||
MaxOutputTokens: capabilityIntPtr(32000),
|
||||
},
|
||||
makeKey("gpt-4o", "openai", "chat"): {
|
||||
Model: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
ContextLength: &contextLengthChat,
|
||||
MaxInputTokens: &maxInputTokensChat,
|
||||
MaxOutputTokens: &maxOutputTokensChat,
|
||||
Architecture: &schemas.Architecture{
|
||||
Modality: &modality,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
entry := mc.GetModelCapabilityEntryForModel("gpt-4o", schemas.OpenAI)
|
||||
if entry == nil {
|
||||
t.Fatal("expected capability entry")
|
||||
}
|
||||
if entry.Mode != "chat" {
|
||||
t.Fatalf("expected chat mode to win, got %q", entry.Mode)
|
||||
}
|
||||
if entry.ContextLength == nil || *entry.ContextLength != contextLengthChat {
|
||||
t.Fatalf("expected context_length=%d, got %#v", contextLengthChat, entry.ContextLength)
|
||||
}
|
||||
if entry.MaxInputTokens == nil || *entry.MaxInputTokens != maxInputTokensChat {
|
||||
t.Fatalf("expected max_input_tokens=%d, got %#v", maxInputTokensChat, entry.MaxInputTokens)
|
||||
}
|
||||
if entry.MaxOutputTokens == nil || *entry.MaxOutputTokens != maxOutputTokensChat {
|
||||
t.Fatalf("expected max_output_tokens=%d, got %#v", maxOutputTokensChat, entry.MaxOutputTokens)
|
||||
}
|
||||
if entry.Architecture == nil || entry.Architecture.Modality == nil || *entry.Architecture.Modality != modality {
|
||||
t.Fatalf("expected architecture modality=%q, got %#v", modality, entry.Architecture)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelCapabilityEntryForModel_FallsBackToAnyModeDeterministically(t *testing.T) {
|
||||
mc := &ModelCatalog{
|
||||
pricingData: map[string]configstoreTables.TableModelPricing{
|
||||
makeKey("imagen", "vertex", "image_generation"): {
|
||||
Model: "imagen",
|
||||
Provider: "vertex",
|
||||
Mode: "image_generation",
|
||||
ContextLength: capabilityIntPtr(4096),
|
||||
MaxOutputTokens: capabilityIntPtr(1),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
entry := mc.GetModelCapabilityEntryForModel("imagen", schemas.Vertex)
|
||||
if entry == nil {
|
||||
t.Fatal("expected capability entry")
|
||||
}
|
||||
if entry.Mode != "image_generation" {
|
||||
t.Fatalf("expected image_generation fallback, got %q", entry.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelCapabilityEntryForModel_ResolvesAliasFamilyViaBaseModel(t *testing.T) {
|
||||
contextLengthChat := 128000
|
||||
|
||||
mc := &ModelCatalog{
|
||||
pricingData: map[string]configstoreTables.TableModelPricing{
|
||||
makeKey("gpt-4o-2024-08-06", "openai", "responses"): {
|
||||
Model: "gpt-4o-2024-08-06",
|
||||
BaseModel: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "responses",
|
||||
ContextLength: capabilityIntPtr(64000),
|
||||
MaxOutputTokens: capabilityIntPtr(8000),
|
||||
},
|
||||
makeKey("gpt-4o-2024-08-06", "openai", "chat"): {
|
||||
Model: "gpt-4o-2024-08-06",
|
||||
BaseModel: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
ContextLength: &contextLengthChat,
|
||||
MaxOutputTokens: capabilityIntPtr(16000),
|
||||
},
|
||||
},
|
||||
baseModelIndex: map[string]string{
|
||||
"gpt-4o-2024-08-06": "gpt-4o",
|
||||
},
|
||||
}
|
||||
|
||||
entry := mc.GetModelCapabilityEntryForModel("gpt-4o", schemas.OpenAI)
|
||||
if entry == nil {
|
||||
t.Fatal("expected capability entry for base-model alias")
|
||||
}
|
||||
if entry.Mode != "chat" {
|
||||
t.Fatalf("expected chat mode to win for alias family, got %q", entry.Mode)
|
||||
}
|
||||
if entry.ContextLength == nil || *entry.ContextLength != contextLengthChat {
|
||||
t.Fatalf("expected alias family context_length=%d, got %#v", contextLengthChat, entry.ContextLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelCapabilityEntryForModel_ResolvesProviderPrefixedAlias(t *testing.T) {
|
||||
mc := &ModelCatalog{
|
||||
pricingData: map[string]configstoreTables.TableModelPricing{
|
||||
makeKey("gpt-4o-2024-08-06", "openai", "chat"): {
|
||||
Model: "gpt-4o-2024-08-06",
|
||||
BaseModel: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
ContextLength: capabilityIntPtr(128000),
|
||||
MaxOutputTokens: capabilityIntPtr(16000),
|
||||
},
|
||||
},
|
||||
baseModelIndex: map[string]string{
|
||||
"gpt-4o-2024-08-06": "gpt-4o",
|
||||
},
|
||||
}
|
||||
|
||||
entry := mc.GetModelCapabilityEntryForModel("openai/gpt-4o", schemas.OpenAI)
|
||||
if entry == nil {
|
||||
t.Fatal("expected capability entry for provider-prefixed alias")
|
||||
}
|
||||
if entry.Mode != "chat" {
|
||||
t.Fatalf("expected chat mode for provider-prefixed alias, got %q", entry.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelCapabilityEntryForModel_PrefersLiteralMatchOverAliasFamily(t *testing.T) {
|
||||
literalContextLength := 32000
|
||||
aliasContextLength := 128000
|
||||
|
||||
mc := &ModelCatalog{
|
||||
pricingData: map[string]configstoreTables.TableModelPricing{
|
||||
makeKey("gpt-4o", "openai", "chat"): {
|
||||
Model: "gpt-4o",
|
||||
BaseModel: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
ContextLength: &literalContextLength,
|
||||
MaxOutputTokens: capabilityIntPtr(4000),
|
||||
},
|
||||
makeKey("gpt-4o-2024-08-06", "openai", "chat"): {
|
||||
Model: "gpt-4o-2024-08-06",
|
||||
BaseModel: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
ContextLength: &aliasContextLength,
|
||||
MaxOutputTokens: capabilityIntPtr(16000),
|
||||
},
|
||||
},
|
||||
baseModelIndex: map[string]string{
|
||||
"gpt-4o": "gpt-4o",
|
||||
"gpt-4o-2024-08-06": "gpt-4o",
|
||||
},
|
||||
}
|
||||
|
||||
entry := mc.GetModelCapabilityEntryForModel("gpt-4o", schemas.OpenAI)
|
||||
if entry == nil {
|
||||
t.Fatal("expected literal capability entry")
|
||||
}
|
||||
if entry.ContextLength == nil || *entry.ContextLength != literalContextLength {
|
||||
t.Fatalf("expected literal match to win with context_length=%d, got %#v", literalContextLength, entry.ContextLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCapabilityFieldsRoundTripThroughPricingConversions(t *testing.T) {
|
||||
modality := "text"
|
||||
inputCost := float64(1)
|
||||
outputCost := float64(2)
|
||||
entry := PricingEntry{
|
||||
BaseModel: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
PricingOptions: PricingOptions{
|
||||
InputCostPerToken: &inputCost,
|
||||
OutputCostPerToken: &outputCost,
|
||||
},
|
||||
ContextLength: capabilityIntPtr(128000),
|
||||
MaxInputTokens: capabilityIntPtr(64000),
|
||||
MaxOutputTokens: capabilityIntPtr(16000),
|
||||
Architecture: &schemas.Architecture{
|
||||
Modality: &modality,
|
||||
},
|
||||
}
|
||||
|
||||
table := convertPricingDataToTableModelPricing("gpt-4o", entry)
|
||||
roundTrip := convertTableModelPricingToPricingData(&table)
|
||||
|
||||
if roundTrip.ContextLength == nil || *roundTrip.ContextLength != 128000 {
|
||||
t.Fatalf("expected context_length to round-trip, got %#v", roundTrip.ContextLength)
|
||||
}
|
||||
if roundTrip.MaxInputTokens == nil || *roundTrip.MaxInputTokens != 64000 {
|
||||
t.Fatalf("expected max_input_tokens to round-trip, got %#v", roundTrip.MaxInputTokens)
|
||||
}
|
||||
if roundTrip.MaxOutputTokens == nil || *roundTrip.MaxOutputTokens != 16000 {
|
||||
t.Fatalf("expected max_output_tokens to round-trip, got %#v", roundTrip.MaxOutputTokens)
|
||||
}
|
||||
if roundTrip.Architecture == nil || roundTrip.Architecture.Modality == nil || *roundTrip.Architecture.Modality != modality {
|
||||
t.Fatalf("expected architecture to round-trip, got %#v", roundTrip.Architecture)
|
||||
}
|
||||
}
|
||||
|
||||
// capabilityIntPtr returns a pointer to a fresh int holding v, for filling
// optional *int fields in test fixtures.
func capabilityIntPtr(v int) *int {
	p := new(int)
	*p = v
	return p
}
|
||||
29
framework/modelcatalog/config.go
Normal file
29
framework/modelcatalog/config.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultSyncInterval = 24 * time.Hour
|
||||
MinimumPricingSyncIntervalSec = int64(3600)
|
||||
|
||||
// syncWorkerTickerPeriod is the fixed interval at which the background sync worker
|
||||
// wakes up to check whether a sync is due. This is independent of pricingSyncInterval —
|
||||
// the ticker defines the check granularity, not the sync frequency.
|
||||
// Setting pricingSyncInterval below this value has no effect on actual sync frequency.
|
||||
syncWorkerTickerPeriod = 1 * time.Hour
|
||||
|
||||
ConfigLastPricingSyncKey = "LastModelPricingSync"
|
||||
ConfigLastParamsSyncKey = "LastModelParametersSync"
|
||||
DefaultPricingURL = "https://getbifrost.ai/datasheet"
|
||||
DefaultModelParametersURL = "https://getbifrost.ai/datasheet/model-parameters"
|
||||
DefaultPricingTimeout = 45 * time.Second
|
||||
DefaultModelParametersTimeout = 45 * time.Second
|
||||
)
|
||||
|
||||
// Config is the model pricing configuration.
|
||||
type Config struct {
|
||||
PricingURL *string `json:"pricing_url,omitempty"`
|
||||
PricingSyncInterval *int64 `json:"pricing_sync_interval,omitempty"` // seconds
|
||||
}
|
||||
459
framework/modelcatalog/main.go
Normal file
459
framework/modelcatalog/main.go
Normal file
@@ -0,0 +1,459 @@
|
||||
// Package modelcatalog provides a pricing manager for the framework.
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
type ModelCatalog struct {
|
||||
configStore configstore.ConfigStore
|
||||
distributedLockManager *configstore.DistributedLockManager
|
||||
|
||||
logger schemas.Logger
|
||||
|
||||
// Configuration fields (protected by syncMu)
|
||||
pricingURL string
|
||||
syncInterval time.Duration
|
||||
lastSyncedAt time.Time
|
||||
syncMu sync.RWMutex
|
||||
|
||||
shouldSyncGate func(ctx context.Context) bool
|
||||
afterSyncHook func(ctx context.Context)
|
||||
|
||||
// In-memory cache for fast access - direct map for O(1) lookups
|
||||
pricingData map[string]configstoreTables.TableModelPricing
|
||||
mu sync.RWMutex
|
||||
|
||||
// rawOverrides is the canonical list of all active overrides. It exists solely
|
||||
// to support incremental mutations: UpsertPricingOverrides and DeletePricingOverride
|
||||
// iterate over it to rebuild the list, then derive customPricing from it.
|
||||
// customPricing is the actual lookup structure used at query time.
|
||||
rawOverrides []PricingOverride
|
||||
customPricing *customPricingData
|
||||
overridesMu sync.RWMutex
|
||||
|
||||
modelPool map[schemas.ModelProvider][]string
|
||||
unfilteredModelPool map[schemas.ModelProvider][]string // model pool without allowed models filtering
|
||||
baseModelIndex map[string]string // model string → canonical base model name
|
||||
|
||||
// Pre-parsed supported response types index (keyed by model name)
|
||||
// Values are normalized response types: "chat_completion", "responses", "text_completion"
|
||||
supportedResponseTypes map[string][]string
|
||||
|
||||
// Pre-parsed supported parameters index (keyed by model name, populated from model parameters supported_parameters)
|
||||
// Values are parameter names the model accepts (e.g., "temperature", "top_p", "tools")
|
||||
supportedParams map[string][]string
|
||||
|
||||
// Background sync worker
|
||||
syncTicker *time.Ticker
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
syncCtx context.Context
|
||||
syncCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// Init initializes the model catalog
|
||||
func Init(ctx context.Context, config *Config, configStore configstore.ConfigStore, logger schemas.Logger) (*ModelCatalog, error) {
|
||||
// Initialize pricing URL and sync interval
|
||||
pricingURL := DefaultPricingURL
|
||||
if config.PricingURL != nil {
|
||||
pricingURL = *config.PricingURL
|
||||
}
|
||||
syncInterval := DefaultSyncInterval
|
||||
if config.PricingSyncInterval != nil {
|
||||
syncInterval = time.Duration(*config.PricingSyncInterval) * time.Second
|
||||
}
|
||||
|
||||
// Log the active interval and the scheduler's actual check frequency so operators
|
||||
// are not surprised that setting interval=1h does not mean checks happen every second.
|
||||
// Actual syncs occur when: (1) the 1-hour ticker fires AND (2) time.Since(lastSync) >= pricingSyncInterval.
|
||||
logger.Info("pricing sync interval set to %v (scheduler checks every %v)", syncInterval, syncWorkerTickerPeriod)
|
||||
|
||||
mc := &ModelCatalog{
|
||||
pricingURL: pricingURL,
|
||||
syncInterval: syncInterval,
|
||||
configStore: configStore,
|
||||
logger: logger,
|
||||
pricingData: make(map[string]configstoreTables.TableModelPricing),
|
||||
modelPool: make(map[schemas.ModelProvider][]string),
|
||||
unfilteredModelPool: make(map[schemas.ModelProvider][]string),
|
||||
baseModelIndex: make(map[string]string),
|
||||
supportedResponseTypes: make(map[string][]string),
|
||||
supportedParams: make(map[string][]string),
|
||||
done: make(chan struct{}),
|
||||
distributedLockManager: configstore.NewDistributedLockManager(configStore, logger, configstore.WithDefaultTTL(30*time.Second)),
|
||||
}
|
||||
|
||||
// Initialize syncCtx early so background startup goroutines can use it and
|
||||
// Cleanup() can cancel them. startSyncWorker is still called at the end after
|
||||
// cold-start paths have completed.
|
||||
mc.syncCtx, mc.syncCancel = context.WithCancel(ctx)
|
||||
|
||||
// If Init returns an error the caller never owns mc and will never call
|
||||
// Cleanup(), so cancel syncCtx to stop any background goroutines that were
|
||||
// already spawned before the failure.
|
||||
initSucceeded := false
|
||||
defer func() {
|
||||
if !initSucceeded {
|
||||
mc.syncCancel()
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Info("initializing model catalog...")
|
||||
if configStore != nil {
|
||||
// Per-model lazy load when the in-memory cache misses (eviction, new models, or if
|
||||
// startup bulk load was skipped). loadModelParametersFromDatabase still bulk-warms
|
||||
// the cache on init and on ReloadFromDB so common paths avoid a DB read per model.
|
||||
providerUtils.SetCacheMissHandler(func(model string) *providerUtils.ModelParams {
|
||||
missCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
params, err := configStore.GetModelParametersByModel(missCtx, model)
|
||||
if err != nil || params == nil {
|
||||
return nil
|
||||
}
|
||||
var p struct {
|
||||
MaxOutputTokens *int `json:"max_output_tokens"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(params.Data), &p); err != nil || p.MaxOutputTokens == nil {
|
||||
return nil
|
||||
}
|
||||
return &providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens}
|
||||
})
|
||||
var wg sync.WaitGroup
|
||||
var pricingErr, paramsErr error
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := mc.loadPricingFromDatabase(ctx); err != nil {
|
||||
pricingErr = fmt.Errorf("failed to load initial pricing data: %w", err)
|
||||
return
|
||||
}
|
||||
mc.mu.RLock()
|
||||
hasPricingData := len(mc.pricingData) > 0
|
||||
mc.mu.RUnlock()
|
||||
if hasPricingData {
|
||||
mc.logger.Info("existing pricing data found in database, syncing from URL in background")
|
||||
mc.wg.Add(1)
|
||||
go func() {
|
||||
defer mc.wg.Done()
|
||||
if err := mc.withDistributedLock(mc.syncCtx, "model_catalog_pricing_startup_sync", 10, func() error {
|
||||
return mc.syncPricing(mc.syncCtx)
|
||||
}); err != nil {
|
||||
mc.logger.Warn("background startup pricing sync failed: %v", err)
|
||||
} else {
|
||||
mc.logger.Info("background startup pricing sync completed successfully")
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
if err := mc.withDistributedLock(ctx, "model_catalog_pricing_startup_sync", 10, func() error {
|
||||
return mc.syncPricing(ctx)
|
||||
}); err != nil {
|
||||
pricingErr = fmt.Errorf("failed to sync pricing data: %w", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
n, err := mc.loadModelParametersFromDatabase(ctx)
|
||||
if err != nil {
|
||||
paramsErr = fmt.Errorf("failed to load initial model parameters: %w", err)
|
||||
return
|
||||
}
|
||||
if n > 0 {
|
||||
mc.logger.Info("existing model parameters found in database (%d records), syncing from URL in background", n)
|
||||
mc.wg.Add(1)
|
||||
go func() {
|
||||
defer mc.wg.Done()
|
||||
if err := mc.withDistributedLock(mc.syncCtx, "model_catalog_params_startup_sync", 10, func() error {
|
||||
return mc.syncModelParameters(mc.syncCtx)
|
||||
}); err != nil {
|
||||
mc.logger.Warn("background startup model parameters sync failed: %v", err)
|
||||
} else {
|
||||
mc.logger.Info("background startup model parameters sync completed successfully")
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
if err := mc.withDistributedLock(ctx, "model_catalog_params_startup_sync", 10, func() error {
|
||||
return mc.syncModelParameters(ctx)
|
||||
}); err != nil {
|
||||
paramsErr = fmt.Errorf("failed to sync model parameters data: %w", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
if pricingErr != nil {
|
||||
return nil, pricingErr
|
||||
}
|
||||
if paramsErr != nil {
|
||||
return nil, paramsErr
|
||||
}
|
||||
} else {
|
||||
// Load pricing and model parameters from URL into memory (no config store)
|
||||
if err := mc.loadPricingIntoMemoryFromURL(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to load pricing data from config memory: %w", err)
|
||||
}
|
||||
if err := mc.loadModelParametersIntoMemoryFromURL(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to load model parameters from URL: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
mc.syncMu.Lock()
|
||||
mc.lastSyncedAt = time.Now()
|
||||
mc.syncMu.Unlock()
|
||||
|
||||
// Populate model pool with normalized providers from pricing data
|
||||
mc.populateModelPoolFromPricingData()
|
||||
|
||||
if err := mc.loadPricingOverridesFromStore(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to load pricing overrides: %w", err)
|
||||
}
|
||||
|
||||
// Start background sync worker
|
||||
mc.startSyncWorker(mc.syncCtx)
|
||||
initSucceeded = true
|
||||
return mc, nil
|
||||
}
|
||||
|
||||
func (mc *ModelCatalog) SetShouldSyncGate(shouldSyncGate func(ctx context.Context) bool) {
|
||||
mc.shouldSyncGate = shouldSyncGate
|
||||
}
|
||||
|
||||
// SetAfterSyncHook registers a callback invoked after every successful URL → DB pricing sync.
|
||||
// In enterprise this is used to broadcast a gossip message so other pods reload from DB.
|
||||
func (mc *ModelCatalog) SetAfterSyncHook(fn func(ctx context.Context)) {
|
||||
mc.afterSyncHook = fn
|
||||
}
|
||||
|
||||
// ReloadFromDB reloads the in-memory pricing cache and model-parameters provider cache from the database.
|
||||
// In enterprise this is called on non-leader pods when they receive a gossip sync notification.
|
||||
func (mc *ModelCatalog) ReloadFromDB(ctx context.Context) error {
|
||||
if err := mc.loadPricingFromDatabase(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
mc.populateModelPoolFromPricingData()
|
||||
_, err := mc.loadModelParametersFromDatabase(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateSyncConfig updates the pricing URL and sync interval, restarts the background sync worker,
|
||||
// then delegates to ForceReloadPricing for a full sync cycle.
|
||||
func (mc *ModelCatalog) UpdateSyncConfig(ctx context.Context, config *Config) error {
|
||||
// Acquire pricing mutex to update configuration atomically
|
||||
mc.syncMu.Lock()
|
||||
|
||||
// Stop existing sync worker before updating configuration
|
||||
if mc.syncCancel != nil {
|
||||
mc.syncCancel()
|
||||
}
|
||||
if mc.syncTicker != nil {
|
||||
mc.syncTicker.Stop()
|
||||
}
|
||||
|
||||
// Update pricing configuration
|
||||
mc.pricingURL = DefaultPricingURL
|
||||
if config.PricingURL != nil {
|
||||
mc.pricingURL = *config.PricingURL
|
||||
}
|
||||
|
||||
mc.syncInterval = DefaultSyncInterval
|
||||
if config.PricingSyncInterval != nil {
|
||||
mc.syncInterval = time.Duration(*config.PricingSyncInterval) * time.Second
|
||||
}
|
||||
|
||||
// Create new sync worker with updated configuration
|
||||
mc.syncCtx, mc.syncCancel = context.WithCancel(ctx)
|
||||
mc.startSyncWorker(mc.syncCtx)
|
||||
|
||||
mc.syncMu.Unlock()
|
||||
|
||||
// Delegate to ForceReloadPricing for a complete sync cycle
|
||||
return mc.ForceReloadPricing(ctx)
|
||||
}
|
||||
|
||||
func (mc *ModelCatalog) ForceReloadPricing(ctx context.Context) error {
|
||||
timeout := DefaultPricingTimeout
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Run pricing sync and model parameters sync in parallel
|
||||
var wg sync.WaitGroup
|
||||
var pricingErr, paramsErr error
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := mc.syncPricing(ctx); err != nil {
|
||||
pricingErr = fmt.Errorf("failed to sync pricing data: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Rebuild model pool from updated pricing data
|
||||
mc.populateModelPoolFromPricingData()
|
||||
|
||||
if err := mc.loadPricingOverridesFromStore(ctx); err != nil {
|
||||
pricingErr = fmt.Errorf("failed to load pricing overrides: %w", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := mc.syncModelParameters(ctx); err != nil {
|
||||
paramsErr = fmt.Errorf("failed to sync model parameters: %w", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
if pricingErr != nil {
|
||||
return pricingErr
|
||||
}
|
||||
if paramsErr != nil {
|
||||
return paramsErr
|
||||
}
|
||||
|
||||
if mc.afterSyncHook != nil {
|
||||
mc.afterSyncHook(ctx)
|
||||
}
|
||||
|
||||
mc.syncMu.Lock()
|
||||
// Reset the ticker so the next scheduled sync waits a full interval from now
|
||||
if mc.syncTicker != nil {
|
||||
mc.syncTicker.Reset(mc.syncInterval)
|
||||
}
|
||||
mc.syncMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPricingURL returns a copy of the pricing URL under mutex protection
|
||||
func (mc *ModelCatalog) getPricingURL() string {
|
||||
mc.syncMu.RLock()
|
||||
defer mc.syncMu.RUnlock()
|
||||
return mc.pricingURL
|
||||
}
|
||||
|
||||
// IsRequestTypeSupported checks if a model supports chat completion.
|
||||
// It checks the supportedResponseTypes index.
|
||||
func (mc *ModelCatalog) IsRequestTypeSupported(model string, provider schemas.ModelProvider, requestType schemas.RequestType) bool {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
outputs, ok := mc.supportedResponseTypes[model]
|
||||
return ok && slices.Contains(outputs, string(requestType))
|
||||
}
|
||||
|
||||
// GetSupportedParameters returns the list of supported parameter names for a model.
|
||||
// Returns nil if the model is not found in the catalog.
|
||||
func (mc *ModelCatalog) GetSupportedParameters(model string) []string {
|
||||
mc.mu.RLock()
|
||||
params, ok := mc.supportedParams[model]
|
||||
mc.mu.RUnlock()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
// Return a copy to prevent external modification
|
||||
result := make([]string, len(params))
|
||||
copy(result, params)
|
||||
return result
|
||||
}
|
||||
|
||||
// populateModelPool populates the model pool with all available models per provider (thread-safe)
|
||||
func (mc *ModelCatalog) populateModelPoolFromPricingData() {
|
||||
// Acquire write lock for the entire rebuild operation
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
// Clear existing model pool and base model index
|
||||
mc.modelPool = make(map[schemas.ModelProvider][]string)
|
||||
mc.unfilteredModelPool = make(map[schemas.ModelProvider][]string)
|
||||
mc.baseModelIndex = make(map[string]string)
|
||||
|
||||
// Map to track unique models per provider
|
||||
providerModels := make(map[schemas.ModelProvider]map[string]bool)
|
||||
|
||||
// Iterate through all pricing data to collect models per provider
|
||||
for _, pricing := range mc.pricingData {
|
||||
// Normalize provider before adding to model pool
|
||||
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
|
||||
|
||||
// Initialize map for this provider if not exists
|
||||
if providerModels[normalizedProvider] == nil {
|
||||
providerModels[normalizedProvider] = make(map[string]bool)
|
||||
}
|
||||
|
||||
// Add model to the provider's model set (using map for deduplication)
|
||||
providerModels[normalizedProvider][pricing.Model] = true
|
||||
|
||||
// Build base model index from pre-computed base_model field
|
||||
if pricing.BaseModel != "" {
|
||||
mc.baseModelIndex[pricing.Model] = pricing.BaseModel
|
||||
}
|
||||
}
|
||||
|
||||
// Convert sets to slices and assign to modelPool
|
||||
for provider, modelSet := range providerModels {
|
||||
models := make([]string, 0, len(modelSet))
|
||||
for model := range modelSet {
|
||||
models = append(models, model)
|
||||
}
|
||||
mc.modelPool[provider] = models
|
||||
mc.unfilteredModelPool[provider] = models
|
||||
}
|
||||
|
||||
// Log the populated model pool for debugging
|
||||
totalModels := 0
|
||||
for provider, models := range mc.modelPool {
|
||||
totalModels += len(models)
|
||||
mc.logger.Debug("populated %d models for provider %s", len(models), string(provider))
|
||||
}
|
||||
mc.logger.Info("populated model pool with %d models across %d providers", totalModels, len(mc.modelPool))
|
||||
}
|
||||
|
||||
// Cleanup cleans up the model catalog
|
||||
func (mc *ModelCatalog) Cleanup() error {
|
||||
if mc.syncCancel != nil {
|
||||
mc.syncCancel()
|
||||
}
|
||||
|
||||
mc.syncMu.Lock()
|
||||
if mc.syncTicker != nil {
|
||||
mc.syncTicker.Stop()
|
||||
}
|
||||
mc.syncMu.Unlock()
|
||||
|
||||
close(mc.done)
|
||||
mc.wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewTestCatalog creates a minimal ModelCatalog for testing purposes.
|
||||
// It does not start background sync workers or connect to external services.
|
||||
func NewTestCatalog(baseModelIndex map[string]string) *ModelCatalog {
|
||||
if baseModelIndex == nil {
|
||||
baseModelIndex = make(map[string]string)
|
||||
}
|
||||
return &ModelCatalog{
|
||||
modelPool: make(map[schemas.ModelProvider][]string),
|
||||
unfilteredModelPool: make(map[schemas.ModelProvider][]string),
|
||||
baseModelIndex: baseModelIndex,
|
||||
pricingData: make(map[string]configstoreTables.TableModelPricing),
|
||||
supportedResponseTypes: make(map[string][]string),
|
||||
supportedParams: make(map[string][]string),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
209
framework/modelcatalog/main_test.go
Normal file
209
framework/modelcatalog/main_test.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// newTestCatalog creates a minimal ModelCatalog for testing within the package.
|
||||
func newTestCatalog(modelPool map[schemas.ModelProvider][]string, baseModelIndex map[string]string) *ModelCatalog {
|
||||
if modelPool == nil {
|
||||
modelPool = make(map[schemas.ModelProvider][]string)
|
||||
}
|
||||
if baseModelIndex == nil {
|
||||
baseModelIndex = make(map[string]string)
|
||||
}
|
||||
return &ModelCatalog{
|
||||
modelPool: modelPool,
|
||||
baseModelIndex: baseModelIndex,
|
||||
pricingData: make(map[string]configstoreTables.TableModelPricing),
|
||||
}
|
||||
}
|
||||
|
||||
// --- GetBaseModelName tests ---
|
||||
|
||||
func TestGetBaseModelName_Simple(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
// No catalog data, no prefix — returns as-is (no date suffix to strip either)
|
||||
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("gpt-4o"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_Prefixed(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
// Provider prefix stripped, no catalog — algorithmic fallback returns base
|
||||
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("openai/gpt-4o"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_PrefixedAnthropic(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.Equal(t, "claude-3-5-sonnet", mc.GetBaseModelName("anthropic/claude-3-5-sonnet"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_FromCatalog(t *testing.T) {
|
||||
// Model has a pre-computed base_model in the catalog
|
||||
mc := newTestCatalog(nil, map[string]string{
|
||||
"gpt-4o": "gpt-4o",
|
||||
"gpt-4o-2024-08-06": "gpt-4o",
|
||||
})
|
||||
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("gpt-4o"))
|
||||
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("gpt-4o-2024-08-06"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_ProviderPrefixWithCatalog(t *testing.T) {
|
||||
// Model has provider prefix — strip prefix, then find in catalog
|
||||
mc := newTestCatalog(nil, map[string]string{
|
||||
"gpt-4o": "gpt-4o",
|
||||
})
|
||||
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("openai/gpt-4o"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_FallbackAlgorithmic(t *testing.T) {
|
||||
// Model NOT in catalog — falls back to schemas.BaseModelName (date stripping)
|
||||
mc := newTestCatalog(nil, nil)
|
||||
// Anthropic-style date suffix
|
||||
assert.Equal(t, "claude-sonnet-4", mc.GetBaseModelName("claude-sonnet-4-20250514"))
|
||||
// OpenAI-style date suffix
|
||||
assert.Equal(t, "gpt-4o", mc.GetBaseModelName("gpt-4o-2024-08-06"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_FallbackAlgorithmicWithPrefix(t *testing.T) {
|
||||
// Provider prefix + not in catalog — strip prefix, then algorithmic fallback
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.Equal(t, "claude-sonnet-4", mc.GetBaseModelName("anthropic/claude-sonnet-4-20250514"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_UnknownModel(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.Equal(t, "some-random-model", mc.GetBaseModelName("some-random-model"))
|
||||
}
|
||||
|
||||
func TestGetBaseModelName_CatalogTakesPrecedence(t *testing.T) {
|
||||
// If catalog says the base_model is X, use it even if algorithmic would give Y
|
||||
mc := newTestCatalog(nil, map[string]string{
|
||||
"my-custom-model-20250101": "my-custom-model-20250101", // catalog says keep the date
|
||||
})
|
||||
assert.Equal(t, "my-custom-model-20250101", mc.GetBaseModelName("my-custom-model-20250101"))
|
||||
}
|
||||
|
||||
// --- IsSameModel tests ---
|
||||
|
||||
func TestIsSameModel_DirectMatch(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.True(t, mc.IsSameModel("gpt-4o", "gpt-4o"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_ProviderPrefix(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.True(t, mc.IsSameModel("openai/gpt-4o", "gpt-4o"))
|
||||
assert.True(t, mc.IsSameModel("gpt-4o", "openai/gpt-4o"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_BothPrefixed(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.True(t, mc.IsSameModel("openai/gpt-4o", "openai/gpt-4o"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_DifferentProvidersSameBase(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
// Both have the same base model after stripping different provider prefixes
|
||||
assert.True(t, mc.IsSameModel("openai/gpt-4o", "azure/gpt-4o"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_DifferentModels(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.False(t, mc.IsSameModel("gpt-4o", "claude-3-5-sonnet"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_DifferentModelsBothPrefixed(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.False(t, mc.IsSameModel("openai/gpt-4o", "anthropic/claude-3-5-sonnet"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_CatalogBacked(t *testing.T) {
|
||||
// Two model strings that look different but the catalog says they have the same base_model
|
||||
mc := newTestCatalog(nil, map[string]string{
|
||||
"claude-3-5-sonnet": "claude-3-5-sonnet",
|
||||
"claude-3-5-sonnet-20241022": "claude-3-5-sonnet",
|
||||
})
|
||||
assert.True(t, mc.IsSameModel("claude-3-5-sonnet", "claude-3-5-sonnet-20241022"))
|
||||
assert.True(t, mc.IsSameModel("claude-3-5-sonnet-20241022", "claude-3-5-sonnet"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_AlgorithmicFallback(t *testing.T) {
|
||||
// Models not in catalog — use algorithmic date stripping
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.True(t, mc.IsSameModel("custom-model-20250101", "custom-model"))
|
||||
}
|
||||
|
||||
func TestIsSameModel_EmptyStrings(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
assert.True(t, mc.IsSameModel("", ""))
|
||||
assert.False(t, mc.IsSameModel("gpt-4o", ""))
|
||||
assert.False(t, mc.IsSameModel("", "gpt-4o"))
|
||||
}
|
||||
|
||||
func TestIsModelAllowedForProvider_PrefixedAllowedModelInCatalog(t *testing.T) {
|
||||
mc := newTestCatalog(
|
||||
map[schemas.ModelProvider][]string{
|
||||
schemas.OpenRouter: {"openai/gpt-4o"},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
providerConfig := configstore.ProviderConfig{}
|
||||
|
||||
assert.True(t, mc.IsModelAllowedForProvider(schemas.OpenRouter, "gpt-4o", &providerConfig, []string{"openai/gpt-4o"}))
|
||||
}
|
||||
|
||||
func TestIsModelAllowedForProvider_CustomProviderListModelsDisabled(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
|
||||
// Custom provider with list-models disabled + ["*"] → should return true
|
||||
providerConfig := configstore.ProviderConfig{
|
||||
CustomProviderConfig: &schemas.CustomProviderConfig{
|
||||
AllowedRequests: &schemas.AllowedRequests{
|
||||
ListModels: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.True(t, mc.IsModelAllowedForProvider("custom-provider", "any-model", &providerConfig, []string{"*"}))
|
||||
}
|
||||
|
||||
func TestIsModelAllowedForProvider_CustomProviderListModelsEnabled(t *testing.T) {
|
||||
mc := newTestCatalog(
|
||||
map[schemas.ModelProvider][]string{
|
||||
"custom-provider": {"model-a"},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
// Custom provider with list-models enabled + ["*"] → should go through catalog
|
||||
providerConfig := configstore.ProviderConfig{
|
||||
CustomProviderConfig: &schemas.CustomProviderConfig{
|
||||
AllowedRequests: &schemas.AllowedRequests{
|
||||
ListModels: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
// model-a is in catalog → allowed
|
||||
assert.True(t, mc.IsModelAllowedForProvider("custom-provider", "model-a", &providerConfig, []string{"*"}))
|
||||
// model-b is NOT in catalog → denied
|
||||
assert.False(t, mc.IsModelAllowedForProvider("custom-provider", "model-b", &providerConfig, []string{"*"}))
|
||||
}
|
||||
|
||||
func TestIsModelAllowedForProvider_NilProviderConfig(t *testing.T) {
|
||||
mc := newTestCatalog(
|
||||
map[schemas.ModelProvider][]string{
|
||||
"some-provider": {"model-x"},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
// nil providerConfig + ["*"] → should go through catalog (not bypass)
|
||||
assert.True(t, mc.IsModelAllowedForProvider("some-provider", "model-x", nil, []string{"*"}))
|
||||
assert.False(t, mc.IsModelAllowedForProvider("some-provider", "model-y", nil, []string{"*"}))
|
||||
}
|
||||
639
framework/modelcatalog/models.go
Normal file
639
framework/modelcatalog/models.go
Normal file
@@ -0,0 +1,639 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
// GetModelCapabilityEntryForModel returns capability metadata for a model/provider pair.
|
||||
// It prefers chat, then responses, then text-completion entries; if none exist,
|
||||
// it falls back to the lexicographically first available mode for deterministic behavior.
|
||||
func (mc *ModelCatalog) GetModelCapabilityEntryForModel(model string, provider schemas.ModelProvider) *PricingEntry {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
if entry := mc.getCapabilityEntryForExactModelUnsafe(model, provider); entry != nil {
|
||||
return entry
|
||||
}
|
||||
|
||||
baseModel := mc.getBaseModelNameUnsafe(model)
|
||||
if baseModel != model {
|
||||
if entry := mc.getCapabilityEntryForExactModelUnsafe(baseModel, provider); entry != nil {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
|
||||
if entry := mc.getCapabilityEntryForModelFamilyUnsafe(baseModel, provider); entry != nil {
|
||||
return entry
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelsForProvider returns all available models for a given provider (thread-safe)
|
||||
func (mc *ModelCatalog) GetModelsForProvider(provider schemas.ModelProvider) []string {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
models, exists := mc.modelPool[provider]
|
||||
if !exists {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Return a copy to prevent external modification
|
||||
result := make([]string, len(models))
|
||||
copy(result, models)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetUnfilteredModelsForProvider returns all available models for a given provider (thread-safe)
|
||||
func (mc *ModelCatalog) GetUnfilteredModelsForProvider(provider schemas.ModelProvider) []string {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
models, exists := mc.unfilteredModelPool[provider]
|
||||
if !exists {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Return a copy to prevent external modification
|
||||
result := make([]string, len(models))
|
||||
copy(result, models)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetDistinctBaseModelNames returns all unique base model names from the catalog (thread-safe).
|
||||
// This is used for governance model selection when no specific provider is chosen.
|
||||
func (mc *ModelCatalog) GetDistinctBaseModelNames() []string {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
seen := make(map[string]bool)
|
||||
for _, baseName := range mc.baseModelIndex {
|
||||
seen[baseName] = true
|
||||
}
|
||||
|
||||
result := make([]string, 0, len(seen))
|
||||
for name := range seen {
|
||||
result = append(result, name)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetProvidersForModel returns all providers for a given model (thread-safe)
|
||||
func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvider {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
providers := make([]schemas.ModelProvider, 0)
|
||||
for provider, models := range mc.modelPool {
|
||||
isModelMatch := false
|
||||
for _, m := range models {
|
||||
if m == model || mc.getBaseModelNameUnsafe(m) == mc.getBaseModelNameUnsafe(model) {
|
||||
isModelMatch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if isModelMatch {
|
||||
providers = append(providers, provider)
|
||||
}
|
||||
}
|
||||
|
||||
// Handler special provider cases
|
||||
// 1. Handler openrouter models
|
||||
if !slices.Contains(providers, schemas.OpenRouter) {
|
||||
for _, provider := range providers {
|
||||
if openRouterModels, ok := mc.modelPool[schemas.OpenRouter]; ok {
|
||||
if slices.Contains(openRouterModels, string(provider)+"/"+model) {
|
||||
providers = append(providers, schemas.OpenRouter)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Handle vertex models
|
||||
if !slices.Contains(providers, schemas.Vertex) {
|
||||
for _, provider := range providers {
|
||||
if vertexModels, ok := mc.modelPool[schemas.Vertex]; ok {
|
||||
if slices.Contains(vertexModels, string(provider)+"/"+model) {
|
||||
providers = append(providers, schemas.Vertex)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Handle openai models for groq
|
||||
if !slices.Contains(providers, schemas.Groq) && strings.Contains(model, "gpt-") {
|
||||
if groqModels, ok := mc.modelPool[schemas.Groq]; ok {
|
||||
if slices.Contains(groqModels, "openai/"+model) {
|
||||
providers = append(providers, schemas.Groq)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Handle anthropic models for bedrock
|
||||
if !slices.Contains(providers, schemas.Bedrock) && strings.Contains(model, "claude") {
|
||||
if bedrockModels, ok := mc.modelPool[schemas.Bedrock]; ok {
|
||||
for _, bedrockModel := range bedrockModels {
|
||||
if strings.Contains(bedrockModel, model) {
|
||||
providers = append(providers, schemas.Bedrock)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return providers
|
||||
}
|
||||
|
||||
// IsModelAllowedForProvider checks if a model is allowed for a specific provider
|
||||
// based on the allowed models list and catalog data. It handles all cross-provider
|
||||
// logic including provider-prefixed models and special routing rules.
|
||||
//
|
||||
// Parameters:
|
||||
// - provider: The provider to check against
|
||||
// - model: The model name (without provider prefix, e.g., "gpt-4o" or "claude-3-5-sonnet")
|
||||
// - allowedModels: List of allowed model names (can be empty, can include provider prefixes)
|
||||
//
|
||||
// Behavior:
|
||||
// - If allowedModels is ["*"]: Uses model catalog to check if provider supports the model
|
||||
// (delegates to GetProvidersForModel which handles all cross-provider logic)
|
||||
// - If allowedModels is empty ([]): Deny-by-default — returns false for any provider/model pair
|
||||
// - If allowedModels is not empty: Checks if model matches any entry in the list
|
||||
// Provider-specific validation:
|
||||
// - Direct matches: "gpt-4o" in allowedModels for any provider
|
||||
// - Prefixed matches: Only if the prefixed model exists in provider's catalog
|
||||
// (e.g., "openai/gpt-4o" in allowedModels only matches if openrouter's catalog
|
||||
// contains "openai/gpt-4o" AND the model part matches the request)
|
||||
//
|
||||
// Returns:
|
||||
// - bool: true if the model is allowed for the provider, false otherwise
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// // Wildcard allowedModels - uses catalog to check provider support
|
||||
// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"*"})
|
||||
// // Returns: true (catalog knows openrouter has "anthropic/claude-3-5-sonnet")
|
||||
//
|
||||
// // Empty allowedModels - deny all (deny-by-default)
|
||||
// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{})
|
||||
// // Returns: false (no models are permitted)
|
||||
//
|
||||
// // Explicit allowedModels with prefix - validates against catalog
|
||||
// mc.IsModelAllowedForProvider("openrouter", "gpt-4o", []string{"openai/gpt-4o"})
|
||||
// // Returns: true (openrouter's catalog contains "openai/gpt-4o" AND model part is "gpt-4o")
|
||||
//
|
||||
// // Explicit allowedModels with prefix - wrong model
|
||||
// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"openai/gpt-4o"})
|
||||
// // Returns: false (model part "gpt-4o" doesn't match request "claude-3-5-sonnet")
|
||||
//
|
||||
// // Explicit allowedModels without prefix
|
||||
// mc.IsModelAllowedForProvider("openai", "gpt-4o", []string{"gpt-4o"})
|
||||
// // Returns: true (direct match)
|
||||
func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, providerConfig *configstore.ProviderConfig, allowedModels schemas.WhiteList) bool {
|
||||
isCustomProvider := false
|
||||
hasListModelsEndpointDisabled := false
|
||||
if providerConfig != nil {
|
||||
isCustomProvider = providerConfig.CustomProviderConfig != nil
|
||||
hasListModelsEndpointDisabled = !providerConfig.CustomProviderConfig.IsOperationAllowed(schemas.ListModelsRequest)
|
||||
}
|
||||
|
||||
// Case 1: ["*"] = allow all models; use catalog to determine support
|
||||
// Empty allowedModels = deny all (fail-safe deny-by-default)
|
||||
if allowedModels.IsUnrestricted() {
|
||||
if isCustomProvider && hasListModelsEndpointDisabled {
|
||||
return true
|
||||
}
|
||||
supportedProviders := mc.GetProvidersForModel(model)
|
||||
return slices.Contains(supportedProviders, provider)
|
||||
}
|
||||
if allowedModels.IsEmpty() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Case 2: Explicit allowedModels = check if model matches any entry
|
||||
// Get provider's catalog models for validation of prefixed entries
|
||||
providerCatalogModels := mc.GetModelsForProvider(provider)
|
||||
|
||||
for _, allowedModel := range allowedModels {
|
||||
// Direct match: "gpt-4o" == "gpt-4o"
|
||||
if allowedModel == model {
|
||||
return true
|
||||
}
|
||||
|
||||
// Provider-prefixed match: verify it exists in provider's catalog first
|
||||
// This ensures we only allow provider-specific model combinations that are actually supported
|
||||
if strings.Contains(allowedModel, "/") {
|
||||
// Check if this exact prefixed model exists in the provider's catalog
|
||||
// e.g., for openrouter, check if "openai/gpt-4o" is in its catalog
|
||||
if slices.Contains(providerCatalogModels, allowedModel) {
|
||||
// Extract the model part and compare with request
|
||||
_, modelPart := schemas.ParseModelString(allowedModel, "")
|
||||
if modelPart == model {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetBaseModelName returns the canonical base model name for a given model string.
|
||||
// It uses the pre-computed base_model from the pricing catalog when available,
|
||||
// falling back to algorithmic date/version stripping for models not in the catalog.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// mc.GetBaseModelName("gpt-4o") // Returns: "gpt-4o"
|
||||
// mc.GetBaseModelName("openai/gpt-4o") // Returns: "gpt-4o"
|
||||
// mc.GetBaseModelName("gpt-4o-2024-08-06") // Returns: "gpt-4o" (algorithmic fallback)
|
||||
func (mc *ModelCatalog) GetBaseModelName(model string) string {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
return mc.getBaseModelNameUnsafe(model)
|
||||
}
|
||||
|
||||
// getBaseModelNameUnsafe returns the canonical base model name for a given model string without locking.
|
||||
// This is used to avoid locking overhead when getting the base model name for many models.
|
||||
// Make sure the caller function is holding the read lock before calling this function.
|
||||
// It is not safe to use this function when the model pool is being updated.
|
||||
func (mc *ModelCatalog) getBaseModelNameUnsafe(model string) string {
|
||||
// Step 1: Direct lookup in base model index
|
||||
if base, ok := mc.baseModelIndex[model]; ok {
|
||||
return base
|
||||
}
|
||||
|
||||
// Step 2: Strip provider prefix and try again
|
||||
_, baseName := schemas.ParseModelString(model, "")
|
||||
if baseName != model {
|
||||
if base, ok := mc.baseModelIndex[baseName]; ok {
|
||||
return base
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Fallback to algorithmic date/version stripping
|
||||
// (for models not in the catalog, e.g., user-configured custom models)
|
||||
return schemas.BaseModelName(baseName)
|
||||
}
|
||||
|
||||
// IsSameModel checks if two model strings refer to the same underlying model.
|
||||
// It compares the canonical base model names derived from the pricing catalog
|
||||
// (or algorithmic fallback for models not in the catalog).
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// mc.IsSameModel("gpt-4o", "gpt-4o") // true (direct match)
|
||||
// mc.IsSameModel("openai/gpt-4o", "gpt-4o") // true (same base model)
|
||||
// mc.IsSameModel("gpt-4o", "claude-3-5-sonnet") // false (different models)
|
||||
// mc.IsSameModel("openai/gpt-4o", "anthropic/claude-3-5-sonnet") // false
|
||||
func (mc *ModelCatalog) IsSameModel(model1, model2 string) bool {
|
||||
if model1 == model2 {
|
||||
return true
|
||||
}
|
||||
return mc.GetBaseModelName(model1) == mc.GetBaseModelName(model2)
|
||||
}
|
||||
|
||||
// DeleteModelDataForProvider deletes all model data from the pool for a given provider
|
||||
func (mc *ModelCatalog) DeleteModelDataForProvider(provider schemas.ModelProvider) {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
delete(mc.modelPool, provider)
|
||||
delete(mc.unfilteredModelPool, provider)
|
||||
}
|
||||
|
||||
// UpsertModelDataForProvider upserts model data for a given provider
|
||||
func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model) {
|
||||
if modelData == nil {
|
||||
return
|
||||
}
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
// Populating models from pricing data for the given provider
|
||||
// Provider models map
|
||||
providerModels := []string{}
|
||||
// Iterate through all pricing data to collect models per provider
|
||||
for _, pricing := range mc.pricingData {
|
||||
// Normalize provider before adding to model pool
|
||||
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
|
||||
// We will only add models for the given provider
|
||||
if normalizedProvider != provider {
|
||||
continue
|
||||
}
|
||||
// Add model to the provider's model set (using map for deduplication)
|
||||
if slices.Contains(providerModels, pricing.Model) {
|
||||
continue
|
||||
}
|
||||
providerModels = append(providerModels, pricing.Model)
|
||||
// Build base model index from pre-computed base_model field
|
||||
if pricing.BaseModel != "" {
|
||||
mc.baseModelIndex[pricing.Model] = pricing.BaseModel
|
||||
}
|
||||
}
|
||||
// If modelData is empty, then we allow all models
|
||||
if len(modelData.Data) == 0 && len(allowedModels) == 0 {
|
||||
mc.modelPool[provider] = providerModels
|
||||
return
|
||||
}
|
||||
// Here we make sure that we still keep the backup for model catalog intact
|
||||
// So we start with a existing model pool and add the new models from incoming data
|
||||
finalModelList := make([]string, 0)
|
||||
seenModels := make(map[string]bool)
|
||||
// Case where list models failed but we have allowed models from keys
|
||||
if len(modelData.Data) == 0 && len(allowedModels) > 0 {
|
||||
for _, allowedModel := range allowedModels {
|
||||
parsedProvider, parsedModel := schemas.ParseModelString(allowedModel.ID, "")
|
||||
if parsedProvider != provider {
|
||||
continue
|
||||
}
|
||||
if !seenModels[parsedModel] {
|
||||
seenModels[parsedModel] = true
|
||||
finalModelList = append(finalModelList, parsedModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, model := range modelData.Data {
|
||||
parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "")
|
||||
if parsedProvider != provider {
|
||||
continue
|
||||
}
|
||||
if !seenModels[parsedModel] {
|
||||
seenModels[parsedModel] = true
|
||||
finalModelList = append(finalModelList, parsedModel)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowedModels) == 0 {
|
||||
for _, model := range providerModels {
|
||||
if !seenModels[model] {
|
||||
seenModels[model] = true
|
||||
finalModelList = append(finalModelList, model)
|
||||
}
|
||||
}
|
||||
}
|
||||
mc.modelPool[provider] = finalModelList
|
||||
}
|
||||
|
||||
// UpsertUnfilteredModelDataForProvider upserts unfiltered model data for a given provider
|
||||
func (mc *ModelCatalog) UpsertUnfilteredModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse) {
|
||||
if modelData == nil {
|
||||
return
|
||||
}
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
// Populating models from pricing data for the given provider
|
||||
providerModels := []string{}
|
||||
seenModels := make(map[string]bool)
|
||||
for _, pricing := range mc.pricingData {
|
||||
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
|
||||
if normalizedProvider != provider {
|
||||
continue
|
||||
}
|
||||
if !seenModels[pricing.Model] {
|
||||
seenModels[pricing.Model] = true
|
||||
providerModels = append(providerModels, pricing.Model)
|
||||
}
|
||||
}
|
||||
for _, model := range modelData.Data {
|
||||
parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "")
|
||||
if parsedProvider != provider {
|
||||
continue
|
||||
}
|
||||
if !seenModels[parsedModel] {
|
||||
seenModels[parsedModel] = true
|
||||
providerModels = append(providerModels, parsedModel)
|
||||
}
|
||||
}
|
||||
mc.unfilteredModelPool[provider] = providerModels
|
||||
}
|
||||
|
||||
// RefineModelForProvider refines the model for a given provider by performing a lookup
|
||||
// in mc.modelPool and using schemas.ParseModelString to extract provider and model parts.
|
||||
// e.g. "gpt-oss-120b" for groq provider -> "openai/gpt-oss-120b"
|
||||
//
|
||||
// Behavior:
|
||||
// - When the provider's catalog (mc.modelPool) yields multiple matching models, returns an error
|
||||
// - When exactly one match is found, returns the fully-qualified model (provider/model format)
|
||||
// - When the provider is not handled or no refinement is needed, returns the original model unchanged
|
||||
func (mc *ModelCatalog) RefineModelForProvider(provider schemas.ModelProvider, model string) (string, error) {
|
||||
switch provider {
|
||||
case schemas.Groq:
|
||||
if strings.Contains(model, "gpt-") {
|
||||
return "openai/" + model, nil
|
||||
}
|
||||
return mc.refineNestedProviderModel(provider, model)
|
||||
case schemas.Replicate:
|
||||
return mc.refineNestedProviderModel(provider, model)
|
||||
}
|
||||
return model, nil
|
||||
}
|
||||
|
||||
// SetPricingOverrides replaces the full in-memory pricing override set.
|
||||
func (mc *ModelCatalog) SetPricingOverrides(rows []configstoreTables.TablePricingOverride) error {
|
||||
seen := make(map[string]int, len(rows))
|
||||
overrides := make([]PricingOverride, 0, len(rows))
|
||||
for i := range rows {
|
||||
o, err := convertTablePricingOverrideToPricingOverride(&rows[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if idx, exists := seen[o.ID]; exists {
|
||||
overrides[idx] = o // last entry wins for duplicate IDs
|
||||
} else {
|
||||
seen[o.ID] = len(overrides)
|
||||
overrides = append(overrides, o)
|
||||
}
|
||||
}
|
||||
mc.overridesMu.Lock()
|
||||
mc.rawOverrides = overrides
|
||||
mc.customPricing = buildCustomPricingData(overrides)
|
||||
mc.overridesMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpsertPricingOverrides inserts or replaces one or more pricing overrides in a single
|
||||
// operation, rebuilding the lookup map only once at the end.
|
||||
func (mc *ModelCatalog) UpsertPricingOverrides(rows ...*configstoreTables.TablePricingOverride) error {
|
||||
// Deduplicate the input batch by ID (last entry wins) and build the
|
||||
// incoming set for O(1) lookup when filtering existing rawOverrides.
|
||||
seenIncoming := make(map[string]int, len(rows))
|
||||
overrides := make([]PricingOverride, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
o, err := convertTablePricingOverrideToPricingOverride(row)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if idx, exists := seenIncoming[o.ID]; exists {
|
||||
overrides[idx] = o // last entry wins for duplicate IDs
|
||||
} else {
|
||||
seenIncoming[o.ID] = len(overrides)
|
||||
overrides = append(overrides, o)
|
||||
}
|
||||
}
|
||||
|
||||
mc.overridesMu.Lock()
|
||||
defer mc.overridesMu.Unlock()
|
||||
|
||||
updated := make([]PricingOverride, 0, len(mc.rawOverrides)+len(overrides))
|
||||
for _, o := range mc.rawOverrides {
|
||||
if _, replacing := seenIncoming[o.ID]; !replacing {
|
||||
updated = append(updated, o)
|
||||
}
|
||||
}
|
||||
updated = append(updated, overrides...)
|
||||
mc.rawOverrides = updated
|
||||
mc.customPricing = buildCustomPricingData(updated)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePricingOverride removes a pricing override by ID.
|
||||
func (mc *ModelCatalog) DeletePricingOverride(id string) {
|
||||
mc.overridesMu.Lock()
|
||||
defer mc.overridesMu.Unlock()
|
||||
|
||||
updated := make([]PricingOverride, 0, len(mc.rawOverrides))
|
||||
for _, o := range mc.rawOverrides {
|
||||
if o.ID != id {
|
||||
updated = append(updated, o)
|
||||
}
|
||||
}
|
||||
mc.rawOverrides = updated
|
||||
mc.customPricing = buildCustomPricingData(updated)
|
||||
}
|
||||
|
||||
// IsTextCompletionSupported checks if a model supports text completion for the given provider.
|
||||
// Returns true if the model has pricing data for text completion ("text_completion"),
|
||||
// false otherwise. This is used by the litellmcompat plugin to determine whether to
|
||||
// convert text completion requests to chat completion requests.
|
||||
func (mc *ModelCatalog) IsTextCompletionSupported(model string, provider schemas.ModelProvider) bool {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
// Check for text completion mode in pricing data
|
||||
key := makeKey(model, normalizeProvider(string(provider)), normalizeRequestType(schemas.TextCompletionRequest))
|
||||
_, ok := mc.pricingData[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
|
||||
func (mc *ModelCatalog) getCapabilityEntryForExactModelUnsafe(model string, provider schemas.ModelProvider) *PricingEntry {
|
||||
preferredModes := []schemas.RequestType{
|
||||
schemas.ChatCompletionRequest,
|
||||
schemas.ResponsesRequest,
|
||||
schemas.TextCompletionRequest,
|
||||
}
|
||||
|
||||
for _, mode := range preferredModes {
|
||||
key := makeKey(model, string(provider), normalizeRequestType(mode))
|
||||
pricing, ok := mc.pricingData[key]
|
||||
if ok {
|
||||
return convertTableModelPricingToPricingData(&pricing)
|
||||
}
|
||||
}
|
||||
|
||||
prefix := model + "|" + string(provider) + "|"
|
||||
matchingKeys := make([]string, 0)
|
||||
for key := range mc.pricingData {
|
||||
if strings.HasPrefix(key, prefix) {
|
||||
matchingKeys = append(matchingKeys, key)
|
||||
}
|
||||
}
|
||||
return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys)
|
||||
}
|
||||
|
||||
func (mc *ModelCatalog) getCapabilityEntryForModelFamilyUnsafe(baseModel string, provider schemas.ModelProvider) *PricingEntry {
|
||||
if baseModel == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
matchingKeys := make([]string, 0)
|
||||
for key, pricing := range mc.pricingData {
|
||||
if normalizeProvider(pricing.Provider) != string(provider) {
|
||||
continue
|
||||
}
|
||||
if mc.getBaseModelNameUnsafe(pricing.Model) != baseModel {
|
||||
continue
|
||||
}
|
||||
matchingKeys = append(matchingKeys, key)
|
||||
}
|
||||
return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys)
|
||||
}
|
||||
|
||||
func (mc *ModelCatalog) selectCapabilityEntryFromKeysUnsafe(matchingKeys []string) *PricingEntry {
|
||||
if len(matchingKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
preferredModes := []string{
|
||||
normalizeRequestType(schemas.ChatCompletionRequest),
|
||||
normalizeRequestType(schemas.ResponsesRequest),
|
||||
normalizeRequestType(schemas.TextCompletionRequest),
|
||||
}
|
||||
|
||||
for _, mode := range preferredModes {
|
||||
modeMatches := make([]string, 0)
|
||||
for _, key := range matchingKeys {
|
||||
parts := strings.SplitN(key, "|", 3)
|
||||
if len(parts) != 3 || parts[2] != mode {
|
||||
continue
|
||||
}
|
||||
modeMatches = append(modeMatches, key)
|
||||
}
|
||||
if len(modeMatches) == 0 {
|
||||
continue
|
||||
}
|
||||
slices.Sort(modeMatches)
|
||||
pricing := mc.pricingData[modeMatches[0]]
|
||||
return convertTableModelPricingToPricingData(&pricing)
|
||||
}
|
||||
|
||||
slices.Sort(matchingKeys)
|
||||
pricing := mc.pricingData[matchingKeys[0]]
|
||||
return convertTableModelPricingToPricingData(&pricing)
|
||||
}
|
||||
|
||||
// refineNestedProviderModel resolves provider-native model slugs such as
|
||||
// "openai/gpt-5-nano" from a base model request like "gpt-5-nano".
|
||||
// It only considers catalog entries whose leading segment is a known Bifrost provider,
|
||||
// so Replicate owner/model identifiers like "meta/llama-3-8b" are left untouched.
|
||||
func (mc *ModelCatalog) refineNestedProviderModel(provider schemas.ModelProvider, model string) (string, error) {
|
||||
mc.mu.RLock()
|
||||
models, ok := mc.modelPool[provider]
|
||||
mc.mu.RUnlock()
|
||||
if !ok {
|
||||
return model, nil
|
||||
}
|
||||
|
||||
candidateModels := make([]string, 0)
|
||||
seenCandidates := make(map[string]struct{})
|
||||
for _, poolModel := range models {
|
||||
providerPart, modelPart := schemas.ParseModelString(poolModel, "")
|
||||
if providerPart == "" || model != modelPart {
|
||||
continue
|
||||
}
|
||||
|
||||
candidate := string(providerPart) + "/" + modelPart
|
||||
if _, seen := seenCandidates[candidate]; seen {
|
||||
continue
|
||||
}
|
||||
seenCandidates[candidate] = struct{}{}
|
||||
candidateModels = append(candidateModels, candidate)
|
||||
}
|
||||
|
||||
switch len(candidateModels) {
|
||||
case 0:
|
||||
return model, nil
|
||||
case 1:
|
||||
return candidateModels[0], nil
|
||||
default:
|
||||
return "", fmt.Errorf("multiple compatible models found for model %s: %v", model, candidateModels)
|
||||
}
|
||||
}
|
||||
1205
framework/modelcatalog/pricing.go
Normal file
1205
framework/modelcatalog/pricing.go
Normal file
File diff suppressed because it is too large
Load Diff
470
framework/modelcatalog/pricing_overrides.go
Normal file
470
framework/modelcatalog/pricing_overrides.go
Normal file
@@ -0,0 +1,470 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
// PricingLookupScopes carries the runtime identifiers used to resolve scoped
// pricing overrides during cost calculation. An empty field means that
// identifier is unavailable; scopePriorityOrder then skips the scope kinds
// that would need it.
type PricingLookupScopes struct {
	VirtualKeyID  string // governance virtual key ID (not the raw VK token)
	SelectedKeyID string // ID of the selected key for the request
	Provider      string // provider name, e.g. "openai"
}
|
||||
|
||||
// PricingLookupScopesFromContext builds a PricingLookupScopes from a BifrostContext.
|
||||
// It reads the governance virtual key ID (not the raw VK token) and the selected key ID.
|
||||
// provider should be the provider name string (e.g. "openai"), pass "" if unavailable.
|
||||
// Returns nil only when ctx is nil. An empty scopes value is still returned when all fields
|
||||
// are empty so that global-scope overrides are always evaluated.
|
||||
// DO NOT USE THIS FUNCTION IN A GO ROUTINE. This is because it reads from ctx which is cancelled when the request ends.
|
||||
// Better to call it in PostHooks synchronously and then pass the scopes object to the pricing manager.
|
||||
// Only use this in go routines when you know for sure that the request will not end before the go routine completes.
|
||||
func PricingLookupScopesFromContext(ctx *schemas.BifrostContext, provider string) *PricingLookupScopes {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string)
|
||||
selectedKeyID, _ := ctx.Value(schemas.BifrostContextKeySelectedKeyID).(string)
|
||||
return &PricingLookupScopes{
|
||||
VirtualKeyID: virtualKeyID,
|
||||
SelectedKeyID: selectedKeyID,
|
||||
Provider: provider,
|
||||
}
|
||||
}
|
||||
|
||||
// ScopeKind names the governance scope a pricing override is bound to.
type ScopeKind string

// Supported scope kinds. scopePriorityOrder evaluates them from the most
// specific (virtual key + provider + provider key) down to global.
const (
	ScopeKindGlobal                = ScopeKind("global")
	ScopeKindProvider              = ScopeKind("provider")
	ScopeKindProviderKey           = ScopeKind("provider_key")
	ScopeKindVirtualKey            = ScopeKind("virtual_key")
	ScopeKindVirtualKeyProvider    = ScopeKind("virtual_key_provider")
	ScopeKindVirtualKeyProviderKey = ScopeKind("virtual_key_provider_key")
)
|
||||
|
||||
// MatchType selects how an override's pattern is compared with model names.
type MatchType string

// Supported match types: an exact model name, or a prefix pattern that ends
// with a single trailing "*".
const (
	MatchTypeExact    = MatchType("exact")
	MatchTypeWildcard = MatchType("wildcard")
)
|
||||
|
||||
// PricingOverride describes a scoped pricing override shared across config storage,
// model catalog compilation, and governance APIs.
// IsValid checks the scope, pattern, and request-type constraints noted on the fields.
type PricingOverride struct {
	ID   string `json:"id"`
	Name string `json:"name"`
	// ScopeKind decides which of the three optional scope IDs below must be set
	// and which must be absent (see validateScopeKind).
	ScopeKind     ScopeKind `json:"scope_kind"`
	VirtualKeyID  *string   `json:"virtual_key_id,omitempty"`
	ProviderID    *string   `json:"provider_id,omitempty"`
	ProviderKeyID *string   `json:"provider_key_id,omitempty"`
	// MatchType and Pattern go together: an exact pattern must not contain "*",
	// a wildcard pattern must end with exactly one trailing "*".
	MatchType MatchType `json:"match_type"`
	Pattern   string    `json:"pattern"`
	// RequestTypes must be non-empty and hold base request types only;
	// stream variants are rejected by validateRequestTypes.
	RequestTypes []schemas.RequestType `json:"request_types,omitempty"`
	// Options holds the pricing values applied when the override matches.
	Options PricingOptions `json:"options"`
}
|
||||
|
||||
// customPricingEntry is a single flattened override ready for lookup.
// It is produced by buildCustomPricingData from a PricingOverride: optional
// scope IDs are dereferenced (left "" when absent) and request types are
// normalized into the requestModes set.
type customPricingEntry struct {
	id            string
	scopeKind     ScopeKind
	virtualKeyID  string
	providerID    string
	providerKeyID string
	pattern       string // exact model name, or wildcard prefix (trailing * stripped)
	wildcard      bool
	requestModes  map[string]struct{} // always non-nil for valid overrides
	options       PricingOptions
}
|
||||
|
||||
// customPricingData is the in-memory lookup structure for pricing overrides.
// Exact matches are indexed by model name; wildcards are a flat slice.
type customPricingData struct {
	// exact maps a trimmed exact-match pattern (model name) to its entries.
	exact map[string][]customPricingEntry
	// wildcard holds prefix entries, sorted by buildCustomPricingData so that
	// longer prefixes come first.
	wildcard []customPricingEntry
}
|
||||
|
||||
// IsValid validates the shared pricing override contract before persistence or runtime use.
|
||||
//
|
||||
// Input: override — the PricingOverride to validate (receiver).
|
||||
// Output: error — non-nil if any scope, pattern, or request-type constraint is violated.
|
||||
func (override *PricingOverride) IsValid() error {
|
||||
if err := override.validateScopeKind(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := override.validatePattern(); err != nil {
|
||||
return err
|
||||
}
|
||||
return override.validateRequestTypes()
|
||||
}
|
||||
|
||||
// validateScopeKind validates the scope identifiers required by override.ScopeKind.
|
||||
//
|
||||
// Input: override — receiver; ScopeKind and the three optional ID fields are inspected.
|
||||
// Output: error — non-nil when required identifiers are absent or forbidden ones are present.
|
||||
func (override *PricingOverride) validateScopeKind() error {
|
||||
switch override.ScopeKind {
|
||||
case ScopeKindGlobal:
|
||||
if override.VirtualKeyID != nil || override.ProviderID != nil || override.ProviderKeyID != nil {
|
||||
return fmt.Errorf("global scope_kind must not include scope identifiers")
|
||||
}
|
||||
case ScopeKindProvider:
|
||||
if override.ProviderID == nil {
|
||||
return fmt.Errorf("provider_id is required for provider scope_kind")
|
||||
}
|
||||
if override.VirtualKeyID != nil || override.ProviderKeyID != nil {
|
||||
return fmt.Errorf("provider scope_kind only supports provider_id")
|
||||
}
|
||||
case ScopeKindProviderKey:
|
||||
if override.ProviderKeyID == nil {
|
||||
return fmt.Errorf("provider_key_id is required for provider_key scope_kind")
|
||||
}
|
||||
if override.VirtualKeyID != nil || override.ProviderID != nil {
|
||||
return fmt.Errorf("provider_key scope_kind only supports provider_key_id")
|
||||
}
|
||||
case ScopeKindVirtualKey:
|
||||
if override.VirtualKeyID == nil {
|
||||
return fmt.Errorf("virtual_key_id is required for virtual_key scope_kind")
|
||||
}
|
||||
if override.ProviderID != nil || override.ProviderKeyID != nil {
|
||||
return fmt.Errorf("virtual_key scope_kind only supports virtual_key_id")
|
||||
}
|
||||
case ScopeKindVirtualKeyProvider:
|
||||
if override.VirtualKeyID == nil || override.ProviderID == nil {
|
||||
return fmt.Errorf("virtual_key_id and provider_id are required for virtual_key_provider scope_kind")
|
||||
}
|
||||
if override.ProviderKeyID != nil {
|
||||
return fmt.Errorf("virtual_key_provider scope_kind does not support provider_key_id")
|
||||
}
|
||||
case ScopeKindVirtualKeyProviderKey:
|
||||
if override.VirtualKeyID == nil || override.ProviderID == nil || override.ProviderKeyID == nil {
|
||||
return fmt.Errorf("virtual_key_id, provider_id, and provider_key_id are required for virtual_key_provider_key scope_kind")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported scope_kind %q", override.ScopeKind)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePattern checks that Pattern is non-empty and consistent with MatchType.
|
||||
//
|
||||
// Input: override — receiver; Pattern and MatchType are inspected.
|
||||
// Output: error — non-nil when the pattern is empty, contains a wildcard for exact mode,
|
||||
//
|
||||
// or does not end with a single trailing "*" for wildcard mode.
|
||||
func (override *PricingOverride) validatePattern() error {
|
||||
pattern := strings.TrimSpace(override.Pattern)
|
||||
if pattern == "" {
|
||||
return fmt.Errorf("pattern is required")
|
||||
}
|
||||
switch override.MatchType {
|
||||
case MatchTypeExact:
|
||||
if strings.Contains(pattern, "*") {
|
||||
return fmt.Errorf("exact match pattern must not contain wildcards")
|
||||
}
|
||||
case MatchTypeWildcard:
|
||||
if !strings.HasSuffix(pattern, "*") {
|
||||
return fmt.Errorf("wildcard pattern must end with *")
|
||||
}
|
||||
if strings.Count(pattern, "*") != 1 {
|
||||
return fmt.Errorf("wildcard pattern must contain exactly one trailing *")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported match_type %q", override.MatchType)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRequestTypes checks that RequestTypes is non-empty and that every entry is a
|
||||
// supported base request type. Stream variants (e.g. chat_completion_stream) are rejected —
|
||||
// the base type (chat_completion) already covers both streaming and non-streaming requests.
|
||||
//
|
||||
// Input: override — receiver; RequestTypes slice is inspected.
|
||||
// Output: error — non-nil if RequestTypes is empty, or contains an unsupported or stream variant.
|
||||
func (override *PricingOverride) validateRequestTypes() error {
|
||||
if len(override.RequestTypes) == 0 {
|
||||
return fmt.Errorf("request_types is required and must contain at least one value")
|
||||
}
|
||||
for _, rt := range override.RequestTypes {
|
||||
if normalizeStreamRequestType(rt) != rt {
|
||||
return fmt.Errorf("unsupported request_type %q: use the base type (e.g. %q covers both streaming and non-streaming)", rt, normalizeStreamRequestType(rt))
|
||||
}
|
||||
if normalizeRequestType(rt) == "unknown" {
|
||||
return fmt.Errorf("unsupported request_type %q", rt)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchesScope reports whether the entry's governance scope matches the runtime identifiers.
|
||||
//
|
||||
// Input: scopes — runtime VirtualKeyID, SelectedKeyID, and Provider to match against.
|
||||
// Output: bool — true when the entry's scope kind and stored IDs align with scopes.
|
||||
func (e *customPricingEntry) matchesScope(scopes PricingLookupScopes) bool {
|
||||
switch e.scopeKind {
|
||||
case ScopeKindGlobal:
|
||||
return true
|
||||
case ScopeKindProvider:
|
||||
return e.providerID == scopes.Provider
|
||||
case ScopeKindProviderKey:
|
||||
return e.providerKeyID == scopes.SelectedKeyID
|
||||
case ScopeKindVirtualKey:
|
||||
return e.virtualKeyID == scopes.VirtualKeyID
|
||||
case ScopeKindVirtualKeyProvider:
|
||||
return e.virtualKeyID == scopes.VirtualKeyID && e.providerID == scopes.Provider
|
||||
case ScopeKindVirtualKeyProviderKey:
|
||||
return e.virtualKeyID == scopes.VirtualKeyID && e.providerID == scopes.Provider && e.providerKeyID == scopes.SelectedKeyID
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchesMode reports whether the entry applies to the given normalized request mode.
|
||||
//
|
||||
// Input: mode — normalized request type string (e.g. "chat", "embedding").
|
||||
// Output: bool — true when requestModes contains mode.
|
||||
func (e *customPricingEntry) matchesMode(mode string) bool {
|
||||
_, ok := e.requestModes[mode]
|
||||
return ok
|
||||
}
|
||||
|
||||
// resolve walks the 6-scope priority hierarchy and returns the first matching
|
||||
// pricing patch for the given model, request mode, and runtime scopes.
|
||||
//
|
||||
// Input: model — exact model name being priced.
|
||||
//
|
||||
// mode — normalized request type string (e.g. "chat", "embedding").
|
||||
// scopes — runtime governance identifiers used to narrow the scope search.
|
||||
//
|
||||
// Output: *PricingOptions — pointer to the first matching override's options, or nil if none match.
|
||||
func (c *customPricingData) resolve(model, mode string, scopes PricingLookupScopes) *PricingOptions {
|
||||
for _, scopeKind := range scopePriorityOrder(scopes) {
|
||||
for i := range c.exact[model] {
|
||||
e := &c.exact[model][i]
|
||||
if e.scopeKind == scopeKind && e.matchesScope(scopes) && e.matchesMode(mode) {
|
||||
return &e.options
|
||||
}
|
||||
}
|
||||
for i := range c.wildcard {
|
||||
e := &c.wildcard[i]
|
||||
if e.scopeKind == scopeKind && e.matchesScope(scopes) && strings.HasPrefix(model, e.pattern) && e.matchesMode(mode) {
|
||||
return &e.options
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// scopePriorityOrder returns scope kinds in most-specific-first order,
|
||||
// skipping scopes that can't match given the available runtime identifiers.
|
||||
//
|
||||
// Input: scopes — runtime governance identifiers; empty fields cause the corresponding scope kinds to be omitted.
|
||||
// Output: []ScopeKind — ordered list from most-specific (VirtualKeyProviderKey) to least-specific (Global).
|
||||
func scopePriorityOrder(scopes PricingLookupScopes) []ScopeKind {
|
||||
order := make([]ScopeKind, 0, 6)
|
||||
if scopes.VirtualKeyID != "" && scopes.Provider != "" && scopes.SelectedKeyID != "" {
|
||||
order = append(order, ScopeKindVirtualKeyProviderKey)
|
||||
}
|
||||
if scopes.VirtualKeyID != "" && scopes.Provider != "" {
|
||||
order = append(order, ScopeKindVirtualKeyProvider)
|
||||
}
|
||||
if scopes.VirtualKeyID != "" {
|
||||
order = append(order, ScopeKindVirtualKey)
|
||||
}
|
||||
if scopes.SelectedKeyID != "" {
|
||||
order = append(order, ScopeKindProviderKey)
|
||||
}
|
||||
if scopes.Provider != "" {
|
||||
order = append(order, ScopeKindProvider)
|
||||
}
|
||||
order = append(order, ScopeKindGlobal)
|
||||
return order
|
||||
}
|
||||
|
||||
// buildCustomPricingData constructs a customPricingData lookup structure from a raw override slice.
|
||||
//
|
||||
// Input: overrides — slice of validated PricingOverride records loaded from the config store.
|
||||
// Output: *customPricingData — ready-to-query structure with exact and wildcard indexes populated.
|
||||
func buildCustomPricingData(overrides []PricingOverride) *customPricingData {
|
||||
data := &customPricingData{
|
||||
exact: make(map[string][]customPricingEntry, len(overrides)),
|
||||
}
|
||||
for _, o := range overrides {
|
||||
entry := customPricingEntry{
|
||||
id: o.ID,
|
||||
scopeKind: o.ScopeKind,
|
||||
options: o.Options,
|
||||
}
|
||||
if o.VirtualKeyID != nil {
|
||||
entry.virtualKeyID = *o.VirtualKeyID
|
||||
}
|
||||
if o.ProviderID != nil {
|
||||
entry.providerID = *o.ProviderID
|
||||
}
|
||||
if o.ProviderKeyID != nil {
|
||||
entry.providerKeyID = *o.ProviderKeyID
|
||||
}
|
||||
entry.requestModes = make(map[string]struct{}, len(o.RequestTypes))
|
||||
for _, rt := range o.RequestTypes {
|
||||
entry.requestModes[normalizeRequestType(rt)] = struct{}{}
|
||||
}
|
||||
pattern := strings.TrimSpace(o.Pattern)
|
||||
switch o.MatchType {
|
||||
case MatchTypeExact:
|
||||
entry.pattern = pattern
|
||||
data.exact[pattern] = append(data.exact[pattern], entry)
|
||||
case MatchTypeWildcard:
|
||||
entry.pattern = strings.TrimSuffix(pattern, "*")
|
||||
entry.wildcard = true
|
||||
data.wildcard = append(data.wildcard, entry)
|
||||
}
|
||||
}
|
||||
// Sort wildcards by descending prefix length so more-specific patterns (e.g. "gpt-4*")
|
||||
// are checked before broader ones (e.g. "gpt-*"), making precedence deterministic.
|
||||
sort.Slice(data.wildcard, func(i, j int) bool {
|
||||
return len(data.wildcard[i].pattern) > len(data.wildcard[j].pattern)
|
||||
})
|
||||
return data
|
||||
}
|
||||
|
||||
// applyPricingOverrides resolves any active scoped pricing override for the given model
|
||||
// and request type, then patches the catalog base pricing with the override values.
|
||||
// It returns the original pricing unchanged when no custom pricing tree is loaded or
|
||||
// when the request type cannot be mapped to a known pricing mode.
|
||||
//
|
||||
// Input: model — exact model name being priced.
|
||||
//
|
||||
// requestType — the request type used to derive the pricing mode.
|
||||
// pricing — base pricing row from the catalog to patch.
|
||||
// scopes — runtime governance identifiers used to narrow the override scope.
|
||||
//
|
||||
// Output: TableModelPricing — patched pricing row, or pricing unchanged if no override matches.
|
||||
// bool — true when an override was applied, false otherwise.
|
||||
func (mc *ModelCatalog) applyPricingOverrides(model string, requestType schemas.RequestType, pricing configstoreTables.TableModelPricing, scopes PricingLookupScopes) (configstoreTables.TableModelPricing, bool) {
|
||||
mc.overridesMu.RLock()
|
||||
custom := mc.customPricing
|
||||
mc.overridesMu.RUnlock()
|
||||
|
||||
if custom == nil {
|
||||
return pricing, false
|
||||
}
|
||||
|
||||
mode := normalizeRequestType(requestType)
|
||||
if mode == "unknown" {
|
||||
return pricing, false
|
||||
}
|
||||
|
||||
if patch := custom.resolve(model, mode, scopes); patch != nil {
|
||||
return patchPricing(pricing, *patch), true
|
||||
}
|
||||
return pricing, false
|
||||
}
|
||||
|
||||
// patchPricing applies override values onto a copy of the base pricing row.
|
||||
// For all fields, a non-nil override pointer replaces the corresponding destination value;
|
||||
// a nil override leaves the base value intact.
|
||||
// The original pricing row is never modified; a patched copy is always returned.
|
||||
//
|
||||
// Input: pricing — base pricing row from the catalog.
|
||||
//
|
||||
// override — pricing options sourced from the matched override entry.
|
||||
//
|
||||
// Output: TableModelPricing — shallow copy of pricing with override fields applied.
|
||||
func patchPricing(pricing configstoreTables.TableModelPricing, override PricingOptions) configstoreTables.TableModelPricing {
|
||||
patched := pricing
|
||||
|
||||
for _, field := range []struct {
|
||||
dst **float64
|
||||
src *float64
|
||||
}{
|
||||
{dst: &patched.InputCostPerToken, src: override.InputCostPerToken},
|
||||
{dst: &patched.OutputCostPerToken, src: override.OutputCostPerToken},
|
||||
{dst: &patched.InputCostPerTokenPriority, src: override.InputCostPerTokenPriority},
|
||||
{dst: &patched.OutputCostPerTokenPriority, src: override.OutputCostPerTokenPriority},
|
||||
{dst: &patched.InputCostPerTokenFlex, src: override.InputCostPerTokenFlex},
|
||||
{dst: &patched.OutputCostPerTokenFlex, src: override.OutputCostPerTokenFlex},
|
||||
{dst: &patched.InputCostPerVideoPerSecond, src: override.InputCostPerVideoPerSecond},
|
||||
{dst: &patched.OutputCostPerVideoPerSecond, src: override.OutputCostPerVideoPerSecond},
|
||||
{dst: &patched.OutputCostPerSecond, src: override.OutputCostPerSecond},
|
||||
{dst: &patched.InputCostPerAudioPerSecond, src: override.InputCostPerAudioPerSecond},
|
||||
{dst: &patched.InputCostPerSecond, src: override.InputCostPerSecond},
|
||||
{dst: &patched.InputCostPerAudioToken, src: override.InputCostPerAudioToken},
|
||||
{dst: &patched.OutputCostPerAudioToken, src: override.OutputCostPerAudioToken},
|
||||
{dst: &patched.InputCostPerCharacter, src: override.InputCostPerCharacter},
|
||||
{dst: &patched.InputCostPerTokenAbove128kTokens, src: override.InputCostPerTokenAbove128kTokens},
|
||||
{dst: &patched.InputCostPerImageAbove128kTokens, src: override.InputCostPerImageAbove128kTokens},
|
||||
{dst: &patched.InputCostPerVideoPerSecondAbove128kTokens, src: override.InputCostPerVideoPerSecondAbove128kTokens},
|
||||
{dst: &patched.InputCostPerAudioPerSecondAbove128kTokens, src: override.InputCostPerAudioPerSecondAbove128kTokens},
|
||||
{dst: &patched.OutputCostPerTokenAbove128kTokens, src: override.OutputCostPerTokenAbove128kTokens},
|
||||
{dst: &patched.InputCostPerTokenAbove200kTokens, src: override.InputCostPerTokenAbove200kTokens},
|
||||
{dst: &patched.InputCostPerTokenAbove200kTokensPriority, src: override.InputCostPerTokenAbove200kTokensPriority},
|
||||
{dst: &patched.OutputCostPerTokenAbove200kTokens, src: override.OutputCostPerTokenAbove200kTokens},
|
||||
{dst: &patched.OutputCostPerTokenAbove200kTokensPriority, src: override.OutputCostPerTokenAbove200kTokensPriority},
|
||||
{dst: &patched.InputCostPerTokenAbove272kTokens, src: override.InputCostPerTokenAbove272kTokens},
|
||||
{dst: &patched.InputCostPerTokenAbove272kTokensPriority, src: override.InputCostPerTokenAbove272kTokensPriority},
|
||||
{dst: &patched.OutputCostPerTokenAbove272kTokens, src: override.OutputCostPerTokenAbove272kTokens},
|
||||
{dst: &patched.OutputCostPerTokenAbove272kTokensPriority, src: override.OutputCostPerTokenAbove272kTokensPriority},
|
||||
{dst: &patched.CacheCreationInputTokenCostAbove200kTokens, src: override.CacheCreationInputTokenCostAbove200kTokens},
|
||||
{dst: &patched.CacheReadInputTokenCostAbove200kTokens, src: override.CacheReadInputTokenCostAbove200kTokens},
|
||||
{dst: &patched.CacheReadInputTokenCost, src: override.CacheReadInputTokenCost},
|
||||
{dst: &patched.CacheCreationInputTokenCost, src: override.CacheCreationInputTokenCost},
|
||||
{dst: &patched.CacheCreationInputTokenCostAbove1hr, src: override.CacheCreationInputTokenCostAbove1hr},
|
||||
{dst: &patched.CacheCreationInputTokenCostAbove1hrAbove200kTokens, src: override.CacheCreationInputTokenCostAbove1hrAbove200kTokens},
|
||||
{dst: &patched.CacheCreationInputAudioTokenCost, src: override.CacheCreationInputAudioTokenCost},
|
||||
{dst: &patched.CacheReadInputTokenCostPriority, src: override.CacheReadInputTokenCostPriority},
|
||||
{dst: &patched.CacheReadInputTokenCostFlex, src: override.CacheReadInputTokenCostFlex},
|
||||
{dst: &patched.CacheReadInputTokenCostAbove200kTokensPriority, src: override.CacheReadInputTokenCostAbove200kTokensPriority},
|
||||
{dst: &patched.CacheReadInputTokenCostAbove272kTokens, src: override.CacheReadInputTokenCostAbove272kTokens},
|
||||
{dst: &patched.CacheReadInputTokenCostAbove272kTokensPriority, src: override.CacheReadInputTokenCostAbove272kTokensPriority},
|
||||
{dst: &patched.InputCostPerTokenBatches, src: override.InputCostPerTokenBatches},
|
||||
{dst: &patched.OutputCostPerTokenBatches, src: override.OutputCostPerTokenBatches},
|
||||
{dst: &patched.InputCostPerImageToken, src: override.InputCostPerImageToken},
|
||||
{dst: &patched.OutputCostPerImageToken, src: override.OutputCostPerImageToken},
|
||||
{dst: &patched.InputCostPerImage, src: override.InputCostPerImage},
|
||||
{dst: &patched.OutputCostPerImage, src: override.OutputCostPerImage},
|
||||
{dst: &patched.InputCostPerPixel, src: override.InputCostPerPixel},
|
||||
{dst: &patched.OutputCostPerPixel, src: override.OutputCostPerPixel},
|
||||
{dst: &patched.OutputCostPerImagePremiumImage, src: override.OutputCostPerImagePremiumImage},
|
||||
{dst: &patched.OutputCostPerImageAbove512x512Pixels, src: override.OutputCostPerImageAbove512x512Pixels},
|
||||
{dst: &patched.OutputCostPerImageAbove512x512PixelsPremium, src: override.OutputCostPerImageAbove512x512PixelsPremium},
|
||||
{dst: &patched.OutputCostPerImageAbove1024x1024Pixels, src: override.OutputCostPerImageAbove1024x1024Pixels},
|
||||
{dst: &patched.OutputCostPerImageAbove1024x1024PixelsPremium, src: override.OutputCostPerImageAbove1024x1024PixelsPremium},
|
||||
{dst: &patched.OutputCostPerImageAbove2048x2048Pixels, src: override.OutputCostPerImageAbove2048x2048Pixels},
|
||||
{dst: &patched.OutputCostPerImageAbove4096x4096Pixels, src: override.OutputCostPerImageAbove4096x4096Pixels},
|
||||
{dst: &patched.CacheReadInputImageTokenCost, src: override.CacheReadInputImageTokenCost},
|
||||
{dst: &patched.SearchContextCostPerQuery, src: override.SearchContextCostPerQuery},
|
||||
{dst: &patched.CodeInterpreterCostPerSession, src: override.CodeInterpreterCostPerSession},
|
||||
{dst: &patched.OutputCostPerImageLowQuality, src: override.OutputCostPerImageLowQuality},
|
||||
{dst: &patched.OutputCostPerImageMediumQuality, src: override.OutputCostPerImageMediumQuality},
|
||||
{dst: &patched.OutputCostPerImageHighQuality, src: override.OutputCostPerImageHighQuality},
|
||||
{dst: &patched.OutputCostPerImageAutoQuality, src: override.OutputCostPerImageAutoQuality},
|
||||
{dst: &patched.OCRCostPerPage, src: override.OCRCostPerPage},
|
||||
{dst: &patched.AnnotationCostPerPage, src: override.AnnotationCostPerPage},
|
||||
} {
|
||||
if field.src != nil {
|
||||
*field.dst = field.src
|
||||
}
|
||||
}
|
||||
return patched
|
||||
}
|
||||
|
||||
func (mc *ModelCatalog) loadPricingOverridesFromStore(ctx context.Context) error {
|
||||
if mc.configStore == nil {
|
||||
return nil
|
||||
}
|
||||
rows, err := mc.configStore.GetPricingOverrides(ctx, configstore.PricingOverrideFilters{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return mc.SetPricingOverrides(rows)
|
||||
}
|
||||
507
framework/modelcatalog/pricing_overrides_test.go
Normal file
507
framework/modelcatalog/pricing_overrides_test.go
Normal file
@@ -0,0 +1,507 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type noOpLogger struct{}
|
||||
|
||||
func (noOpLogger) Debug(string, ...any) {}
|
||||
func (noOpLogger) Info(string, ...any) {}
|
||||
func (noOpLogger) Warn(string, ...any) {}
|
||||
func (noOpLogger) Error(string, ...any) {}
|
||||
func (noOpLogger) Fatal(string, ...any) {}
|
||||
func (noOpLogger) SetLevel(schemas.LogLevel) {}
|
||||
func (noOpLogger) SetOutputType(schemas.LoggerOutputType) {}
|
||||
func (noOpLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
func TestGetPricing_OverridePrecedenceExactWildcard(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "openai-override-0",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeWildcard),
|
||||
Pattern: "gpt-*",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":10}`,
|
||||
},
|
||||
{
|
||||
ID: "openai-override-1",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":20}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
require.NotNil(t, pricing.InputCostPerToken)
|
||||
assert.Equal(t, 20.0, *pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestGetPricing_RequestTypeSpecificOverrideBeatsGeneric(t *testing.T) {
|
||||
t.Skip()
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("gpt-4o", "openai", "responses")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "responses",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "openai-generic",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
PricingPatchJSON: `{"input_cost_per_token":9}`,
|
||||
},
|
||||
{
|
||||
ID: "openai-specific",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
RequestTypes: []schemas.RequestType{schemas.ResponsesRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":15}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ResponsesRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
assert.Equal(t, 15.0, pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestGetPricing_AppliesOverrideAfterFallbackResolution(t *testing.T) {
|
||||
t.Skip()
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("gpt-4o", "vertex", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o",
|
||||
Provider: "vertex",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
geminiProviderID := "gemini"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "gemini-override",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &geminiProviderID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
PricingPatchJSON: `{"input_cost_per_token":7}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("gemini", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "gemini"})
|
||||
require.NotNil(t, pricing)
|
||||
assert.Equal(t, 7.0, pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestGetPricing_DeploymentLookupUsesResolvedModelForOverrideMatching(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("dep-gpt4o", "openai", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "dep-gpt4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "resolved-model-override",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "dep-gpt4o",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":7}`,
|
||||
},
|
||||
}))
|
||||
|
||||
// Override pattern matches the resolved model name ("dep-gpt4o"), not the
|
||||
// originally requested name ("gpt-4o"), because resolved model has priority.
|
||||
pricing := mc.resolvePricing("openai", "gpt-4o", "dep-gpt4o", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
require.NotNil(t, pricing.InputCostPerToken)
|
||||
assert.Equal(t, 7.0, *pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestGetPricing_FallbackUsesRequestedProviderForScopeMatching(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("gpt-4o", "vertex", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o",
|
||||
Provider: "vertex",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
geminiProviderID := "gemini"
|
||||
vertexProviderID := "vertex"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "gemini-provider-override",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &geminiProviderID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":5}`,
|
||||
},
|
||||
{
|
||||
ID: "vertex-provider-override",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &vertexProviderID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":9}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("gemini", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "gemini"})
|
||||
require.NotNil(t, pricing)
|
||||
require.NotNil(t, pricing.InputCostPerToken)
|
||||
assert.Equal(t, 5.0, *pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestGetPricing_ExactOverrideDoesNotMatchProviderPrefixedModel(t *testing.T) {
|
||||
t.Skip()
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("openai/gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "openai/gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "openai-override-0",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
PricingPatchJSON: `{"input_cost_per_token":19}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("openai", "openai/gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
assert.Equal(t, 1.0, pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestGetPricing_NoMatchingOverrideLeavesPricingUnchanged(t *testing.T) {
|
||||
t.Skip()
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
baseCacheRead := 0.4
|
||||
mc.pricingData[makeKey("gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
CacheReadInputTokenCost: &baseCacheRead,
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "openai-override-0",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeWildcard),
|
||||
Pattern: "claude-*",
|
||||
PricingPatchJSON: `{"input_cost_per_token":9}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
assert.Equal(t, 1.0, pricing.InputCostPerToken)
|
||||
assert.Equal(t, 2.0, pricing.OutputCostPerToken)
|
||||
require.NotNil(t, pricing.CacheReadInputTokenCost)
|
||||
assert.Equal(t, 0.4, *pricing.CacheReadInputTokenCost)
|
||||
}
|
||||
|
||||
func TestDeleteProviderPricingOverrides_StopsApplying(t *testing.T) {
|
||||
t.Skip()
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "openai-override-0",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-4o",
|
||||
PricingPatchJSON: `{"input_cost_per_token":11}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
assert.Equal(t, 11.0, pricing.InputCostPerToken)
|
||||
|
||||
require.NoError(t, mc.SetPricingOverrides(nil))
|
||||
|
||||
pricing = mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
assert.Equal(t, 1.0, pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestGetPricing_WildcardSpecificityLongerLiteralWins(t *testing.T) {
|
||||
t.Skip()
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("gpt-4o-mini", "openai", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o-mini",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "openai-override-0",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeWildcard),
|
||||
Pattern: "gpt-*",
|
||||
PricingPatchJSON: `{"input_cost_per_token":5}`,
|
||||
},
|
||||
{
|
||||
ID: "openai-override-1",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeWildcard),
|
||||
Pattern: "gpt-4o*",
|
||||
PricingPatchJSON: `{"input_cost_per_token":6}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("openai", "gpt-4o-mini", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
assert.Equal(t, 6.0, pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
// TestGetPricing_FirstInsertionWinsOnTie verifies that when multiple wildcard overrides
|
||||
// match the same model and scope, the first one inserted takes precedence.
|
||||
func TestGetPricing_FirstInsertionWinsOnTie(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
mc.pricingData[makeKey("gpt-4o-mini", "openai", "chat")] = configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o-mini",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
providerID := "openai"
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "a-override",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeWildcard),
|
||||
Pattern: "gpt-4o*",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":8}`,
|
||||
},
|
||||
{
|
||||
ID: "b-override",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerID,
|
||||
MatchType: string(MatchTypeWildcard),
|
||||
Pattern: "gpt-4o*",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":9}`,
|
||||
},
|
||||
}))
|
||||
|
||||
pricing := mc.resolvePricing("openai", "gpt-4o-mini", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"})
|
||||
require.NotNil(t, pricing)
|
||||
require.NotNil(t, pricing.InputCostPerToken)
|
||||
assert.Equal(t, 8.0, *pricing.InputCostPerToken)
|
||||
}
|
||||
|
||||
func TestPatchPricing_PartialPatchOnlyChangesSpecifiedFields(t *testing.T) {
|
||||
t.Skip()
|
||||
baseCacheRead := 0.4
|
||||
baseInputImage := 0.7
|
||||
base := configstoreTables.TableModelPricing{
|
||||
Model: "gpt-4o",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
CacheReadInputTokenCost: &baseCacheRead,
|
||||
InputCostPerImage: &baseInputImage,
|
||||
}
|
||||
|
||||
cacheRead := 0.9
|
||||
patched := patchPricing(base, PricingOptions{
|
||||
InputCostPerToken: bifrost.Ptr(3.0),
|
||||
CacheReadInputTokenCost: &cacheRead,
|
||||
})
|
||||
|
||||
assert.Equal(t, 3.0, patched.InputCostPerToken)
|
||||
require.NotNil(t, patched.CacheReadInputTokenCost)
|
||||
assert.Equal(t, 0.9, *patched.CacheReadInputTokenCost)
|
||||
|
||||
assert.Equal(t, 2.0, patched.OutputCostPerToken)
|
||||
require.NotNil(t, patched.InputCostPerImage)
|
||||
assert.Equal(t, 0.7, *patched.InputCostPerImage)
|
||||
}
|
||||
|
||||
func TestApplyScopedPricingOverrides_ScopePrecedence(t *testing.T) {
|
||||
mc := newTestCatalog(nil, nil)
|
||||
mc.logger = noOpLogger{}
|
||||
|
||||
providerScopeID := "openai"
|
||||
providerKeyScopeID := "provider-key-1"
|
||||
virtualKeyScopeID := "virtual-key-1"
|
||||
|
||||
require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{
|
||||
{
|
||||
ID: "global",
|
||||
ScopeKind: string(ScopeKindGlobal),
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-5-nano",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":2}`,
|
||||
},
|
||||
{
|
||||
ID: "provider",
|
||||
ScopeKind: string(ScopeKindProvider),
|
||||
ProviderID: &providerScopeID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-5-nano",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":3}`,
|
||||
},
|
||||
{
|
||||
ID: "provider-key",
|
||||
ScopeKind: string(ScopeKindProviderKey),
|
||||
ProviderKeyID: &providerKeyScopeID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-5-nano",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":4}`,
|
||||
},
|
||||
{
|
||||
ID: "virtual-key",
|
||||
ScopeKind: string(ScopeKindVirtualKey),
|
||||
VirtualKeyID: &virtualKeyScopeID,
|
||||
MatchType: string(MatchTypeExact),
|
||||
Pattern: "gpt-5-nano",
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
PricingPatchJSON: `{"input_cost_per_token":5}`,
|
||||
},
|
||||
}))
|
||||
|
||||
base := configstoreTables.TableModelPricing{
|
||||
Model: "gpt-5-nano",
|
||||
Provider: "openai",
|
||||
Mode: "chat",
|
||||
InputCostPerToken: bifrost.Ptr(1.0),
|
||||
OutputCostPerToken: bifrost.Ptr(2.0),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
scopes PricingLookupScopes
|
||||
expected float64
|
||||
}{
|
||||
{
|
||||
name: "virtual key wins over provider key, provider and global",
|
||||
scopes: PricingLookupScopes{
|
||||
VirtualKeyID: virtualKeyScopeID,
|
||||
SelectedKeyID: providerKeyScopeID,
|
||||
Provider: providerScopeID,
|
||||
},
|
||||
expected: 5.0,
|
||||
},
|
||||
{
|
||||
name: "provider key wins over provider and global",
|
||||
scopes: PricingLookupScopes{
|
||||
SelectedKeyID: providerKeyScopeID,
|
||||
Provider: providerScopeID,
|
||||
},
|
||||
expected: 4.0,
|
||||
},
|
||||
{
|
||||
name: "provider wins over global",
|
||||
scopes: PricingLookupScopes{
|
||||
Provider: providerScopeID,
|
||||
},
|
||||
expected: 3.0,
|
||||
},
|
||||
{
|
||||
name: "global applies when no narrower scope is provided",
|
||||
scopes: PricingLookupScopes{},
|
||||
expected: 2.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
patched, applied := mc.applyPricingOverrides("gpt-5-nano", schemas.ChatCompletionRequest, base, tc.scopes)
|
||||
require.True(t, applied)
|
||||
require.NotNil(t, patched.InputCostPerToken)
|
||||
assert.Equal(t, tc.expected, *patched.InputCostPerToken)
|
||||
})
|
||||
}
|
||||
}
|
||||
2164
framework/modelcatalog/pricing_test.go
Normal file
2164
framework/modelcatalog/pricing_test.go
Normal file
File diff suppressed because it is too large
Load Diff
51
framework/modelcatalog/refine_test.go
Normal file
51
framework/modelcatalog/refine_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRefineModelForProvider_ReplicateRefinesOpenAIModel verifies that
|
||||
// Replicate can recover nested provider slugs for provider-pinned OpenAI-family models.
|
||||
func TestRefineModelForProvider_ReplicateRefinesOpenAIModel(t *testing.T) {
|
||||
mc := newTestCatalog(map[schemas.ModelProvider][]string{
|
||||
schemas.Replicate: {"openai/gpt-5-nano"},
|
||||
}, map[string]string{
|
||||
"openai/gpt-5-nano": "gpt-5-nano",
|
||||
})
|
||||
|
||||
refined, err := mc.RefineModelForProvider(schemas.Replicate, "gpt-5-nano")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/gpt-5-nano", refined)
|
||||
}
|
||||
|
||||
// TestRefineModelForProvider_ReplicatePreservesOwnerSlashModel verifies that
|
||||
// standard Replicate owner/model slugs are not mistaken for nested provider slugs.
|
||||
func TestRefineModelForProvider_ReplicatePreservesOwnerSlashModel(t *testing.T) {
|
||||
mc := newTestCatalog(map[schemas.ModelProvider][]string{
|
||||
schemas.Replicate: {"meta/meta-llama-3-8b"},
|
||||
}, nil)
|
||||
|
||||
refined, err := mc.RefineModelForProvider(schemas.Replicate, "meta/meta-llama-3-8b")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "meta/meta-llama-3-8b", refined)
|
||||
}
|
||||
|
||||
// TestRefineModelForProvider_ReplicateReturnsAmbiguousMatchError verifies that
|
||||
// refinement fails fast when multiple nested provider slugs match the same base model.
|
||||
func TestRefineModelForProvider_ReplicateReturnsAmbiguousMatchError(t *testing.T) {
|
||||
mc := newTestCatalog(map[schemas.ModelProvider][]string{
|
||||
schemas.Replicate: {
|
||||
"openai/gpt-5-nano",
|
||||
"xai/gpt-5-nano",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
refined, err := mc.RefineModelForProvider(schemas.Replicate, "gpt-5-nano")
|
||||
require.Error(t, err)
|
||||
assert.Empty(t, refined)
|
||||
assert.Contains(t, err.Error(), "multiple compatible models found")
|
||||
}
|
||||
505
framework/modelcatalog/sync.go
Normal file
505
framework/modelcatalog/sync.go
Normal file
@@ -0,0 +1,505 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/tidwall/gjson"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Retry settings passed to WithRetries when fetching data from a URL.
const (
	// urlFetchMaxRetries counts retries after the first attempt (4 attempts total).
	urlFetchMaxRetries = 3
	// urlFetchMaxBackoff caps the exponential backoff (steps start at 1s).
	urlFetchMaxBackoff = 10 * time.Second
)
|
||||
|
||||
// syncPricing syncs pricing data from URL to database and updates cache
|
||||
func (mc *ModelCatalog) syncPricing(ctx context.Context) error {
|
||||
if mc.shouldSyncGate != nil {
|
||||
if !mc.shouldSyncGate(ctx) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// Load pricing data from URL
|
||||
pricingData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]PricingEntry, error) {
|
||||
return mc.loadPricingFromURL(ctx)
|
||||
})
|
||||
if err != nil {
|
||||
// Check if we have existing data in database
|
||||
pricingRecords, pricingErr := mc.configStore.GetModelPrices(ctx)
|
||||
if pricingErr != nil {
|
||||
return fmt.Errorf("failed to get pricing records: %w", pricingErr)
|
||||
}
|
||||
if len(pricingRecords) > 0 {
|
||||
mc.logger.Warn("failed to fetch pricing from URL, falling back to existing database records: %v", err)
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("failed to load pricing data from URL and no existing data in database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update database in transaction
|
||||
err = mc.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// Deduplicate and insert new pricing data
|
||||
seen := make(map[string]bool)
|
||||
for modelKey, entry := range pricingData {
|
||||
pricing := convertPricingDataToTableModelPricing(modelKey, entry)
|
||||
// Create composite key for deduplication
|
||||
key := makeKey(pricing.Model, pricing.Provider, pricing.Mode)
|
||||
// Skip if already seen
|
||||
if exists, ok := seen[key]; ok && exists {
|
||||
continue
|
||||
}
|
||||
// Mark as seen
|
||||
seen[key] = true
|
||||
if err := mc.configStore.UpsertModelPrices(ctx, &pricing, tx); err != nil {
|
||||
return fmt.Errorf("failed to create pricing record for model %s: %w", pricing.Model, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear seen map
|
||||
seen = nil
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync pricing data to database: %w", err)
|
||||
}
|
||||
|
||||
// Reload cache from database
|
||||
if err := mc.loadPricingFromDatabase(ctx); err != nil {
|
||||
return fmt.Errorf("failed to reload pricing cache: %w", err)
|
||||
}
|
||||
|
||||
// Populate model params cache from pricing datasheet max_output_tokens
|
||||
mc.populateModelParamsFromPricing(pricingData)
|
||||
|
||||
mc.logger.Debug("successfully synced %d pricing records", len(pricingData))
|
||||
return nil
|
||||
}
|
||||
|
||||
// populateModelParamsFromPricing extracts max_output_tokens from pricing entries
|
||||
// and populates the model params cache so that providers can look up max output
|
||||
// tokens without a separate model-parameters sync.
|
||||
func (mc *ModelCatalog) populateModelParamsFromPricing(pricingData map[string]PricingEntry) {
|
||||
modelParamsEntries := make(map[string]providerUtils.ModelParams)
|
||||
for modelKey, entry := range pricingData {
|
||||
if entry.MaxOutputTokens != nil {
|
||||
modelName := extractModelName(modelKey)
|
||||
modelParamsEntries[modelName] = providerUtils.ModelParams{MaxOutputTokens: entry.MaxOutputTokens}
|
||||
}
|
||||
}
|
||||
if len(modelParamsEntries) > 0 {
|
||||
providerUtils.BulkSetModelParams(modelParamsEntries)
|
||||
mc.logger.Debug("populated %d model params entries from pricing datasheet", len(modelParamsEntries))
|
||||
}
|
||||
}
|
||||
|
||||
// loadPricingFromURL loads pricing data from the remote URL
|
||||
func (mc *ModelCatalog) loadPricingFromURL(ctx context.Context) (map[string]PricingEntry, error) {
|
||||
// Create HTTP client with timeout
|
||||
client := &http.Client{}
|
||||
client.Timeout = DefaultPricingTimeout
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, mc.getPricingURL(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
// Make HTTP request
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to download pricing data: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check HTTP status
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to download pricing data: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Read response body
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read pricing data response: %w", err)
|
||||
}
|
||||
|
||||
// Unmarshal JSON data
|
||||
var pricingData map[string]PricingEntry
|
||||
if err := json.Unmarshal(data, &pricingData); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal pricing data: %w", err)
|
||||
}
|
||||
|
||||
mc.logger.Debug("successfully downloaded and parsed %d pricing records", len(pricingData))
|
||||
return pricingData, nil
|
||||
}
|
||||
|
||||
// loadPricingIntoMemoryFromURL loads pricing data from URL into memory cache (when config store is not available)
|
||||
func (mc *ModelCatalog) loadPricingIntoMemoryFromURL(ctx context.Context) error {
|
||||
pricingData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]PricingEntry, error) {
|
||||
return mc.loadPricingFromURL(ctx)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load pricing data from URL: %w", err)
|
||||
}
|
||||
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
// Clear and rebuild the pricing map
|
||||
mc.pricingData = make(map[string]configstoreTables.TableModelPricing, len(pricingData))
|
||||
for modelKey, entry := range pricingData {
|
||||
pricing := convertPricingDataToTableModelPricing(modelKey, entry)
|
||||
key := makeKey(pricing.Model, pricing.Provider, pricing.Mode)
|
||||
mc.pricingData[key] = pricing
|
||||
}
|
||||
|
||||
// Populate model params cache from pricing datasheet max_output_tokens
|
||||
mc.populateModelParamsFromPricing(pricingData)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadPricingFromDatabase loads pricing data from database into memory cache
|
||||
func (mc *ModelCatalog) loadPricingFromDatabase(ctx context.Context) error {
|
||||
if mc.configStore == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
pricingRecords, err := mc.configStore.GetModelPrices(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load pricing from database: %w", err)
|
||||
}
|
||||
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
// Clear and rebuild the pricing map
|
||||
mc.pricingData = make(map[string]configstoreTables.TableModelPricing, len(pricingRecords))
|
||||
for _, pricing := range pricingRecords {
|
||||
key := makeKey(pricing.Model, pricing.Provider, pricing.Mode)
|
||||
mc.pricingData[key] = pricing
|
||||
}
|
||||
|
||||
mc.logger.Debug("loaded %d pricing records from database into memory", len(mc.pricingData))
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadModelParametersFromDatabase bulk-loads model parameters from the DB into the provider
|
||||
// utils cache (startup / ReloadFromDB). The SetCacheMissHandler path still loads one row at
|
||||
// a time on cache miss; both use the same table JSON shape.
|
||||
// Returns the number of rows loaded so callers can decide whether to background-sync from URL.
|
||||
func (mc *ModelCatalog) loadModelParametersFromDatabase(ctx context.Context) (int, error) {
|
||||
if mc.configStore == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
rows, err := mc.configStore.GetModelParameters(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to load model parameters from database: %w", err)
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
mc.logger.Debug("no model parameters rows in database")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
paramsData := make(map[string]json.RawMessage, len(rows))
|
||||
for _, row := range rows {
|
||||
paramsData[row.Model] = json.RawMessage(row.Data)
|
||||
}
|
||||
mc.applyModelParameters(paramsData)
|
||||
mc.logger.Debug("loaded %d model parameters records from database into cache", len(rows))
|
||||
return len(rows), nil
|
||||
}
|
||||
|
||||
// startSyncWorker starts the background sync worker
|
||||
func (mc *ModelCatalog) startSyncWorker(ctx context.Context) {
|
||||
// IMPORTANT: scheduling model
|
||||
//
|
||||
// The sync worker wakes on a fixed ticker (syncWorkerTickerPeriod = 1h).
|
||||
// On each wake it calls checkAndSyncPricing, which checks:
|
||||
//
|
||||
// time.Since(lastSyncTimestamp) >= pricingSyncInterval
|
||||
//
|
||||
// This means:
|
||||
// • pricingSyncInterval defines the *minimum elapsed time* between syncs.
|
||||
// • The actual sync frequency = max(syncWorkerTickerPeriod, pricingSyncInterval).
|
||||
// • Setting pricingSyncInterval < 1h does NOT increase sync frequency —
|
||||
// the hourly ticker is the hard lower bound on check granularity.
|
||||
//
|
||||
// Design rationale: avoids high-frequency polling while allowing operators to
|
||||
// tune how stale pricing data can get (e.g., 1h vs 24h vs 7d).
|
||||
mc.syncTicker = time.NewTicker(syncWorkerTickerPeriod)
|
||||
mc.wg.Add(1)
|
||||
go mc.syncWorker(ctx)
|
||||
}
|
||||
|
||||
// withDistributedLock acquires a named distributed lock and executes fn under it.
|
||||
// Pass retries=0 to block until acquired (Lock); pass retries>0 to use LockWithRetry.
|
||||
func (mc *ModelCatalog) withDistributedLock(ctx context.Context, key string, retries int, fn func() error) error {
|
||||
lock, err := mc.distributedLockManager.NewLock(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create lock %q: %w", key, err)
|
||||
}
|
||||
if retries > 0 {
|
||||
if err := lock.LockWithRetry(ctx, retries); err != nil {
|
||||
return fmt.Errorf("failed to acquire lock %q: %w", key, err)
|
||||
}
|
||||
} else {
|
||||
if err := lock.Lock(ctx); err != nil {
|
||||
return fmt.Errorf("failed to acquire lock %q: %w", key, err)
|
||||
}
|
||||
}
|
||||
// Use a fresh context for unlock so that a cancelled or timed-out work context
|
||||
// does not prevent the lock row from being deleted. If we reused ctx and it was
|
||||
// already cancelled when the defer fires, ReleaseLock's DB call would fail
|
||||
// silently and the lock would stay in the database until TTL expiry (30s),
|
||||
// blocking every other node from acquiring it during that window.
|
||||
defer func() {
|
||||
if err := lock.Unlock(context.Background()); err != nil {
|
||||
mc.logger.Warn("failed to release distributed lock %q: %v", key, err)
|
||||
}
|
||||
}()
|
||||
return fn()
|
||||
}
|
||||
|
||||
// syncTick performs a single sync tick with proper lock management
|
||||
// if the last sync was more than the sync interval ago, sync pricing and model parameters in parallel
|
||||
func (mc *ModelCatalog) syncTick(ctx context.Context) {
|
||||
mc.syncMu.RLock()
|
||||
lastSync := mc.lastSyncedAt
|
||||
interval := mc.syncInterval
|
||||
mc.syncMu.RUnlock()
|
||||
|
||||
if time.Since(lastSync) >= interval {
|
||||
mc.logger.Debug("starting model catalog background sync")
|
||||
if err := mc.withDistributedLock(ctx, "model_catalog_pricing_sync", 10, func() error {
|
||||
// Sync pricing and model parameters in parallel
|
||||
var wg sync.WaitGroup
|
||||
var pricingErr, paramsErr error
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := mc.syncPricing(ctx); err != nil {
|
||||
mc.logger.Error("background pricing sync failed: %v", err)
|
||||
pricingErr = err
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := mc.syncModelParameters(ctx); err != nil {
|
||||
mc.logger.Error("background model parameters sync failed: %v", err)
|
||||
paramsErr = err
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
if pricingErr == nil && paramsErr == nil {
|
||||
if mc.afterSyncHook != nil {
|
||||
mc.afterSyncHook(ctx)
|
||||
}
|
||||
mc.syncMu.Lock()
|
||||
mc.lastSyncedAt = time.Now()
|
||||
mc.syncMu.Unlock()
|
||||
}
|
||||
if pricingErr != nil {
|
||||
return pricingErr
|
||||
}
|
||||
return paramsErr
|
||||
}); err != nil {
|
||||
mc.logger.Error("failed to run model catalog sync: %v", err)
|
||||
}
|
||||
mc.logger.Debug("model catalog background sync completed")
|
||||
}
|
||||
}
|
||||
|
||||
// syncWorker runs the background sync check
|
||||
func (mc *ModelCatalog) syncWorker(ctx context.Context) {
|
||||
defer mc.wg.Done()
|
||||
defer mc.syncTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-mc.syncTicker.C:
|
||||
mc.syncTick(ctx)
|
||||
case <-mc.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Model Parameters sync ---
|
||||
|
||||
func (mc *ModelCatalog) applyModelParameters(paramsData map[string]json.RawMessage) {
|
||||
modelParamsEntries := make(map[string]providerUtils.ModelParams, len(paramsData))
|
||||
newResponseTypes := make(map[string][]string, len(paramsData))
|
||||
newParamsIndex := make(map[string][]string, len(paramsData))
|
||||
|
||||
for model, rawData := range paramsData {
|
||||
var parsed modelParametersParseResult
|
||||
if err := json.Unmarshal(rawData, &parsed); err != nil {
|
||||
mc.logger.Warn("model-parameters-sync: skipping malformed parameters for model %s: %v", model, err)
|
||||
continue
|
||||
}
|
||||
|
||||
outputs := make([]string, 0, len(parsed.SupportedEndpoints))
|
||||
for _, endpoint := range parsed.SupportedEndpoints {
|
||||
if normalized := normalizeEndpointToOutputType(endpoint); normalized != "" && !slices.Contains(outputs, normalized) {
|
||||
outputs = append(outputs, normalized)
|
||||
}
|
||||
}
|
||||
|
||||
if parsed.Mode != nil {
|
||||
if normalized := normalizeModeToOutputType(*parsed.Mode); normalized != "" && !slices.Contains(outputs, normalized) {
|
||||
outputs = append(outputs, normalized)
|
||||
}
|
||||
}
|
||||
|
||||
if !slices.Contains(outputs, "text_completion") {
|
||||
provider := gjson.GetBytes(rawData, "provider")
|
||||
if provider.Exists() {
|
||||
key := makeKey(model, normalizeProvider(provider.String()), normalizeRequestType(schemas.TextCompletionRequest))
|
||||
|
||||
mc.mu.RLock()
|
||||
_, ok := mc.pricingData[key]
|
||||
mc.mu.RUnlock()
|
||||
if ok {
|
||||
outputs = append(outputs, "text_completion")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(outputs) > 0 {
|
||||
newResponseTypes[model] = outputs
|
||||
}
|
||||
|
||||
supported := extractSupportedParams(&parsed)
|
||||
if len(supported) > 0 {
|
||||
newParamsIndex[model] = supported
|
||||
}
|
||||
|
||||
var p struct {
|
||||
MaxOutputTokens *int `json:"max_output_tokens"`
|
||||
}
|
||||
if p.MaxOutputTokens == nil {
|
||||
if err := json.Unmarshal(rawData, &p); err == nil && p.MaxOutputTokens != nil {
|
||||
modelParamsEntries[model] = providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens}
|
||||
}
|
||||
} else {
|
||||
modelParamsEntries[model] = providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens}
|
||||
}
|
||||
}
|
||||
|
||||
mc.mu.Lock()
|
||||
mc.supportedResponseTypes = newResponseTypes
|
||||
mc.supportedParams = newParamsIndex
|
||||
mc.mu.Unlock()
|
||||
|
||||
if len(modelParamsEntries) > 0 {
|
||||
providerUtils.BulkSetModelParams(modelParamsEntries)
|
||||
}
|
||||
}
|
||||
|
||||
// loadModelParametersIntoMemoryFromURL loads model parameters from the remote URL into the
|
||||
// provider utils cache (when config store is not available).
|
||||
func (mc *ModelCatalog) loadModelParametersIntoMemoryFromURL(ctx context.Context) error {
|
||||
paramsData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]json.RawMessage, error) {
|
||||
return mc.loadModelParametersFromURL(ctx)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load model parameters from URL: %w", err)
|
||||
}
|
||||
mc.applyModelParameters(paramsData)
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncModelParameters syncs model parameters data from URL into memory cache
|
||||
func (mc *ModelCatalog) syncModelParameters(ctx context.Context) error {
|
||||
if mc.shouldSyncGate != nil {
|
||||
if !mc.shouldSyncGate(ctx) {
|
||||
mc.logger.Debug("model parameters sync cancelled by custom gate")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
mc.logger.Debug("starting model parameters synchronization")
|
||||
|
||||
paramsData, err := WithRetries(ctx, urlFetchMaxRetries, urlFetchMaxBackoff, func() (map[string]json.RawMessage, error) {
|
||||
return mc.loadModelParametersFromURL(ctx)
|
||||
})
|
||||
if err != nil {
|
||||
if mc.configStore != nil {
|
||||
rows, dbErr := mc.configStore.GetModelParameters(ctx)
|
||||
if dbErr == nil && len(rows) > 0 {
|
||||
mc.logger.Error("failed to load model parameters from URL, falling back to existing database records: %v", err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to load model parameters from URL and no existing data in database: %w", err)
|
||||
}
|
||||
|
||||
// Persist to database if config store is available
|
||||
if mc.configStore != nil {
|
||||
err = mc.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error {
|
||||
for model, data := range paramsData {
|
||||
params := &configstoreTables.TableModelParameters{
|
||||
Model: model,
|
||||
Data: string(data),
|
||||
}
|
||||
if err := mc.configStore.UpsertModelParameters(ctx, params, tx); err != nil {
|
||||
return fmt.Errorf("failed to upsert model parameters for model %s: %w", model, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync model parameters to database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
mc.applyModelParameters(paramsData)
|
||||
|
||||
mc.logger.Info("successfully synced %d model parameters records", len(paramsData))
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadModelParametersFromURL loads model parameters data from the remote URL
|
||||
func (mc *ModelCatalog) loadModelParametersFromURL(ctx context.Context) (map[string]json.RawMessage, error) {
|
||||
client := &http.Client{}
|
||||
client.Timeout = DefaultModelParametersTimeout
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, DefaultModelParametersURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to download model parameters data: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to download model parameters data: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read model parameters response: %w", err)
|
||||
}
|
||||
|
||||
var paramsData map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, ¶msData); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal model parameters data: %w", err)
|
||||
}
|
||||
|
||||
mc.logger.Debug("successfully downloaded and parsed %d model parameters records", len(paramsData))
|
||||
return paramsData, nil
|
||||
}
|
||||
441
framework/modelcatalog/utils.go
Normal file
441
framework/modelcatalog/utils.go
Normal file
@@ -0,0 +1,441 @@
|
||||
package modelcatalog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
// retryBackoffMin is the delay before the first retry; each later retry doubles it.
const retryBackoffMin = time.Second

// WithRetries runs op until it succeeds or maxRetries retries are exhausted
// (1 initial attempt + maxRetries retries). A negative maxRetries is treated as 0.
//
// After each failure it waits with exponential backoff starting at
// retryBackoffMin and doubling per retry, capped at maxBackoff when
// maxBackoff > 0. If maxBackoff is zero there is no configured cap; the delay
// then saturates at the largest representable time.Duration instead of
// overflowing into a zero or negative value.
//
// It returns op's value on the first success. If ctx is done before an attempt
// or during a backoff wait, it returns the zero value of T and ctx.Err().
// Otherwise it returns the zero value of T and the error from the last attempt.
func WithRetries[T any](ctx context.Context, maxRetries int, maxBackoff time.Duration, op func() (T, error)) (T, error) {
	const maxDuration = time.Duration(1<<63 - 1)

	var zero T
	if maxRetries < 0 {
		maxRetries = 0
	}

	var lastErr error
	// backoff is the uncapped delay for the next retry: retryBackoffMin * 2^(retry-1).
	backoff := retryBackoffMin
	for attempt := 0; attempt <= maxRetries; attempt++ {
		// Never start an attempt on an already-cancelled context.
		if err := ctx.Err(); err != nil {
			return zero, err
		}

		if attempt > 0 {
			wait := backoff
			if maxBackoff > 0 && wait > maxBackoff {
				wait = maxBackoff
			}

			timer := time.NewTimer(wait)
			select {
			case <-ctx.Done():
				timer.Stop()
				return zero, ctx.Err()
			case <-timer.C:
			}

			// Double for the next retry, saturating rather than overflowing int64.
			if backoff <= maxDuration/2 {
				backoff *= 2
			} else {
				backoff = maxDuration
			}
		}

		v, err := op()
		if err == nil {
			return v, nil
		}
		lastErr = err
	}
	return zero, lastErr
}
|
||||
|
||||
// makeKey builds the pricingData map key for a (model, provider, mode) triple by
// joining the three parts with "|".
func makeKey(model, provider, mode string) string {
	const sep = "|"
	return model + sep + provider + sep + mode
}
|
||||
|
||||
// normalizeProvider normalizes the provider name to a consistent format
|
||||
func normalizeProvider(p string) string {
|
||||
if strings.Contains(p, "vertex_ai") || p == "google-vertex" {
|
||||
return string(schemas.Vertex)
|
||||
} else if strings.Contains(p, "bedrock") {
|
||||
return string(schemas.Bedrock)
|
||||
} else if strings.Contains(p, "cohere") {
|
||||
return string(schemas.Cohere)
|
||||
} else if strings.Contains(p, "runwayml") {
|
||||
return string(schemas.Runway)
|
||||
} else if strings.Contains(p, "fireworks_ai") {
|
||||
return string(schemas.Fireworks)
|
||||
} else {
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeRequestType normalizes the request type to a consistent format
|
||||
func normalizeRequestType(reqType schemas.RequestType) string {
|
||||
baseType := "unknown"
|
||||
|
||||
switch reqType {
|
||||
case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest:
|
||||
baseType = "completion"
|
||||
case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest:
|
||||
baseType = "chat"
|
||||
case schemas.ResponsesRequest, schemas.ResponsesStreamRequest, schemas.RealtimeRequest:
|
||||
baseType = "responses"
|
||||
case schemas.EmbeddingRequest:
|
||||
baseType = "embedding"
|
||||
case schemas.RerankRequest:
|
||||
baseType = "rerank"
|
||||
case schemas.SpeechRequest, schemas.SpeechStreamRequest:
|
||||
baseType = "audio_speech"
|
||||
case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest:
|
||||
baseType = "audio_transcription"
|
||||
case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest, schemas.ImageVariationRequest:
|
||||
baseType = "image_generation"
|
||||
case schemas.ImageEditRequest, schemas.ImageEditStreamRequest:
|
||||
baseType = "image_edit"
|
||||
case schemas.VideoGenerationRequest, schemas.VideoRemixRequest:
|
||||
baseType = "video_generation"
|
||||
case schemas.OCRRequest:
|
||||
baseType = "ocr"
|
||||
}
|
||||
|
||||
return baseType
|
||||
}
|
||||
|
||||
// normalizeStreamRequestType normalizes the stream request type to a consistent format
|
||||
// It returns the base request type for the stream request type.
|
||||
func normalizeStreamRequestType(rt schemas.RequestType) schemas.RequestType {
|
||||
switch rt {
|
||||
case schemas.TextCompletionStreamRequest:
|
||||
return schemas.TextCompletionRequest
|
||||
case schemas.ChatCompletionStreamRequest:
|
||||
return schemas.ChatCompletionRequest
|
||||
case schemas.ResponsesStreamRequest:
|
||||
return schemas.ResponsesRequest
|
||||
case schemas.RealtimeRequest:
|
||||
return schemas.RealtimeRequest
|
||||
case schemas.SpeechStreamRequest:
|
||||
return schemas.SpeechRequest
|
||||
case schemas.TranscriptionStreamRequest:
|
||||
return schemas.TranscriptionRequest
|
||||
case schemas.ImageGenerationStreamRequest:
|
||||
return schemas.ImageGenerationRequest
|
||||
case schemas.ImageEditStreamRequest:
|
||||
return schemas.ImageEditRequest
|
||||
default:
|
||||
return rt
|
||||
}
|
||||
}
|
||||
|
||||
// extractModelName strips the leading "provider/" segment from a datasheet key.
// Everything after the first "/" is returned, so later slashes are kept as part
// of the model name; a key without "/" is returned unchanged.
func extractModelName(modelKey string) string {
	_, name, hasPrefix := strings.Cut(modelKey, "/")
	if !hasPrefix {
		return modelKey
	}
	return name
}
|
||||
|
||||
// convertPricingDataToTableModelPricing converts the pricing data to a TableModelPricing struct
|
||||
func convertPricingDataToTableModelPricing(modelKey string, entry PricingEntry) configstoreTables.TableModelPricing {
|
||||
provider := normalizeProvider(entry.Provider)
|
||||
modelName := extractModelName(modelKey)
|
||||
|
||||
return configstoreTables.TableModelPricing{
|
||||
Model: modelName,
|
||||
BaseModel: entry.BaseModel,
|
||||
Provider: provider,
|
||||
Mode: entry.Mode,
|
||||
ContextLength: entry.ContextLength,
|
||||
MaxInputTokens: entry.MaxInputTokens,
|
||||
MaxOutputTokens: entry.MaxOutputTokens,
|
||||
Architecture: entry.Architecture,
|
||||
|
||||
// Costs - Text
|
||||
InputCostPerToken: entry.InputCostPerToken,
|
||||
OutputCostPerToken: entry.OutputCostPerToken,
|
||||
InputCostPerTokenBatches: entry.InputCostPerTokenBatches,
|
||||
OutputCostPerTokenBatches: entry.OutputCostPerTokenBatches,
|
||||
InputCostPerTokenPriority: entry.InputCostPerTokenPriority,
|
||||
OutputCostPerTokenPriority: entry.OutputCostPerTokenPriority,
|
||||
InputCostPerTokenFlex: entry.InputCostPerTokenFlex,
|
||||
OutputCostPerTokenFlex: entry.OutputCostPerTokenFlex,
|
||||
InputCostPerTokenAbove200kTokens: entry.InputCostPerTokenAbove200kTokens,
|
||||
InputCostPerTokenAbove200kTokensPriority: entry.InputCostPerTokenAbove200kTokensPriority,
|
||||
OutputCostPerTokenAbove200kTokens: entry.OutputCostPerTokenAbove200kTokens,
|
||||
OutputCostPerTokenAbove200kTokensPriority: entry.OutputCostPerTokenAbove200kTokensPriority,
|
||||
// Costs - 272k Tier
|
||||
InputCostPerTokenAbove272kTokens: entry.InputCostPerTokenAbove272kTokens,
|
||||
InputCostPerTokenAbove272kTokensPriority: entry.InputCostPerTokenAbove272kTokensPriority,
|
||||
OutputCostPerTokenAbove272kTokens: entry.OutputCostPerTokenAbove272kTokens,
|
||||
OutputCostPerTokenAbove272kTokensPriority: entry.OutputCostPerTokenAbove272kTokensPriority,
|
||||
// Costs - Character
|
||||
InputCostPerCharacter: entry.InputCostPerCharacter,
|
||||
// Costs - 128k Tier
|
||||
InputCostPerTokenAbove128kTokens: entry.InputCostPerTokenAbove128kTokens,
|
||||
InputCostPerImageAbove128kTokens: entry.InputCostPerImageAbove128kTokens,
|
||||
InputCostPerVideoPerSecondAbove128kTokens: entry.InputCostPerVideoPerSecondAbove128kTokens,
|
||||
InputCostPerAudioPerSecondAbove128kTokens: entry.InputCostPerAudioPerSecondAbove128kTokens,
|
||||
OutputCostPerTokenAbove128kTokens: entry.OutputCostPerTokenAbove128kTokens,
|
||||
|
||||
// Costs - Cache
|
||||
CacheCreationInputTokenCost: entry.CacheCreationInputTokenCost,
|
||||
CacheReadInputTokenCost: entry.CacheReadInputTokenCost,
|
||||
CacheCreationInputTokenCostAbove200kTokens: entry.CacheCreationInputTokenCostAbove200kTokens,
|
||||
CacheReadInputTokenCostAbove200kTokens: entry.CacheReadInputTokenCostAbove200kTokens,
|
||||
CacheReadInputTokenCostAbove200kTokensPriority: entry.CacheReadInputTokenCostAbove200kTokensPriority,
|
||||
CacheCreationInputTokenCostAbove1hr: entry.CacheCreationInputTokenCostAbove1hr,
|
||||
CacheCreationInputTokenCostAbove1hrAbove200kTokens: entry.CacheCreationInputTokenCostAbove1hrAbove200kTokens,
|
||||
CacheCreationInputAudioTokenCost: entry.CacheCreationInputAudioTokenCost,
|
||||
CacheReadInputTokenCostPriority: entry.CacheReadInputTokenCostPriority,
|
||||
CacheReadInputTokenCostFlex: entry.CacheReadInputTokenCostFlex,
|
||||
CacheReadInputImageTokenCost: entry.CacheReadInputImageTokenCost,
|
||||
CacheReadInputTokenCostAbove272kTokens: entry.CacheReadInputTokenCostAbove272kTokens,
|
||||
CacheReadInputTokenCostAbove272kTokensPriority: entry.CacheReadInputTokenCostAbove272kTokensPriority,
|
||||
|
||||
// Costs - Image
|
||||
InputCostPerImage: entry.InputCostPerImage,
|
||||
InputCostPerPixel: entry.InputCostPerPixel,
|
||||
OutputCostPerImage: entry.OutputCostPerImage,
|
||||
OutputCostPerPixel: entry.OutputCostPerPixel,
|
||||
OutputCostPerImagePremiumImage: entry.OutputCostPerImagePremiumImage,
|
||||
OutputCostPerImageAbove512x512Pixels: entry.OutputCostPerImageAbove512x512Pixels,
|
||||
OutputCostPerImageAbove512x512PixelsPremium: entry.OutputCostPerImageAbove512x512PixelsPremium,
|
||||
OutputCostPerImageAbove1024x1024Pixels: entry.OutputCostPerImageAbove1024x1024Pixels,
|
||||
OutputCostPerImageAbove1024x1024PixelsPremium: entry.OutputCostPerImageAbove1024x1024PixelsPremium,
|
||||
OutputCostPerImageAbove2048x2048Pixels: entry.OutputCostPerImageAbove2048x2048Pixels,
|
||||
OutputCostPerImageAbove4096x4096Pixels: entry.OutputCostPerImageAbove4096x4096Pixels,
|
||||
OutputCostPerImageLowQuality: entry.OutputCostPerImageLowQuality,
|
||||
OutputCostPerImageMediumQuality: entry.OutputCostPerImageMediumQuality,
|
||||
OutputCostPerImageHighQuality: entry.OutputCostPerImageHighQuality,
|
||||
OutputCostPerImageAutoQuality: entry.OutputCostPerImageAutoQuality,
|
||||
// Costs - Image Token
|
||||
InputCostPerImageToken: entry.InputCostPerImageToken,
|
||||
OutputCostPerImageToken: entry.OutputCostPerImageToken,
|
||||
|
||||
// Costs - Audio/Video
|
||||
InputCostPerAudioToken: entry.InputCostPerAudioToken,
|
||||
InputCostPerAudioPerSecond: entry.InputCostPerAudioPerSecond,
|
||||
InputCostPerSecond: entry.InputCostPerSecond,
|
||||
InputCostPerVideoPerSecond: entry.InputCostPerVideoPerSecond,
|
||||
OutputCostPerAudioToken: entry.OutputCostPerAudioToken,
|
||||
OutputCostPerVideoPerSecond: entry.OutputCostPerVideoPerSecond,
|
||||
OutputCostPerSecond: entry.OutputCostPerSecond,
|
||||
|
||||
// Costs - Other
|
||||
SearchContextCostPerQuery: entry.SearchContextCostPerQuery,
|
||||
CodeInterpreterCostPerSession: entry.CodeInterpreterCostPerSession,
|
||||
|
||||
// Costs - OCR
|
||||
OCRCostPerPage: entry.OCRCostPerPage,
|
||||
AnnotationCostPerPage: entry.AnnotationCostPerPage,
|
||||
}
|
||||
}
|
||||
|
||||
// convertTableModelPricingToPricingData converts the TableModelPricing struct to a PricingEntry struct
|
||||
func convertTableModelPricingToPricingData(pricing *configstoreTables.TableModelPricing) *PricingEntry {
|
||||
options := PricingOptions{
|
||||
// Costs - Text
|
||||
InputCostPerToken: pricing.InputCostPerToken,
|
||||
OutputCostPerToken: pricing.OutputCostPerToken,
|
||||
InputCostPerTokenBatches: pricing.InputCostPerTokenBatches,
|
||||
OutputCostPerTokenBatches: pricing.OutputCostPerTokenBatches,
|
||||
InputCostPerTokenPriority: pricing.InputCostPerTokenPriority,
|
||||
OutputCostPerTokenPriority: pricing.OutputCostPerTokenPriority,
|
||||
InputCostPerTokenFlex: pricing.InputCostPerTokenFlex,
|
||||
OutputCostPerTokenFlex: pricing.OutputCostPerTokenFlex,
|
||||
InputCostPerTokenAbove200kTokens: pricing.InputCostPerTokenAbove200kTokens,
|
||||
InputCostPerTokenAbove200kTokensPriority: pricing.InputCostPerTokenAbove200kTokensPriority,
|
||||
OutputCostPerTokenAbove200kTokens: pricing.OutputCostPerTokenAbove200kTokens,
|
||||
OutputCostPerTokenAbove200kTokensPriority: pricing.OutputCostPerTokenAbove200kTokensPriority,
|
||||
// Costs - 272k Tier
|
||||
InputCostPerTokenAbove272kTokens: pricing.InputCostPerTokenAbove272kTokens,
|
||||
InputCostPerTokenAbove272kTokensPriority: pricing.InputCostPerTokenAbove272kTokensPriority,
|
||||
OutputCostPerTokenAbove272kTokens: pricing.OutputCostPerTokenAbove272kTokens,
|
||||
OutputCostPerTokenAbove272kTokensPriority: pricing.OutputCostPerTokenAbove272kTokensPriority,
|
||||
// Costs - Character
|
||||
InputCostPerCharacter: pricing.InputCostPerCharacter,
|
||||
// Costs - 128k Tier
|
||||
InputCostPerTokenAbove128kTokens: pricing.InputCostPerTokenAbove128kTokens,
|
||||
InputCostPerImageAbove128kTokens: pricing.InputCostPerImageAbove128kTokens,
|
||||
InputCostPerVideoPerSecondAbove128kTokens: pricing.InputCostPerVideoPerSecondAbove128kTokens,
|
||||
InputCostPerAudioPerSecondAbove128kTokens: pricing.InputCostPerAudioPerSecondAbove128kTokens,
|
||||
OutputCostPerTokenAbove128kTokens: pricing.OutputCostPerTokenAbove128kTokens,
|
||||
|
||||
// Costs - Cache
|
||||
CacheCreationInputTokenCost: pricing.CacheCreationInputTokenCost,
|
||||
CacheReadInputTokenCost: pricing.CacheReadInputTokenCost,
|
||||
CacheCreationInputTokenCostAbove200kTokens: pricing.CacheCreationInputTokenCostAbove200kTokens,
|
||||
CacheReadInputTokenCostAbove200kTokens: pricing.CacheReadInputTokenCostAbove200kTokens,
|
||||
CacheReadInputTokenCostAbove200kTokensPriority: pricing.CacheReadInputTokenCostAbove200kTokensPriority,
|
||||
CacheCreationInputTokenCostAbove1hr: pricing.CacheCreationInputTokenCostAbove1hr,
|
||||
CacheCreationInputTokenCostAbove1hrAbove200kTokens: pricing.CacheCreationInputTokenCostAbove1hrAbove200kTokens,
|
||||
CacheCreationInputAudioTokenCost: pricing.CacheCreationInputAudioTokenCost,
|
||||
CacheReadInputTokenCostPriority: pricing.CacheReadInputTokenCostPriority,
|
||||
CacheReadInputTokenCostFlex: pricing.CacheReadInputTokenCostFlex,
|
||||
CacheReadInputImageTokenCost: pricing.CacheReadInputImageTokenCost,
|
||||
CacheReadInputTokenCostAbove272kTokens: pricing.CacheReadInputTokenCostAbove272kTokens,
|
||||
CacheReadInputTokenCostAbove272kTokensPriority: pricing.CacheReadInputTokenCostAbove272kTokensPriority,
|
||||
|
||||
// Costs - Image
|
||||
InputCostPerImage: pricing.InputCostPerImage,
|
||||
InputCostPerPixel: pricing.InputCostPerPixel,
|
||||
OutputCostPerImage: pricing.OutputCostPerImage,
|
||||
OutputCostPerPixel: pricing.OutputCostPerPixel,
|
||||
OutputCostPerImagePremiumImage: pricing.OutputCostPerImagePremiumImage,
|
||||
OutputCostPerImageAbove512x512Pixels: pricing.OutputCostPerImageAbove512x512Pixels,
|
||||
OutputCostPerImageAbove512x512PixelsPremium: pricing.OutputCostPerImageAbove512x512PixelsPremium,
|
||||
OutputCostPerImageAbove1024x1024Pixels: pricing.OutputCostPerImageAbove1024x1024Pixels,
|
||||
OutputCostPerImageAbove1024x1024PixelsPremium: pricing.OutputCostPerImageAbove1024x1024PixelsPremium,
|
||||
OutputCostPerImageAbove2048x2048Pixels: pricing.OutputCostPerImageAbove2048x2048Pixels,
|
||||
OutputCostPerImageAbove4096x4096Pixels: pricing.OutputCostPerImageAbove4096x4096Pixels,
|
||||
OutputCostPerImageLowQuality: pricing.OutputCostPerImageLowQuality,
|
||||
OutputCostPerImageMediumQuality: pricing.OutputCostPerImageMediumQuality,
|
||||
OutputCostPerImageHighQuality: pricing.OutputCostPerImageHighQuality,
|
||||
OutputCostPerImageAutoQuality: pricing.OutputCostPerImageAutoQuality,
|
||||
// Costs - Image Token
|
||||
InputCostPerImageToken: pricing.InputCostPerImageToken,
|
||||
OutputCostPerImageToken: pricing.OutputCostPerImageToken,
|
||||
|
||||
// Costs - Audio/Video
|
||||
InputCostPerAudioToken: pricing.InputCostPerAudioToken,
|
||||
InputCostPerAudioPerSecond: pricing.InputCostPerAudioPerSecond,
|
||||
InputCostPerSecond: pricing.InputCostPerSecond,
|
||||
InputCostPerVideoPerSecond: pricing.InputCostPerVideoPerSecond,
|
||||
OutputCostPerAudioToken: pricing.OutputCostPerAudioToken,
|
||||
OutputCostPerVideoPerSecond: pricing.OutputCostPerVideoPerSecond,
|
||||
OutputCostPerSecond: pricing.OutputCostPerSecond,
|
||||
|
||||
// Costs - Other
|
||||
SearchContextCostPerQuery: pricing.SearchContextCostPerQuery,
|
||||
CodeInterpreterCostPerSession: pricing.CodeInterpreterCostPerSession,
|
||||
|
||||
// Costs - OCR
|
||||
OCRCostPerPage: pricing.OCRCostPerPage,
|
||||
AnnotationCostPerPage: pricing.AnnotationCostPerPage,
|
||||
}
|
||||
return &PricingEntry{
|
||||
BaseModel: pricing.BaseModel,
|
||||
Provider: pricing.Provider,
|
||||
Mode: pricing.Mode,
|
||||
ContextLength: pricing.ContextLength,
|
||||
MaxInputTokens: pricing.MaxInputTokens,
|
||||
MaxOutputTokens: pricing.MaxOutputTokens,
|
||||
Architecture: pricing.Architecture,
|
||||
PricingOptions: options,
|
||||
}
|
||||
}
|
||||
|
||||
// convertTablePricingOverrideToPricingOverride converts a TablePricingOverride to a PricingOverride.
|
||||
func convertTablePricingOverrideToPricingOverride(override *configstoreTables.TablePricingOverride) (PricingOverride, error) {
|
||||
var options PricingOptions
|
||||
if err := sonic.Unmarshal([]byte(override.PricingPatchJSON), &options); err != nil {
|
||||
return PricingOverride{}, err
|
||||
}
|
||||
return PricingOverride{
|
||||
ID: override.ID,
|
||||
Name: override.Name,
|
||||
ScopeKind: ScopeKind(override.ScopeKind),
|
||||
VirtualKeyID: override.VirtualKeyID,
|
||||
ProviderID: override.ProviderID,
|
||||
ProviderKeyID: override.ProviderKeyID,
|
||||
MatchType: MatchType(override.MatchType),
|
||||
Pattern: override.Pattern,
|
||||
RequestTypes: override.RequestTypes,
|
||||
Options: options,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// normalizeEndpointToOutputType maps a supported_endpoints URL path onto the
// normalized output type it serves ("chat_completion", "responses" or
// "text_completion"). An endpoint that matches none of the known paths yields
// the empty string.
func normalizeEndpointToOutputType(endpoint string) string {
	// The order of these checks is significant: "/chat/completions" also
	// contains "/completions", so the more specific path is tested first.
	if strings.Contains(endpoint, "/chat/completions") {
		return "chat_completion"
	}
	if strings.Contains(endpoint, "/responses") {
		return "responses"
	}
	if strings.Contains(endpoint, "/completions") {
		return "text_completion"
	}
	return ""
}
|
||||
|
||||
// normalizeModeToOutputType translates a model "mode" value into the
// normalized output type name. The match is exact and case-sensitive; any
// mode other than "chat", "completion" or "responses" yields the empty string.
func normalizeModeToOutputType(mode string) string {
	if mode == "chat" {
		return "chat_completion"
	}
	if mode == "completion" {
		return "text_completion"
	}
	if mode == "responses" {
		return "responses"
	}
	return ""
}
|
||||
|
||||
// modelParametersParseResult is the parsed result type used by
// buildSupportedOutputsIndex. Every field is optional in the JSON payload;
// the pointer-typed fields distinguish "absent" from an explicit value.
type modelParametersParseResult struct {
	// Mode is the raw "mode" value, if present.
	Mode *string `json:"mode,omitempty"`
	// SupportedEndpoints lists the "supported_endpoints" URL paths.
	SupportedEndpoints []string `json:"supported_endpoints,omitempty"`
	// ModelParameters carries the "model_parameters" entries; only each
	// entry's id is decoded.
	ModelParameters []struct {
		ID string `json:"id"`
	} `json:"model_parameters,omitempty"`

	// supports_* capability flags; nil means the flag was not present.
	SupportsFunctionCalling         *bool `json:"supports_function_calling,omitempty"`
	SupportsParallelFunctionCalling *bool `json:"supports_parallel_function_calling,omitempty"`
	SupportsToolChoice              *bool `json:"supports_tool_choice,omitempty"`
	SupportsReasoning               *bool `json:"supports_reasoning,omitempty"`
	SupportsServiceTier             *bool `json:"supports_service_tier,omitempty"`
	SupportsPromptCaching           *bool `json:"supports_prompt_caching,omitempty"`
}
|
||||
|
||||
// extractSupportedParams builds a list of supported OpenAI-compatible parameter
|
||||
// names from model_parameters[].id values and supports_* boolean flags.
|
||||
func extractSupportedParams(parsed *modelParametersParseResult) []string {
|
||||
var supported []string
|
||||
addParam := func(name string) {
|
||||
if !slices.Contains(supported, name) {
|
||||
supported = append(supported, name)
|
||||
}
|
||||
}
|
||||
|
||||
// From model_parameters[].id — map IDs to request param names
|
||||
for _, mp := range parsed.ModelParameters {
|
||||
switch mp.ID {
|
||||
case "reasoning_effort", "reasoning_summary":
|
||||
addParam("reasoning")
|
||||
case "web_search":
|
||||
addParam("web_search_options")
|
||||
case "promptTools", "image_detail", "stream":
|
||||
// skip — not top-level request parameters
|
||||
default:
|
||||
addParam(mp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// From supports_* boolean flags
|
||||
if parsed.SupportsFunctionCalling != nil && *parsed.SupportsFunctionCalling {
|
||||
addParam("tools")
|
||||
}
|
||||
if parsed.SupportsParallelFunctionCalling != nil && *parsed.SupportsParallelFunctionCalling {
|
||||
addParam("parallel_tool_calls")
|
||||
}
|
||||
if parsed.SupportsToolChoice != nil && *parsed.SupportsToolChoice {
|
||||
addParam("tool_choice")
|
||||
}
|
||||
if parsed.SupportsReasoning != nil && *parsed.SupportsReasoning {
|
||||
addParam("reasoning")
|
||||
}
|
||||
if parsed.SupportsServiceTier != nil && *parsed.SupportsServiceTier {
|
||||
addParam("service_tier")
|
||||
}
|
||||
if parsed.SupportsPromptCaching != nil && *parsed.SupportsPromptCaching {
|
||||
addParam("prompt_cache_key")
|
||||
addParam("prompt_cache_retention")
|
||||
}
|
||||
|
||||
return supported
|
||||
}
|
||||
454
framework/oauth2/discovery.go
Normal file
454
framework/oauth2/discovery.go
Normal file
@@ -0,0 +1,454 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthMetadata holds the OAuth configuration discovered from an
// authorization server's metadata document. Field names follow the JSON keys
// given in the struct tags.
type OAuthMetadata struct {
	// Endpoints.
	AuthorizationURL string  `json:"authorization_endpoint"`
	TokenURL         string  `json:"token_endpoint"`
	RegistrationURL  *string `json:"registration_endpoint,omitempty"` // nil when not advertised

	// Advertised capabilities.
	ScopesSupported  []string `json:"scopes_supported,omitempty"`
	Issuer           string   `json:"issuer,omitempty"`
	ResponseTypes    []string `json:"response_types_supported,omitempty"`
	GrantTypes       []string `json:"grant_types_supported,omitempty"`
	TokenAuthMethods []string `json:"token_endpoint_auth_methods_supported,omitempty"`
	PKCEMethods      []string `json:"code_challenge_methods_supported,omitempty"`
}
|
||||
|
||||
// ResourceMetadata is the subset of a protected resource's metadata document
// that discovery reads: the authorization servers it trusts and the scopes it
// advertises.
type ResourceMetadata struct {
	AuthorizationServers []string `json:"authorization_servers"`
	ScopesSupported      []string `json:"scopes_supported,omitempty"`
	// Scopes is an alternative field name for the scope list.
	Scopes []string `json:"scopes,omitempty"`
}
|
||||
|
||||
// DiscoverOAuthMetadata performs OAuth 2.0 discovery for the given MCP server URL
|
||||
// Following RFC 8414 (Authorization Server Discovery) and RFC 9728 (Protected Resource Metadata)
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the discovery requests
|
||||
// - serverURL: The MCP server URL to discover OAuth configuration from
|
||||
// - logger: Logger for discovery progress (can be nil for silent operation)
|
||||
//
|
||||
// The discovery process:
|
||||
// 1. Attempt to connect to MCP server, expect 401 with WWW-Authenticate header
|
||||
// 2. Parse WWW-Authenticate header for resource_metadata URL and scopes
|
||||
// 3. Fetch resource metadata to get authorization server URLs
|
||||
// 4. Try .well-known discovery if resource metadata is not available
|
||||
// 5. Fetch authorization server metadata from discovered URLs
|
||||
// 6. Return complete OAuth configuration
|
||||
func DiscoverOAuthMetadata(ctx context.Context, serverURL string) (*OAuthMetadata, error) {
|
||||
if logger != nil {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Starting discovery for server: %s", serverURL))
|
||||
}
|
||||
|
||||
// Step 1: Attempt to connect to MCP server, expect 401 with WWW-Authenticate header
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", serverURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to server: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Server responded with status: %d", resp.StatusCode))
|
||||
|
||||
// Step 2: Parse WWW-Authenticate header
|
||||
wwwAuth := resp.Header.Get("WWW-Authenticate")
|
||||
if wwwAuth == "" {
|
||||
wwwAuth = resp.Header.Get("www-authenticate")
|
||||
}
|
||||
|
||||
resourceMetadataURL, scopesFromHeader := parseWWWAuthenticateHeader(wwwAuth)
|
||||
if resourceMetadataURL != "" {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found resource_metadata URL: %s", resourceMetadataURL))
|
||||
}
|
||||
if len(scopesFromHeader) > 0 {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found scopes in header: %v", scopesFromHeader))
|
||||
}
|
||||
|
||||
// Step 3: Fetch resource metadata if available
|
||||
var authServers []string
|
||||
var resourceScopes []string
|
||||
|
||||
if resourceMetadataURL != "" {
|
||||
authServers, resourceScopes, err = fetchResourceMetadata(ctx, resourceMetadataURL)
|
||||
if err != nil {
|
||||
// Log but continue to well-known discovery
|
||||
logger.Warn(fmt.Sprintf("[OAuth Discovery] Failed to fetch resource metadata: %v", err))
|
||||
} else {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found %d authorization servers from resource metadata", len(authServers)))
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Try well-known discovery if no resource metadata
|
||||
if len(authServers) == 0 {
|
||||
logger.Debug("[OAuth Discovery] Attempting .well-known discovery")
|
||||
authServers, resourceScopes, err = attemptWellKnownDiscovery(ctx, serverURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("OAuth discovery failed: %w", err)
|
||||
}
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found %d authorization servers from .well-known", len(authServers)))
|
||||
}
|
||||
|
||||
// Step 5: Fetch authorization server metadata
|
||||
metadata, err := fetchAuthorizationServerMetadata(ctx, authServers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch authorization server metadata: %w", err)
|
||||
}
|
||||
|
||||
// Step 6: Merge scopes (priority: header > resource metadata > discovered)
|
||||
if len(scopesFromHeader) > 0 {
|
||||
metadata.ScopesSupported = scopesFromHeader
|
||||
} else if len(resourceScopes) > 0 {
|
||||
metadata.ScopesSupported = resourceScopes
|
||||
}
|
||||
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Successfully discovered OAuth metadata for %s", serverURL))
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Authorization URL: %s", metadata.AuthorizationURL))
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Token URL: %s", metadata.TokenURL))
|
||||
if metadata.RegistrationURL != nil {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Registration URL: %s", *metadata.RegistrationURL))
|
||||
}
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Scopes: %v", metadata.ScopesSupported))
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// wwwAuthenticateParamPattern matches one auth-param of a WWW-Authenticate
// header, either param_name="value" or param_name=value. It is compiled once
// at package scope instead of on every parse.
var wwwAuthenticateParamPattern = regexp.MustCompile(`([a-zA-Z0-9_]+)\s*=\s*"?([^",]+)"?`)

// parseWWWAuthenticateHeader extracts resource_metadata URL and scopes from WWW-Authenticate header
// Example header: Bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource", scope="read write"
//
// Parameter names are matched case-insensitively. The scope value is split on
// whitespace. An empty header, or one without the respective parameter,
// yields "" and/or a nil scope slice.
func parseWWWAuthenticateHeader(header string) (resourceMetadataURL string, scopes []string) {
	if header == "" {
		return "", nil
	}

	matches := wwwAuthenticateParamPattern.FindAllStringSubmatch(header, -1)

	// Collect params keyed by lower-cased name; a repeated name keeps the
	// last occurrence.
	params := make(map[string]string, len(matches))
	for _, match := range matches {
		if len(match) == 3 {
			params[strings.ToLower(match[1])] = strings.TrimSpace(match[2])
		}
	}

	resourceMetadataURL = params["resource_metadata"]

	if scopeValue := params["scope"]; scopeValue != "" {
		scopes = strings.Fields(scopeValue)
	}

	return resourceMetadataURL, scopes
}
|
||||
|
||||
// fetchResourceMetadata fetches OAuth metadata from resource metadata endpoint (RFC 9728)
|
||||
func fetchResourceMetadata(ctx context.Context, metadataURL string) ([]string, []string, error) {
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("unexpected status %d from resource metadata endpoint", resp.StatusCode)
|
||||
}
|
||||
|
||||
var data ResourceMetadata
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to decode resource metadata: %w", err)
|
||||
}
|
||||
|
||||
// Use scopes_supported first, fall back to scopes
|
||||
scopes := data.ScopesSupported
|
||||
if len(scopes) == 0 {
|
||||
scopes = data.Scopes
|
||||
}
|
||||
|
||||
return data.AuthorizationServers, scopes, nil
|
||||
}
|
||||
|
||||
// attemptWellKnownDiscovery tries standard .well-known endpoints for protected resource discovery
|
||||
func attemptWellKnownDiscovery(ctx context.Context, serverURL string) ([]string, []string, error) {
|
||||
// Parse server URL to get base and path
|
||||
base, path := splitURL(serverURL)
|
||||
if base == "" {
|
||||
return nil, nil, fmt.Errorf("invalid server URL: %s", serverURL)
|
||||
}
|
||||
|
||||
// Try different well-known locations
|
||||
var candidateURLs []string
|
||||
if path != "" {
|
||||
candidateURLs = append(candidateURLs, fmt.Sprintf("%s/.well-known/oauth-protected-resource/%s", base, path))
|
||||
}
|
||||
candidateURLs = append(candidateURLs, fmt.Sprintf("%s/.well-known/oauth-protected-resource", base))
|
||||
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Trying %d .well-known URLs", len(candidateURLs)))
|
||||
|
||||
for _, candidateURL := range candidateURLs {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Trying: %s", candidateURL))
|
||||
authServers, scopes, err := fetchResourceMetadata(ctx, candidateURL)
|
||||
if err == nil && len(authServers) > 0 {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Found metadata at: %s", candidateURL))
|
||||
return authServers, scopes, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: assume server base is the authorization server
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] No .well-known found, assuming server base is auth server: %s", base))
|
||||
return []string{base}, nil, nil
|
||||
}
|
||||
|
||||
// fetchAuthorizationServerMetadata fetches OAuth endpoints from authorization server(s)
|
||||
// Tries multiple authorization servers until one succeeds
|
||||
func fetchAuthorizationServerMetadata(ctx context.Context, authServers []string) (*OAuthMetadata, error) {
|
||||
for _, issuer := range authServers {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Fetching metadata from authorization server: %s", issuer))
|
||||
metadata, err := fetchSingleAuthServerMetadata(ctx, issuer)
|
||||
if err == nil && metadata != nil {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Successfully fetched metadata from: %s", issuer))
|
||||
return metadata, nil
|
||||
}
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Failed to fetch from %s: %v", issuer, err))
|
||||
}
|
||||
return nil, fmt.Errorf("failed to fetch metadata from any authorization server")
|
||||
}
|
||||
|
||||
// fetchSingleAuthServerMetadata tries multiple well-known endpoints for a single authorization server
|
||||
// Implements RFC 8414 discovery
|
||||
func fetchSingleAuthServerMetadata(ctx context.Context, issuer string) (*OAuthMetadata, error) {
|
||||
base, path := splitURL(issuer)
|
||||
if base == "" {
|
||||
return nil, fmt.Errorf("invalid issuer URL: %s", issuer)
|
||||
}
|
||||
|
||||
// Try different well-known endpoint patterns
|
||||
var candidateURLs []string
|
||||
if path != "" {
|
||||
candidateURLs = append(candidateURLs,
|
||||
fmt.Sprintf("%s/.well-known/oauth-authorization-server/%s", base, path),
|
||||
fmt.Sprintf("%s/.well-known/openid-configuration/%s", base, path),
|
||||
)
|
||||
}
|
||||
candidateURLs = append(candidateURLs,
|
||||
fmt.Sprintf("%s/.well-known/oauth-authorization-server", base),
|
||||
fmt.Sprintf("%s/.well-known/openid-configuration", base),
|
||||
strings.TrimSuffix(issuer, "/"), // Try the issuer URL itself
|
||||
)
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
for _, candidateURL := range candidateURLs {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Trying metadata endpoint: %s", candidateURL))
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", candidateURL, nil)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
var metadata OAuthMetadata
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bodyBytes, &metadata); err == nil {
|
||||
// Validate that we got at least authorization_endpoint
|
||||
if metadata.AuthorizationURL != "" {
|
||||
logger.Debug(fmt.Sprintf("[OAuth Discovery] Valid metadata found at: %s", candidateURL))
|
||||
return &metadata, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
resp.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no valid metadata found for issuer: %s", issuer)
|
||||
}
|
||||
|
||||
// splitURL splits a URL into base (scheme://host) and path (without the
// leading slash).
//
// It returns two empty strings when urlStr cannot be parsed or is not an
// absolute URL with both a scheme and a host. Callers treat an empty base as
// "invalid URL", so a scheme-less input such as "example.com/x" must not
// produce the bogus base "://".
func splitURL(urlStr string) (base, path string) {
	parsedURL, err := url.Parse(urlStr)
	if err != nil {
		return "", ""
	}
	if parsedURL.Scheme == "" || parsedURL.Host == "" {
		return "", ""
	}

	// Build base URL (scheme + host)
	base = parsedURL.Scheme + "://" + parsedURL.Host

	// Get path without leading slash
	path = strings.TrimPrefix(parsedURL.Path, "/")

	return base, path
}
|
||||
|
||||
// GeneratePKCEChallenge generates code_verifier and code_challenge for PKCE (RFC 7636)
|
||||
// Returns:
|
||||
// - verifier: Random 128-character string (stored securely, never sent to server)
|
||||
// - challenge: SHA256 hash of verifier, base64url encoded (sent in authorization request)
|
||||
func GeneratePKCEChallenge() (verifier, challenge string, err error) {
|
||||
// Generate random 43-128 character string (we use 128 for maximum entropy)
|
||||
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
|
||||
const length = 128
|
||||
|
||||
// Use crypto/rand for secure random generation
|
||||
randomBytes := make([]byte, length)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// Convert to allowed charset
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[int(randomBytes[i])%len(charset)]
|
||||
}
|
||||
verifier = string(b)
|
||||
|
||||
// Generate SHA256 hash and base64url encode
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
|
||||
logger.Debug("[OAuth PKCE] Generated code_verifier and code_challenge")
|
||||
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
|
||||
// ValidatePKCEChallenge reports whether challenge equals the unpadded
// base64url encoding of the SHA-256 digest of verifier (the PKCE "S256"
// transformation). Used during testing or debugging.
func ValidatePKCEChallenge(verifier, challenge string) bool {
	digest := sha256.Sum256([]byte(verifier))
	return base64.RawURLEncoding.EncodeToString(digest[:]) == challenge
}
|
||||
|
||||
// DynamicClientRegistrationRequest is the JSON body sent to an OAuth
// registration endpoint when registering a client dynamically (RFC 7591).
type DynamicClientRegistrationRequest struct {
	// Always serialized.
	ClientName              string   `json:"client_name"`
	RedirectURIs            []string `json:"redirect_uris"`
	GrantTypes              []string `json:"grant_types"`
	ResponseTypes           []string `json:"response_types"`
	TokenEndpointAuthMethod string   `json:"token_endpoint_auth_method"`

	// Omitted from the JSON when empty.
	Scope     string   `json:"scope,omitempty"`
	LogoURI   string   `json:"logo_uri,omitempty"`
	ClientURI string   `json:"client_uri,omitempty"`
	Contacts  []string `json:"contacts,omitempty"`
}
|
||||
|
||||
// DynamicClientRegistrationResponse is the JSON document a registration
// endpoint returns for a dynamic client registration (RFC 7591). Only
// client_id is always present; the remaining fields are optional.
type DynamicClientRegistrationResponse struct {
	ClientID     string `json:"client_id"`
	ClientSecret string `json:"client_secret,omitempty"`

	// Timestamps as sent by the server.
	ClientIDIssuedAt      int64 `json:"client_id_issued_at,omitempty"`
	ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`

	// Registration management credentials, when the server issues them.
	RegistrationAccessToken string `json:"registration_access_token,omitempty"`
	RegistrationClientURI   string `json:"registration_client_uri,omitempty"`
}
|
||||
|
||||
// RegisterDynamicClient performs dynamic client registration with the OAuth provider (RFC 7591)
|
||||
// This allows Bifrost to automatically register as an OAuth client without manual setup.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the registration request
|
||||
// - registrationURL: The registration endpoint (discovered or user-provided)
|
||||
// - req: Client registration details
|
||||
//
|
||||
// Returns client_id and optional client_secret that can be used for OAuth flows.
|
||||
func RegisterDynamicClient(ctx context.Context, registrationURL string, req *DynamicClientRegistrationRequest) (*DynamicClientRegistrationResponse, error) {
|
||||
logger.Debug(fmt.Sprintf("[Dynamic Registration] Registering client at: %s", registrationURL))
|
||||
logger.Debug(fmt.Sprintf("[Dynamic Registration] Client name: %s, Redirect URIs: %v", req.ClientName, req.RedirectURIs))
|
||||
|
||||
// Serialize request
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal registration request: %w", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", registrationURL, strings.NewReader(string(reqBody)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create registration request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
|
||||
// Send request
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("registration request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read registration response: %w", err)
|
||||
}
|
||||
|
||||
// Check status code (201 Created or 200 OK are both valid per RFC 7591)
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
logger.Error(fmt.Sprintf("[Dynamic Registration] Failed with status %d: %s", resp.StatusCode, string(respBody)))
|
||||
return nil, fmt.Errorf("registration failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var regResp DynamicClientRegistrationResponse
|
||||
if err := json.Unmarshal(respBody, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse registration response: %w", err)
|
||||
}
|
||||
|
||||
// Validate response
|
||||
if regResp.ClientID == "" {
|
||||
return nil, fmt.Errorf("registration response missing client_id")
|
||||
}
|
||||
|
||||
logger.Debug(fmt.Sprintf("[Dynamic Registration] Successfully registered client_id: %s", regResp.ClientID))
|
||||
if regResp.ClientSecret != "" {
|
||||
logger.Debug("[Dynamic Registration] Client secret provided by server")
|
||||
} else {
|
||||
logger.Debug("[Dynamic Registration] No client secret provided (public client)")
|
||||
}
|
||||
|
||||
return ®Resp, nil
|
||||
}
|
||||
9
framework/oauth2/init.go
Normal file
9
framework/oauth2/init.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package oauth2
|
||||
|
||||
import "github.com/maximhq/bifrost/core/schemas"
|
||||
|
||||
var logger schemas.Logger
|
||||
|
||||
func SetLogger(l schemas.Logger) {
|
||||
logger = l
|
||||
}
|
||||
1110
framework/oauth2/main.go
Normal file
1110
framework/oauth2/main.go
Normal file
File diff suppressed because it is too large
Load Diff
135
framework/oauth2/sync.go
Normal file
135
framework/oauth2/sync.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// TokenRefreshWorker manages automatic token refresh for expiring OAuth tokens
|
||||
type TokenRefreshWorker struct {
|
||||
provider *OAuth2Provider
|
||||
refreshInterval time.Duration
|
||||
lookAheadWindow time.Duration // How far ahead to look for expiring tokens
|
||||
stopCh chan struct{}
|
||||
logger schemas.Logger
|
||||
}
|
||||
|
||||
// NewTokenRefreshWorker creates a new token refresh worker
|
||||
func NewTokenRefreshWorker(provider *OAuth2Provider, logger schemas.Logger) *TokenRefreshWorker {
|
||||
if provider.configStore == nil {
|
||||
logger.Warn("config store is nil, skipping token refresh worker")
|
||||
return nil
|
||||
}
|
||||
return &TokenRefreshWorker{
|
||||
provider: provider,
|
||||
refreshInterval: 5 * time.Minute, // Check every 5 minutes
|
||||
lookAheadWindow: 5 * time.Minute, // Refresh tokens expiring in next 5 minutes
|
||||
stopCh: make(chan struct{}),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the token refresh worker in a background goroutine
|
||||
func (w *TokenRefreshWorker) Start(ctx context.Context) {
|
||||
go w.run(ctx)
|
||||
if w.logger != nil {
|
||||
w.logger.Info("Token refresh worker started")
|
||||
}
|
||||
}
|
||||
|
||||
// Stop gracefully stops the token refresh worker
|
||||
func (w *TokenRefreshWorker) Stop() {
|
||||
close(w.stopCh)
|
||||
if w.logger != nil {
|
||||
w.logger.Info("Token refresh worker stopped")
|
||||
}
|
||||
}
|
||||
|
||||
// run is the main worker loop
|
||||
func (w *TokenRefreshWorker) run(ctx context.Context) {
|
||||
ticker := time.NewTicker(w.refreshInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run immediately on start
|
||||
w.refreshExpiredTokens(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
w.refreshExpiredTokens(ctx)
|
||||
case <-w.stopCh:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// refreshExpiredTokens queries and refreshes tokens that are expiring soon
|
||||
func (w *TokenRefreshWorker) refreshExpiredTokens(ctx context.Context) {
|
||||
expiryThreshold := time.Now().Add(w.lookAheadWindow)
|
||||
|
||||
// Get tokens expiring before the threshold
|
||||
tokens, err := w.provider.configStore.GetExpiringOauthTokens(ctx, expiryThreshold)
|
||||
if err != nil {
|
||||
if w.logger != nil {
|
||||
w.logger.Error("Failed to get expiring tokens", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("Found expiring tokens to refresh: %d", len(tokens))
|
||||
}
|
||||
|
||||
// Refresh each expiring token
|
||||
for _, token := range tokens {
|
||||
// Find the oauth_config that references this token
|
||||
oauthConfig, err := w.provider.configStore.GetOauthConfigByTokenID(ctx, token.ID)
|
||||
if err != nil {
|
||||
if w.logger != nil {
|
||||
w.logger.Error("Failed to find oauth config for token: %s, error: %s", token.ID, err.Error())
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if oauthConfig == nil {
|
||||
if w.logger != nil {
|
||||
w.logger.Warn("No oauth config found for token: %s", token.ID)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Attempt to refresh the token
|
||||
if err := w.provider.RefreshAccessToken(ctx, oauthConfig.ID); err != nil {
|
||||
if w.logger != nil {
|
||||
w.logger.Error("Failed to refresh token", "oauth_config_id", oauthConfig.ID, "error", err)
|
||||
}
|
||||
|
||||
// Only mark as expired for permanent auth rejections (e.g. invalid_grant, 401).
|
||||
// Transient failures (DNS, timeout, offline) are skipped — the worker will
|
||||
// retry on the next tick and the connection heals automatically when online.
|
||||
w.provider.markExpiredIfPermanent(ctx, oauthConfig, err)
|
||||
} else {
|
||||
if w.logger != nil {
|
||||
w.logger.Debug("Successfully refreshed token: %s", oauthConfig.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetRefreshInterval updates the refresh check interval (for testing)
|
||||
func (w *TokenRefreshWorker) SetRefreshInterval(interval time.Duration) {
|
||||
w.refreshInterval = interval
|
||||
}
|
||||
|
||||
// SetLookAheadWindow updates the look-ahead window for token expiry (for testing)
|
||||
func (w *TokenRefreshWorker) SetLookAheadWindow(window time.Duration) {
|
||||
w.lookAheadWindow = window
|
||||
}
|
||||
310
framework/oauth2/sync_test.go
Normal file
310
framework/oauth2/sync_test.go
Normal file
@@ -0,0 +1,310 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
// testConfigStore is a minimal in-memory implementation of configstore.ConfigStore
|
||||
// for use in oauth2 tests. Embeds the interface so unneeded methods panic if called.
|
||||
type testConfigStore struct {
|
||||
configstore.ConfigStore
|
||||
|
||||
mu sync.Mutex
|
||||
oauthConfigs map[string]*tables.TableOauthConfig
|
||||
oauthTokens map[string]*tables.TableOauthToken
|
||||
}
|
||||
|
||||
func newTestConfigStore() *testConfigStore {
|
||||
return &testConfigStore{
|
||||
oauthConfigs: make(map[string]*tables.TableOauthConfig),
|
||||
oauthTokens: make(map[string]*tables.TableOauthToken),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testConfigStore) GetOauthConfigByID(_ context.Context, id string) (*tables.TableOauthConfig, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cfg := s.oauthConfigs[id]
|
||||
if cfg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return bifrost.Ptr(*cfg), nil
|
||||
}
|
||||
|
||||
func (s *testConfigStore) GetOauthConfigByTokenID(_ context.Context, tokenID string) (*tables.TableOauthConfig, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, cfg := range s.oauthConfigs {
|
||||
if cfg.TokenID != nil && *cfg.TokenID == tokenID {
|
||||
return bifrost.Ptr(*cfg), nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *testConfigStore) UpdateOauthConfig(_ context.Context, cfg *tables.TableOauthConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.oauthConfigs[cfg.ID] = bifrost.Ptr(*cfg)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *testConfigStore) GetOauthTokenByID(_ context.Context, id string) (*tables.TableOauthToken, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
token := s.oauthTokens[id]
|
||||
if token == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return bifrost.Ptr(*token), nil
|
||||
}
|
||||
|
||||
func (s *testConfigStore) UpdateOauthToken(_ context.Context, token *tables.TableOauthToken) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.oauthTokens[token.ID] = bifrost.Ptr(*token)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *testConfigStore) GetExpiringOauthTokens(_ context.Context, before time.Time) ([]*tables.TableOauthToken, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var expiring []*tables.TableOauthToken
|
||||
for _, token := range s.oauthTokens {
|
||||
if token.ExpiresAt.Before(before) {
|
||||
expiring = append(expiring, bifrost.Ptr(*token))
|
||||
}
|
||||
}
|
||||
return expiring, nil
|
||||
}
|
||||
|
||||
// seedFixtures inserts an authorized oauth_config + token pair into the store.
|
||||
// The token expires 1 minute from now so GetExpiringOauthTokens will find it.
|
||||
func seedFixtures(t *testing.T, store *testConfigStore, tokenURL string) (oauthConfigID string) {
|
||||
t.Helper()
|
||||
|
||||
tokenID := "test-token-id"
|
||||
store.oauthTokens[tokenID] = &tables.TableOauthToken{
|
||||
ID: tokenID,
|
||||
AccessToken: "old-access-token",
|
||||
RefreshToken: "refresh-token",
|
||||
TokenType: "bearer",
|
||||
ExpiresAt: time.Now().Add(1 * time.Minute),
|
||||
Scopes: "[]",
|
||||
}
|
||||
|
||||
oauthConfigID = "test-oauth-config-id"
|
||||
store.oauthConfigs[oauthConfigID] = &tables.TableOauthConfig{
|
||||
ID: oauthConfigID,
|
||||
ClientID: "test-client-id",
|
||||
TokenURL: tokenURL,
|
||||
RedirectURI: "http://localhost/callback",
|
||||
Scopes: `["read"]`,
|
||||
Status: "authorized",
|
||||
TokenID: bifrost.Ptr(tokenID),
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
return oauthConfigID
|
||||
}
|
||||
|
||||
func newTestWorker(store *testConfigStore) *TokenRefreshWorker {
|
||||
noopLogger := bifrost.NewDefaultLogger(schemas.LogLevelError)
|
||||
provider := NewOAuth2Provider(store, noopLogger)
|
||||
provider.retryBaseDelay = 1 * time.Millisecond // speed up retry backoff in tests
|
||||
return NewTokenRefreshWorker(provider, noopLogger)
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_TransientError_DoesNotMarkExpired(t *testing.T) {
|
||||
// A 503 response from the token server is a transient failure.
|
||||
// The oauth_config must stay "authorized" so the connection can
|
||||
// heal automatically when the server recovers.
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "authorized", cfg.Status, "transient server error must not mark config as expired")
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_PermanentError_MarksExpired(t *testing.T) {
|
||||
// A 401 invalid_grant response is a permanent rejection from the auth server.
|
||||
// The oauth_config must be marked "expired" to prompt the user to re-authorize.
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "invalid_grant",
|
||||
"error_description": "Refresh token expired or revoked",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "expired", cfg.Status, "permanent auth rejection must mark config as expired")
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_SuccessfulRefresh_UpdatesToken(t *testing.T) {
|
||||
// A successful refresh must update the stored access token and
|
||||
// leave the oauth_config status as "authorized".
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "new-access-token",
|
||||
"refresh_token": "new-refresh-token",
|
||||
"token_type": "bearer",
|
||||
"expires_in": 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "authorized", cfg.Status)
|
||||
|
||||
token, err := store.GetOauthTokenByID(context.Background(), *cfg.TokenID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "new-access-token", token.AccessToken)
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_ConnectionRefused_DoesNotMarkExpired(t *testing.T) {
|
||||
// This is the exact failure mode that triggered this fix: the machine goes
|
||||
// offline, DNS fails, and the token endpoint is unreachable. The transport
|
||||
// error (client.Do fails) must not mark the config expired.
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
tokenURL := server.URL + "/token"
|
||||
server.Close() // close immediately so all connection attempts are refused
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, tokenURL)
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "authorized", cfg.Status, "connection refused must not mark config as expired")
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_400InvalidGrant_MarksExpired(t *testing.T) {
|
||||
// 400 invalid_grant is the canonical RFC 6749 signal that a refresh token
|
||||
// has been revoked. Must mark the config expired.
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "invalid_grant",
|
||||
"error_description": "The refresh token has been revoked",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "expired", cfg.Status, "400 invalid_grant must mark config as expired")
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_429RateLimit_DoesNotMarkExpired(t *testing.T) {
|
||||
// 429 Too Many Requests is a transient rate limit — not a permanent auth
|
||||
// rejection. Must not mark the config expired.
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Retry-After", "1")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "authorized", cfg.Status, "429 rate limit must not mark config as expired")
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_400InvalidRequest_DoesNotMarkExpired(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Missing required parameter",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "authorized", cfg.Status, "400 invalid_request must not mark config as expired")
|
||||
}
|
||||
|
||||
func TestTokenRefreshWorker_400UnauthorizedClient_MarksExpired(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "unauthorized_client",
|
||||
"error_description": "Client is not authorized for this grant type",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
store := newTestConfigStore()
|
||||
oauthConfigID := seedFixtures(t, store, server.URL+"/token")
|
||||
|
||||
worker := newTestWorker(store)
|
||||
worker.refreshExpiredTokens(context.Background())
|
||||
|
||||
cfg, err := store.GetOauthConfigByID(context.Background(), oauthConfigID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "expired", cfg.Status, "400 unauthorized_client must mark config as expired")
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user