first commit
This commit is contained in:
298
transports/bifrost-http/server/plugins.go
Normal file
298
transports/bifrost-http/server/plugins.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/plugins/compat"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/maximhq/bifrost/plugins/logging"
|
||||
"github.com/maximhq/bifrost/plugins/maxim"
|
||||
"github.com/maximhq/bifrost/plugins/otel"
|
||||
"github.com/maximhq/bifrost/plugins/prompts"
|
||||
"github.com/maximhq/bifrost/plugins/semanticcache"
|
||||
"github.com/maximhq/bifrost/plugins/telemetry"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/handlers"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
// InferPluginTypes determines which interface types a plugin implements
|
||||
func InferPluginTypes(plugin schemas.BasePlugin) []schemas.PluginType {
|
||||
var types []schemas.PluginType
|
||||
if _, ok := plugin.(schemas.LLMPlugin); ok {
|
||||
types = append(types, schemas.PluginTypeLLM)
|
||||
}
|
||||
if _, ok := plugin.(schemas.MCPPlugin); ok {
|
||||
types = append(types, schemas.PluginTypeMCP)
|
||||
}
|
||||
if _, ok := plugin.(schemas.HTTPTransportPlugin); ok {
|
||||
types = append(types, schemas.PluginTypeHTTP)
|
||||
}
|
||||
return types
|
||||
}
|
||||
|
||||
// Single-plugin methods used plugin create/update
|
||||
|
||||
// InstantiatePlugin creates a plugin instance but does NOT register it
|
||||
// Registration is done separately via Config.RegisterPlugin()
|
||||
func InstantiatePlugin(ctx context.Context, name string, path *string, pluginConfig any, bifrostConfig *lib.Config) (schemas.BasePlugin, error) {
|
||||
// Custom plugin (has path)
|
||||
if path != nil {
|
||||
return loadCustomPlugin(ctx, path, pluginConfig, bifrostConfig)
|
||||
}
|
||||
|
||||
// Built-in plugin (by name)
|
||||
return loadBuiltinPlugin(ctx, name, pluginConfig, bifrostConfig)
|
||||
}
|
||||
|
||||
// loadBuiltinPlugin instantiates a built-in plugin by name
|
||||
func loadBuiltinPlugin(ctx context.Context, name string, pluginConfig any, bifrostConfig *lib.Config) (schemas.BasePlugin, error) {
|
||||
switch name {
|
||||
case telemetry.PluginName:
|
||||
telConfig := &telemetry.Config{
|
||||
CustomLabels: bifrostConfig.ClientConfig.PrometheusLabels,
|
||||
}
|
||||
// Merge push gateway config if provided (e.g., from config file or UI update)
|
||||
if pluginConfig != nil {
|
||||
extraConfig, err := MarshalPluginConfig[telemetry.Config](pluginConfig)
|
||||
if err == nil && extraConfig != nil && extraConfig.PushGateway != nil {
|
||||
telConfig.PushGateway = extraConfig.PushGateway
|
||||
}
|
||||
}
|
||||
return telemetry.Init(telConfig, bifrostConfig.ModelCatalog, logger)
|
||||
|
||||
case prompts.PluginName:
|
||||
return prompts.Init(ctx, bifrostConfig.ConfigStore, logger)
|
||||
|
||||
case logging.PluginName:
|
||||
loggingConfig, err := MarshalPluginConfig[logging.Config](pluginConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal logging plugin config: %w", err)
|
||||
}
|
||||
return logging.Init(ctx, loggingConfig, logger, bifrostConfig.LogsStore,
|
||||
bifrostConfig.ModelCatalog, bifrostConfig.MCPCatalog)
|
||||
|
||||
case governance.PluginName:
|
||||
governanceConfig, err := MarshalPluginConfig[governance.Config](pluginConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal governance plugin config: %w", err)
|
||||
}
|
||||
inMemoryStore := &GovernanceInMemoryStore{Config: bifrostConfig}
|
||||
return governance.Init(ctx, governanceConfig, logger, bifrostConfig.ConfigStore,
|
||||
bifrostConfig.GovernanceConfig, bifrostConfig.ModelCatalog,
|
||||
bifrostConfig.MCPCatalog, inMemoryStore)
|
||||
|
||||
case maxim.PluginName:
|
||||
maximConfig, err := MarshalPluginConfig[maxim.Config](pluginConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal maxim plugin config: %w", err)
|
||||
}
|
||||
return maxim.Init(maximConfig, logger)
|
||||
|
||||
case semanticcache.PluginName:
|
||||
semanticConfig, err := MarshalPluginConfig[semanticcache.Config](pluginConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal semantic cache plugin config: %w", err)
|
||||
}
|
||||
return semanticcache.Init(ctx, semanticConfig, logger, bifrostConfig.VectorStore)
|
||||
|
||||
case otel.PluginName:
|
||||
otelConfig, err := MarshalPluginConfig[otel.Config](pluginConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal otel plugin config: %w", err)
|
||||
}
|
||||
return otel.Init(ctx, otelConfig, logger, bifrostConfig.ModelCatalog, handlers.GetVersion())
|
||||
|
||||
case compat.PluginName:
|
||||
compatConfig, err := MarshalPluginConfig[compat.Config](pluginConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal compat plugin config: %w", err)
|
||||
}
|
||||
return compat.Init(*compatConfig, logger, bifrostConfig.ModelCatalog)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown built-in plugin: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
// loadCustomPlugin loads a plugin from a shared object file
|
||||
func loadCustomPlugin(ctx context.Context, path *string, pluginConfig any, bifrostConfig *lib.Config) (schemas.BasePlugin, error) {
|
||||
logger.Info("loading custom plugin from path %s", *path)
|
||||
|
||||
plugin, err := bifrostConfig.PluginLoader.LoadPlugin(*path, pluginConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load custom plugin: %w", err)
|
||||
}
|
||||
return plugin, nil
|
||||
}
|
||||
|
||||
// LoadPlugins loads the plugins for the server.
|
||||
func (s *BifrostHTTPServer) LoadPlugins(ctx context.Context) error {
|
||||
// Load built-in plugins first (order matters)
|
||||
if err := s.loadBuiltinPlugins(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
// Load custom plugins from config
|
||||
if err := s.loadCustomPlugins(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
// Sort all plugins by placement group and order
|
||||
s.Config.SortAndRebuildPlugins()
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPluginConfig retrieves a plugin's config from PluginConfigs by name
|
||||
func (s *BifrostHTTPServer) getPluginConfig(name string) *schemas.PluginConfig {
|
||||
for _, cfg := range s.Config.PluginConfigs {
|
||||
if cfg.Name == name {
|
||||
return cfg
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadBuiltinPlugins loads required built-in plugins in specific order
|
||||
func (s *BifrostHTTPServer) loadBuiltinPlugins(ctx context.Context) error {
|
||||
builtinPlacement := schemas.Ptr(schemas.PluginPlacementBuiltin)
|
||||
|
||||
// 1. Telemetry (always first - tracks everything)
|
||||
if err := s.registerPluginWithStatus(ctx, telemetry.PluginName, nil, nil, true); err != nil {
|
||||
return err
|
||||
}
|
||||
s.Config.SetPluginOrderInfo(telemetry.PluginName, builtinPlacement, schemas.Ptr(1))
|
||||
|
||||
// 2. Prompts (requires config store for prompt repository; disabled in enterprise)
|
||||
if s.Config.ConfigStore != nil && ctx.Value(schemas.BifrostContextKeyIsEnterprise) == nil {
|
||||
s.registerPluginWithStatus(ctx, prompts.PluginName, nil, nil, false)
|
||||
} else {
|
||||
s.markPluginDisabled(prompts.PluginName)
|
||||
}
|
||||
s.Config.SetPluginOrderInfo(prompts.PluginName, builtinPlacement, schemas.Ptr(2))
|
||||
|
||||
// 3. Logging (if enabled)
|
||||
if (s.Config.ClientConfig.EnableLogging == nil || *s.Config.ClientConfig.EnableLogging) && s.Config.LogsStore != nil {
|
||||
config := &logging.Config{
|
||||
DisableContentLogging: &s.Config.ClientConfig.DisableContentLogging,
|
||||
LoggingHeaders: &s.Config.ClientConfig.LoggingHeaders,
|
||||
}
|
||||
s.registerPluginWithStatus(ctx, logging.PluginName, nil, config, false)
|
||||
} else {
|
||||
s.markPluginDisabled(logging.PluginName)
|
||||
}
|
||||
s.Config.SetPluginOrderInfo(logging.PluginName, builtinPlacement, schemas.Ptr(3))
|
||||
|
||||
// 4. Governance (if enabled and not enterprise)
|
||||
if ctx.Value(schemas.BifrostContextKeyIsEnterprise) == nil {
|
||||
config := &governance.Config{
|
||||
IsVkMandatory: &s.Config.ClientConfig.EnforceAuthOnInference,
|
||||
RequiredHeaders: &s.Config.ClientConfig.RequiredHeaders,
|
||||
DisableAutoToolInject: &s.Config.ClientConfig.MCPDisableAutoToolInject,
|
||||
RoutingChainMaxDepth: &s.Config.ClientConfig.RoutingChainMaxDepth,
|
||||
}
|
||||
s.registerPluginWithStatus(ctx, governance.PluginName, nil, config, false)
|
||||
} else {
|
||||
s.markPluginDisabled(governance.PluginName)
|
||||
}
|
||||
s.Config.SetPluginOrderInfo(governance.PluginName, builtinPlacement, schemas.Ptr(4))
|
||||
|
||||
// 5. OTEL (if configured in PluginConfigs)
|
||||
otelConfig := s.getPluginConfig(otel.PluginName)
|
||||
if otelConfig != nil && otelConfig.Enabled {
|
||||
s.registerPluginWithStatus(ctx, otel.PluginName, nil, otelConfig.Config, false)
|
||||
} else {
|
||||
s.markPluginDisabled(otel.PluginName)
|
||||
}
|
||||
s.Config.SetPluginOrderInfo(otel.PluginName, builtinPlacement, schemas.Ptr(5))
|
||||
|
||||
// 6. Semantic Cache (if configured in PluginConfigs)
|
||||
semanticCacheConfig := s.getPluginConfig(semanticcache.PluginName)
|
||||
if semanticCacheConfig != nil && semanticCacheConfig.Enabled {
|
||||
s.registerPluginWithStatus(ctx, semanticcache.PluginName, nil, semanticCacheConfig.Config, false)
|
||||
} else {
|
||||
s.markPluginDisabled(semanticcache.PluginName)
|
||||
}
|
||||
s.Config.SetPluginOrderInfo(semanticcache.PluginName, builtinPlacement, schemas.Ptr(6))
|
||||
|
||||
// 7. Compat (if any compat feature is enabled in ClientConfig)
|
||||
cc := s.Config.ClientConfig.Compat
|
||||
compatCfg := &compat.Config{
|
||||
ConvertTextToChat: cc.ConvertTextToChat,
|
||||
ConvertChatToResponses: cc.ConvertChatToResponses,
|
||||
ShouldDropParams: cc.ShouldDropParams,
|
||||
ShouldConvertParams: cc.ShouldConvertParams,
|
||||
}
|
||||
s.registerPluginWithStatus(ctx, compat.PluginName, nil, compatCfg, false)
|
||||
s.Config.SetPluginOrderInfo(compat.PluginName, builtinPlacement, schemas.Ptr(7))
|
||||
|
||||
// 8. Maxim (if configured in PluginConfigs)
|
||||
maximConfig := s.getPluginConfig(maxim.PluginName)
|
||||
if maximConfig != nil && maximConfig.Enabled {
|
||||
s.registerPluginWithStatus(ctx, maxim.PluginName, nil, maximConfig.Config, false)
|
||||
} else {
|
||||
s.markPluginDisabled(maxim.PluginName)
|
||||
}
|
||||
s.Config.SetPluginOrderInfo(maxim.PluginName, builtinPlacement, schemas.Ptr(8))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadCustomPlugins loads plugins from PluginConfigs
|
||||
func (s *BifrostHTTPServer) loadCustomPlugins(ctx context.Context) error {
|
||||
for _, cfg := range s.Config.PluginConfigs {
|
||||
// Skip built-ins (already loaded)
|
||||
if lib.IsBuiltinPlugin(cfg.Name) {
|
||||
continue
|
||||
}
|
||||
// Handle disabled plugins
|
||||
if !cfg.Enabled {
|
||||
// For custom plugins with a path, verify to get the real plugin name
|
||||
if cfg.Path != nil {
|
||||
pluginName, err := s.Config.PluginLoader.VerifyBasePlugin(*cfg.Path)
|
||||
if err != nil {
|
||||
logger.Error("failed to verify disabled plugin %s: %v", cfg.Name, err)
|
||||
continue
|
||||
}
|
||||
// Store plugin status without instantiating (no Init() call, no resource usage)
|
||||
// Note: We can't determine types without instantiating, so pass empty slice
|
||||
s.Config.UpdatePluginOverallStatus(pluginName, cfg.Name, schemas.PluginStatusDisabled,
|
||||
[]string{fmt.Sprintf("plugin %s is disabled", cfg.Name)}, []schemas.PluginType{})
|
||||
} else {
|
||||
// Built-in plugin - use cfg.Name directly
|
||||
s.Config.UpdatePluginOverallStatus(cfg.Name, cfg.Name, schemas.PluginStatusDisabled,
|
||||
[]string{fmt.Sprintf("plugin %s is disabled", cfg.Name)}, []schemas.PluginType{})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Plugin is enabled - instantiate it
|
||||
plugin, err := InstantiatePlugin(ctx, cfg.Name, cfg.Path, cfg.Config, s.Config)
|
||||
if err != nil {
|
||||
// Skip enterprise plugins silently
|
||||
if slices.Contains(enterprisePlugins, cfg.Name) {
|
||||
continue
|
||||
}
|
||||
logger.Error("failed to load plugin %s: %v", cfg.Name, err)
|
||||
// Use cfg.Name since plugin may be nil when InstantiatePlugin returns an error
|
||||
s.Config.UpdatePluginOverallStatus(cfg.Name, cfg.Name, schemas.PluginStatusError,
|
||||
[]string{fmt.Sprintf("error loading plugin %s: %v", cfg.Name, err)}, []schemas.PluginType{})
|
||||
continue
|
||||
}
|
||||
|
||||
// Ensure plugin is not nil before using it (defensive check)
|
||||
if plugin == nil {
|
||||
logger.Error("plugin %s instantiated but returned nil", cfg.Name)
|
||||
s.Config.UpdatePluginOverallStatus(cfg.Name, cfg.Name, schemas.PluginStatusError,
|
||||
[]string{fmt.Sprintf("plugin %s instantiated but returned nil", cfg.Name)}, []schemas.PluginType{})
|
||||
continue
|
||||
}
|
||||
|
||||
// Register enabled plugin and mark as active
|
||||
s.Config.ReloadPlugin(plugin)
|
||||
s.Config.SetPluginOrderInfo(plugin.GetName(), cfg.Placement, cfg.Order)
|
||||
s.Config.UpdatePluginOverallStatus(plugin.GetName(), cfg.Name, schemas.PluginStatusActive,
|
||||
[]string{fmt.Sprintf("plugin %s initialized successfully", cfg.Name)}, InferPluginTypes(plugin))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
1544
transports/bifrost-http/server/server.go
Normal file
1544
transports/bifrost-http/server/server.go
Normal file
File diff suppressed because it is too large
Load Diff
393
transports/bifrost-http/server/server_test.go
Normal file
393
transports/bifrost-http/server/server_test.go
Normal file
@@ -0,0 +1,393 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
// TestConfig is a sample config struct for testing
|
||||
type TestConfig struct {
|
||||
Name string `json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
type updateStatusOnlyConfigStore struct {
|
||||
configstore.ConfigStore
|
||||
calls []schemas.KeyStatus
|
||||
}
|
||||
|
||||
type noopTestLogger struct{}
|
||||
|
||||
func (noopTestLogger) Debug(string, ...any) {}
|
||||
func (noopTestLogger) Info(string, ...any) {}
|
||||
func (noopTestLogger) Warn(string, ...any) {}
|
||||
func (noopTestLogger) Error(string, ...any) {}
|
||||
func (noopTestLogger) Fatal(string, ...any) {}
|
||||
func (noopTestLogger) SetLevel(schemas.LogLevel) {}
|
||||
func (noopTestLogger) SetOutputType(schemas.LoggerOutputType) {}
|
||||
func (noopTestLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
func (s *updateStatusOnlyConfigStore) UpdateStatus(ctx context.Context, provider schemas.ModelProvider, keyID string, status, errorMsg string) error {
|
||||
s.calls = append(s.calls, schemas.KeyStatus{
|
||||
Provider: provider,
|
||||
KeyID: keyID,
|
||||
Status: schemas.KeyStatusType(status),
|
||||
Error: &schemas.BifrostError{Error: &schemas.ErrorField{Message: errorMsg}},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUpdateKeyStatus_KeylessProviderUpdatesProviderStatusInMemory(t *testing.T) {
|
||||
prevLogger := logger
|
||||
logger = noopTestLogger{}
|
||||
defer func() { logger = prevLogger }()
|
||||
|
||||
store := &updateStatusOnlyConfigStore{}
|
||||
server := &BifrostHTTPServer{
|
||||
Config: &lib.Config{
|
||||
ConfigStore: store,
|
||||
Providers: map[schemas.ModelProvider]configstore.ProviderConfig{
|
||||
"mock-openai": {
|
||||
CustomProviderConfig: &schemas.CustomProviderConfig{IsKeyLess: true},
|
||||
Status: "unknown",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
server.updateKeyStatus(context.Background(), []schemas.KeyStatus{{
|
||||
Provider: "mock-openai",
|
||||
KeyID: "",
|
||||
Status: schemas.KeyStatusListModelsFailed,
|
||||
Error: &schemas.BifrostError{Error: &schemas.ErrorField{Message: "preview missing model"}},
|
||||
}})
|
||||
|
||||
provider := server.Config.Providers["mock-openai"]
|
||||
if provider.Status != string(schemas.KeyStatusListModelsFailed) {
|
||||
t.Fatalf("expected provider status %q, got %q", schemas.KeyStatusListModelsFailed, provider.Status)
|
||||
}
|
||||
if provider.Description != "preview missing model" {
|
||||
t.Fatalf("expected provider description to be updated, got %q", provider.Description)
|
||||
}
|
||||
if len(store.calls) != 1 {
|
||||
t.Fatalf("expected one status update call, got %d", len(store.calls))
|
||||
}
|
||||
if store.calls[0].Provider != "mock-openai" || store.calls[0].KeyID != "" {
|
||||
t.Fatalf("expected provider-level status update, got provider=%q keyID=%q", store.calls[0].Provider, store.calls[0].KeyID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateKeyStatus_EmptyKeyIDDoesNotOverwriteKeyedProviderStatus(t *testing.T) {
|
||||
prevLogger := logger
|
||||
logger = noopTestLogger{}
|
||||
defer func() { logger = prevLogger }()
|
||||
|
||||
store := &updateStatusOnlyConfigStore{}
|
||||
server := &BifrostHTTPServer{
|
||||
Config: &lib.Config{
|
||||
ConfigStore: store,
|
||||
Providers: map[schemas.ModelProvider]configstore.ProviderConfig{
|
||||
"openai": {
|
||||
Keys: []schemas.Key{{ID: "key-1"}},
|
||||
Status: "healthy",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
server.updateKeyStatus(context.Background(), []schemas.KeyStatus{{
|
||||
Provider: "openai",
|
||||
KeyID: "",
|
||||
Status: schemas.KeyStatusListModelsFailed,
|
||||
Error: &schemas.BifrostError{Error: &schemas.ErrorField{Message: "malformed status"}},
|
||||
}})
|
||||
|
||||
provider := server.Config.Providers["openai"]
|
||||
if provider.Status != "healthy" {
|
||||
t.Fatalf("expected keyed provider status to remain unchanged, got %q", provider.Status)
|
||||
}
|
||||
if provider.Description != "" {
|
||||
t.Fatalf("expected keyed provider description to remain unchanged, got %q", provider.Description)
|
||||
}
|
||||
if len(store.calls) != 1 {
|
||||
t.Fatalf("expected one status update call, got %d", len(store.calls))
|
||||
}
|
||||
if store.calls[0].Provider != "openai" || store.calls[0].KeyID != "" {
|
||||
t.Fatalf("expected DB status update to retain empty key ID, got provider=%q keyID=%q", store.calls[0].Provider, store.calls[0].KeyID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithPointerType(t *testing.T) {
|
||||
// Test case 1: source is already *T
|
||||
expected := &TestConfig{
|
||||
Name: "test-plugin",
|
||||
Enabled: true,
|
||||
Count: 42,
|
||||
}
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](expected)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected same pointer, got different pointer")
|
||||
}
|
||||
|
||||
if result.Name != expected.Name {
|
||||
t.Errorf("Expected Name=%s, got %s", expected.Name, result.Name)
|
||||
}
|
||||
if result.Enabled != expected.Enabled {
|
||||
t.Errorf("Expected Enabled=%v, got %v", expected.Enabled, result.Enabled)
|
||||
}
|
||||
if result.Count != expected.Count {
|
||||
t.Errorf("Expected Count=%d, got %d", expected.Count, result.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithMap(t *testing.T) {
|
||||
// Test case 2: source is map[string]any
|
||||
configMap := map[string]any{
|
||||
"name": "test-plugin",
|
||||
"enabled": true,
|
||||
"count": 42,
|
||||
}
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](configMap)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected non-nil result")
|
||||
}
|
||||
|
||||
if result.Name != "test-plugin" {
|
||||
t.Errorf("Expected Name=test-plugin, got %s", result.Name)
|
||||
}
|
||||
if result.Enabled != true {
|
||||
t.Errorf("Expected Enabled=true, got %v", result.Enabled)
|
||||
}
|
||||
if result.Count != 42 {
|
||||
t.Errorf("Expected Count=42, got %d", result.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithString(t *testing.T) {
|
||||
// Test case 3: source is string (JSON)
|
||||
configStr := `{"name":"test-plugin","enabled":true,"count":42}`
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](configStr)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected non-nil result")
|
||||
}
|
||||
|
||||
if result.Name != "test-plugin" {
|
||||
t.Errorf("Expected Name=test-plugin, got %s", result.Name)
|
||||
}
|
||||
if result.Enabled != true {
|
||||
t.Errorf("Expected Enabled=true, got %v", result.Enabled)
|
||||
}
|
||||
if result.Count != 42 {
|
||||
t.Errorf("Expected Count=42, got %d", result.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithInvalidType(t *testing.T) {
|
||||
// Test case 4: source is invalid type (should return error)
|
||||
invalidSource := 12345
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](invalidSource)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid type, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for invalid type, got %v", result)
|
||||
}
|
||||
|
||||
expectedError := "invalid config type"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("Expected error message '%s', got '%s'", expectedError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithInvalidJSONString(t *testing.T) {
|
||||
// Test case 5: source is string but invalid JSON
|
||||
invalidJSON := `{"name":"test-plugin","enabled":true,count:42}` // missing quotes around count
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](invalidJSON)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid JSON, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for invalid JSON, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithInvalidMapData(t *testing.T) {
|
||||
// Test case 6: source is map but contains invalid data types
|
||||
configMap := map[string]any{
|
||||
"name": "test-plugin",
|
||||
"enabled": "not-a-boolean", // wrong type
|
||||
"count": 42,
|
||||
}
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](configMap)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid map data, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for invalid map data, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithEmptyMap(t *testing.T) {
|
||||
// Test case 7: source is empty map (should work, return zero values)
|
||||
configMap := map[string]any{}
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](configMap)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error for empty map, got: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected non-nil result")
|
||||
}
|
||||
|
||||
// All fields should have zero values
|
||||
if result.Name != "" {
|
||||
t.Errorf("Expected empty Name, got %s", result.Name)
|
||||
}
|
||||
if result.Enabled != false {
|
||||
t.Errorf("Expected Enabled=false, got %v", result.Enabled)
|
||||
}
|
||||
if result.Count != 0 {
|
||||
t.Errorf("Expected Count=0, got %d", result.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithEmptyString(t *testing.T) {
|
||||
// Test case 8: source is empty string (should fail as invalid JSON)
|
||||
configStr := ""
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](configStr)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty string, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for empty string, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithNil(t *testing.T) {
|
||||
// Test case 9: source is nil (should return error as invalid type)
|
||||
result, err := MarshalPluginConfig[TestConfig](nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for nil source, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for nil source, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkMarshalPluginConfig_WithPointerType(b *testing.B) {
|
||||
config := &TestConfig{
|
||||
Name: "test-plugin",
|
||||
Enabled: true,
|
||||
Count: 42,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = MarshalPluginConfig[TestConfig](config)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMarshalPluginConfig_WithMap(b *testing.B) {
|
||||
configMap := map[string]any{
|
||||
"name": "test-plugin",
|
||||
"enabled": true,
|
||||
"count": 42,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = MarshalPluginConfig[TestConfig](configMap)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMarshalPluginConfig_WithString(b *testing.B) {
|
||||
configStr := `{"name":"test-plugin","enabled":true,"count":42}`
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = MarshalPluginConfig[TestConfig](configStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Complex config for additional testing
|
||||
type ComplexConfig struct {
|
||||
Settings map[string]string `json:"settings"`
|
||||
Tags []string `json:"tags"`
|
||||
Metadata map[string]any `json:"metadata"`
|
||||
Nested *TestConfig `json:"nested"`
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithComplexType(t *testing.T) {
|
||||
// Test with a more complex nested structure
|
||||
configMap := map[string]any{
|
||||
"settings": map[string]any{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
},
|
||||
"tags": []any{"tag1", "tag2", "tag3"},
|
||||
"metadata": map[string]any{
|
||||
"version": "1.0.0",
|
||||
"author": "test",
|
||||
},
|
||||
"nested": map[string]any{
|
||||
"name": "nested-config",
|
||||
"enabled": true,
|
||||
"count": 10,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := MarshalPluginConfig[ComplexConfig](configMap)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected non-nil result")
|
||||
}
|
||||
|
||||
if len(result.Settings) != 2 {
|
||||
t.Errorf("Expected 2 settings, got %d", len(result.Settings))
|
||||
}
|
||||
if len(result.Tags) != 3 {
|
||||
t.Errorf("Expected 3 tags, got %d", len(result.Tags))
|
||||
}
|
||||
if result.Nested == nil {
|
||||
t.Fatal("Expected non-nil nested config")
|
||||
}
|
||||
if result.Nested.Name != "nested-config" {
|
||||
t.Errorf("Expected nested name=nested-config, got %s", result.Nested.Name)
|
||||
}
|
||||
}
|
||||
208
transports/bifrost-http/server/utils.go
Normal file
208
transports/bifrost-http/server/utils.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// GetDefaultConfigDir returns the OS-specific default configuration directory for Bifrost.
|
||||
// This follows standard conventions:
|
||||
// - Linux/macOS: ~/.config/bifrost
|
||||
// - Windows: %APPDATA%\bifrost
|
||||
// - If appDir is provided (non-empty), it returns that instead
|
||||
func GetDefaultConfigDir(appDir string) string {
|
||||
// If appDir is provided, use it directly
|
||||
if appDir != "" {
|
||||
return appDir
|
||||
}
|
||||
|
||||
// Get OS-specific config directory
|
||||
var configDir string
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
// Windows: %APPDATA%\bifrost
|
||||
if appData := os.Getenv("APPDATA"); appData != "" {
|
||||
configDir = filepath.Join(appData, "bifrost")
|
||||
} else {
|
||||
// Fallback to user home directory
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
configDir = filepath.Join(homeDir, "AppData", "Roaming", "bifrost")
|
||||
}
|
||||
}
|
||||
default:
|
||||
// Linux, macOS and other Unix-like systems: ~/.config/bifrost
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
configDir = filepath.Join(homeDir, ".config", "bifrost")
|
||||
}
|
||||
}
|
||||
|
||||
// If we couldn't determine the config directory, fall back to current directory
|
||||
if configDir == "" {
|
||||
configDir = "./bifrost-data"
|
||||
}
|
||||
|
||||
return configDir
|
||||
}
|
||||
|
||||
// registerPluginWithStatus instantiates, registers, and updates status for a plugin (used by builtin plugins)
|
||||
func (s *BifrostHTTPServer) registerPluginWithStatus(ctx context.Context, name string, path *string, config any, failOnError bool) error {
|
||||
plugin, err := InstantiatePlugin(ctx, name, path, config, s.Config)
|
||||
if err != nil {
|
||||
logger.Error("failed to initialize %s plugin: %v", name, err)
|
||||
// Use name since plugin may be nil when InstantiatePlugin returns an error
|
||||
s.Config.UpdatePluginOverallStatus(name, name, schemas.PluginStatusError,
|
||||
[]string{fmt.Sprintf("error initializing %s plugin: %v", name, err)}, []schemas.PluginType{})
|
||||
if failOnError {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure plugin is not nil before using it (defensive check)
|
||||
if plugin == nil {
|
||||
logger.Error("plugin %s instantiated but returned nil", name)
|
||||
s.Config.UpdatePluginOverallStatus(name, name, schemas.PluginStatusError,
|
||||
[]string{fmt.Sprintf("plugin %s instantiated but returned nil", name)}, []schemas.PluginType{})
|
||||
if failOnError {
|
||||
return fmt.Errorf("plugin %s instantiated but returned nil", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
s.Config.ReloadPlugin(plugin)
|
||||
s.Config.UpdatePluginOverallStatus(name, name, schemas.PluginStatusActive,
|
||||
[]string{fmt.Sprintf("%s plugin initialized successfully", name)}, InferPluginTypes(plugin))
|
||||
return nil
|
||||
}
|
||||
|
||||
// CollectObservabilityPlugins gathers all loaded plugins that implement ObservabilityPlugin interface
|
||||
func (s *BifrostHTTPServer) CollectObservabilityPlugins() []schemas.ObservabilityPlugin {
|
||||
var observabilityPlugins []schemas.ObservabilityPlugin
|
||||
|
||||
// Check LLM plugins
|
||||
for _, plugin := range s.Config.GetLoadedLLMPlugins() {
|
||||
if observabilityPlugin, ok := plugin.(schemas.ObservabilityPlugin); ok {
|
||||
observabilityPlugins = append(observabilityPlugins, observabilityPlugin)
|
||||
}
|
||||
}
|
||||
|
||||
// Check MCP plugins
|
||||
for _, plugin := range s.Config.GetLoadedMCPPlugins() {
|
||||
if observabilityPlugin, ok := plugin.(schemas.ObservabilityPlugin); ok {
|
||||
observabilityPlugins = append(observabilityPlugins, observabilityPlugin)
|
||||
}
|
||||
}
|
||||
|
||||
return observabilityPlugins
|
||||
}
|
||||
|
||||
// MarshalPluginConfig marshals the plugin configuration
|
||||
func MarshalPluginConfig[T any](source any) (*T, error) {
|
||||
// If its a *T, then we will confirm
|
||||
if config, ok := source.(*T); ok {
|
||||
return config, nil
|
||||
}
|
||||
// Initialize a new instance for unmarshaling
|
||||
config := new(T)
|
||||
// If its a map[string]any, then we will JSON parse and confirm
|
||||
if configMap, ok := source.(map[string]any); ok {
|
||||
configString, err := sonic.Marshal(configMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := sonic.Unmarshal([]byte(configString), config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
// If its a string, then we will JSON parse and confirm
|
||||
if configStr, ok := source.(string); ok {
|
||||
if err := sonic.Unmarshal([]byte(configStr), config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid config type")
|
||||
}
|
||||
|
||||
// updateKeyStatus updates the model discovery status for keys or providers based on key statuses.
|
||||
// For keyed providers: updates individual key status
|
||||
// For keyless providers: updates provider-level status
|
||||
func (s *BifrostHTTPServer) updateKeyStatus(
|
||||
ctx context.Context,
|
||||
keyStatuses []schemas.KeyStatus,
|
||||
) {
|
||||
if s.Config == nil || s.Config.ConfigStore == nil || len(keyStatuses) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Update each key/provider status individually
|
||||
for _, ks := range keyStatuses {
|
||||
errorMsg := ""
|
||||
if ks.Error != nil && ks.Error.Error != nil {
|
||||
errorMsg = ks.Error.Error.Message
|
||||
}
|
||||
|
||||
if err := s.Config.ConfigStore.UpdateStatus(ctx, ks.Provider, ks.KeyID, string(ks.Status), errorMsg); err != nil {
|
||||
target := ks.KeyID
|
||||
if target == "" {
|
||||
target = string(ks.Provider)
|
||||
}
|
||||
logger.Error("failed to update model discovery status for %s: %v", target, err)
|
||||
continue // Skip in-memory update if DB update failed
|
||||
}
|
||||
|
||||
s.Config.Mu.Lock()
|
||||
|
||||
providerConfig, exists := s.Config.Providers[ks.Provider]
|
||||
if !exists {
|
||||
s.Config.Mu.Unlock()
|
||||
logger.Warn("provider %s not found in memory during status update", ks.Provider)
|
||||
continue
|
||||
}
|
||||
|
||||
isKeylessProvider := providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.IsKeyLess
|
||||
|
||||
if ks.KeyID == "" {
|
||||
if !isKeylessProvider {
|
||||
logger.Warn("received provider-level status update for keyed provider %s; skipping in-memory update", ks.Provider)
|
||||
s.Config.Mu.Unlock()
|
||||
continue
|
||||
}
|
||||
providerConfig.Status = string(ks.Status)
|
||||
providerConfig.Description = errorMsg
|
||||
s.Config.Providers[ks.Provider] = providerConfig
|
||||
logger.Debug("updated in-memory status for keyless provider %s", ks.Provider)
|
||||
s.Config.Mu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
// Find and update the specific key in the Keys slice
|
||||
updated := false
|
||||
for i := range providerConfig.Keys {
|
||||
if providerConfig.Keys[i].ID == ks.KeyID {
|
||||
// Update Status and Description fields
|
||||
providerConfig.Keys[i].Status = ks.Status
|
||||
providerConfig.Keys[i].Description = errorMsg
|
||||
updated = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if updated {
|
||||
// Write the modified config back to the map
|
||||
s.Config.Providers[ks.Provider] = providerConfig
|
||||
logger.Debug("updated in-memory status for key %s of provider %s", ks.KeyID, ks.Provider)
|
||||
} else {
|
||||
logger.Warn("key %s not found in provider %s during in-memory update", ks.KeyID, ks.Provider)
|
||||
}
|
||||
|
||||
s.Config.Mu.Unlock()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user