first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
package plugins
import "github.com/maximhq/bifrost/core/schemas"
// PluginLoader is the contract for a plugin loader
type PluginLoader interface {
// LoadPlugin loads a generic plugin from the given path with the provided config
// Returns a BasePlugin that can be type-asserted to specific plugin interfaces
LoadPlugin(path string, config any) (schemas.BasePlugin, error)
// VerifyBasePlugin verifies a plugin at the given path
// Returns the name of the plugin or an empty string if the plugin is invalid
// Returns an error if the plugin is invalid
// This method is used to verify that the plugin is a valid base plugin and has the required symbols
VerifyBasePlugin(path string) (string, error)
}

162
framework/plugins/main.go Normal file
View File

@@ -0,0 +1,162 @@
// Package plugins provides a framework for dynamically loading and managing plugins
package plugins
import (
"github.com/maximhq/bifrost/core/schemas"
)
// PluginConfig is the generic configuration for any plugin type
// Plugin types are automatically detected based on implemented interfaces
type PluginConfig struct {
Path string `json:"path"`
Name string `json:"name"`
Enabled bool `json:"enabled"`
Config any `json:"config,omitempty"`
}
// Config is the configuration for the plugins framework
type Config struct {
// Plugins is the unified configuration for all plugin types
Plugins []PluginConfig `json:"plugins"`
}
// AsLLMPlugin checks if a base plugin implements LLMPlugin and actually has LLM hooks.
// For DynamicPlugin, it checks if the hook function pointers are not nil.
// Returns nil if the plugin does not implement the interface or has no LLM hooks.
func AsLLMPlugin(plugin schemas.BasePlugin) schemas.LLMPlugin {
// Check if it's a DynamicPlugin first
if dp, ok := plugin.(*DynamicPlugin); ok {
// Only return as LLMPlugin if it actually has LLM hooks
if dp.preLLMHook != nil || dp.postLLMHook != nil {
return dp
}
return nil
}
// For non-DynamicPlugin types, use normal type assertion
if llmPlugin, ok := plugin.(schemas.LLMPlugin); ok {
return llmPlugin
}
return nil
}
// AsMCPPlugin checks if a base plugin implements MCPPlugin and actually has MCP hooks.
// For DynamicPlugin, it checks if the hook function pointers are not nil.
// Returns nil if the plugin does not implement the interface or has no MCP hooks.
func AsMCPPlugin(plugin schemas.BasePlugin) schemas.MCPPlugin {
// Check if it's a DynamicPlugin first
if dp, ok := plugin.(*DynamicPlugin); ok {
// Only return as MCPPlugin if it actually has MCP hooks
if dp.preMCPHook != nil || dp.postMCPHook != nil {
return dp
}
return nil
}
// For non-DynamicPlugin types, use normal type assertion
if mcpPlugin, ok := plugin.(schemas.MCPPlugin); ok {
return mcpPlugin
}
return nil
}
// AsHTTPTransportPlugin checks if a base plugin implements HTTPTransportPlugin and actually has HTTP transport hooks.
// For DynamicPlugin, it checks if the hook function pointers are not nil.
// Returns nil if the plugin does not implement the interface or has no HTTP transport hooks.
func AsHTTPTransportPlugin(plugin schemas.BasePlugin) schemas.HTTPTransportPlugin {
// Check if it's a DynamicPlugin first
if dp, ok := plugin.(*DynamicPlugin); ok {
// Only return as HTTPTransportPlugin if it actually has HTTP transport hooks
if dp.httpTransportPreHook != nil || dp.httpTransportPostHook != nil {
return dp
}
return nil
}
// For non-DynamicPlugin types, use normal type assertion
if httpPlugin, ok := plugin.(schemas.HTTPTransportPlugin); ok {
return httpPlugin
}
return nil
}
// AsObservabilityPlugin checks if a base plugin implements ObservabilityPlugin and actually has observability hooks.
// For DynamicPlugin, it checks if the hook function pointer is not nil.
// Returns nil if the plugin does not implement the interface or has no observability hooks.
func AsObservabilityPlugin(plugin schemas.BasePlugin) schemas.ObservabilityPlugin {
// Check if it's a DynamicPlugin first
if dp, ok := plugin.(*DynamicPlugin); ok {
// Only return as ObservabilityPlugin if it actually has the Inject hook
if dp.inject != nil {
return dp
}
return nil
}
// For non-DynamicPlugin types, use normal type assertion
if obsPlugin, ok := plugin.(schemas.ObservabilityPlugin); ok {
return obsPlugin
}
return nil
}
// LoadPlugins loads all plugins from the config
func LoadPlugins(loader PluginLoader, config *Config) ([]schemas.BasePlugin, error) {
plugins := []schemas.BasePlugin{}
if config == nil {
return plugins, nil
}
for _, pc := range config.Plugins {
if !pc.Enabled {
continue
}
plugin, err := loader.LoadPlugin(pc.Path, pc.Config)
if err != nil {
return nil, err
}
plugins = append(plugins, plugin)
}
return plugins, nil
}
// FilterLLMPlugins filters a list of BasePlugins to only include those implementing LLMPlugin
func FilterLLMPlugins(plugins []schemas.BasePlugin) []schemas.LLMPlugin {
result := []schemas.LLMPlugin{}
for _, p := range plugins {
if llmPlugin := AsLLMPlugin(p); llmPlugin != nil {
result = append(result, llmPlugin)
}
}
return result
}
// FilterMCPPlugins filters a list of BasePlugins to only include those implementing MCPPlugin
func FilterMCPPlugins(plugins []schemas.BasePlugin) []schemas.MCPPlugin {
result := []schemas.MCPPlugin{}
for _, p := range plugins {
if mcpPlugin := AsMCPPlugin(p); mcpPlugin != nil {
result = append(result, mcpPlugin)
}
}
return result
}
// FilterHTTPTransportPlugins filters a list of BasePlugins to only include those implementing HTTPTransportPlugin
func FilterHTTPTransportPlugins(plugins []schemas.BasePlugin) []schemas.HTTPTransportPlugin {
result := []schemas.HTTPTransportPlugin{}
for _, p := range plugins {
if httpPlugin := AsHTTPTransportPlugin(p); httpPlugin != nil {
result = append(result, httpPlugin)
}
}
return result
}
// FilterObservabilityPlugins filters a list of BasePlugins to only include those implementing ObservabilityPlugin
func FilterObservabilityPlugins(plugins []schemas.BasePlugin) []schemas.ObservabilityPlugin {
result := []schemas.ObservabilityPlugin{}
for _, p := range plugins {
if obsPlugin := AsObservabilityPlugin(p); obsPlugin != nil {
result = append(result, obsPlugin)
}
}
return result
}

