first commit
This commit is contained in:
16
framework/plugins/loader.go
Normal file
16
framework/plugins/loader.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package plugins
|
||||
|
||||
import "github.com/maximhq/bifrost/core/schemas"
|
||||
|
||||
// PluginLoader is the contract for a plugin loader
|
||||
type PluginLoader interface {
|
||||
// LoadPlugin loads a generic plugin from the given path with the provided config
|
||||
// Returns a BasePlugin that can be type-asserted to specific plugin interfaces
|
||||
LoadPlugin(path string, config any) (schemas.BasePlugin, error)
|
||||
|
||||
// VerifyBasePlugin verifies a plugin at the given path
|
||||
// Returns the name of the plugin or an empty string if the plugin is invalid
|
||||
// Returns an error if the plugin is invalid
|
||||
// This method is used to verify that the plugin is a valid base plugin and has the required symbols
|
||||
VerifyBasePlugin(path string) (string, error)
|
||||
}
|
||||
162
framework/plugins/main.go
Normal file
162
framework/plugins/main.go
Normal file
@@ -0,0 +1,162 @@
|
||||
// Package plugins provides a framework for dynamically loading and managing plugins
|
||||
package plugins
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// PluginConfig is the generic configuration for any plugin type
|
||||
// Plugin types are automatically detected based on implemented interfaces
|
||||
type PluginConfig struct {
|
||||
Path string `json:"path"`
|
||||
Name string `json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Config any `json:"config,omitempty"`
|
||||
}
|
||||
|
||||
// Config is the configuration for the plugins framework
|
||||
type Config struct {
|
||||
// Plugins is the unified configuration for all plugin types
|
||||
Plugins []PluginConfig `json:"plugins"`
|
||||
}
|
||||
|
||||
// AsLLMPlugin checks if a base plugin implements LLMPlugin and actually has LLM hooks.
|
||||
// For DynamicPlugin, it checks if the hook function pointers are not nil.
|
||||
// Returns nil if the plugin does not implement the interface or has no LLM hooks.
|
||||
func AsLLMPlugin(plugin schemas.BasePlugin) schemas.LLMPlugin {
|
||||
// Check if it's a DynamicPlugin first
|
||||
if dp, ok := plugin.(*DynamicPlugin); ok {
|
||||
// Only return as LLMPlugin if it actually has LLM hooks
|
||||
if dp.preLLMHook != nil || dp.postLLMHook != nil {
|
||||
return dp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// For non-DynamicPlugin types, use normal type assertion
|
||||
if llmPlugin, ok := plugin.(schemas.LLMPlugin); ok {
|
||||
return llmPlugin
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AsMCPPlugin checks if a base plugin implements MCPPlugin and actually has MCP hooks.
|
||||
// For DynamicPlugin, it checks if the hook function pointers are not nil.
|
||||
// Returns nil if the plugin does not implement the interface or has no MCP hooks.
|
||||
func AsMCPPlugin(plugin schemas.BasePlugin) schemas.MCPPlugin {
|
||||
// Check if it's a DynamicPlugin first
|
||||
if dp, ok := plugin.(*DynamicPlugin); ok {
|
||||
// Only return as MCPPlugin if it actually has MCP hooks
|
||||
if dp.preMCPHook != nil || dp.postMCPHook != nil {
|
||||
return dp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// For non-DynamicPlugin types, use normal type assertion
|
||||
if mcpPlugin, ok := plugin.(schemas.MCPPlugin); ok {
|
||||
return mcpPlugin
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AsHTTPTransportPlugin checks if a base plugin implements HTTPTransportPlugin and actually has HTTP transport hooks.
|
||||
// For DynamicPlugin, it checks if the hook function pointers are not nil.
|
||||
// Returns nil if the plugin does not implement the interface or has no HTTP transport hooks.
|
||||
func AsHTTPTransportPlugin(plugin schemas.BasePlugin) schemas.HTTPTransportPlugin {
|
||||
// Check if it's a DynamicPlugin first
|
||||
if dp, ok := plugin.(*DynamicPlugin); ok {
|
||||
// Only return as HTTPTransportPlugin if it actually has HTTP transport hooks
|
||||
if dp.httpTransportPreHook != nil || dp.httpTransportPostHook != nil {
|
||||
return dp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// For non-DynamicPlugin types, use normal type assertion
|
||||
if httpPlugin, ok := plugin.(schemas.HTTPTransportPlugin); ok {
|
||||
return httpPlugin
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AsObservabilityPlugin checks if a base plugin implements ObservabilityPlugin and actually has observability hooks.
|
||||
// For DynamicPlugin, it checks if the hook function pointer is not nil.
|
||||
// Returns nil if the plugin does not implement the interface or has no observability hooks.
|
||||
func AsObservabilityPlugin(plugin schemas.BasePlugin) schemas.ObservabilityPlugin {
|
||||
// Check if it's a DynamicPlugin first
|
||||
if dp, ok := plugin.(*DynamicPlugin); ok {
|
||||
// Only return as ObservabilityPlugin if it actually has the Inject hook
|
||||
if dp.inject != nil {
|
||||
return dp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// For non-DynamicPlugin types, use normal type assertion
|
||||
if obsPlugin, ok := plugin.(schemas.ObservabilityPlugin); ok {
|
||||
return obsPlugin
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadPlugins loads all plugins from the config
|
||||
func LoadPlugins(loader PluginLoader, config *Config) ([]schemas.BasePlugin, error) {
|
||||
plugins := []schemas.BasePlugin{}
|
||||
if config == nil {
|
||||
return plugins, nil
|
||||
}
|
||||
|
||||
for _, pc := range config.Plugins {
|
||||
if !pc.Enabled {
|
||||
continue
|
||||
}
|
||||
plugin, err := loader.LoadPlugin(pc.Path, pc.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plugins = append(plugins, plugin)
|
||||
}
|
||||
|
||||
return plugins, nil
|
||||
}
|
||||
|
||||
// FilterLLMPlugins filters a list of BasePlugins to only include those implementing LLMPlugin
|
||||
func FilterLLMPlugins(plugins []schemas.BasePlugin) []schemas.LLMPlugin {
|
||||
result := []schemas.LLMPlugin{}
|
||||
for _, p := range plugins {
|
||||
if llmPlugin := AsLLMPlugin(p); llmPlugin != nil {
|
||||
result = append(result, llmPlugin)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// FilterMCPPlugins filters a list of BasePlugins to only include those implementing MCPPlugin
|
||||
func FilterMCPPlugins(plugins []schemas.BasePlugin) []schemas.MCPPlugin {
|
||||
result := []schemas.MCPPlugin{}
|
||||
for _, p := range plugins {
|
||||
if mcpPlugin := AsMCPPlugin(p); mcpPlugin != nil {
|
||||
result = append(result, mcpPlugin)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// FilterHTTPTransportPlugins filters a list of BasePlugins to only include those implementing HTTPTransportPlugin
|
||||
func FilterHTTPTransportPlugins(plugins []schemas.BasePlugin) []schemas.HTTPTransportPlugin {
|
||||
result := []schemas.HTTPTransportPlugin{}
|
||||
for _, p := range plugins {
|
||||
if httpPlugin := AsHTTPTransportPlugin(p); httpPlugin != nil {
|
||||
result = append(result, httpPlugin)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// FilterObservabilityPlugins filters a list of BasePlugins to only include those implementing ObservabilityPlugin
|
||||
func FilterObservabilityPlugins(plugins []schemas.BasePlugin) []schemas.ObservabilityPlugin {
|
||||
result := []schemas.ObservabilityPlugin{}
|
||||
for _, p := range plugins {
|
||||
if obsPlugin := AsObservabilityPlugin(p); obsPlugin != nil {
|
||||
result = append(result, obsPlugin)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
6
framework/plugins/race_disabled.go
Normal file
6
framework/plugins/race_disabled.go
Normal file
@@ -0,0 +1,6 @@
|
||||
//go:build !race
|
||||
|
||||
package plugins
|
||||
|
||||
// raceEnabled indicates if the binary was built with race detection
|
||||
const raceEnabled = false
|
||||
6
framework/plugins/race_enabled.go
Normal file
6
framework/plugins/race_enabled.go
Normal file
@@ -0,0 +1,6 @@
|
||||
//go:build race
|
||||
|
||||
package plugins
|
||||
|
||||
// raceEnabled indicates if the binary was built with race detection
|
||||
const raceEnabled = true
|
||||
175
framework/plugins/soloader.go
Normal file
175
framework/plugins/soloader.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package plugins
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"plugin"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// SharedObjectPluginLoader is the loader for shared object plugins
|
||||
type SharedObjectPluginLoader struct{}
|
||||
|
||||
func openPlugin(dp *DynamicPlugin) (*plugin.Plugin, error) {
|
||||
// Checking if path is URL or file path
|
||||
if strings.HasPrefix(dp.Path, "http") {
|
||||
// Download the file
|
||||
tempPath, err := DownloadPlugin(dp.Path, ".so")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dp.Path = tempPath
|
||||
}
|
||||
pluginObj, err := plugin.Open(dp.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dp.plugin = pluginObj
|
||||
return pluginObj, nil
|
||||
}
|
||||
|
||||
// LoadPlugin loads a generic plugin from a shared object file
|
||||
// It uses optional symbol lookup - only GetName and Cleanup are required
|
||||
// All other hook methods are optional and stored as nil if not implemented
|
||||
func (l *SharedObjectPluginLoader) LoadPlugin(path string, config any) (schemas.BasePlugin, error) {
|
||||
dp := &DynamicPlugin{
|
||||
Path: path,
|
||||
}
|
||||
|
||||
pluginObj, err := openPlugin(dp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Optional Init method
|
||||
if initSym, err := pluginObj.Lookup("Init"); err == nil {
|
||||
if initFunc, ok := initSym.(func(config any) error); ok {
|
||||
if err := initFunc(config); err != nil {
|
||||
return nil, fmt.Errorf("plugin Init failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to cast Init to func(config any) error")
|
||||
}
|
||||
}
|
||||
|
||||
// Required: GetName
|
||||
getNameSym, err := pluginObj.Lookup("GetName")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("required symbol GetName not found: %w", err)
|
||||
}
|
||||
var ok bool
|
||||
if dp.getName, ok = getNameSym.(func() string); !ok {
|
||||
return nil, fmt.Errorf("failed to cast GetName to func() string\nSee docs for more information: https://docs.getbifrost.ai/plugins/writing-go-plugin")
|
||||
}
|
||||
|
||||
// Required: Cleanup
|
||||
cleanupSym, err := pluginObj.Lookup("Cleanup")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("required symbol Cleanup not found: %w", err)
|
||||
}
|
||||
if dp.cleanup, ok = cleanupSym.(func() error); !ok {
|
||||
return nil, fmt.Errorf("failed to cast Cleanup to func() error\nSee docs for more information: https://docs.getbifrost.ai/plugins/writing-go-plugin")
|
||||
}
|
||||
|
||||
// Optional: HTTPTransportPreHook
|
||||
if sym, err := pluginObj.Lookup("HTTPTransportPreHook"); err == nil {
|
||||
if dp.httpTransportPreHook, ok = sym.(func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error)); !ok {
|
||||
return nil, fmt.Errorf("failed to cast HTTPTransportPreHook to expected signature")
|
||||
}
|
||||
}
|
||||
|
||||
// Optional: HTTPTransportPostHook
|
||||
if sym, err := pluginObj.Lookup("HTTPTransportPostHook"); err == nil {
|
||||
if dp.httpTransportPostHook, ok = sym.(func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error); !ok {
|
||||
return nil, fmt.Errorf("failed to cast HTTPTransportPostHook to expected signature")
|
||||
}
|
||||
}
|
||||
|
||||
// Optional: HTTPTransportStreamChunkHook
|
||||
if sym, err := pluginObj.Lookup("HTTPTransportStreamChunkHook"); err == nil {
|
||||
if dp.httpTransportStreamChunkHook, ok = sym.(func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error)); !ok {
|
||||
return nil, fmt.Errorf("failed to cast HTTPTransportStreamChunkHook to expected signature")
|
||||
}
|
||||
}
|
||||
|
||||
// Optional: PreLLMHook (with backward compatibility for legacy PreHook)
|
||||
if sym, err := pluginObj.Lookup("PreLLMHook"); err == nil {
|
||||
if dp.preLLMHook, ok = sym.(func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error)); !ok {
|
||||
return nil, fmt.Errorf("failed to cast PreLLMHook to expected signature")
|
||||
}
|
||||
} else if sym, err := pluginObj.Lookup("PreHook"); err == nil {
|
||||
// Legacy backward compatibility (v1.3.x): treat PreHook as PreLLMHook
|
||||
if dp.preLLMHook, ok = sym.(func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error)); !ok {
|
||||
return nil, fmt.Errorf("failed to cast PreHook to expected signature (legacy backward compatibility)")
|
||||
}
|
||||
}
|
||||
|
||||
// Optional: PostLLMHook (with backward compatibility for legacy PostHook)
|
||||
if sym, err := pluginObj.Lookup("PostLLMHook"); err == nil {
|
||||
if dp.postLLMHook, ok = sym.(func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)); !ok {
|
||||
return nil, fmt.Errorf("failed to cast PostLLMHook to expected signature")
|
||||
}
|
||||
} else if sym, err := pluginObj.Lookup("PostHook"); err == nil {
|
||||
// Legacy backward compatibility (v1.3.x): treat PostHook as PostLLMHook
|
||||
if dp.postLLMHook, ok = sym.(func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)); !ok {
|
||||
return nil, fmt.Errorf("failed to cast PostHook to expected signature (legacy backward compatibility)")
|
||||
}
|
||||
}
|
||||
|
||||
// Optional: PreMCPHook
|
||||
if sym, err := pluginObj.Lookup("PreMCPHook"); err == nil {
|
||||
if dp.preMCPHook, ok = sym.(func(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error)); !ok {
|
||||
return nil, fmt.Errorf("failed to cast PreMCPHook to expected signature")
|
||||
}
|
||||
}
|
||||
|
||||
// Optional: PostMCPHook
|
||||
if sym, err := pluginObj.Lookup("PostMCPHook"); err == nil {
|
||||
if dp.postMCPHook, ok = sym.(func(ctx *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error)); !ok {
|
||||
return nil, fmt.Errorf("failed to cast PostMCPHook to expected signature")
|
||||
}
|
||||
}
|
||||
|
||||
// Optional: Inject (ObservabilityPlugin)
|
||||
if sym, err := pluginObj.Lookup("Inject"); err == nil {
|
||||
if dp.inject, ok = sym.(func(ctx context.Context, trace *schemas.Trace) error); !ok {
|
||||
return nil, fmt.Errorf("failed to cast Inject to expected signature")
|
||||
}
|
||||
}
|
||||
|
||||
return dp, nil
|
||||
}
|
||||
|
||||
// VerifyBasePlugin verifies a plugin at the given path
|
||||
// Returns the name of the plugin or an empty string if the plugin is invalid
|
||||
// Returns an error if the plugin is invalid
|
||||
// This method is used to verify that the plugin is a valid base plugin and has the required symbols
|
||||
func (l *SharedObjectPluginLoader) VerifyBasePlugin(path string) (string, error) {
|
||||
dp := &DynamicPlugin{
|
||||
Path: path,
|
||||
}
|
||||
pluginObj, err := openPlugin(dp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Required: GetName
|
||||
getNameSym, err := pluginObj.Lookup("GetName")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("required symbol GetName not found: %w", err)
|
||||
}
|
||||
var ok bool
|
||||
if dp.getName, ok = getNameSym.(func() string); !ok {
|
||||
return "", fmt.Errorf("failed to cast GetName to func() string")
|
||||
}
|
||||
// Required: Cleanup
|
||||
cleanupSym, err := pluginObj.Lookup("Cleanup")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("required symbol Cleanup not found: %w", err)
|
||||
}
|
||||
if dp.cleanup, ok = cleanupSym.(func() error); !ok {
|
||||
return "", fmt.Errorf("failed to cast Cleanup to func() error")
|
||||
}
|
||||
return dp.getName(), nil
|
||||
}
|
||||
113
framework/plugins/soplugin.go
Normal file
113
framework/plugins/soplugin.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package plugins
|
||||
|
||||
import (
|
||||
"context"
|
||||
"plugin"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// DynamicPlugin is a generic dynamic plugin that can implement any combination of plugin interfaces
|
||||
// It uses optional function pointers - nil pointers indicate the interface is not implemented
|
||||
type DynamicPlugin struct {
|
||||
Enabled bool
|
||||
Path string
|
||||
Config any
|
||||
|
||||
filename string
|
||||
plugin *plugin.Plugin
|
||||
|
||||
// BasePlugin (required)
|
||||
getName func() string
|
||||
cleanup func() error
|
||||
|
||||
// HTTPTransportPlugin (optional)
|
||||
httpTransportPreHook func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error)
|
||||
httpTransportPostHook func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error
|
||||
httpTransportStreamChunkHook func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, stream *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error)
|
||||
|
||||
// LLMPlugin (optional)
|
||||
preLLMHook func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error)
|
||||
postLLMHook func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)
|
||||
|
||||
// MCPPlugin (optional)
|
||||
preMCPHook func(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error)
|
||||
postMCPHook func(ctx *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error)
|
||||
|
||||
// ObservabilityPlugin (optional)
|
||||
inject func(ctx context.Context, trace *schemas.Trace) error
|
||||
}
|
||||
|
||||
// GetName returns the name of the plugin (BasePlugin interface)
|
||||
func (dp *DynamicPlugin) GetName() string {
|
||||
return dp.getName()
|
||||
}
|
||||
|
||||
// Cleanup is invoked by core/bifrost.go during plugin unload, reload, and shutdown (BasePlugin interface)
|
||||
func (dp *DynamicPlugin) Cleanup() error {
|
||||
return dp.cleanup()
|
||||
}
|
||||
|
||||
// HTTPTransportPreHook intercepts HTTP requests at the transport layer before entering Bifrost core (HTTPTransportPlugin interface)
|
||||
func (dp *DynamicPlugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) {
|
||||
if dp.httpTransportPreHook == nil {
|
||||
return nil, nil // No-op if not implemented
|
||||
}
|
||||
return dp.httpTransportPreHook(ctx, req)
|
||||
}
|
||||
|
||||
// HTTPTransportPostHook intercepts HTTP responses at the transport layer after exiting Bifrost core (HTTPTransportPlugin interface)
|
||||
func (dp *DynamicPlugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error {
|
||||
if dp.httpTransportPostHook == nil {
|
||||
return nil // No-op if not implemented
|
||||
}
|
||||
return dp.httpTransportPostHook(ctx, req, resp)
|
||||
}
|
||||
|
||||
// HTTPTransportStreamChunkHook intercepts streaming chunks before they are written to the client
|
||||
func (dp *DynamicPlugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, stream *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) {
|
||||
if dp.httpTransportStreamChunkHook == nil {
|
||||
return stream, nil // No-op if not implemented
|
||||
}
|
||||
return dp.httpTransportStreamChunkHook(ctx, req, stream)
|
||||
}
|
||||
|
||||
// PreLLMHook is invoked before LLM provider calls (LLMPlugin interface)
|
||||
func (dp *DynamicPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) {
|
||||
if dp.preLLMHook == nil {
|
||||
return req, nil, nil // No-op if not implemented
|
||||
}
|
||||
return dp.preLLMHook(ctx, req)
|
||||
}
|
||||
|
||||
// PostLLMHook is invoked after LLM provider calls (LLMPlugin interface)
|
||||
func (dp *DynamicPlugin) PostLLMHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) {
|
||||
if dp.postLLMHook == nil {
|
||||
return resp, bifrostErr, nil // No-op if not implemented
|
||||
}
|
||||
return dp.postLLMHook(ctx, resp, bifrostErr)
|
||||
}
|
||||
|
||||
// PreMCPHook is invoked before MCP calls (MCPPlugin interface)
|
||||
func (dp *DynamicPlugin) PreMCPHook(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error) {
|
||||
if dp.preMCPHook == nil {
|
||||
return req, nil, nil // No-op if not implemented
|
||||
}
|
||||
return dp.preMCPHook(ctx, req)
|
||||
}
|
||||
|
||||
// PostMCPHook is invoked after MCP calls (MCPPlugin interface)
|
||||
func (dp *DynamicPlugin) PostMCPHook(ctx *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error) {
|
||||
if dp.postMCPHook == nil {
|
||||
return resp, bifrostErr, nil // No-op if not implemented
|
||||
}
|
||||
return dp.postMCPHook(ctx, resp, bifrostErr)
|
||||
}
|
||||
|
||||
// Inject receives completed traces for observability backends (ObservabilityPlugin interface)
|
||||
func (dp *DynamicPlugin) Inject(ctx context.Context, trace *schemas.Trace) error {
|
||||
if dp.inject == nil {
|
||||
return nil // No-op if not implemented
|
||||
}
|
||||
return dp.inject(ctx, trace)
|
||||
}
|
||||
809
framework/plugins/soplugin_test.go
Normal file
809
framework/plugins/soplugin_test.go
Normal file
@@ -0,0 +1,809 @@
|
||||
package plugins
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
helloWorldPluginDir = "../../examples/plugins/hello-world"
|
||||
helloWorldBuildDir = "../../examples/plugins/hello-world/build"
|
||||
)
|
||||
|
||||
// TestDynamicPluginLifecycle tests the complete lifecycle of a dynamic plugin
|
||||
func TestDynamicPluginLifecycle(t *testing.T) {
|
||||
// Build the hello-world plugin first
|
||||
pluginPath := buildHelloWorldPlugin(t)
|
||||
defer cleanupHelloWorldPlugin(t)
|
||||
|
||||
// Test loading the plugin
|
||||
config := &Config{
|
||||
Plugins: []PluginConfig{
|
||||
{
|
||||
Path: pluginPath,
|
||||
Name: "hello-world",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{"test": "config"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
basePlugins, err := LoadPlugins(loader, config)
|
||||
require.NoError(t, err, "Failed to load plugins")
|
||||
require.Len(t, basePlugins, 1, "Expected exactly one plugin to be loaded")
|
||||
|
||||
plugins := FilterLLMPlugins(basePlugins)
|
||||
require.Len(t, plugins, 1, "Expected plugin to implement LLMPlugin")
|
||||
plugin := plugins[0]
|
||||
|
||||
// Test GetName
|
||||
t.Run("GetName", func(t *testing.T) {
|
||||
name := plugin.GetName()
|
||||
assert.Equal(t, "hello-world", name, "Plugin name should match")
|
||||
})
|
||||
|
||||
// Test HTTPTransportPreHook
|
||||
t.Run("HTTPTransportPreHook", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Create a test HTTP request
|
||||
req := &schemas.HTTPRequest{
|
||||
Method: "POST",
|
||||
Path: "/api",
|
||||
Headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer token123",
|
||||
},
|
||||
Query: map[string]string{},
|
||||
Body: []byte(`{"test": "data"}`),
|
||||
}
|
||||
|
||||
// Call HTTPTransportPreHook
|
||||
httpTransportPlugin, ok := plugin.(schemas.HTTPTransportPlugin)
|
||||
require.True(t, ok, "Plugin should be a HTTPTransportPlugin")
|
||||
resp, err := httpTransportPlugin.HTTPTransportPreHook(pluginCtx, req)
|
||||
require.NoError(t, err, "HTTPTransportPreHook should not return error")
|
||||
assert.Nil(t, resp, "HTTPTransportPreHook should return nil response to continue")
|
||||
|
||||
// Verify headers were modified (hello-world plugin adds a header)
|
||||
assert.Equal(t, "transport-pre-hook-value", req.Headers["x-hello-world-plugin"], "Plugin should have added custom header")
|
||||
})
|
||||
|
||||
// Test HTTPTransportPostHook
|
||||
t.Run("HTTPTransportPostHook", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Create a test HTTP response
|
||||
req := &schemas.HTTPRequest{
|
||||
Method: "POST",
|
||||
Path: "/api",
|
||||
Headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer token123",
|
||||
},
|
||||
Query: map[string]string{},
|
||||
Body: []byte(`{"test": "data"}`),
|
||||
}
|
||||
resp := &schemas.HTTPResponse{
|
||||
StatusCode: 200,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
Body: []byte(`{"result": "success"}`),
|
||||
}
|
||||
|
||||
// Call HTTPTransportPostHook
|
||||
httpTransportPlugin, ok := plugin.(schemas.HTTPTransportPlugin)
|
||||
require.True(t, ok, "Plugin should be a HTTPTransportPlugin")
|
||||
err := httpTransportPlugin.HTTPTransportPostHook(pluginCtx, req, resp)
|
||||
require.NoError(t, err, "HTTPTransportPostHook should not return error")
|
||||
// Verify headers were modified (hello-world plugin adds a header)
|
||||
assert.Equal(t, "transport-post-hook-value", resp.Headers["x-hello-world-plugin"], "Plugin should have added custom header")
|
||||
})
|
||||
|
||||
// Test PreLLMHook
|
||||
t.Run("PreLLMHook", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
req := &schemas.BifrostRequest{
|
||||
RequestType: schemas.ChatCompletionRequest,
|
||||
ChatRequest: &schemas.BifrostChatRequest{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
Input: []schemas.ChatMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: stringPtr("Hello"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
modifiedReq, shortCircuit, err := plugin.PreLLMHook(pluginCtx, req)
|
||||
require.NoError(t, err, "PreLLMHook should not return error")
|
||||
assert.Nil(t, shortCircuit, "PreLLMHook should not return short circuit")
|
||||
assert.Equal(t, req, modifiedReq, "Request should be unchanged")
|
||||
})
|
||||
|
||||
// Test PostLLMHook
|
||||
t.Run("PostLLMHook", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
resp := &schemas.BifrostResponse{
|
||||
ChatResponse: &schemas.BifrostChatResponse{
|
||||
ID: "test-id",
|
||||
Model: "gpt-4",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: stringPtr("Hello! How can I help you?"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
bifrostErr := (*schemas.BifrostError)(nil)
|
||||
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
modifiedResp, modifiedErr, err := plugin.PostLLMHook(pluginCtx, resp, bifrostErr)
|
||||
require.NoError(t, err, "PostLLMHook should not return error")
|
||||
assert.Equal(t, resp, modifiedResp, "Response should be unchanged")
|
||||
assert.Equal(t, bifrostErr, modifiedErr, "Error should be unchanged")
|
||||
})
|
||||
|
||||
// Test PostLLMHook with error
|
||||
t.Run("PostHook_WithError", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
statusCode := 500
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
StatusCode: &statusCode,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "Test error",
|
||||
},
|
||||
}
|
||||
|
||||
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
modifiedResp, modifiedErr, err := plugin.PostLLMHook(pluginCtx, nil, bifrostErr)
|
||||
require.NoError(t, err, "PostLLMHook should not return error")
|
||||
assert.Nil(t, modifiedResp, "Response should be nil")
|
||||
assert.Equal(t, bifrostErr, modifiedErr, "Error should be unchanged")
|
||||
})
|
||||
|
||||
// Test Cleanup
|
||||
t.Run("Cleanup", func(t *testing.T) {
|
||||
err := plugin.Cleanup()
|
||||
assert.NoError(t, err, "Cleanup should not return error")
|
||||
})
|
||||
}
|
||||
|
||||
// TestLoadPlugins_DisabledPlugin tests that disabled plugins are not loaded
|
||||
func TestLoadPlugins_DisabledPlugin(t *testing.T) {
|
||||
pluginPath := buildHelloWorldPlugin(t)
|
||||
defer cleanupHelloWorldPlugin(t)
|
||||
|
||||
config := &Config{
|
||||
Plugins: []PluginConfig{
|
||||
{
|
||||
Path: pluginPath,
|
||||
Name: "hello-world",
|
||||
Enabled: false, // Plugin is disabled
|
||||
Config: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
plugins, err := LoadPlugins(loader, config)
|
||||
require.NoError(t, err, "LoadPlugins should not error for disabled plugins")
|
||||
assert.Len(t, plugins, 0, "No plugins should be loaded when all are disabled")
|
||||
}
|
||||
|
||||
// TestLoadPlugins_MultiplePlugins tests loading multiple plugins
|
||||
func TestLoadPlugins_MultiplePlugins(t *testing.T) {
|
||||
pluginPath := buildHelloWorldPlugin(t)
|
||||
defer cleanupHelloWorldPlugin(t)
|
||||
|
||||
config := &Config{
|
||||
Plugins: []PluginConfig{
|
||||
{
|
||||
Path: pluginPath,
|
||||
Name: "hello-world-1",
|
||||
Enabled: true,
|
||||
Config: nil,
|
||||
},
|
||||
{
|
||||
Path: pluginPath,
|
||||
Name: "hello-world-2",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{"key": "value"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
plugins, err := LoadPlugins(loader, config)
|
||||
require.NoError(t, err, "LoadPlugins should succeed for multiple plugins")
|
||||
assert.Len(t, plugins, 2, "Two plugins should be loaded")
|
||||
|
||||
for _, plugin := range plugins {
|
||||
assert.Equal(t, "hello-world", plugin.GetName())
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadPlugins_InvalidPath tests loading a plugin with invalid path
|
||||
func TestLoadPlugins_InvalidPath(t *testing.T) {
|
||||
config := &Config{
|
||||
Plugins: []PluginConfig{
|
||||
{
|
||||
Path: "/nonexistent/path/plugin.so",
|
||||
Name: "invalid-plugin",
|
||||
Enabled: true,
|
||||
Config: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
plugins, err := LoadPlugins(loader, config)
|
||||
assert.Error(t, err, "LoadPlugins should return error for invalid path")
|
||||
assert.Nil(t, plugins, "No plugins should be loaded on error")
|
||||
}
|
||||
|
||||
// TestLoadPlugins_EmptyConfig tests loading plugins with empty config
|
||||
func TestLoadPlugins_EmptyConfig(t *testing.T) {
|
||||
config := &Config{
|
||||
Plugins: []PluginConfig{},
|
||||
}
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
plugins, err := LoadPlugins(loader, config)
|
||||
require.NoError(t, err, "LoadPlugins should succeed with empty config")
|
||||
assert.Len(t, plugins, 0, "No plugins should be loaded with empty config")
|
||||
}
|
||||
|
||||
// TestDynamicPlugin_ContextPropagation tests that context is properly propagated
|
||||
func TestDynamicPlugin_ContextPropagation(t *testing.T) {
|
||||
pluginPath := buildHelloWorldPlugin(t)
|
||||
defer cleanupHelloWorldPlugin(t)
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
basePlugin, err := loader.LoadPlugin(pluginPath, nil)
|
||||
require.NoError(t, err, "Failed to load plugin")
|
||||
|
||||
// Type assert to LLMPlugin
|
||||
plugin, ok := basePlugin.(schemas.LLMPlugin)
|
||||
require.True(t, ok, "Plugin should implement LLMPlugin interface")
|
||||
|
||||
// Create a context with a value
|
||||
ctx := context.WithValue(context.Background(), "test-key", "test-value")
|
||||
|
||||
// Test PreLLMHook with context
|
||||
req := &schemas.BifrostRequest{
|
||||
RequestType: schemas.ChatCompletionRequest,
|
||||
ChatRequest: &schemas.BifrostChatRequest{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
},
|
||||
}
|
||||
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
_, _, err = plugin.PreLLMHook(pluginCtx, req)
|
||||
require.NoError(t, err, "PreLLMHook should succeed with context")
|
||||
|
||||
// Test PostLLMHook with context
|
||||
resp := &schemas.BifrostResponse{
|
||||
ChatResponse: &schemas.BifrostChatResponse{
|
||||
ID: "test-id",
|
||||
Model: "gpt-4",
|
||||
},
|
||||
}
|
||||
_, _, err = plugin.PostLLMHook(pluginCtx, resp, nil)
|
||||
require.NoError(t, err, "PostLLMHook should succeed with context")
|
||||
}
|
||||
|
||||
// TestDynamicPlugin_ConcurrentCalls tests concurrent plugin calls
|
||||
func TestDynamicPlugin_ConcurrentCalls(t *testing.T) {
|
||||
pluginPath := buildHelloWorldPlugin(t)
|
||||
defer cleanupHelloWorldPlugin(t)
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
basePlugin, err := loader.LoadPlugin(pluginPath, nil)
|
||||
require.NoError(t, err, "Failed to load plugin")
|
||||
|
||||
// Type assert to LLMPlugin
|
||||
plugin, ok := basePlugin.(schemas.LLMPlugin)
|
||||
require.True(t, ok, "Plugin should implement LLMPlugin interface")
|
||||
|
||||
// Run multiple goroutines calling plugin methods
|
||||
const numGoroutines = 10
|
||||
done := make(chan bool, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
ctx := context.Background()
|
||||
req := &schemas.BifrostRequest{
|
||||
RequestType: schemas.ChatCompletionRequest,
|
||||
ChatRequest: &schemas.BifrostChatRequest{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
},
|
||||
}
|
||||
|
||||
// Call PreLLMHook
|
||||
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
_, _, err := plugin.PreLLMHook(pluginCtx, req)
|
||||
assert.NoError(t, err, "PreLLMHook should succeed in goroutine %d", id)
|
||||
|
||||
// Call PostLLMHook
|
||||
resp := &schemas.BifrostResponse{
|
||||
ChatResponse: &schemas.BifrostChatResponse{
|
||||
ID: "test-id",
|
||||
Model: "gpt-4",
|
||||
},
|
||||
}
|
||||
_, _, err = plugin.PostLLMHook(pluginCtx, resp, nil)
|
||||
assert.NoError(t, err, "PostLLMHook should succeed in goroutine %d", id)
|
||||
|
||||
// Call GetName
|
||||
name := basePlugin.GetName()
|
||||
assert.Equal(t, "hello-world", name, "GetName should return correct name in goroutine %d", id)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to build the hello-world plugin
|
||||
func buildHelloWorldPlugin(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
// Get absolute path to the hello-world plugin directory
|
||||
absPluginDir, err := filepath.Abs(helloWorldPluginDir)
|
||||
require.NoError(t, err, "Failed to get absolute path")
|
||||
|
||||
// Determine plugin extension based on OS
|
||||
pluginExt := ".so"
|
||||
if runtime.GOOS == "windows" {
|
||||
pluginExt = ".dll"
|
||||
}
|
||||
|
||||
// Clean and create build directory to ensure fresh build with current Go version
|
||||
buildDir := filepath.Join(absPluginDir, "build")
|
||||
os.RemoveAll(buildDir)
|
||||
err = os.MkdirAll(buildDir, 0755)
|
||||
require.NoError(t, err, "Failed to create build directory")
|
||||
|
||||
// Build the plugin directly with go build
|
||||
pluginPath := filepath.Join(buildDir, "hello-world"+pluginExt)
|
||||
args := []string{"build", "-buildmode=plugin", "-o", pluginPath}
|
||||
if raceEnabled {
|
||||
args = append(args, "-race")
|
||||
}
|
||||
args = append(args, "main.go")
|
||||
cmd := exec.Command("go", args...)
|
||||
cmd.Dir = absPluginDir
|
||||
cmd.Env = append(os.Environ(), "CGO_ENABLED=1")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Build output: %s", string(output))
|
||||
require.NoError(t, err, "Failed to build hello-world plugin")
|
||||
}
|
||||
|
||||
// Verify the plugin was built
|
||||
_, err = os.Stat(pluginPath)
|
||||
require.NoError(t, err, "Plugin file should exist after build")
|
||||
|
||||
return pluginPath
|
||||
}
|
||||
|
||||
// Helper function to clean up the hello-world plugin build
|
||||
func cleanupHelloWorldPlugin(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
absPluginDir, err := filepath.Abs(helloWorldPluginDir)
|
||||
if err != nil {
|
||||
t.Logf("Failed to get absolute path for cleanup: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command("make", "clean")
|
||||
cmd.Dir = absPluginDir
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.Logf("Failed to clean hello-world plugin: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadDynamicPlugin_DirectCall tests loading a plugin directly
|
||||
func TestLoadDynamicPlugin_DirectCall(t *testing.T) {
|
||||
pluginPath := buildHelloWorldPlugin(t)
|
||||
defer cleanupHelloWorldPlugin(t)
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
plugin, err := loader.LoadPlugin(pluginPath, map[string]interface{}{
|
||||
"test": "config",
|
||||
})
|
||||
require.NoError(t, err, "loadDynamicPlugin should succeed")
|
||||
assert.NotNil(t, plugin, "Plugin should not be nil")
|
||||
|
||||
// Verify it's a DynamicPlugin
|
||||
dynamicPlugin, ok := plugin.(*DynamicPlugin)
|
||||
assert.True(t, ok, "Plugin should be a DynamicPlugin")
|
||||
assert.Equal(t, pluginPath, dynamicPlugin.Path)
|
||||
}
|
||||
|
||||
// TestDynamicPlugin_NilConfig tests loading a plugin with nil config
|
||||
func TestDynamicPlugin_NilConfig(t *testing.T) {
|
||||
pluginPath := buildHelloWorldPlugin(t)
|
||||
defer cleanupHelloWorldPlugin(t)
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
plugin, err := loader.LoadPlugin(pluginPath, nil)
|
||||
require.NoError(t, err, "loadDynamicPlugin should succeed with nil config")
|
||||
assert.NotNil(t, plugin, "Plugin should not be nil")
|
||||
|
||||
// Verify plugin works correctly
|
||||
name := plugin.GetName()
|
||||
assert.Equal(t, "hello-world", name)
|
||||
}
|
||||
|
||||
// TestDynamicPlugin_ShortCircuitNil tests that nil short circuit is handled properly
|
||||
func TestDynamicPlugin_ShortCircuitNil(t *testing.T) {
|
||||
pluginPath := buildHelloWorldPlugin(t)
|
||||
defer cleanupHelloWorldPlugin(t)
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
basePlugin, err := loader.LoadPlugin(pluginPath, nil)
|
||||
require.NoError(t, err, "Failed to load plugin")
|
||||
|
||||
// Type assert to LLMPlugin
|
||||
plugin, ok := basePlugin.(schemas.LLMPlugin)
|
||||
require.True(t, ok, "Plugin should implement LLMPlugin interface")
|
||||
|
||||
ctx := context.Background()
|
||||
req := &schemas.BifrostRequest{
|
||||
RequestType: schemas.ChatCompletionRequest,
|
||||
ChatRequest: &schemas.BifrostChatRequest{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
},
|
||||
}
|
||||
|
||||
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
modifiedReq, shortCircuit, err := plugin.PreLLMHook(pluginCtx, req)
|
||||
require.NoError(t, err, "PreLLMHook should succeed")
|
||||
assert.Nil(t, shortCircuit, "Short circuit should be nil")
|
||||
assert.NotNil(t, modifiedReq, "Modified request should not be nil")
|
||||
}
|
||||
|
||||
// BenchmarkDynamicPlugin_PreHook benchmarks the PreLLMHook method
|
||||
func BenchmarkDynamicPlugin_PreHook(b *testing.B) {
|
||||
pluginPath := buildHelloWorldPluginForBenchmark(b)
|
||||
defer cleanupHelloWorldPluginForBenchmark(b)
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
basePlugin, err := loader.LoadPlugin(pluginPath, nil)
|
||||
require.NoError(b, err, "Failed to load plugin")
|
||||
|
||||
// Type assert to LLMPlugin
|
||||
plugin, ok := basePlugin.(schemas.LLMPlugin)
|
||||
require.True(b, ok, "Plugin should implement LLMPlugin interface")
|
||||
|
||||
ctx := context.Background()
|
||||
req := &schemas.BifrostRequest{
|
||||
RequestType: schemas.ChatCompletionRequest,
|
||||
ChatRequest: &schemas.BifrostChatRequest{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = plugin.PreLLMHook(pluginCtx, req)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkDynamicPlugin_PostHook benchmarks the PostLLMHook method
|
||||
func BenchmarkDynamicPlugin_PostHook(b *testing.B) {
|
||||
pluginPath := buildHelloWorldPluginForBenchmark(b)
|
||||
defer cleanupHelloWorldPluginForBenchmark(b)
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
basePlugin, err := loader.LoadPlugin(pluginPath, nil)
|
||||
require.NoError(b, err, "Failed to load plugin")
|
||||
|
||||
// Type assert to LLMPlugin
|
||||
plugin, ok := basePlugin.(schemas.LLMPlugin)
|
||||
require.True(b, ok, "Plugin should implement LLMPlugin interface")
|
||||
|
||||
ctx := context.Background()
|
||||
resp := &schemas.BifrostResponse{
|
||||
ChatResponse: &schemas.BifrostChatResponse{
|
||||
ID: "test-id",
|
||||
Model: "gpt-4",
|
||||
},
|
||||
}
|
||||
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = plugin.PostLLMHook(pluginCtx, resp, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkDynamicPlugin_GetName benchmarks the GetName method
|
||||
func BenchmarkDynamicPlugin_GetName(b *testing.B) {
|
||||
pluginPath := buildHelloWorldPluginForBenchmark(b)
|
||||
defer cleanupHelloWorldPluginForBenchmark(b)
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
plugin, err := loader.LoadPlugin(pluginPath, nil)
|
||||
require.NoError(b, err, "Failed to load plugin")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = plugin.GetName()
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to build plugin for benchmarks
|
||||
func buildHelloWorldPluginForBenchmark(b *testing.B) string {
|
||||
b.Helper()
|
||||
|
||||
absPluginDir, err := filepath.Abs(helloWorldPluginDir)
|
||||
require.NoError(b, err, "Failed to get absolute path")
|
||||
|
||||
pluginExt := ".so"
|
||||
if runtime.GOOS == "windows" {
|
||||
pluginExt = ".dll"
|
||||
}
|
||||
|
||||
// Clean and create build directory to ensure fresh build with current Go version
|
||||
buildDir := filepath.Join(absPluginDir, "build")
|
||||
pluginPath := filepath.Join(buildDir, "hello-world"+pluginExt)
|
||||
os.RemoveAll(buildDir)
|
||||
err = os.MkdirAll(buildDir, 0755)
|
||||
require.NoError(b, err, "Failed to create build directory")
|
||||
|
||||
// Build the plugin directly with go build
|
||||
args := []string{"build", "-buildmode=plugin", "-o", pluginPath}
|
||||
if raceEnabled {
|
||||
args = append(args, "-race")
|
||||
}
|
||||
args = append(args, "main.go")
|
||||
cmd := exec.Command("go", args...)
|
||||
cmd.Dir = absPluginDir
|
||||
cmd.Env = append(os.Environ(), "CGO_ENABLED=1")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
b.Logf("Build output: %s", string(output))
|
||||
require.NoError(b, err, "Failed to build hello-world plugin")
|
||||
}
|
||||
|
||||
return pluginPath
|
||||
}
|
||||
|
||||
// Helper function to clean up plugin for benchmarks
|
||||
func cleanupHelloWorldPluginForBenchmark(b *testing.B) {
|
||||
b.Helper()
|
||||
|
||||
absPluginDir, err := filepath.Abs(helloWorldPluginDir)
|
||||
if err != nil {
|
||||
b.Logf("Failed to get absolute path for cleanup: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command("make", "clean")
|
||||
cmd.Dir = absPluginDir
|
||||
if err := cmd.Run(); err != nil {
|
||||
b.Logf("Failed to clean hello-world plugin: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDynamicPlugin_GetNameNotEmpty tests that GetName returns non-empty string
|
||||
func TestDynamicPlugin_GetNameNotEmpty(t *testing.T) {
|
||||
pluginPath := buildHelloWorldPlugin(t)
|
||||
defer cleanupHelloWorldPlugin(t)
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
plugin, err := loader.LoadPlugin(pluginPath, nil)
|
||||
require.NoError(t, err, "Failed to load plugin")
|
||||
|
||||
name := plugin.GetName()
|
||||
assert.NotEmpty(t, name, "Plugin name should not be empty")
|
||||
assert.True(t, strings.Contains(name, "hello-world"), "Plugin name should contain 'hello-world'")
|
||||
}
|
||||
|
||||
// Helper function to create a pointer to a string
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// TestLoadPlugins tests the new generic LoadPlugins function
|
||||
func TestLoadPlugins(t *testing.T) {
|
||||
// Build the hello-world plugin first
|
||||
pluginPath := buildHelloWorldPlugin(t)
|
||||
defer cleanupHelloWorldPlugin(t)
|
||||
|
||||
t.Run("LoadSinglePlugin", func(t *testing.T) {
|
||||
config := &Config{
|
||||
Plugins: []PluginConfig{
|
||||
{
|
||||
Path: pluginPath,
|
||||
Name: "hello-world",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{"test": "config"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
plugins, err := LoadPlugins(loader, config)
|
||||
require.NoError(t, err, "Failed to load plugins")
|
||||
require.Len(t, plugins, 1, "Expected exactly one plugin to be loaded")
|
||||
|
||||
plugin := plugins[0]
|
||||
assert.Equal(t, "hello-world", plugin.GetName())
|
||||
})
|
||||
|
||||
t.Run("LoadMultiplePlugins", func(t *testing.T) {
|
||||
config := &Config{
|
||||
Plugins: []PluginConfig{
|
||||
{
|
||||
Path: pluginPath,
|
||||
Name: "hello-world-1",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{"test": "config1"},
|
||||
},
|
||||
{
|
||||
Path: pluginPath,
|
||||
Name: "hello-world-2",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{"test": "config2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
plugins, err := LoadPlugins(loader, config)
|
||||
require.NoError(t, err, "Failed to load plugins")
|
||||
require.Len(t, plugins, 2, "Expected two plugins to be loaded")
|
||||
})
|
||||
|
||||
t.Run("SkipDisabledPlugins", func(t *testing.T) {
|
||||
config := &Config{
|
||||
Plugins: []PluginConfig{
|
||||
{
|
||||
Path: pluginPath,
|
||||
Name: "hello-world-enabled",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{"test": "config"},
|
||||
},
|
||||
{
|
||||
Path: pluginPath,
|
||||
Name: "hello-world-disabled",
|
||||
Enabled: false,
|
||||
Config: map[string]interface{}{"test": "config"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
plugins, err := LoadPlugins(loader, config)
|
||||
require.NoError(t, err, "Failed to load plugins")
|
||||
require.Len(t, plugins, 1, "Expected only enabled plugin to be loaded")
|
||||
})
|
||||
}
|
||||
|
||||
// TestFilterPlugins tests the plugin filter functions
|
||||
func TestFilterPlugins(t *testing.T) {
|
||||
// Build the hello-world plugin first
|
||||
pluginPath := buildHelloWorldPlugin(t)
|
||||
defer cleanupHelloWorldPlugin(t)
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
plugin, err := loader.LoadPlugin(pluginPath, nil)
|
||||
require.NoError(t, err, "Failed to load plugin")
|
||||
|
||||
plugins := []schemas.BasePlugin{plugin}
|
||||
|
||||
t.Run("FilterLLMPlugins", func(t *testing.T) {
|
||||
llmPlugins := FilterLLMPlugins(plugins)
|
||||
assert.Len(t, llmPlugins, 1, "hello-world should implement LLMPlugin")
|
||||
})
|
||||
|
||||
t.Run("FilterHTTPTransportPlugins", func(t *testing.T) {
|
||||
httpPlugins := FilterHTTPTransportPlugins(plugins)
|
||||
assert.Len(t, httpPlugins, 1, "hello-world should implement HTTPTransportPlugin")
|
||||
})
|
||||
|
||||
t.Run("FilterMCPPlugins", func(t *testing.T) {
|
||||
mcpPlugins := FilterMCPPlugins(plugins)
|
||||
assert.Len(t, mcpPlugins, 0, "hello-world does not implement MCPPlugin")
|
||||
})
|
||||
|
||||
t.Run("FilterObservabilityPlugins", func(t *testing.T) {
|
||||
obsPlugins := FilterObservabilityPlugins(plugins)
|
||||
assert.Len(t, obsPlugins, 0, "hello-world does not implement ObservabilityPlugin")
|
||||
})
|
||||
}
|
||||
|
||||
// TestLoadPluginWithOptionalHooks tests that plugins can implement only a subset of hooks
|
||||
func TestLoadPluginWithOptionalHooks(t *testing.T) {
|
||||
// Build the hello-world plugin first
|
||||
pluginPath := buildHelloWorldPlugin(t)
|
||||
defer cleanupHelloWorldPlugin(t)
|
||||
|
||||
loader := &SharedObjectPluginLoader{}
|
||||
plugin, err := loader.LoadPlugin(pluginPath, nil)
|
||||
require.NoError(t, err, "Failed to load plugin")
|
||||
|
||||
// The plugin should load successfully even if it doesn't implement all hooks
|
||||
assert.NotNil(t, plugin, "Plugin should be loaded")
|
||||
|
||||
// Test that DynamicPlugin properly handles unimplemented methods by returning no-op values
|
||||
dynamicPlugin, ok := plugin.(*DynamicPlugin)
|
||||
require.True(t, ok, "Plugin should be a DynamicPlugin")
|
||||
|
||||
// Test MCP hooks (not implemented by hello-world plugin)
|
||||
t.Run("UnimplementedMCPHooks", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// PreMCPHook should return no-op values
|
||||
mcpReq := &schemas.BifrostMCPRequest{}
|
||||
returnedReq, shortCircuit, err := dynamicPlugin.PreMCPHook(pluginCtx, mcpReq)
|
||||
assert.NoError(t, err, "PreMCPHook should not error for unimplemented hook")
|
||||
assert.Equal(t, mcpReq, returnedReq, "PreMCPHook should return original request")
|
||||
assert.Nil(t, shortCircuit, "PreMCPHook should return nil short circuit")
|
||||
|
||||
// PostMCPHook should return no-op values
|
||||
mcpResp := &schemas.BifrostMCPResponse{}
|
||||
bifrostErr := &schemas.BifrostError{}
|
||||
returnedResp, returnedErr, hookErr := dynamicPlugin.PostMCPHook(pluginCtx, mcpResp, bifrostErr)
|
||||
assert.NoError(t, hookErr, "PostMCPHook should not error for unimplemented hook")
|
||||
assert.Equal(t, mcpResp, returnedResp, "PostMCPHook should return original response")
|
||||
assert.Equal(t, bifrostErr, returnedErr, "PostMCPHook should return original error")
|
||||
})
|
||||
|
||||
// Test Observability hooks (not implemented by hello-world plugin)
|
||||
t.Run("UnimplementedObservabilityHooks", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
trace := &schemas.Trace{}
|
||||
err := dynamicPlugin.Inject(ctx, trace)
|
||||
assert.NoError(t, err, "Inject should not error for unimplemented hook")
|
||||
})
|
||||
}
|
||||
111
framework/plugins/utils.go
Normal file
111
framework/plugins/utils.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package plugins
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPluginNotFound = fmt.Errorf("plugin not found")
|
||||
)
|
||||
|
||||
// pluginDownloadClient is a fasthttp client with a larger read buffer to handle
|
||||
// responses with large headers.
|
||||
var pluginDownloadClient = &fasthttp.Client{
|
||||
ReadBufferSize: 64 * 1024, // 64KB, matches the bifrost HTTP server setting
|
||||
}
|
||||
|
||||
// DownloadPlugin downloads a plugin from a URL and returns the local file path
|
||||
func DownloadPlugin(pluginURL string, extension string) (string, error) {
|
||||
req := fasthttp.AcquireRequest()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
response := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseResponse(response)
|
||||
|
||||
req.Header.SetMethod(fasthttp.MethodGet)
|
||||
req.Header.Set("Accept", "application/octet-stream")
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
||||
|
||||
const maxRedirects = 5
|
||||
currentURL := pluginURL
|
||||
for i := 0; i <= maxRedirects; i++ {
|
||||
req.SetRequestURI(currentURL)
|
||||
if i > 0 {
|
||||
response.Reset()
|
||||
}
|
||||
|
||||
if err := pluginDownloadClient.DoTimeout(req, response, 120*time.Second); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
statusCode := response.StatusCode()
|
||||
if statusCode == fasthttp.StatusOK {
|
||||
break
|
||||
}
|
||||
if statusCode >= 300 && statusCode < 400 {
|
||||
if i == maxRedirects {
|
||||
return "", fmt.Errorf("too many redirects downloading plugin")
|
||||
}
|
||||
location := string(response.Header.Peek("Location"))
|
||||
if location == "" {
|
||||
return "", fmt.Errorf("redirect response missing Location header: HTTP %d", statusCode)
|
||||
}
|
||||
loc, err := url.Parse(location)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid Location header %q: %w", location, err)
|
||||
}
|
||||
base, err := url.Parse(currentURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid request URL %q: %w", currentURL, err)
|
||||
}
|
||||
currentURL = base.ResolveReference(loc).String()
|
||||
continue
|
||||
}
|
||||
return "", fmt.Errorf("failed to download plugin: HTTP %d", statusCode)
|
||||
}
|
||||
|
||||
// Decompress the response body if it was gzip/deflate compressed
|
||||
// BodyUncompressed handles both gzip and deflate encodings based on Content-Encoding header
|
||||
body, err := response.BodyUncompressed()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decompress response body: %w", err)
|
||||
}
|
||||
|
||||
// Create a unique temporary file for the plugin
|
||||
tempFile, err := os.CreateTemp(os.TempDir(), "bifrost-plugin-*"+extension)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temporary file: %w", err)
|
||||
}
|
||||
tempPath := tempFile.Name()
|
||||
|
||||
// Write the downloaded body to the temporary file
|
||||
_, err = tempFile.Write(body)
|
||||
if err != nil {
|
||||
tempFile.Close()
|
||||
os.Remove(tempPath)
|
||||
return "", fmt.Errorf("failed to write plugin to temporary file: %w", err)
|
||||
}
|
||||
|
||||
// Close the file
|
||||
err = tempFile.Close()
|
||||
if err != nil {
|
||||
os.Remove(tempPath)
|
||||
return "", fmt.Errorf("failed to close temporary file: %w", err)
|
||||
}
|
||||
|
||||
// Set file permissions to be executable (for .so files)
|
||||
if extension == ".so" {
|
||||
err = os.Chmod(tempPath, 0755)
|
||||
if err != nil {
|
||||
os.Remove(tempPath)
|
||||
return "", fmt.Errorf("failed to set executable permissions on plugin: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return tempPath, nil
|
||||
}
|
||||
90
framework/plugins/utils_test.go
Normal file
90
framework/plugins/utils_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package plugins
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const fakePluginBytes = "fake-plugin-binary-content"
|
||||
|
||||
func TestDownloadPlugin_DirectDownload(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(fakePluginBytes))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
path, err := DownloadPlugin(server.URL, ".so")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(path)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, fakePluginBytes, string(data))
|
||||
}
|
||||
|
||||
func TestDownloadPlugin_FollowsRedirect(t *testing.T) {
|
||||
// Final destination
|
||||
target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(fakePluginBytes))
|
||||
}))
|
||||
defer target.Close()
|
||||
|
||||
// Redirect server (simulates GitHub → S3)
|
||||
redirector := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, target.URL, http.StatusFound)
|
||||
}))
|
||||
defer redirector.Close()
|
||||
|
||||
path, err := DownloadPlugin(redirector.URL, ".so")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(path)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, fakePluginBytes, string(data))
|
||||
}
|
||||
|
||||
func TestDownloadPlugin_TooManyRedirects(t *testing.T) {
|
||||
// Server that always redirects to itself
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, server.URL, http.StatusFound)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err := DownloadPlugin(server.URL, ".so")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "too many redirects")
|
||||
}
|
||||
|
||||
func TestDownloadPlugin_NonOKStatus(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err := DownloadPlugin(server.URL, ".so")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "404")
|
||||
}
|
||||
|
||||
func TestDownloadPlugin_FileExtensionPreserved(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(fakePluginBytes))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
path, err := DownloadPlugin(server.URL, ".so")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(path)
|
||||
|
||||
assert.Contains(t, path, ".so")
|
||||
}
|
||||
Reference in New Issue
Block a user