first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,107 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "hello-world-wasm-rust"
version = "0.1.0"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "itoa"
version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
[[package]]
name = "memchr"
version = "2.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
[[package]]
name = "proc-macro2"
version = "1.0.105"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "535d180e0ecab6268a3e718bb9fd44db66bbbc256257165fc699dadf70d16fe7"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc74d9a594b72ae6656596548f56f667211f8a97b3d4c3d467150794690dc40a"
dependencies = [
"proc-macro2",
]
[[package]]
name = "serde"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
dependencies = [
"serde_core",
"serde_derive",
]
[[package]]
name = "serde_core"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.149"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86"
dependencies = [
"itoa",
"memchr",
"serde",
"serde_core",
"zmij",
]
[[package]]
name = "syn"
version = "2.0.114"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "unicode-ident"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "zmij"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fc5a66a20078bf1251bde995aa2fdcc4b800c70b5d92dd2c62abc5c60f679f8"

View File

@@ -0,0 +1,18 @@
[package]
name = "hello-world-wasm-rust"
version = "0.1.0"
edition = "2021"
description = "A minimal Bifrost WASM plugin example in Rust"
[lib]
crate-type = ["cdylib"]
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
[profile.release]
opt-level = "s"
lto = true
strip = true
panic = "abort"

View File