View File

@@ -0,0 +1,6 @@
//go:build !race
package plugins
// raceEnabled indicates if the binary was built with race detection
const raceEnabled = false

View File

@@ -0,0 +1,6 @@
//go:build race
package plugins
// raceEnabled indicates if the binary was built with race detection
const raceEnabled = true

View File

@@ -0,0 +1,175 @@
package plugins
import (
"context"
"fmt"
"plugin"
"strings"
"github.com/maximhq/bifrost/core/schemas"
)
// SharedObjectPluginLoader is the loader for shared object plugins
type SharedObjectPluginLoader struct{}
func openPlugin(dp *DynamicPlugin) (*plugin.Plugin, error) {
// Checking if path is URL or file path
if strings.HasPrefix(dp.Path, "http") {
// Download the file
tempPath, err := DownloadPlugin(dp.Path, ".so")
if err != nil {
return nil, err
}
dp.Path = tempPath
}
pluginObj, err := plugin.Open(dp.Path)
if err != nil {
return nil, err
}
dp.plugin = pluginObj
return pluginObj, nil
}
// LoadPlugin loads a generic plugin from a shared object file
// It uses optional symbol lookup - only GetName and Cleanup are required
// All other hook methods are optional and stored as nil if not implemented
func (l *SharedObjectPluginLoader) LoadPlugin(path string, config any) (schemas.BasePlugin, error) {
dp := &DynamicPlugin{
Path: path,
}
pluginObj, err := openPlugin(dp)
if err != nil {
return nil, err
}
// Optional Init method
if initSym, err := pluginObj.Lookup("Init"); err == nil {
if initFunc, ok := initSym.(func(config any) error); ok {
if err := initFunc(config); err != nil {
return nil, fmt.Errorf("plugin Init failed: %w", err)
}
} else {
return nil, fmt.Errorf("failed to cast Init to func(config any) error")
}
}
// Required: GetName
getNameSym, err := pluginObj.Lookup("GetName")
if err != nil {
return nil, fmt.Errorf("required symbol GetName not found: %w", err)
}
var ok bool
if dp.getName, ok = getNameSym.(func() string); !ok {
return nil, fmt.Errorf("failed to cast GetName to func() string\nSee docs for more information: https://docs.getbifrost.ai/plugins/writing-go-plugin")
}
// Required: Cleanup
cleanupSym, err := pluginObj.Lookup("Cleanup")
if err != nil {
return nil, fmt.Errorf("required symbol Cleanup not found: %w", err)
}
if dp.cleanup, ok = cleanupSym.(func() error); !ok {
return nil, fmt.Errorf("failed to cast Cleanup to func() error\nSee docs for more information: https://docs.getbifrost.ai/plugins/writing-go-plugin")
}
// Optional: HTTPTransportPreHook
if sym, err := pluginObj.Lookup("HTTPTransportPreHook"); err == nil {
if dp.httpTransportPreHook, ok = sym.(func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error)); !ok {
return nil, fmt.Errorf("failed to cast HTTPTransportPreHook to expected signature")
}
}
// Optional: HTTPTransportPostHook
if sym, err := pluginObj.Lookup("HTTPTransportPostHook"); err == nil {
if dp.httpTransportPostHook, ok = sym.(func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error); !ok {
return nil, fmt.Errorf("failed to cast HTTPTransportPostHook to expected signature")
}
}
// Optional: HTTPTransportStreamChunkHook
if sym, err := pluginObj.Lookup("HTTPTransportStreamChunkHook"); err == nil {
if dp.httpTransportStreamChunkHook, ok = sym.(func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error)); !ok {
return nil, fmt.Errorf("failed to cast HTTPTransportStreamChunkHook to expected signature")
}
}
// Optional: PreLLMHook (with backward compatibility for legacy PreHook)
if sym, err := pluginObj.Lookup("PreLLMHook"); err == nil {
if dp.preLLMHook, ok = sym.(func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error)); !ok {
return nil, fmt.Errorf("failed to cast PreLLMHook to expected signature")
}
} else if sym, err := pluginObj.Lookup("PreHook"); err == nil {
// Legacy backward compatibility (v1.3.x): treat PreHook as PreLLMHook
if dp.preLLMHook, ok = sym.(func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error)); !ok {
return nil, fmt.Errorf("failed to cast PreHook to expected signature (legacy backward compatibility)")
}
}
// Optional: PostLLMHook (with backward compatibility for legacy PostHook)
if sym, err := pluginObj.Lookup("PostLLMHook"); err == nil {
if dp.postLLMHook, ok = sym.(func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)); !ok {
return nil, fmt.Errorf("failed to cast PostLLMHook to expected signature")
}
} else if sym, err := pluginObj.Lookup("PostHook"); err == nil {
// Legacy backward compatibility (v1.3.x): treat PostHook as PostLLMHook
if dp.postLLMHook, ok = sym.(func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)); !ok {
return nil, fmt.Errorf("failed to cast PostHook to expected signature (legacy backward compatibility)")
}
}
// Optional: PreMCPHook
if sym, err := pluginObj.Lookup("PreMCPHook"); err == nil {
if dp.preMCPHook, ok = sym.(func(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error)); !ok {
return nil, fmt.Errorf("failed to cast PreMCPHook to expected signature")
}
}
// Optional: PostMCPHook
if sym, err := pluginObj.Lookup("PostMCPHook"); err == nil {
if dp.postMCPHook, ok = sym.(func(ctx *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error)); !ok {
return nil, fmt.Errorf("failed to cast PostMCPHook to expected signature")
}
}
// Optional: Inject (ObservabilityPlugin)
if sym, err := pluginObj.Lookup("Inject"); err == nil {
if dp.inject, ok = sym.(func(ctx context.Context, trace *schemas.Trace) error); !ok {
return nil, fmt.Errorf("failed to cast Inject to expected signature")
}
}
return dp, nil
}
// VerifyBasePlugin verifies a plugin at the given path
// Returns the name of the plugin or an empty string if the plugin is invalid
// Returns an error if the plugin is invalid
// This method is used to verify that the plugin is a valid base plugin and has the required symbols
func (l *SharedObjectPluginLoader) VerifyBasePlugin(path string) (string, error) {
dp := &DynamicPlugin{
Path: path,
}
pluginObj, err := openPlugin(dp)
if err != nil {
return "", err
}
// Required: GetName
getNameSym, err := pluginObj.Lookup("GetName")
if err != nil {
return "", fmt.Errorf("required symbol GetName not found: %w", err)
}
var ok bool
if dp.getName, ok = getNameSym.(func() string); !ok {
return "", fmt.Errorf("failed to cast GetName to func() string")
}
// Required: Cleanup
cleanupSym, err := pluginObj.Lookup("Cleanup")
if err != nil {
return "", fmt.Errorf("required symbol Cleanup not found: %w", err)
}
if dp.cleanup, ok = cleanupSym.(func() error); !ok {
return "", fmt.Errorf("failed to cast Cleanup to func() error")
}
return dp.getName(), nil
}

View File

@@ -0,0 +1,113 @@
package plugins
import (
"context"
"plugin"
"github.com/maximhq/bifrost/core/schemas"
)
// DynamicPlugin is a generic dynamic plugin that can implement any combination of plugin interfaces
// It uses optional function pointers - nil pointers indicate the interface is not implemented
type DynamicPlugin struct {
Enabled bool
Path string
Config any
filename string
plugin *plugin.Plugin
// BasePlugin (required)
getName func() string
cleanup func() error
// HTTPTransportPlugin (optional)
httpTransportPreHook func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error)
httpTransportPostHook func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error
httpTransportStreamChunkHook func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, stream *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error)
// LLMPlugin (optional)
preLLMHook func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error)
postLLMHook func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)
// MCPPlugin (optional)
preMCPHook func(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error)
postMCPHook func(ctx *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error)
// ObservabilityPlugin (optional)
inject func(ctx context.Context, trace *schemas.Trace) error
}
// GetName returns the name of the plugin (BasePlugin interface)
func (dp *DynamicPlugin) GetName() string {
return dp.getName()
}
// Cleanup is invoked by core/bifrost.go during plugin unload, reload, and shutdown (BasePlugin interface)
func (dp *DynamicPlugin) Cleanup() error {
return dp.cleanup()
}
// HTTPTransportPreHook intercepts HTTP requests at the transport layer before entering Bifrost core (HTTPTransportPlugin interface)
func (dp *DynamicPlugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) {
if dp.httpTransportPreHook == nil {
return nil, nil // No-op if not implemented
}
return dp.httpTransportPreHook(ctx, req)
}
// HTTPTransportPostHook intercepts HTTP responses at the transport layer after exiting Bifrost core (HTTPTransportPlugin interface)
func (dp *DynamicPlugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error {
if dp.httpTransportPostHook == nil {
return nil // No-op if not implemented
}
return dp.httpTransportPostHook(ctx, req, resp)
}
// HTTPTransportStreamChunkHook intercepts streaming chunks before they are written to the client
func (dp *DynamicPlugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, stream *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) {
if dp.httpTransportStreamChunkHook == nil {
return stream, nil // No-op if not implemented
}
return dp.httpTransportStreamChunkHook(ctx, req, stream)
}
// PreLLMHook is invoked before LLM provider calls (LLMPlugin interface)
func (dp *DynamicPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) {
if dp.preLLMHook == nil {
return req, nil, nil // No-op if not implemented
}
return dp.preLLMHook(ctx, req)
}
// PostLLMHook is invoked after LLM provider calls (LLMPlugin interface)
func (dp *DynamicPlugin) PostLLMHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) {
if dp.postLLMHook == nil {
return resp, bifrostErr, nil // No-op if not implemented
}
return dp.postLLMHook(ctx, resp, bifrostErr)
}
// PreMCPHook is invoked before MCP calls (MCPPlugin interface)
func (dp *DynamicPlugin) PreMCPHook(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error) {
if dp.preMCPHook == nil {
return req, nil, nil // No-op if not implemented
}
return dp.preMCPHook(ctx, req)
}
// PostMCPHook is invoked after MCP calls (MCPPlugin interface)
func (dp *DynamicPlugin) PostMCPHook(ctx *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error) {
if dp.postMCPHook == nil {
return resp, bifrostErr, nil // No-op if not implemented
}
return dp.postMCPHook(ctx, resp, bifrostErr)
}
// Inject receives completed traces for observability backends (ObservabilityPlugin interface)
func (dp *DynamicPlugin) Inject(ctx context.Context, trace *schemas.Trace) error {
if dp.inject == nil {
return nil // No-op if not implemented
}
return dp.inject(ctx, trace)
}

View File