@@ -0,0 +1,80 @@
.PHONY: all build build-optimized clean help check-rust
# Colors
COLOR_RESET = \033[0m
COLOR_INFO = \033[36m
COLOR_SUCCESS = \033[32m
COLOR_WARNING = \033[33m
COLOR_ERROR = \033[31m
COLOR_BOLD = \033[1m
# Plugin configuration
PLUGIN_NAME = hello-world
OUTPUT_DIR = build
OUTPUT = $(OUTPUT_DIR)/$(PLUGIN_NAME).wasm
TARGET = wasm32-unknown-unknown
help: ## Show this help message
@echo '$(COLOR_BOLD)Hello World WASM Plugin (Rust)$(COLOR_RESET)'
@echo ''
@echo '$(COLOR_BOLD)Usage:$(COLOR_RESET) make [target]'
@echo ''
@echo '$(COLOR_BOLD)Prerequisites:$(COLOR_RESET)'
@echo ' - Rust with wasm32-unknown-unknown target'
@echo ' rustup target add wasm32-unknown-unknown'
@echo ''
@echo '$(COLOR_BOLD)Available targets:$(COLOR_RESET)'
@awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf " $(COLOR_INFO)%-15s$(COLOR_RESET) %s\n", $$1, $$2}' $(MAKEFILE_LIST)
check-rust: ## Check if Rust and WASM target are installed
@which cargo > /dev/null 2>&1 || (echo "$(COLOR_ERROR)Error: Rust/Cargo is not installed$(COLOR_RESET)"; \
echo "$(COLOR_INFO)Install Rust: https://rustup.rs/$(COLOR_RESET)"; \
exit 1)
@rustup target list --installed | grep -q $(TARGET) || (echo "$(COLOR_ERROR)Error: WASM target not installed$(COLOR_RESET)"; \
echo "$(COLOR_INFO)Install with: rustup target add $(TARGET)$(COLOR_RESET)"; \
exit 1)
@echo "$(COLOR_SUCCESS)✓ Rust found: $$(rustc --version)$(COLOR_RESET)"
@echo "$(COLOR_SUCCESS)✓ WASM target: $(TARGET)$(COLOR_RESET)"
build: check-rust ## Build the WASM plugin
@mkdir -p $(OUTPUT_DIR)
@echo "$(COLOR_INFO)Building WASM plugin...$(COLOR_RESET)"
cargo build --release --target $(TARGET)
@cp target/$(TARGET)/release/hello_world_wasm_rust.wasm $(OUTPUT)
@echo "$(COLOR_SUCCESS)✓ Plugin built successfully: $(OUTPUT)$(COLOR_RESET)"
@ls -lh $(OUTPUT) | awk '{print " Size: " $$5}'
build-optimized: check-rust ## Build with wasm-opt optimization (requires wasm-opt)
@mkdir -p $(OUTPUT_DIR)
@echo "$(COLOR_INFO)Building optimized WASM plugin...$(COLOR_RESET)"
cargo build --release --target $(TARGET)
@cp target/$(TARGET)/release/hello_world_wasm_rust.wasm $(OUTPUT)
@if which wasm-opt > /dev/null 2>&1; then \
echo "$(COLOR_INFO)Running wasm-opt...$(COLOR_RESET)"; \
wasm-opt -Os -o $(OUTPUT) $(OUTPUT); \
else \
echo "$(COLOR_WARNING)wasm-opt not found, skipping optimization$(COLOR_RESET)"; \
fi
@echo "$(COLOR_SUCCESS)✓ Plugin built: $(OUTPUT)$(COLOR_RESET)"
@ls -lh $(OUTPUT) | awk '{print " Size: " $$5}'
clean: ## Remove build artifacts
@echo "$(COLOR_INFO)Cleaning build artifacts...$(COLOR_RESET)"
@cargo clean
@rm -rf $(OUTPUT_DIR)
@echo "$(COLOR_SUCCESS)✓ Clean complete$(COLOR_RESET)"
info: ## Show build information
@echo "$(COLOR_BOLD)Build Configuration$(COLOR_RESET)"
@echo " Plugin Name: $(PLUGIN_NAME)"
@echo " Output: $(OUTPUT)"
@echo " Target: $(TARGET)"
@echo ""
@if [ -f "$(OUTPUT)" ]; then \
echo "$(COLOR_SUCCESS)Plugin exists:$(COLOR_RESET)"; \
ls -lh $(OUTPUT) | awk '{print " " $$9 " (" $$5 ")"}'; \
else \
echo "$(COLOR_WARNING)Plugin not built yet$(COLOR_RESET)"; \
fi
.DEFAULT_GOAL := help

View File

@@ -0,0 +1,528 @@
# Bifrost WASM Plugin (Rust)
A comprehensive example of a Bifrost plugin written in Rust and compiled to WebAssembly. This plugin demonstrates proper structure definitions with serde, JSON parsing, context handling, and request/response modification patterns.
## Prerequisites
### Rust Installation
Install Rust from [rustup.rs](https://rustup.rs/) and add the WASM target:
```bash
# Install Rust (if not already installed)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
# Add WASM target
rustup target add wasm32-unknown-unknown
```
### Optional: wasm-opt
For smaller binaries, install `wasm-opt` from [binaryen](https://github.com/WebAssembly/binaryen):
```bash
# macOS
brew install binaryen
# Linux
apt install binaryen
```
## Building
```bash
# Build the WASM plugin
make build
# Build with wasm-opt optimization
make build-optimized
# Clean build artifacts
make clean
```
The compiled plugin will be at `build/hello-world.wasm`.
## File Structure
```
src/
├── lib.rs # Plugin implementation (hooks)
├── memory.rs # Memory management utilities
└── types.rs # Type definitions (mirrors Go SDK)
```
## Plugin Structure
WASM plugins must export the following functions:
| Export | Signature | Description |
|--------|-----------|-------------|
| `malloc` | `(size: u32) -> u32` | Allocate memory for host to write data |
| `free` | `(ptr: u32, size: u32)` | Free allocated memory |
| `get_name` | `() -> u64` | Returns packed ptr+len of plugin name |
| `init` | `(config_ptr, config_len: u32) -> i32` | Initialize with config (optional) |
| `http_intercept` | `(input_ptr, input_len: u32) -> u64` | HTTP transport intercept |
| `pre_hook` | `(input_ptr, input_len: u32) -> u64` | Pre-request hook |
| `post_hook` | `(input_ptr, input_len: u32) -> u64` | Post-response hook |
| `cleanup` | `() -> i32` | Cleanup resources (0 = success) |
### Return Value Format
Functions returning data use a packed `u64` format:
- Upper 32 bits: pointer to data in WASM memory
- Lower 32 bits: length of data
## Data Structures
This plugin uses `serde` with derive macros for JSON serialization. All structures mirror the Go SDK types:
### Context
```rust
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BifrostContext {
pub request_id: Option<String>,
// Custom values via HashMap
#[serde(flatten)]
pub values: HashMap<String, serde_json::Value>,
}
impl BifrostContext {
pub fn set_value(&mut self, key: &str, value: impl Into<serde_json::Value>);
pub fn get_string(&self, key: &str) -> Option<&str>;
pub fn get_bool(&self, key: &str) -> Option<bool>;
}
```
### HTTP Transport Types
```rust
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HTTPRequest {
pub method: String,
pub path: String,
pub headers: HashMap<String, String>,
pub query: HashMap<String, String>,
pub body: String, // base64 encoded
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HTTPResponse {
pub status_code: i32,
pub headers: HashMap<String, String>,
pub body: String, // base64 encoded
}
```
### Chat Completion Types
```rust
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ChatMessageRole {
User,
Assistant,
System,
Tool,
Developer,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
Text(String),
Blocks(Vec<ChatContentBlock>),
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatMessage {
pub role: ChatMessageRole,
pub content: Option<ChatMessageContent>,
pub name: Option<String>,
pub tool_call_id: Option<String>,
pub tool_calls: Option<Vec<ToolCall>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatParameters {
pub temperature: Option<f64>,
pub max_completion_tokens: Option<i32>,
pub top_p: Option<f64>,
pub frequency_penalty: Option<f64>,
pub presence_penalty: Option<f64>,
pub stop: Option<Vec<String>>,
pub tools: Option<Vec<ChatTool>>,
#[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BifrostChatRequest {
pub provider: String,
pub model: String,
pub input: Vec<ChatMessage>,
pub params: Option<ChatParameters>,
pub fallbacks: Option<Vec<Fallback>>,
}
```
### Response Types
```rust
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct LLMUsage {
pub prompt_tokens: i32,
pub completion_tokens: i32,
pub total_tokens: i32,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ResponseChoice {
pub index: i32,
pub message: Option<ChatMessage>,
pub delta: Option<ChatMessage>,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BifrostChatResponse {
pub id: String,
pub model: String,
pub choices: Vec<ResponseChoice>,
pub usage: Option<LLMUsage>,
pub created: Option<i64>,
pub object: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BifrostResponse {
pub chat_response: Option<BifrostChatResponse>,
}
```
### Error Types
```rust
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ErrorField {
pub message: String,
#[serde(rename = "type")]
pub error_type: Option<String>,
pub code: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BifrostError {
pub error: ErrorField,
pub status_code: Option<i32>,
pub allow_fallbacks: Option<bool>,
}
impl BifrostError {
pub fn new(message: &str) -> Self;
pub fn with_type(self, error_type: &str) -> Self;
pub fn with_code(self, code: &str) -> Self;
pub fn with_status(self, status: i32) -> Self;
}
```
### Short Circuit
```rust
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct LLMPluginShortCircuit {
pub response: Option<BifrostResponse>,
pub error: Option<BifrostError>,
}
```
## Hook Input/Output Structures
### http_intercept
**Input:**
```json
{
"context": { "request_id": "abc-123" },
"request": {
"method": "POST",
"path": "/v1/chat/completions",
"headers": { "Content-Type": "application/json" },
"query": {},
"body": "<base64-encoded>"
}
}
```
**Output:**
```json
{
"context": { "request_id": "abc-123" },
"request": {},
"response": { "status_code": 200, "headers": {}, "body": "<base64>" },
"has_response": false,
"error": ""
}
```
### pre_hook
**Input:**
```json
{
"context": { "request_id": "abc-123" },
"request": {
"provider": "openai",
"model": "gpt-4",
"input": [{ "role": "user", "content": "Hello" }],
"params": { "temperature": 0.7 }
}
}
```
**Output:**
```json
{
"context": { "request_id": "abc-123", "plugin_processed": true },
"request": {},
"short_circuit": {
"response": { "chat_response": { ... } }
},
"has_short_circuit": false,
"error": ""
}
```
### post_hook
**Input:**
```json
{
"context": { "request_id": "abc-123", "plugin_processed": true },
"response": {
"chat_response": {
"id": "chatcmpl-123",
"model": "gpt-4",
"choices": [{ "index": 0, "message": { "role": "assistant", "content": "Hi!" } }],
"usage": { "prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15 }
}
},
"error": {},
"has_error": false
}
```
**Output:**
```json
{
"context": { "request_id": "abc-123", "post_hook_completed": true },
"response": {},
"error": {},
"has_error": false,
"hook_error": ""
}
```
## Usage Examples
### Modifying Context
```rust
#[no_mangle]
pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 {
let input_str = read_string(input_ptr, input_len);
let input: PreHookInput = serde_json::from_str(&input_str).unwrap();
let mut output = PreHookOutput {
context: input.context.clone(),
..Default::default()
};
// Add custom values to context
output.context.set_value("plugin_processed", serde_json::json!(true));
output.context.set_value("plugin_name", serde_json::json!("my-rust-plugin"));
write_string(&serde_json::to_string(&output).unwrap())
}
```
### Short-Circuit with Mock Response
```rust
#[no_mangle]
pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 {
let input_str = read_string(input_ptr, input_len);
let input: PreHookInput = serde_json::from_str(&input_str).unwrap();
let (provider, model) = input.get_provider_model();
// Check if this should be mocked
if model == "mock-model" {
let mut output = PreHookOutput {
context: input.context.clone(),
has_short_circuit: true,
..Default::default()
};
// Build mock response
let mock_response = BifrostResponse {
chat_response: Some(BifrostChatResponse {
id: format!("mock-{}", input.context.request_id.unwrap_or_default()),
model: "mock-model".to_string(),
choices: vec![ResponseChoice {
index: 0,
message: Some(ChatMessage {
role: ChatMessageRole::Assistant,
content: Some(ChatMessageContent::Text(
"This is a mock response!".to_string()
)),
..Default::default()
}),
finish_reason: Some("stop".to_string()),
..Default::default()
}],
usage: Some(LLMUsage {
prompt_tokens: 10,
completion_tokens: 15,
total_tokens: 25,
..Default::default()
}),
..Default::default()
}),
..Default::default()
};
output.short_circuit = Some(LLMPluginShortCircuit {
response: Some(mock_response),
error: None,
});
return write_string(&serde_json::to_string(&output).unwrap());
}
// Pass through
let output = PreHookOutput {
context: input.context,
..Default::default()
};
write_string(&serde_json::to_string(&output).unwrap())
}
```
### Short-Circuit with Error
```rust
#[no_mangle]
pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 {
let input_str = read_string(input_ptr, input_len);
let input: PreHookInput = serde_json::from_str(&input_str).unwrap();
// Check rate limit (example)
if should_rate_limit(&input.context) {
let mut output = PreHookOutput {
context: input.context.clone(),
has_short_circuit: true,
..Default::default()
};
output.short_circuit = Some(LLMPluginShortCircuit {
response: None,
error: Some(
BifrostError::new("Rate limit exceeded")
.with_type("rate_limit")
.with_code("429")
.with_status(429)
),
});
return write_string(&serde_json::to_string(&output).unwrap());
}
// Pass through
let output = PreHookOutput {
context: input.context,
..Default::default()
};
write_string(&serde_json::to_string(&output).unwrap())
}
```
### Modifying Responses in post_hook
```rust
#[no_mangle]
pub extern "C" fn post_hook(input_ptr: u32, input_len: u32) -> u64 {
let input_str = read_string(input_ptr, input_len);
let input: PostHookInput = serde_json::from_str(&input_str).unwrap();
let mut output = PostHookOutput {
context: input.context.clone(),
..Default::default()
};
// Handle errors
if input.has_error {
output.has_error = true;
output.error = input.error.clone();
// Optionally modify the error
if let Some(mut error) = input.parse_error() {
error.error.message = format!("{} (via rust plugin)", error.error.message);
output.error = serde_json::to_value(&error).unwrap_or_default();
}
return write_string(&serde_json::to_string(&output).unwrap());
}
// Pass through or modify response
if let Some(mut response) = input.parse_response() {
if let Some(ref mut chat) = response.chat_response {
// Add a marker to the model name
chat.model = format!("{} (via rust-wasm)", chat.model);
}
output.response = serde_json::to_value(&response).unwrap_or_default();
}
write_string(&serde_json::to_string(&output).unwrap())
}
```
## Usage with Bifrost
Configure the plugin in your Bifrost config:
```json
{
"plugins": [
{
"path": "/path/to/hello-world.wasm",
"name": "hello-world-wasm-rust",
"enabled": true,
"config": {
"custom_option": "value"
}
}
]
}
```
## Testing
The plugin includes unit tests that can be run with:
```bash
cargo test
```
## Benefits
1. **Performance**: Rust compiles to highly optimized WASM
2. **Safety**: Memory safety without garbage collection
3. **Small binaries**: Rust WASM binaries are typically very small
4. **Cross-platform**: Single `.wasm` binary runs on any OS/architecture
5. **Security**: WASM provides sandboxed execution
6. **Type Safety**: Strongly typed structures with serde derive macros
7. **Excellent JSON**: serde_json provides robust JSON handling

View File

@@ -0,0 +1,327 @@
//! Bifrost WASM Plugin for Rust
//!
//! This plugin demonstrates the proper structure for parsing inputs,
//! building responses, and handling context - similar to Go plugin patterns.
//!
//! Build with: cargo build --release --target wasm32-unknown-unknown
mod memory;
mod types;
use memory::{read_string, write_string};
use types::*;
// Global configuration storage
static mut PLUGIN_CONFIG: Option<PluginConfig> = None;
// =============================================================================
// Exported Plugin Functions
// =============================================================================
/// Return the plugin name
#[no_mangle]
pub extern "C" fn get_name() -> u64 {
write_string("hello-world-wasm-rust")
}
/// Initialize the plugin with config
/// Returns 0 on success, non-zero on error
#[no_mangle]
pub extern "C" fn init(config_ptr: u32, config_len: u32) -> i32 {
let config_str = read_string(config_ptr, config_len);
// Parse configuration
let config: PluginConfig = if config_str.is_empty() {
PluginConfig::default()
} else {
match serde_json::from_str(&config_str) {
Ok(c) => c,
Err(_) => return 1, // Config parse error
}
};
// Store configuration
unsafe {
PLUGIN_CONFIG = Some(config);
}
0 // Success
}
/// HTTP transport intercept
/// Called at the HTTP layer before request enters Bifrost core.
/// Can modify headers, query params, or short-circuit with a response.
#[no_mangle]
pub extern "C" fn http_intercept(input_ptr: u32, input_len: u32) -> u64 {
let input_str = read_string(input_ptr, input_len);
// Parse input
let input: HTTPInterceptInput = match serde_json::from_str(&input_str) {
Ok(i) => i,
Err(e) => {
// Include context around the error position for debugging
let error_context = if let Some(col) = extract_column(&e.to_string()) {
let start = col.saturating_sub(50);
let end = (col + 50).min(input_str.len());
format!(" | context: ...{}...", &input_str[start..end])
} else {
String::new()
};
let output = HTTPInterceptOutput {
error: format!("Failed to parse input: {}{}", e, error_context),
..Default::default()
};
return write_string(&serde_json::to_string(&output).unwrap_or_default());
}
};
// Add context value like Go plugin does
let mut context = input.context;
context.set_value("from-http", serde_json::json!("123"));
// Create output with context and request preserved (pass-through)
// Serialize request to Value to ensure proper JSON structure
let request_value = serde_json::to_value(&input.request).ok();
let output = HTTPInterceptOutput {
context: input.context,
request: input.request,
has_response: false,
..Default::default()
};
// Pass through
write_string(&serde_json::to_string(&output).unwrap_or_default())
}
/// Pre-request hook
/// Called before request is sent to the provider.
/// Can modify the request or short-circuit with a response/error.
#[no_mangle]
pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 {
let input_str = read_string(input_ptr, input_len);
// Parse input
let input: PreHookInput = match serde_json::from_str(&input_str) {
Ok(i) => i,
Err(e) => {
let output = PreHookOutput {
error: format!("Failed to parse input: {}", e),
..Default::default()
};
return write_string(&serde_json::to_string(&output).unwrap_or_default());
}
};
// Create output with context preserved
let mut output = PreHookOutput {
context: input.context.clone(),
request: input.request.clone(),
has_short_circuit: false,
..Default::default()
};
// Get provider and model for potential modifications
let (_provider, model) = input.get_provider_model();
// Example: Short-circuit with mock response for specific model
// Uncomment to test:
/*
if model == "mock-model" {
output.has_short_circuit = true;
let mock_response = BifrostResponse {
chat_response: Some(BifrostChatResponse {
id: format!("mock-{}", input.context.request_id.unwrap_or_default()),
model: "mock-model".to_string(),
choices: vec![ResponseChoice {
index: 0,
message: Some(ChatMessage {
role: ChatMessageRole::Assistant,
content: Some(ChatMessageContent::Text(
"This is a mock response from the Rust WASM plugin!".to_string()
)),
..Default::default()
}),
finish_reason: Some("stop".to_string()),
..Default::default()
}],
usage: Some(LLMUsage {
prompt_tokens: 10,
completion_tokens: 15,
total_tokens: 25,
..Default::default()
}),
..Default::default()
}),
..Default::default()
};
output.short_circuit = Some(LLMPluginShortCircuit {
response: Some(mock_response),
error: None,
});
return write_string(&serde_json::to_string(&output).unwrap_or_default());
}
*/
// Example: Short-circuit with rate limit error
// Uncomment to test:
/*
if should_rate_limit(&input.context) {
output.has_short_circuit = true;
output.short_circuit = Some(LLMPluginShortCircuit {
response: None,
error: Some(
BifrostError::new("Rate limit exceeded")
.with_type("rate_limit")
.with_code("429")
.with_status(429)
),
});
return write_string(&serde_json::to_string(&output).unwrap_or_default());
}
*/
// Silence unused variable warning in example code
let _ = model;
// Pass through - empty request means use original
write_string(&serde_json::to_string(&output).unwrap_or_default())
}
/// Post-response hook
/// Called after response is received from provider.
/// Can modify the response or error.
#[no_mangle]
pub extern "C" fn post_hook(input_ptr: u32, input_len: u32) -> u64 {
let input_str = read_string(input_ptr, input_len);
// Parse input
let input: PostHookInput = match serde_json::from_str(&input_str) {
Ok(i) => i,
Err(e) => {
let output = PostHookOutput {
hook_error: format!("Failed to parse input: {}", e),
..Default::default()
};
return write_string(&serde_json::to_string(&output).unwrap_or_default());
}
};
// Add context value like Go plugin does
let mut context = input.context.clone();
context.set_value("from-post-hook", serde_json::json!("456"));
// Create output with context and response/error preserved (pass-through)
// This matches Go plugin behavior exactly
let output = PostHookOutput {
context,
response: Some(input.response.clone()),
error: Some(input.error.clone()),
has_error: input.has_error,
hook_error: String::new(),
};
// Example: Modify error message when has_error is true
// Uncomment to test:
/*
if input.has_error {
if let Some(mut error) = input.parse_error() {
error.error.message = format!("{} (processed by Rust WASM plugin)", error.error.message);
let mut output = output;
output.error = Some(serde_json::to_value(&error).unwrap_or_default());
return write_string(&serde_json::to_string(&output).unwrap_or_default());
}
}
*/
// Example: Modify response
// Uncomment to test:
/*
if let Some(mut response) = input.parse_response() {
// Add custom metadata, modify model name, etc.
if let Some(ref mut chat) = response.chat_response {
// Add a marker to the model name
chat.model = format!("{} (via rust-wasm)", chat.model);
}
let mut output = output;
output.response = Some(serde_json::to_value(&response).unwrap_or_default());
return write_string(&serde_json::to_string(&output).unwrap_or_default());
}
*/
write_string(&serde_json::to_string(&output).unwrap_or_default())
}
/// HTTP stream chunk hook
/// Called for each chunk during streaming responses.
/// Can modify, skip, or stop streaming based on return values.
#[no_mangle]
pub extern "C" fn http_stream_chunk_hook(input_ptr: u32, input_len: u32) -> u64 {
let input_str = read_string(input_ptr, input_len);
// Parse input
let input: HTTPStreamChunkHookInput = match serde_json::from_str(&input_str) {
Ok(i) => i,
Err(e) => {
let output = HTTPStreamChunkHookOutput {
error: format!("Failed to parse input: {}", e),
..Default::default()
};
return write_string(&serde_json::to_string(&output).unwrap_or_default());
}
};
// Add context value like Go plugin does
let mut context = input.context.clone();
context.set_value("from-stream-chunk", serde_json::json!("rust-wasm"));
// Pass through chunk unchanged
let output = HTTPStreamChunkHookOutput {
context,
chunk: Some(input.chunk),
has_chunk: true,
skip: false,
error: String::new(),
};
write_string(&serde_json::to_string(&output).unwrap_or_default())
}
/// Cleanup resources
/// Called when plugin is being unloaded.
/// Returns 0 on success, non-zero on error
#[no_mangle]
pub extern "C" fn cleanup() -> i32 {
// Clear stored configuration
unsafe {
PLUGIN_CONFIG = None;
}
0 // Success
}
// =============================================================================
// Helper Functions
// =============================================================================
/// Extract column number from serde error message for debugging
fn extract_column(error_msg: &str) -> Option<usize> {
// Error format: "... at line X column Y"
if let Some(idx) = error_msg.rfind("column ") {
let col_str = &error_msg[idx + 7..];
col_str.split_whitespace().next()?.parse().ok()
} else {
None
}
}
/// Example rate limit check function
#[allow(dead_code)]
fn should_rate_limit(_context: &BifrostContext) -> bool {
// Implement your rate limiting logic here
false
}

View File

@@ -0,0 +1,70 @@
//! Memory management utilities for WASM plugins.
//! Handles allocation, deallocation, and string read/write operations.
use std::alloc::{alloc, dealloc, Layout};
use std::slice;
/// Pack a pointer and length into a single u64
/// Upper 32 bits: pointer, Lower 32 bits: length
pub fn pack_result(ptr: u32, len: u32) -> u64 {
((ptr as u64) << 32) | (len as u64)
}
/// Write a string to WASM memory and return packed pointer+length
pub fn write_string(s: &str) -> u64 {
if s.is_empty() {
return 0;
}
let bytes = s.as_bytes();
let ptr = unsafe { malloc(bytes.len() as u32) };
if ptr == 0 {
return 0;
}
unsafe {
std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr as *mut u8, bytes.len());
}
pack_result(ptr, bytes.len() as u32)
}
/// Read a string from WASM memory given pointer and length
pub fn read_string(ptr: u32, len: u32) -> String {
if len == 0 {
return String::new();
}
let bytes = unsafe { slice::from_raw_parts(ptr as *const u8, len as usize) };
String::from_utf8_lossy(bytes).into_owned()
}
/// Allocate memory for the host to write data
///
/// # Safety
/// This function is marked as safe but performs unsafe operations internally.
/// It is intended to be called from WASM host.
#[no_mangle]
pub extern "C" fn malloc(size: u32) -> u32 {
if size == 0 {
return 0;
}
let layout = match Layout::from_size_align(size as usize, 1) {
Ok(l) => l,
Err(_) => return 0,
};
unsafe { alloc(layout) as u32 }
}
/// Free allocated memory
///
/// # Safety
/// This function is marked as safe but performs unsafe operations internally.
/// It is intended to be called from WASM host.
#[no_mangle]
pub extern "C" fn free(ptr: u32, size: u32) {
if ptr == 0 || size == 0 {
return;
}
let layout = match Layout::from_size_align(size as usize, 1) {
Ok(l) => l,
Err(_) => return,
};
unsafe { dealloc(ptr as *mut u8, layout) }
}

View File

@@ -0,0 +1,870 @@
//! Type definitions for Bifrost WASM plugins.
//! These structures mirror the Go SDK types for interoperability.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
// =============================================================================
// Nullable Deserializers
// =============================================================================
/// Helper module for deserializing fields that may be null in JSON.
/// Go's JSON encoder outputs `null` for nil slices/maps, but Rust's serde
/// with `#[serde(default)]` only handles missing fields, not explicit nulls.
mod nullable {
use serde::{Deserialize, Deserializer};
use std::collections::HashMap;
/// Deserialize a string that may be null, converting null to empty string.
pub fn string<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
Option::<String>::deserialize(deserializer).map(|opt| opt.unwrap_or_default())
}
/// Deserialize a HashMap<String, String> that may be null or contain null values.
/// Handles both `null` (entire map is null) and `{"key": null}` (value is null).
pub fn string_map<'de, D>(deserializer: D) -> Result<HashMap<String, String>, D::Error>
where
D: Deserializer<'de>,
{
// First deserialize as Option<HashMap<String, Option<String>>> to handle null values
let opt_map: Option<HashMap<String, Option<String>>> = Option::deserialize(deserializer)?;
match opt_map {
None => Ok(HashMap::new()),
Some(map) => {
// Filter out null values and unwrap the rest
Ok(map
.into_iter()
.filter_map(|(k, v)| v.map(|val| (k, val)))
.collect())
}
}
}
/// Deserialize an i32 that may be null, converting null to 0.
pub fn i32_field<'de, D>(deserializer: D) -> Result<i32, D::Error>
where
D: Deserializer<'de>,
{
Option::<i32>::deserialize(deserializer).map(|opt| opt.unwrap_or_default())
}
/// Deserialize an HTTPRequest that may be null, converting null to default.
pub fn http_request<'de, D>(deserializer: D) -> Result<super::HTTPRequest, D::Error>
where
D: Deserializer<'de>,
{
Option::<super::HTTPRequest>::deserialize(deserializer).map(|opt| opt.unwrap_or_default())
}
/// Deserialize a BifrostContext that may be null, converting null to default.
pub fn context<'de, D>(deserializer: D) -> Result<super::BifrostContext, D::Error>
where
D: Deserializer<'de>,
{
Option::<super::BifrostContext>::deserialize(deserializer).map(|opt| opt.unwrap_or_default())
}
}
// =============================================================================
// Context Structure
// =============================================================================
/// BifrostContext holds request-scoped values passed between hooks.
/// This is a dynamic map (map[string]any in Go) that can hold any JSON values.
/// Common keys include:
/// - request_id: Unique identifier for the request
/// - Custom plugin values can be added and will be persisted across hooks
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(transparent)]
pub struct BifrostContext(pub HashMap<String, serde_json::Value>);
impl BifrostContext {
pub fn new() -> Self {
Self(HashMap::new())
}
/// Set a custom value in the context
pub fn set_value(&mut self, key: &str, value: impl Into<serde_json::Value>) {
self.0.insert(key.to_string(), value.into());
}
/// Get a value from the context
pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
self.0.get(key)
}
/// Get a string value from the context
pub fn get_string(&self, key: &str) -> Option<&str> {
self.0.get(key).and_then(|v| v.as_str())
}
/// Get a boolean value from the context
pub fn get_bool(&self, key: &str) -> Option<bool> {
self.0.get(key).and_then(|v| v.as_bool())
}
/// Get an i64 value from the context
pub fn get_i64(&self, key: &str) -> Option<i64> {
self.0.get(key).and_then(|v| v.as_i64())
}
/// Check if a key exists in the context
pub fn contains_key(&self, key: &str) -> bool {
self.0.contains_key(key)
}
/// Remove a value from the context
pub fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
self.0.remove(key)
}
/// Get the underlying HashMap for iteration
pub fn inner(&self) -> &HashMap<String, serde_json::Value> {
&self.0
}
/// Get mutable access to the underlying HashMap
pub fn inner_mut(&mut self) -> &mut HashMap<String, serde_json::Value> {
&mut self.0
}
}
// =============================================================================
// HTTP Transport Structures
// =============================================================================
/// HTTPRequest represents an incoming HTTP request at the transport layer.
/// Body is base64-encoded.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HTTPRequest {
#[serde(default, deserialize_with = "nullable::string")]
pub method: String,
#[serde(default, deserialize_with = "nullable::string")]
pub path: String,
#[serde(default, deserialize_with = "nullable::string_map")]
pub headers: HashMap<String, String>,
#[serde(default, deserialize_with = "nullable::string_map")]
pub query: HashMap<String, String>,
/// Base64-encoded request body
#[serde(default, deserialize_with = "nullable::string")]
pub body: String,
}
/// HTTPResponse represents an HTTP response to return.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HTTPResponse {
#[serde(default, deserialize_with = "nullable::i32_field")]
pub status_code: i32,
#[serde(default, deserialize_with = "nullable::string_map")]
pub headers: HashMap<String, String>,
/// Base64-encoded response body
#[serde(default, deserialize_with = "nullable::string")]
pub body: String,
}
/// HTTPInterceptInput is the input for http_intercept hook.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HTTPInterceptInput {
#[serde(default, deserialize_with = "nullable::context")]
pub context: BifrostContext,
#[serde(default, deserialize_with = "nullable::http_request")]
pub request: HTTPRequest,
}
/// HTTPInterceptOutput is the output for http_intercept hook.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HTTPInterceptOutput {
pub context: BifrostContext,
#[serde(skip_serializing_if = "Option::is_none")]
pub request: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response: Option<HTTPResponse>,
#[serde(default)]
pub has_response: bool,
#[serde(default)]
pub error: String,
}
// =============================================================================
// Chat Completion Structures (BifrostRequest)
// =============================================================================
/// ChatMessageRole represents the role of a message sender.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ChatMessageRole {
User,
Assistant,
System,
Tool,
Developer,
}
impl Default for ChatMessageRole {
fn default() -> Self {
ChatMessageRole::User
}
}
/// ChatMessageContent can be either a string or an array of content blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
Text(String),
Blocks(Vec<ChatContentBlock>),
}
impl Default for ChatMessageContent {
fn default() -> Self {
ChatMessageContent::Text(String::new())
}
}
/// ChatContentBlock represents a content block in a message.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatContentBlock {
#[serde(rename = "type")]
pub block_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub image_url: Option<ImageUrl>,
}
/// ImageUrl represents an image URL in a content block.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ImageUrl {
pub url: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>,
}
/// ChatMessage represents a message in the conversation.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatMessage {
#[serde(default)]
pub role: ChatMessageRole,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<ChatMessageContent>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
}
/// ToolCall represents a tool call made by the assistant.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolCall {
#[serde(default)]
pub id: Option<String>,
#[serde(rename = "type", default)]
pub call_type: Option<String>,
#[serde(default)]
pub function: ToolCallFunction,
}
/// ToolCallFunction represents the function being called.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolCallFunction {
#[serde(default)]
pub name: Option<String>,
#[serde(default)]
pub arguments: String,
}
/// ChatParameters contains optional parameters for chat completion.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatParameters {
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_completion_tokens: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ChatTool>>,
/// Catch-all for additional parameters
#[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>,
}
/// ChatTool represents a tool definition.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatTool {
#[serde(rename = "type")]
pub tool_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<ChatToolFunction>,
}
/// ChatToolFunction represents a function definition.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatToolFunction {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<serde_json::Value>,
}
/// BifrostChatRequest represents a chat completion request.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BifrostChatRequest {
#[serde(default)]
pub provider: String,
#[serde(default)]
pub model: String,
#[serde(default)]
pub input: Vec<ChatMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<ChatParameters>,
#[serde(skip_serializing_if = "Option::is_none")]
pub fallbacks: Option<Vec<Fallback>>,
}
/// Fallback represents a fallback provider/model.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Fallback {
pub provider: String,
pub model: String,
}
/// BifrostRequest is the unified request structure.
/// Only one of the request types should be present.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BifrostRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub chat_request: Option<BifrostChatRequest>,
// Add other request types as needed
#[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>,
}
impl BifrostRequest {
/// Get provider and model from the request
pub fn get_provider_model(&self) -> (String, String) {
if let Some(ref chat) = self.chat_request {
return (chat.provider.clone(), chat.model.clone());
}
(String::new(), String::new())
}
}
// =============================================================================
// Response Structures (BifrostResponse)
// =============================================================================
/// LLMUsage contains token usage information.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct LLMUsage {
#[serde(default)]
pub prompt_tokens: i32,
#[serde(default)]
pub completion_tokens: i32,
#[serde(default)]
pub total_tokens: i32,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_tokens_details: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub completion_tokens_details: Option<serde_json::Value>,
}
/// ResponseChoice represents a single completion choice.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ResponseChoice {
#[serde(default)]
pub index: i32,
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<ChatMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub delta: Option<ChatMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<serde_json::Value>,
}
/// BifrostChatResponse represents a chat completion response.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BifrostChatResponse {
#[serde(default)]
pub id: String,
#[serde(default)]
pub model: String,
#[serde(default)]
pub choices: Vec<ResponseChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<LLMUsage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub created: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub object: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
/// BifrostResponse is the unified response structure.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BifrostResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub chat_response: Option<BifrostChatResponse>,
#[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>,
}
// =============================================================================
// Error Structure
// =============================================================================
/// ErrorField contains the error details.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ErrorField {
#[serde(default)]
pub message: String,
#[serde(skip_serializing_if = "Option::is_none", rename = "type")]
pub error_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub code: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub param: Option<String>,
}
/// BifrostError represents an error response.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BifrostError {
#[serde(default)]
pub error: ErrorField,
#[serde(skip_serializing_if = "Option::is_none")]
pub status_code: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub allow_fallbacks: Option<bool>,
}
impl BifrostError {
/// Create a new error with a message
pub fn new(message: &str) -> Self {
Self {
error: ErrorField {
message: message.to_string(),
..Default::default()
},
..Default::default()
}
}
/// Set the error type
pub fn with_type(mut self, error_type: &str) -> Self {
self.error.error_type = Some(error_type.to_string());
self
}
/// Set the error code
pub fn with_code(mut self, code: &str) -> Self {
self.error.code = Some(code.to_string());
self
}
/// Set the status code
pub fn with_status(mut self, status: i32) -> Self {
self.status_code = Some(status);
self
}
}
// =============================================================================
// Short Circuit Structure
// =============================================================================
/// LLMPluginShortCircuit allows plugins to short-circuit the request flow.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct LLMPluginShortCircuit {
#[serde(skip_serializing_if = "Option::is_none")]
pub response: Option<BifrostResponse>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<BifrostError>,
}
// =============================================================================
// Hook Input/Output Structures
// =============================================================================
/// PreHookInput is the input for pre_hook.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PreHookInput {
#[serde(default)]
pub context: BifrostContext,
#[serde(default)]
pub request: serde_json::Value,
}
impl PreHookInput {
/// Parse the request as a BifrostRequest
pub fn parse_request(&self) -> Option<BifrostRequest> {
serde_json::from_value(self.request.clone()).ok()
}
/// Get provider and model from the request
pub fn get_provider_model(&self) -> (String, String) {
if let Some(req) = self.parse_request() {
return req.get_provider_model();
}
// Try direct access for simpler structures
let provider = self.request.get("provider")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let model = self.request.get("model")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
(provider, model)
}
}
/// PreHookOutput is the output for pre_hook.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PreHookOutput {
pub context: BifrostContext,
#[serde(skip_serializing_if = "Option::is_none")]
pub request: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub short_circuit: Option<LLMPluginShortCircuit>,
#[serde(default)]
pub has_short_circuit: bool,
#[serde(default)]
pub error: String,
}
/// PostHookInput is the input for post_hook.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PostHookInput {
#[serde(default)]
pub context: BifrostContext,
#[serde(default)]
pub response: serde_json::Value,
#[serde(default)]
pub error: serde_json::Value,
#[serde(default)]
pub has_error: bool,
}
impl PostHookInput {
/// Parse the response as a BifrostResponse
pub fn parse_response(&self) -> Option<BifrostResponse> {
serde_json::from_value(self.response.clone()).ok()
}
/// Parse the error as a BifrostError
pub fn parse_error(&self) -> Option<BifrostError> {
if self.has_error {
serde_json::from_value(self.error.clone()).ok()
} else {
None
}
}
}
/// PostHookOutput is the output for post_hook.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PostHookOutput {
pub context: BifrostContext,
#[serde(skip_serializing_if = "Option::is_none")]
pub response: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<serde_json::Value>,
#[serde(default)]
pub has_error: bool,
#[serde(default)]
pub hook_error: String,
}
// =============================================================================
// HTTP Stream Chunk Hook Input/Output Structures
// =============================================================================
/// HTTPStreamChunkHookInput is the input for http_stream_chunk_hook.
/// Called for each chunk during streaming responses.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HTTPStreamChunkHookInput {
#[serde(default)]
pub context: BifrostContext,
#[serde(default)]
pub request: serde_json::Value,
#[serde(default)]
pub chunk: serde_json::Value, // BifrostStreamChunk as JSON
}
/// HTTPStreamChunkHookOutput is the output for http_stream_chunk_hook.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HTTPStreamChunkHookOutput {
pub context: BifrostContext,
#[serde(skip_serializing_if = "Option::is_none")]
pub chunk: Option<serde_json::Value>, // BifrostStreamChunk as JSON, None to skip
#[serde(default)]
pub has_chunk: bool,
#[serde(default)]
pub skip: bool,
#[serde(default)]
pub error: String,
}
// =============================================================================
// Plugin Configuration
// =============================================================================
/// Plugin configuration (customize as needed)
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PluginConfig {
#[serde(flatten)]
pub values: HashMap<String, serde_json::Value>,
}
// =============================================================================
// Tests
// =============================================================================
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_context_serialization() {
let mut ctx = BifrostContext::new();
ctx.set_value("request_id", "test-123");
ctx.set_value("custom_key", "custom_value");
ctx.set_value("is_enabled", true);
ctx.set_value("count", 42);
let json = serde_json::to_string(&ctx).unwrap();
assert!(json.contains("request_id"));
assert!(json.contains("custom_key"));
assert!(json.contains("is_enabled"));
assert!(json.contains("count"));
}
#[test]
fn test_context_deserialization() {
let json = r#"{"request_id": "test-123", "custom_key": "custom_value", "is_enabled": true}"#;
let ctx: BifrostContext = serde_json::from_str(json).unwrap();
assert_eq!(ctx.get_string("request_id"), Some("test-123"));
assert_eq!(ctx.get_string("custom_key"), Some("custom_value"));
assert_eq!(ctx.get_bool("is_enabled"), Some(true));
}
#[test]
fn test_context_methods() {
let mut ctx = BifrostContext::new();
ctx.set_value("key1", "value1");
ctx.set_value("enabled", true);
ctx.set_value("count", 42);
assert_eq!(ctx.get_string("key1"), Some("value1"));
assert_eq!(ctx.get_bool("enabled"), Some(true));
assert_eq!(ctx.get_i64("count"), Some(42));
assert!(ctx.contains_key("key1"));
assert!(!ctx.contains_key("nonexistent"));
ctx.remove("key1");
assert!(!ctx.contains_key("key1"));
}
#[test]
fn test_chat_message() {
let msg = ChatMessage {
role: ChatMessageRole::User,
content: Some(ChatMessageContent::Text("Hello!".to_string())),
..Default::default()
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains("user"));
assert!(json.contains("Hello!"));
}
#[test]
fn test_bifrost_error() {
let error = BifrostError::new("Test error")
.with_type("test_type")
.with_code("500")
.with_status(500);
let json = serde_json::to_string(&error).unwrap();
assert!(json.contains("Test error"));
assert!(json.contains("test_type"));
}
#[test]
fn test_pre_hook_input_parsing() {
let json = r#"{
"context": {"request_id": "test-123", "custom": "value"},
"request": {"provider": "openai", "model": "gpt-4"}
}"#;
let input: PreHookInput = serde_json::from_str(json).unwrap();
assert_eq!(input.context.get_string("request_id"), Some("test-123"));
assert_eq!(input.context.get_string("custom"), Some("value"));
let (provider, model) = input.get_provider_model();
assert_eq!(provider, "openai");
assert_eq!(model, "gpt-4");
}
#[test]
fn test_http_request_with_null_fields() {
// Simulates Go sending null for nil []byte and nil maps
let json = r#"{
"method": "POST",
"path": "/v1/chat/completions",
"headers": null,
"query": null,
"body": null
}"#;
let req: HTTPRequest = serde_json::from_str(json).unwrap();
assert_eq!(req.method, "POST");
assert_eq!(req.path, "/v1/chat/completions");
assert!(req.headers.is_empty());
assert!(req.query.is_empty());
assert_eq!(req.body, "");
}
#[test]
fn test_http_request_with_missing_fields() {
// Test that missing fields also work (default behavior)
let json = r#"{
"method": "GET",
"path": "/health"
}"#;
let req: HTTPRequest = serde_json::from_str(json).unwrap();
assert_eq!(req.method, "GET");
assert_eq!(req.path, "/health");
assert!(req.headers.is_empty());
assert!(req.query.is_empty());
assert_eq!(req.body, "");
}
#[test]
fn test_http_intercept_input_with_nulls() {
// Simulates a full HTTP intercept input with null body from Go
let json = r#"{
"context": {"request_id": "abc-123"},
"request": {
"method": "POST",
"path": "/v1/chat/completions",
"headers": {"content-type": "application/json"},
"query": {},
"body": null
}
}"#;
let input: HTTPInterceptInput = serde_json::from_str(json).unwrap();
assert_eq!(input.context.get_string("request_id"), Some("abc-123"));
assert_eq!(input.request.method, "POST");
assert_eq!(input.request.path, "/v1/chat/completions");
assert_eq!(input.request.headers.get("content-type"), Some(&"application/json".to_string()));
assert_eq!(input.request.body, "");
}
#[test]
fn test_http_response_with_null_fields() {
let json = r#"{
"status_code": null,
"headers": null,
"body": null
}"#;
let resp: HTTPResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.status_code, 0);
assert!(resp.headers.is_empty());
assert_eq!(resp.body, "");
}
}