@@ -0,0 +1,809 @@
package plugins
import (
"context"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
helloWorldPluginDir = "../../examples/plugins/hello-world"
helloWorldBuildDir = "../../examples/plugins/hello-world/build"
)
// TestDynamicPluginLifecycle tests the complete lifecycle of a dynamic plugin
func TestDynamicPluginLifecycle(t *testing.T) {
// Build the hello-world plugin first
pluginPath := buildHelloWorldPlugin(t)
defer cleanupHelloWorldPlugin(t)
// Test loading the plugin
config := &Config{
Plugins: []PluginConfig{
{
Path: pluginPath,
Name: "hello-world",
Enabled: true,
Config: map[string]interface{}{"test": "config"},
},
},
}
loader := &SharedObjectPluginLoader{}
basePlugins, err := LoadPlugins(loader, config)
require.NoError(t, err, "Failed to load plugins")
require.Len(t, basePlugins, 1, "Expected exactly one plugin to be loaded")
plugins := FilterLLMPlugins(basePlugins)
require.Len(t, plugins, 1, "Expected plugin to implement LLMPlugin")
plugin := plugins[0]
// Test GetName
t.Run("GetName", func(t *testing.T) {
name := plugin.GetName()
assert.Equal(t, "hello-world", name, "Plugin name should match")
})
// Test HTTPTransportPreHook
t.Run("HTTPTransportPreHook", func(t *testing.T) {
ctx := context.Background()
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
defer cancel()
// Create a test HTTP request
req := &schemas.HTTPRequest{
Method: "POST",
Path: "/api",
Headers: map[string]string{
"Content-Type": "application/json",
"Authorization": "Bearer token123",
},
Query: map[string]string{},
Body: []byte(`{"test": "data"}`),
}
// Call HTTPTransportPreHook
httpTransportPlugin, ok := plugin.(schemas.HTTPTransportPlugin)
require.True(t, ok, "Plugin should be a HTTPTransportPlugin")
resp, err := httpTransportPlugin.HTTPTransportPreHook(pluginCtx, req)
require.NoError(t, err, "HTTPTransportPreHook should not return error")
assert.Nil(t, resp, "HTTPTransportPreHook should return nil response to continue")
// Verify headers were modified (hello-world plugin adds a header)
assert.Equal(t, "transport-pre-hook-value", req.Headers["x-hello-world-plugin"], "Plugin should have added custom header")
})
// Test HTTPTransportPostHook
t.Run("HTTPTransportPostHook", func(t *testing.T) {
ctx := context.Background()
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
defer cancel()
// Create a test HTTP response
req := &schemas.HTTPRequest{
Method: "POST",
Path: "/api",
Headers: map[string]string{
"Content-Type": "application/json",
"Authorization": "Bearer token123",
},
Query: map[string]string{},
Body: []byte(`{"test": "data"}`),
}
resp := &schemas.HTTPResponse{
StatusCode: 200,
Headers: map[string]string{
"Content-Type": "application/json",
},
Body: []byte(`{"result": "success"}`),
}
// Call HTTPTransportPostHook
httpTransportPlugin, ok := plugin.(schemas.HTTPTransportPlugin)
require.True(t, ok, "Plugin should be a HTTPTransportPlugin")
err := httpTransportPlugin.HTTPTransportPostHook(pluginCtx, req, resp)
require.NoError(t, err, "HTTPTransportPostHook should not return error")
// Verify headers were modified (hello-world plugin adds a header)
assert.Equal(t, "transport-post-hook-value", resp.Headers["x-hello-world-plugin"], "Plugin should have added custom header")
})
// Test PreLLMHook
t.Run("PreLLMHook", func(t *testing.T) {
ctx := context.Background()
req := &schemas.BifrostRequest{
RequestType: schemas.ChatCompletionRequest,
ChatRequest: &schemas.BifrostChatRequest{
Provider: "openai",
Model: "gpt-4",
Input: []schemas.ChatMessage{
{
Role: schemas.ChatMessageRoleUser,
Content: &schemas.ChatMessageContent{
ContentStr: stringPtr("Hello"),
},
},
},
},
}
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
defer cancel()
modifiedReq, shortCircuit, err := plugin.PreLLMHook(pluginCtx, req)
require.NoError(t, err, "PreLLMHook should not return error")
assert.Nil(t, shortCircuit, "PreLLMHook should not return short circuit")
assert.Equal(t, req, modifiedReq, "Request should be unchanged")
})
// Test PostLLMHook
t.Run("PostLLMHook", func(t *testing.T) {
ctx := context.Background()
resp := &schemas.BifrostResponse{
ChatResponse: &schemas.BifrostChatResponse{
ID: "test-id",
Model: "gpt-4",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: &schemas.ChatMessage{
Role: schemas.ChatMessageRoleAssistant,
Content: &schemas.ChatMessageContent{
ContentStr: stringPtr("Hello! How can I help you?"),
},
},
},
},
},
},
}
bifrostErr := (*schemas.BifrostError)(nil)
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
defer cancel()
modifiedResp, modifiedErr, err := plugin.PostLLMHook(pluginCtx, resp, bifrostErr)
require.NoError(t, err, "PostLLMHook should not return error")
assert.Equal(t, resp, modifiedResp, "Response should be unchanged")
assert.Equal(t, bifrostErr, modifiedErr, "Error should be unchanged")
})
// Test PostLLMHook with error
t.Run("PostHook_WithError", func(t *testing.T) {
ctx := context.Background()
statusCode := 500
bifrostErr := &schemas.BifrostError{
StatusCode: &statusCode,
Error: &schemas.ErrorField{
Message: "Test error",
},
}
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
defer cancel()
modifiedResp, modifiedErr, err := plugin.PostLLMHook(pluginCtx, nil, bifrostErr)
require.NoError(t, err, "PostLLMHook should not return error")
assert.Nil(t, modifiedResp, "Response should be nil")
assert.Equal(t, bifrostErr, modifiedErr, "Error should be unchanged")
})
// Test Cleanup
t.Run("Cleanup", func(t *testing.T) {
err := plugin.Cleanup()
assert.NoError(t, err, "Cleanup should not return error")
})
}
// TestLoadPlugins_DisabledPlugin tests that disabled plugins are not loaded
func TestLoadPlugins_DisabledPlugin(t *testing.T) {
pluginPath := buildHelloWorldPlugin(t)
defer cleanupHelloWorldPlugin(t)
config := &Config{
Plugins: []PluginConfig{
{
Path: pluginPath,
Name: "hello-world",
Enabled: false, // Plugin is disabled
Config: nil,
},
},
}
loader := &SharedObjectPluginLoader{}
plugins, err := LoadPlugins(loader, config)
require.NoError(t, err, "LoadPlugins should not error for disabled plugins")
assert.Len(t, plugins, 0, "No plugins should be loaded when all are disabled")
}
// TestLoadPlugins_MultiplePlugins tests loading multiple plugins
func TestLoadPlugins_MultiplePlugins(t *testing.T) {
pluginPath := buildHelloWorldPlugin(t)
defer cleanupHelloWorldPlugin(t)
config := &Config{
Plugins: []PluginConfig{
{
Path: pluginPath,
Name: "hello-world-1",
Enabled: true,
Config: nil,
},
{
Path: pluginPath,
Name: "hello-world-2",
Enabled: true,
Config: map[string]interface{}{"key": "value"},
},
},
}
loader := &SharedObjectPluginLoader{}
plugins, err := LoadPlugins(loader, config)
require.NoError(t, err, "LoadPlugins should succeed for multiple plugins")
assert.Len(t, plugins, 2, "Two plugins should be loaded")
for _, plugin := range plugins {
assert.Equal(t, "hello-world", plugin.GetName())
}
}
// TestLoadPlugins_InvalidPath tests loading a plugin with invalid path
func TestLoadPlugins_InvalidPath(t *testing.T) {
config := &Config{
Plugins: []PluginConfig{
{
Path: "/nonexistent/path/plugin.so",
Name: "invalid-plugin",
Enabled: true,
Config: nil,
},
},
}
loader := &SharedObjectPluginLoader{}
plugins, err := LoadPlugins(loader, config)
assert.Error(t, err, "LoadPlugins should return error for invalid path")
assert.Nil(t, plugins, "No plugins should be loaded on error")
}
// TestLoadPlugins_EmptyConfig tests loading plugins with empty config
func TestLoadPlugins_EmptyConfig(t *testing.T) {
config := &Config{
Plugins: []PluginConfig{},
}
loader := &SharedObjectPluginLoader{}
plugins, err := LoadPlugins(loader, config)
require.NoError(t, err, "LoadPlugins should succeed with empty config")
assert.Len(t, plugins, 0, "No plugins should be loaded with empty config")
}
// TestDynamicPlugin_ContextPropagation tests that context is properly propagated
func TestDynamicPlugin_ContextPropagation(t *testing.T) {
pluginPath := buildHelloWorldPlugin(t)
defer cleanupHelloWorldPlugin(t)
loader := &SharedObjectPluginLoader{}
basePlugin, err := loader.LoadPlugin(pluginPath, nil)
require.NoError(t, err, "Failed to load plugin")
// Type assert to LLMPlugin
plugin, ok := basePlugin.(schemas.LLMPlugin)
require.True(t, ok, "Plugin should implement LLMPlugin interface")
// Create a context with a value
ctx := context.WithValue(context.Background(), "test-key", "test-value")
// Test PreLLMHook with context
req := &schemas.BifrostRequest{
RequestType: schemas.ChatCompletionRequest,
ChatRequest: &schemas.BifrostChatRequest{
Provider: "openai",
Model: "gpt-4",
},
}
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
defer cancel()
_, _, err = plugin.PreLLMHook(pluginCtx, req)
require.NoError(t, err, "PreLLMHook should succeed with context")
// Test PostLLMHook with context
resp := &schemas.BifrostResponse{
ChatResponse: &schemas.BifrostChatResponse{
ID: "test-id",
Model: "gpt-4",
},
}
_, _, err = plugin.PostLLMHook(pluginCtx, resp, nil)
require.NoError(t, err, "PostLLMHook should succeed with context")
}
// TestDynamicPlugin_ConcurrentCalls tests concurrent plugin calls
func TestDynamicPlugin_ConcurrentCalls(t *testing.T) {
pluginPath := buildHelloWorldPlugin(t)
defer cleanupHelloWorldPlugin(t)
loader := &SharedObjectPluginLoader{}
basePlugin, err := loader.LoadPlugin(pluginPath, nil)
require.NoError(t, err, "Failed to load plugin")
// Type assert to LLMPlugin
plugin, ok := basePlugin.(schemas.LLMPlugin)
require.True(t, ok, "Plugin should implement LLMPlugin interface")
// Run multiple goroutines calling plugin methods
const numGoroutines = 10
done := make(chan bool, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer func() { done <- true }()
ctx := context.Background()
req := &schemas.BifrostRequest{
RequestType: schemas.ChatCompletionRequest,
ChatRequest: &schemas.BifrostChatRequest{
Provider: "openai",
Model: "gpt-4",
},
}
// Call PreLLMHook
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
defer cancel()
_, _, err := plugin.PreLLMHook(pluginCtx, req)
assert.NoError(t, err, "PreLLMHook should succeed in goroutine %d", id)
// Call PostLLMHook
resp := &schemas.BifrostResponse{
ChatResponse: &schemas.BifrostChatResponse{
ID: "test-id",
Model: "gpt-4",
},
}
_, _, err = plugin.PostLLMHook(pluginCtx, resp, nil)
assert.NoError(t, err, "PostLLMHook should succeed in goroutine %d", id)
// Call GetName
name := basePlugin.GetName()
assert.Equal(t, "hello-world", name, "GetName should return correct name in goroutine %d", id)
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < numGoroutines; i++ {
<-done
}
}
// Helper function to build the hello-world plugin
func buildHelloWorldPlugin(t *testing.T) string {
t.Helper()
// Get absolute path to the hello-world plugin directory
absPluginDir, err := filepath.Abs(helloWorldPluginDir)
require.NoError(t, err, "Failed to get absolute path")
// Determine plugin extension based on OS
pluginExt := ".so"
if runtime.GOOS == "windows" {
pluginExt = ".dll"
}
// Clean and create build directory to ensure fresh build with current Go version
buildDir := filepath.Join(absPluginDir, "build")
os.RemoveAll(buildDir)
err = os.MkdirAll(buildDir, 0755)
require.NoError(t, err, "Failed to create build directory")
// Build the plugin directly with go build
pluginPath := filepath.Join(buildDir, "hello-world"+pluginExt)
args := []string{"build", "-buildmode=plugin", "-o", pluginPath}
if raceEnabled {
args = append(args, "-race")
}
args = append(args, "main.go")
cmd := exec.Command("go", args...)
cmd.Dir = absPluginDir
cmd.Env = append(os.Environ(), "CGO_ENABLED=1")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Build output: %s", string(output))
require.NoError(t, err, "Failed to build hello-world plugin")
}
// Verify the plugin was built
_, err = os.Stat(pluginPath)
require.NoError(t, err, "Plugin file should exist after build")
return pluginPath
}
// Helper function to clean up the hello-world plugin build
func cleanupHelloWorldPlugin(t *testing.T) {
t.Helper()
absPluginDir, err := filepath.Abs(helloWorldPluginDir)
if err != nil {
t.Logf("Failed to get absolute path for cleanup: %v", err)
return
}
cmd := exec.Command("make", "clean")
cmd.Dir = absPluginDir
if err := cmd.Run(); err != nil {
t.Logf("Failed to clean hello-world plugin: %v", err)
}
}
// TestLoadDynamicPlugin_DirectCall tests loading a plugin directly
func TestLoadDynamicPlugin_DirectCall(t *testing.T) {
pluginPath := buildHelloWorldPlugin(t)
defer cleanupHelloWorldPlugin(t)
loader := &SharedObjectPluginLoader{}
plugin, err := loader.LoadPlugin(pluginPath, map[string]interface{}{
"test": "config",
})
require.NoError(t, err, "loadDynamicPlugin should succeed")
assert.NotNil(t, plugin, "Plugin should not be nil")
// Verify it's a DynamicPlugin
dynamicPlugin, ok := plugin.(*DynamicPlugin)
assert.True(t, ok, "Plugin should be a DynamicPlugin")
assert.Equal(t, pluginPath, dynamicPlugin.Path)
}
// TestDynamicPlugin_NilConfig tests loading a plugin with nil config
func TestDynamicPlugin_NilConfig(t *testing.T) {
pluginPath := buildHelloWorldPlugin(t)
defer cleanupHelloWorldPlugin(t)
loader := &SharedObjectPluginLoader{}
plugin, err := loader.LoadPlugin(pluginPath, nil)
require.NoError(t, err, "loadDynamicPlugin should succeed with nil config")
assert.NotNil(t, plugin, "Plugin should not be nil")
// Verify plugin works correctly
name := plugin.GetName()
assert.Equal(t, "hello-world", name)
}
// TestDynamicPlugin_ShortCircuitNil tests that nil short circuit is handled properly
func TestDynamicPlugin_ShortCircuitNil(t *testing.T) {
pluginPath := buildHelloWorldPlugin(t)
defer cleanupHelloWorldPlugin(t)
loader := &SharedObjectPluginLoader{}
basePlugin, err := loader.LoadPlugin(pluginPath, nil)
require.NoError(t, err, "Failed to load plugin")
// Type assert to LLMPlugin
plugin, ok := basePlugin.(schemas.LLMPlugin)
require.True(t, ok, "Plugin should implement LLMPlugin interface")
ctx := context.Background()
req := &schemas.BifrostRequest{
RequestType: schemas.ChatCompletionRequest,
ChatRequest: &schemas.BifrostChatRequest{
Provider: "openai",
Model: "gpt-4",
},
}
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
defer cancel()
modifiedReq, shortCircuit, err := plugin.PreLLMHook(pluginCtx, req)
require.NoError(t, err, "PreLLMHook should succeed")
assert.Nil(t, shortCircuit, "Short circuit should be nil")
assert.NotNil(t, modifiedReq, "Modified request should not be nil")
}
// BenchmarkDynamicPlugin_PreHook benchmarks the PreLLMHook method
func BenchmarkDynamicPlugin_PreHook(b *testing.B) {
pluginPath := buildHelloWorldPluginForBenchmark(b)
defer cleanupHelloWorldPluginForBenchmark(b)
loader := &SharedObjectPluginLoader{}
basePlugin, err := loader.LoadPlugin(pluginPath, nil)
require.NoError(b, err, "Failed to load plugin")
// Type assert to LLMPlugin
plugin, ok := basePlugin.(schemas.LLMPlugin)
require.True(b, ok, "Plugin should implement LLMPlugin interface")
ctx := context.Background()
req := &schemas.BifrostRequest{
RequestType: schemas.ChatCompletionRequest,
ChatRequest: &schemas.BifrostChatRequest{
Provider: "openai",
Model: "gpt-4",
},
}
b.ResetTimer()
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
defer cancel()
for i := 0; i < b.N; i++ {
_, _, _ = plugin.PreLLMHook(pluginCtx, req)
}
}
// BenchmarkDynamicPlugin_PostHook benchmarks the PostLLMHook method
func BenchmarkDynamicPlugin_PostHook(b *testing.B) {
pluginPath := buildHelloWorldPluginForBenchmark(b)
defer cleanupHelloWorldPluginForBenchmark(b)
loader := &SharedObjectPluginLoader{}
basePlugin, err := loader.LoadPlugin(pluginPath, nil)
require.NoError(b, err, "Failed to load plugin")
// Type assert to LLMPlugin
plugin, ok := basePlugin.(schemas.LLMPlugin)
require.True(b, ok, "Plugin should implement LLMPlugin interface")
ctx := context.Background()
resp := &schemas.BifrostResponse{
ChatResponse: &schemas.BifrostChatResponse{
ID: "test-id",
Model: "gpt-4",
},
}
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
defer cancel()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _ = plugin.PostLLMHook(pluginCtx, resp, nil)
}
}
// BenchmarkDynamicPlugin_GetName benchmarks the GetName method
func BenchmarkDynamicPlugin_GetName(b *testing.B) {
pluginPath := buildHelloWorldPluginForBenchmark(b)
defer cleanupHelloWorldPluginForBenchmark(b)
loader := &SharedObjectPluginLoader{}
plugin, err := loader.LoadPlugin(pluginPath, nil)
require.NoError(b, err, "Failed to load plugin")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = plugin.GetName()
}
}
// Helper function to build plugin for benchmarks
func buildHelloWorldPluginForBenchmark(b *testing.B) string {
b.Helper()
absPluginDir, err := filepath.Abs(helloWorldPluginDir)
require.NoError(b, err, "Failed to get absolute path")
pluginExt := ".so"
if runtime.GOOS == "windows" {
pluginExt = ".dll"
}
// Clean and create build directory to ensure fresh build with current Go version
buildDir := filepath.Join(absPluginDir, "build")
pluginPath := filepath.Join(buildDir, "hello-world"+pluginExt)
os.RemoveAll(buildDir)
err = os.MkdirAll(buildDir, 0755)
require.NoError(b, err, "Failed to create build directory")
// Build the plugin directly with go build
args := []string{"build", "-buildmode=plugin", "-o", pluginPath}
if raceEnabled {
args = append(args, "-race")
}
args = append(args, "main.go")
cmd := exec.Command("go", args...)
cmd.Dir = absPluginDir
cmd.Env = append(os.Environ(), "CGO_ENABLED=1")
output, err := cmd.CombinedOutput()
if err != nil {
b.Logf("Build output: %s", string(output))
require.NoError(b, err, "Failed to build hello-world plugin")
}
return pluginPath
}
// Helper function to clean up plugin for benchmarks
func cleanupHelloWorldPluginForBenchmark(b *testing.B) {
b.Helper()
absPluginDir, err := filepath.Abs(helloWorldPluginDir)
if err != nil {
b.Logf("Failed to get absolute path for cleanup: %v", err)
return
}
cmd := exec.Command("make", "clean")
cmd.Dir = absPluginDir
if err := cmd.Run(); err != nil {
b.Logf("Failed to clean hello-world plugin: %v", err)
}
}
// TestDynamicPlugin_GetNameNotEmpty tests that GetName returns non-empty string
func TestDynamicPlugin_GetNameNotEmpty(t *testing.T) {
pluginPath := buildHelloWorldPlugin(t)
defer cleanupHelloWorldPlugin(t)
loader := &SharedObjectPluginLoader{}
plugin, err := loader.LoadPlugin(pluginPath, nil)
require.NoError(t, err, "Failed to load plugin")
name := plugin.GetName()
assert.NotEmpty(t, name, "Plugin name should not be empty")
assert.True(t, strings.Contains(name, "hello-world"), "Plugin name should contain 'hello-world'")
}
// Helper function to create a pointer to a string
func stringPtr(s string) *string {
return &s
}
// TestLoadPlugins tests the new generic LoadPlugins function
func TestLoadPlugins(t *testing.T) {
// Build the hello-world plugin first
pluginPath := buildHelloWorldPlugin(t)
defer cleanupHelloWorldPlugin(t)
t.Run("LoadSinglePlugin", func(t *testing.T) {
config := &Config{
Plugins: []PluginConfig{
{
Path: pluginPath,
Name: "hello-world",
Enabled: true,
Config: map[string]interface{}{"test": "config"},
},
},
}
loader := &SharedObjectPluginLoader{}
plugins, err := LoadPlugins(loader, config)
require.NoError(t, err, "Failed to load plugins")
require.Len(t, plugins, 1, "Expected exactly one plugin to be loaded")
plugin := plugins[0]
assert.Equal(t, "hello-world", plugin.GetName())
})
t.Run("LoadMultiplePlugins", func(t *testing.T) {
config := &Config{
Plugins: []PluginConfig{
{
Path: pluginPath,
Name: "hello-world-1",
Enabled: true,
Config: map[string]interface{}{"test": "config1"},
},
{
Path: pluginPath,
Name: "hello-world-2",
Enabled: true,
Config: map[string]interface{}{"test": "config2"},
},
},
}
loader := &SharedObjectPluginLoader{}
plugins, err := LoadPlugins(loader, config)
require.NoError(t, err, "Failed to load plugins")
require.Len(t, plugins, 2, "Expected two plugins to be loaded")
})
t.Run("SkipDisabledPlugins", func(t *testing.T) {
config := &Config{
Plugins: []PluginConfig{
{
Path: pluginPath,
Name: "hello-world-enabled",
Enabled: true,
Config: map[string]interface{}{"test": "config"},
},
{
Path: pluginPath,
Name: "hello-world-disabled",
Enabled: false,
Config: map[string]interface{}{"test": "config"},
},
},
}
loader := &SharedObjectPluginLoader{}
plugins, err := LoadPlugins(loader, config)
require.NoError(t, err, "Failed to load plugins")
require.Len(t, plugins, 1, "Expected only enabled plugin to be loaded")
})
}
// TestFilterPlugins tests the plugin filter functions
func TestFilterPlugins(t *testing.T) {
// Build the hello-world plugin first
pluginPath := buildHelloWorldPlugin(t)
defer cleanupHelloWorldPlugin(t)
loader := &SharedObjectPluginLoader{}
plugin, err := loader.LoadPlugin(pluginPath, nil)
require.NoError(t, err, "Failed to load plugin")
plugins := []schemas.BasePlugin{plugin}
t.Run("FilterLLMPlugins", func(t *testing.T) {
llmPlugins := FilterLLMPlugins(plugins)
assert.Len(t, llmPlugins, 1, "hello-world should implement LLMPlugin")
})
t.Run("FilterHTTPTransportPlugins", func(t *testing.T) {
httpPlugins := FilterHTTPTransportPlugins(plugins)
assert.Len(t, httpPlugins, 1, "hello-world should implement HTTPTransportPlugin")
})
t.Run("FilterMCPPlugins", func(t *testing.T) {
mcpPlugins := FilterMCPPlugins(plugins)
assert.Len(t, mcpPlugins, 0, "hello-world does not implement MCPPlugin")
})
t.Run("FilterObservabilityPlugins", func(t *testing.T) {
obsPlugins := FilterObservabilityPlugins(plugins)
assert.Len(t, obsPlugins, 0, "hello-world does not implement ObservabilityPlugin")
})
}
// TestLoadPluginWithOptionalHooks tests that plugins can implement only a subset of hooks
func TestLoadPluginWithOptionalHooks(t *testing.T) {
// Build the hello-world plugin first
pluginPath := buildHelloWorldPlugin(t)
defer cleanupHelloWorldPlugin(t)
loader := &SharedObjectPluginLoader{}
plugin, err := loader.LoadPlugin(pluginPath, nil)
require.NoError(t, err, "Failed to load plugin")
// The plugin should load successfully even if it doesn't implement all hooks
assert.NotNil(t, plugin, "Plugin should be loaded")
// Test that DynamicPlugin properly handles unimplemented methods by returning no-op values
dynamicPlugin, ok := plugin.(*DynamicPlugin)
require.True(t, ok, "Plugin should be a DynamicPlugin")
// Test MCP hooks (not implemented by hello-world plugin)
t.Run("UnimplementedMCPHooks", func(t *testing.T) {
ctx := context.Background()
pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second)
defer cancel()
// PreMCPHook should return no-op values
mcpReq := &schemas.BifrostMCPRequest{}
returnedReq, shortCircuit, err := dynamicPlugin.PreMCPHook(pluginCtx, mcpReq)
assert.NoError(t, err, "PreMCPHook should not error for unimplemented hook")
assert.Equal(t, mcpReq, returnedReq, "PreMCPHook should return original request")
assert.Nil(t, shortCircuit, "PreMCPHook should return nil short circuit")
// PostMCPHook should return no-op values
mcpResp := &schemas.BifrostMCPResponse{}
bifrostErr := &schemas.BifrostError{}
returnedResp, returnedErr, hookErr := dynamicPlugin.PostMCPHook(pluginCtx, mcpResp, bifrostErr)
assert.NoError(t, hookErr, "PostMCPHook should not error for unimplemented hook")
assert.Equal(t, mcpResp, returnedResp, "PostMCPHook should return original response")
assert.Equal(t, bifrostErr, returnedErr, "PostMCPHook should return original error")
})
// Test Observability hooks (not implemented by hello-world plugin)
t.Run("UnimplementedObservabilityHooks", func(t *testing.T) {
ctx := context.Background()
trace := &schemas.Trace{}
err := dynamicPlugin.Inject(ctx, trace)
assert.NoError(t, err, "Inject should not error for unimplemented hook")
})
}

111
framework/plugins/utils.go Normal file
View File

@@ -0,0 +1,111 @@
package plugins
import (
"fmt"
"net/url"
"os"
"time"
"github.com/valyala/fasthttp"
)
var (
ErrPluginNotFound = fmt.Errorf("plugin not found")
)
// pluginDownloadClient is a fasthttp client with a larger read buffer to handle
// responses with large headers.
var pluginDownloadClient = &fasthttp.Client{
ReadBufferSize: 64 * 1024, // 64KB, matches the bifrost HTTP server setting
}
// DownloadPlugin downloads a plugin from a URL and returns the local file path
func DownloadPlugin(pluginURL string, extension string) (string, error) {
req := fasthttp.AcquireRequest()
defer fasthttp.ReleaseRequest(req)
response := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(response)
req.Header.SetMethod(fasthttp.MethodGet)
req.Header.Set("Accept", "application/octet-stream")
req.Header.Set("Accept-Encoding", "gzip")
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
const maxRedirects = 5
currentURL := pluginURL
for i := 0; i <= maxRedirects; i++ {
req.SetRequestURI(currentURL)
if i > 0 {
response.Reset()
}
if err := pluginDownloadClient.DoTimeout(req, response, 120*time.Second); err != nil {
return "", err
}
statusCode := response.StatusCode()
if statusCode == fasthttp.StatusOK {
break
}
if statusCode >= 300 && statusCode < 400 {
if i == maxRedirects {
return "", fmt.Errorf("too many redirects downloading plugin")
}
location := string(response.Header.Peek("Location"))
if location == "" {
return "", fmt.Errorf("redirect response missing Location header: HTTP %d", statusCode)
}
loc, err := url.Parse(location)
if err != nil {
return "", fmt.Errorf("invalid Location header %q: %w", location, err)
}
base, err := url.Parse(currentURL)
if err != nil {
return "", fmt.Errorf("invalid request URL %q: %w", currentURL, err)
}
currentURL = base.ResolveReference(loc).String()
continue
}
return "", fmt.Errorf("failed to download plugin: HTTP %d", statusCode)
}
// Decompress the response body if it was gzip/deflate compressed
// BodyUncompressed handles both gzip and deflate encodings based on Content-Encoding header
body, err := response.BodyUncompressed()
if err != nil {
return "", fmt.Errorf("failed to decompress response body: %w", err)
}
// Create a unique temporary file for the plugin
tempFile, err := os.CreateTemp(os.TempDir(), "bifrost-plugin-*"+extension)
if err != nil {
return "", fmt.Errorf("failed to create temporary file: %w", err)
}
tempPath := tempFile.Name()
// Write the downloaded body to the temporary file
_, err = tempFile.Write(body)
if err != nil {
tempFile.Close()
os.Remove(tempPath)
return "", fmt.Errorf("failed to write plugin to temporary file: %w", err)
}
// Close the file
err = tempFile.Close()
if err != nil {
os.Remove(tempPath)
return "", fmt.Errorf("failed to close temporary file: %w", err)
}
// Set file permissions to be executable (for .so files)
if extension == ".so" {
err = os.Chmod(tempPath, 0755)
if err != nil {
os.Remove(tempPath)
return "", fmt.Errorf("failed to set executable permissions on plugin: %w", err)
}
}
return tempPath, nil
}

View File

@@ -0,0 +1,90 @@
package plugins
import (
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const fakePluginBytes = "fake-plugin-binary-content"
func TestDownloadPlugin_DirectDownload(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(fakePluginBytes))
}))
defer server.Close()
path, err := DownloadPlugin(server.URL, ".so")
require.NoError(t, err)
defer os.Remove(path)
data, err := os.ReadFile(path)
require.NoError(t, err)
assert.Equal(t, fakePluginBytes, string(data))
}
func TestDownloadPlugin_FollowsRedirect(t *testing.T) {
// Final destination
target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(fakePluginBytes))
}))
defer target.Close()
// Redirect server (simulates GitHub → S3)
redirector := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, target.URL, http.StatusFound)
}))
defer redirector.Close()
path, err := DownloadPlugin(redirector.URL, ".so")
require.NoError(t, err)
defer os.Remove(path)
data, err := os.ReadFile(path)
require.NoError(t, err)
assert.Equal(t, fakePluginBytes, string(data))
}
func TestDownloadPlugin_TooManyRedirects(t *testing.T) {
// Server that always redirects to itself
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, server.URL, http.StatusFound)
}))
defer server.Close()
_, err := DownloadPlugin(server.URL, ".so")
require.Error(t, err)
assert.Contains(t, err.Error(), "too many redirects")
}
func TestDownloadPlugin_NonOKStatus(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer server.Close()
_, err := DownloadPlugin(server.URL, ".so")
require.Error(t, err)
assert.Contains(t, err.Error(), "404")
}
func TestDownloadPlugin_FileExtensionPreserved(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(fakePluginBytes))
}))
defer server.Close()
path, err := DownloadPlugin(server.URL, ".so")
require.NoError(t, err)
defer os.Remove(path)
assert.Contains(t, path, ".so")
}