first commit
This commit is contained in:
107
examples/plugins/hello-world-wasm-rust/Cargo.lock
generated
Normal file
107
examples/plugins/hello-world-wasm-rust/Cargo.lock
generated
Normal file
@@ -0,0 +1,107 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "hello-world-wasm-rust"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.105"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "535d180e0ecab6268a3e718bb9fd44db66bbbc256257165fc699dadf70d16fe7"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.43"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc74d9a594b72ae6656596548f56f667211f8a97b3d4c3d467150794690dc40a"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.228"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
|
||||
dependencies = [
|
||||
"serde_core",
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_core"
|
||||
version = "1.0.228"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.228"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.149"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"memchr",
|
||||
"serde",
|
||||
"serde_core",
|
||||
"zmij",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.114"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
||||
|
||||
[[package]]
|
||||
name = "zmij"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2fc5a66a20078bf1251bde995aa2fdcc4b800c70b5d92dd2c62abc5c60f679f8"
|
||||
18
examples/plugins/hello-world-wasm-rust/Cargo.toml
Normal file
18
examples/plugins/hello-world-wasm-rust/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "hello-world-wasm-rust"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
description = "A minimal Bifrost WASM plugin example in Rust"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
[profile.release]
|
||||
opt-level = "s"
|
||||
lto = true
|
||||
strip = true
|
||||
panic = "abort"
|
||||
80
examples/plugins/hello-world-wasm-rust/Makefile
Normal file
80
examples/plugins/hello-world-wasm-rust/Makefile
Normal file
@@ -0,0 +1,80 @@
|
||||
.PHONY: all build build-optimized clean help check-rust
|
||||
|
||||
# Colors
|
||||
COLOR_RESET = \033[0m
|
||||
COLOR_INFO = \033[36m
|
||||
COLOR_SUCCESS = \033[32m
|
||||
COLOR_WARNING = \033[33m
|
||||
COLOR_ERROR = \033[31m
|
||||
COLOR_BOLD = \033[1m
|
||||
|
||||
# Plugin configuration
|
||||
PLUGIN_NAME = hello-world
|
||||
OUTPUT_DIR = build
|
||||
OUTPUT = $(OUTPUT_DIR)/$(PLUGIN_NAME).wasm
|
||||
TARGET = wasm32-unknown-unknown
|
||||
|
||||
help: ## Show this help message
|
||||
@echo '$(COLOR_BOLD)Hello World WASM Plugin (Rust)$(COLOR_RESET)'
|
||||
@echo ''
|
||||
@echo '$(COLOR_BOLD)Usage:$(COLOR_RESET) make [target]'
|
||||
@echo ''
|
||||
@echo '$(COLOR_BOLD)Prerequisites:$(COLOR_RESET)'
|
||||
@echo ' - Rust with wasm32-unknown-unknown target'
|
||||
@echo ' rustup target add wasm32-unknown-unknown'
|
||||
@echo ''
|
||||
@echo '$(COLOR_BOLD)Available targets:$(COLOR_RESET)'
|
||||
@awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf " $(COLOR_INFO)%-15s$(COLOR_RESET) %s\n", $$1, $$2}' $(MAKEFILE_LIST)
|
||||
|
||||
check-rust: ## Check if Rust and WASM target are installed
|
||||
@which cargo > /dev/null 2>&1 || (echo "$(COLOR_ERROR)Error: Rust/Cargo is not installed$(COLOR_RESET)"; \
|
||||
echo "$(COLOR_INFO)Install Rust: https://rustup.rs/$(COLOR_RESET)"; \
|
||||
exit 1)
|
||||
@rustup target list --installed | grep -q $(TARGET) || (echo "$(COLOR_ERROR)Error: WASM target not installed$(COLOR_RESET)"; \
|
||||
echo "$(COLOR_INFO)Install with: rustup target add $(TARGET)$(COLOR_RESET)"; \
|
||||
exit 1)
|
||||
@echo "$(COLOR_SUCCESS)✓ Rust found: $$(rustc --version)$(COLOR_RESET)"
|
||||
@echo "$(COLOR_SUCCESS)✓ WASM target: $(TARGET)$(COLOR_RESET)"
|
||||
|
||||
build: check-rust ## Build the WASM plugin
|
||||
@mkdir -p $(OUTPUT_DIR)
|
||||
@echo "$(COLOR_INFO)Building WASM plugin...$(COLOR_RESET)"
|
||||
cargo build --release --target $(TARGET)
|
||||
@cp target/$(TARGET)/release/hello_world_wasm_rust.wasm $(OUTPUT)
|
||||
@echo "$(COLOR_SUCCESS)✓ Plugin built successfully: $(OUTPUT)$(COLOR_RESET)"
|
||||
@ls -lh $(OUTPUT) | awk '{print " Size: " $$5}'
|
||||
|
||||
build-optimized: check-rust ## Build with wasm-opt optimization (requires wasm-opt)
|
||||
@mkdir -p $(OUTPUT_DIR)
|
||||
@echo "$(COLOR_INFO)Building optimized WASM plugin...$(COLOR_RESET)"
|
||||
cargo build --release --target $(TARGET)
|
||||
@cp target/$(TARGET)/release/hello_world_wasm_rust.wasm $(OUTPUT)
|
||||
@if which wasm-opt > /dev/null 2>&1; then \
|
||||
echo "$(COLOR_INFO)Running wasm-opt...$(COLOR_RESET)"; \
|
||||
wasm-opt -Os -o $(OUTPUT) $(OUTPUT); \
|
||||
else \
|
||||
echo "$(COLOR_WARNING)wasm-opt not found, skipping optimization$(COLOR_RESET)"; \
|
||||
fi
|
||||
@echo "$(COLOR_SUCCESS)✓ Plugin built: $(OUTPUT)$(COLOR_RESET)"
|
||||
@ls -lh $(OUTPUT) | awk '{print " Size: " $$5}'
|
||||
|
||||
clean: ## Remove build artifacts
|
||||
@echo "$(COLOR_INFO)Cleaning build artifacts...$(COLOR_RESET)"
|
||||
@cargo clean
|
||||
@rm -rf $(OUTPUT_DIR)
|
||||
@echo "$(COLOR_SUCCESS)✓ Clean complete$(COLOR_RESET)"
|
||||
|
||||
info: ## Show build information
|
||||
@echo "$(COLOR_BOLD)Build Configuration$(COLOR_RESET)"
|
||||
@echo " Plugin Name: $(PLUGIN_NAME)"
|
||||
@echo " Output: $(OUTPUT)"
|
||||
@echo " Target: $(TARGET)"
|
||||
@echo ""
|
||||
@if [ -f "$(OUTPUT)" ]; then \
|
||||
echo "$(COLOR_SUCCESS)Plugin exists:$(COLOR_RESET)"; \
|
||||
ls -lh $(OUTPUT) | awk '{print " " $$9 " (" $$5 ")"}'; \
|
||||
else \
|
||||
echo "$(COLOR_WARNING)Plugin not built yet$(COLOR_RESET)"; \
|
||||
fi
|
||||
|
||||
.DEFAULT_GOAL := help
|
||||
528
examples/plugins/hello-world-wasm-rust/README.md
Normal file
528
examples/plugins/hello-world-wasm-rust/README.md
Normal file
@@ -0,0 +1,528 @@
|
||||
# Bifrost WASM Plugin (Rust)
|
||||
|
||||
A comprehensive example of a Bifrost plugin written in Rust and compiled to WebAssembly. This plugin demonstrates proper structure definitions with serde, JSON parsing, context handling, and request/response modification patterns.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### Rust Installation
|
||||
|
||||
Install Rust from [rustup.rs](https://rustup.rs/) and add the WASM target:
|
||||
|
||||
```bash
|
||||
# Install Rust (if not already installed)
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
|
||||
|
||||
# Add WASM target
|
||||
rustup target add wasm32-unknown-unknown
|
||||
```
|
||||
|
||||
### Optional: wasm-opt
|
||||
|
||||
For smaller binaries, install `wasm-opt` from [binaryen](https://github.com/WebAssembly/binaryen):
|
||||
|
||||
```bash
|
||||
# macOS
|
||||
brew install binaryen
|
||||
|
||||
# Linux
|
||||
apt install binaryen
|
||||
```
|
||||
|
||||
## Building
|
||||
|
||||
```bash
|
||||
# Build the WASM plugin
|
||||
make build
|
||||
|
||||
# Build with wasm-opt optimization
|
||||
make build-optimized
|
||||
|
||||
# Clean build artifacts
|
||||
make clean
|
||||
```
|
||||
|
||||
The compiled plugin will be at `build/hello-world.wasm`.
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
src/
|
||||
├── lib.rs # Plugin implementation (hooks)
|
||||
├── memory.rs # Memory management utilities
|
||||
└── types.rs # Type definitions (mirrors Go SDK)
|
||||
```
|
||||
|
||||
## Plugin Structure
|
||||
|
||||
WASM plugins must export the following functions:
|
||||
|
||||
| Export | Signature | Description |
|
||||
|--------|-----------|-------------|
|
||||
| `malloc` | `(size: u32) -> u32` | Allocate memory for host to write data |
|
||||
| `free` | `(ptr: u32, size: u32)` | Free allocated memory |
|
||||
| `get_name` | `() -> u64` | Returns packed ptr+len of plugin name |
|
||||
| `init` | `(config_ptr, config_len: u32) -> i32` | Initialize with config (optional) |
|
||||
| `http_intercept` | `(input_ptr, input_len: u32) -> u64` | HTTP transport intercept |
|
||||
| `pre_hook` | `(input_ptr, input_len: u32) -> u64` | Pre-request hook |
|
||||
| `post_hook` | `(input_ptr, input_len: u32) -> u64` | Post-response hook |
|
||||
| `cleanup` | `() -> i32` | Cleanup resources (0 = success) |
|
||||
|
||||
### Return Value Format
|
||||
|
||||
Functions returning data use a packed `u64` format:
|
||||
- Upper 32 bits: pointer to data in WASM memory
|
||||
- Lower 32 bits: length of data
|
||||
|
||||
## Data Structures
|
||||
|
||||
This plugin uses `serde` with derive macros for JSON serialization. All structures mirror the Go SDK types:
|
||||
|
||||
### Context
|
||||
|
||||
```rust
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct BifrostContext {
|
||||
pub request_id: Option<String>,
|
||||
|
||||
// Custom values via HashMap
|
||||
#[serde(flatten)]
|
||||
pub values: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
impl BifrostContext {
|
||||
pub fn set_value(&mut self, key: &str, value: impl Into<serde_json::Value>);
|
||||
pub fn get_string(&self, key: &str) -> Option<&str>;
|
||||
pub fn get_bool(&self, key: &str) -> Option<bool>;
|
||||
}
|
||||
```
|
||||
|
||||
### HTTP Transport Types
|
||||
|
||||
```rust
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct HTTPRequest {
|
||||
pub method: String,
|
||||
pub path: String,
|
||||
pub headers: HashMap<String, String>,
|
||||
pub query: HashMap<String, String>,
|
||||
pub body: String, // base64 encoded
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct HTTPResponse {
|
||||
pub status_code: i32,
|
||||
pub headers: HashMap<String, String>,
|
||||
pub body: String, // base64 encoded
|
||||
}
|
||||
```
|
||||
|
||||
### Chat Completion Types
|
||||
|
||||
```rust
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ChatMessageRole {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
Tool,
|
||||
Developer,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ChatMessageContent {
|
||||
Text(String),
|
||||
Blocks(Vec<ChatContentBlock>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ChatMessage {
|
||||
pub role: ChatMessageRole,
|
||||
pub content: Option<ChatMessageContent>,
|
||||
pub name: Option<String>,
|
||||
pub tool_call_id: Option<String>,
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ChatParameters {
|
||||
pub temperature: Option<f64>,
|
||||
pub max_completion_tokens: Option<i32>,
|
||||
pub top_p: Option<f64>,
|
||||
pub frequency_penalty: Option<f64>,
|
||||
pub presence_penalty: Option<f64>,
|
||||
pub stop: Option<Vec<String>>,
|
||||
pub tools: Option<Vec<ChatTool>>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct BifrostChatRequest {
|
||||
pub provider: String,
|
||||
pub model: String,
|
||||
pub input: Vec<ChatMessage>,
|
||||
pub params: Option<ChatParameters>,
|
||||
pub fallbacks: Option<Vec<Fallback>>,
|
||||
}
|
||||
```
|
||||
|
||||
### Response Types
|
||||
|
||||
```rust
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct LLMUsage {
|
||||
pub prompt_tokens: i32,
|
||||
pub completion_tokens: i32,
|
||||
pub total_tokens: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ResponseChoice {
|
||||
pub index: i32,
|
||||
pub message: Option<ChatMessage>,
|
||||
pub delta: Option<ChatMessage>,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct BifrostChatResponse {
|
||||
pub id: String,
|
||||
pub model: String,
|
||||
pub choices: Vec<ResponseChoice>,
|
||||
pub usage: Option<LLMUsage>,
|
||||
pub created: Option<i64>,
|
||||
pub object: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct BifrostResponse {
|
||||
pub chat_response: Option<BifrostChatResponse>,
|
||||
}
|
||||
```
|
||||
|
||||
### Error Types
|
||||
|
||||
```rust
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ErrorField {
|
||||
pub message: String,
|
||||
#[serde(rename = "type")]
|
||||
pub error_type: Option<String>,
|
||||
pub code: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct BifrostError {
|
||||
pub error: ErrorField,
|
||||
pub status_code: Option<i32>,
|
||||
pub allow_fallbacks: Option<bool>,
|
||||
}
|
||||
|
||||
impl BifrostError {
|
||||
pub fn new(message: &str) -> Self;
|
||||
pub fn with_type(self, error_type: &str) -> Self;
|
||||
pub fn with_code(self, code: &str) -> Self;
|
||||
pub fn with_status(self, status: i32) -> Self;
|
||||
}
|
||||
```
|
||||
|
||||
### Short Circuit
|
||||
|
||||
```rust
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct LLMPluginShortCircuit {
|
||||
pub response: Option<BifrostResponse>,
|
||||
pub error: Option<BifrostError>,
|
||||
}
|
||||
```
|
||||
|
||||
## Hook Input/Output Structures
|
||||
|
||||
### http_intercept
|
||||
|
||||
**Input:**
|
||||
```json
|
||||
{
|
||||
"context": { "request_id": "abc-123" },
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"path": "/v1/chat/completions",
|
||||
"headers": { "Content-Type": "application/json" },
|
||||
"query": {},
|
||||
"body": "<base64-encoded>"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Output:**
|
||||
```json
|
||||
{
|
||||
"context": { "request_id": "abc-123" },
|
||||
"request": {},
|
||||
"response": { "status_code": 200, "headers": {}, "body": "<base64>" },
|
||||
"has_response": false,
|
||||
"error": ""
|
||||
}
|
||||
```
|
||||
|
||||
### pre_hook
|
||||
|
||||
**Input:**
|
||||
```json
|
||||
{
|
||||
"context": { "request_id": "abc-123" },
|
||||
"request": {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4",
|
||||
"input": [{ "role": "user", "content": "Hello" }],
|
||||
"params": { "temperature": 0.7 }
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Output:**
|
||||
```json
|
||||
{
|
||||
"context": { "request_id": "abc-123", "plugin_processed": true },
|
||||
"request": {},
|
||||
"short_circuit": {
|
||||
"response": { "chat_response": { ... } }
|
||||
},
|
||||
"has_short_circuit": false,
|
||||
"error": ""
|
||||
}
|
||||
```
|
||||
|
||||
### post_hook
|
||||
|
||||
**Input:**
|
||||
```json
|
||||
{
|
||||
"context": { "request_id": "abc-123", "plugin_processed": true },
|
||||
"response": {
|
||||
"chat_response": {
|
||||
"id": "chatcmpl-123",
|
||||
"model": "gpt-4",
|
||||
"choices": [{ "index": 0, "message": { "role": "assistant", "content": "Hi!" } }],
|
||||
"usage": { "prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15 }
|
||||
}
|
||||
},
|
||||
"error": {},
|
||||
"has_error": false
|
||||
}
|
||||
```
|
||||
|
||||
**Output:**
|
||||
```json
|
||||
{
|
||||
"context": { "request_id": "abc-123", "post_hook_completed": true },
|
||||
"response": {},
|
||||
"error": {},
|
||||
"has_error": false,
|
||||
"hook_error": ""
|
||||
}
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Modifying Context
|
||||
|
||||
```rust
|
||||
#[no_mangle]
|
||||
pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 {
|
||||
let input_str = read_string(input_ptr, input_len);
|
||||
let input: PreHookInput = serde_json::from_str(&input_str).unwrap();
|
||||
|
||||
let mut output = PreHookOutput {
|
||||
context: input.context.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Add custom values to context
|
||||
output.context.set_value("plugin_processed", serde_json::json!(true));
|
||||
output.context.set_value("plugin_name", serde_json::json!("my-rust-plugin"));
|
||||
|
||||
write_string(&serde_json::to_string(&output).unwrap())
|
||||
}
|
||||
```
|
||||
|
||||
### Short-Circuit with Mock Response
|
||||
|
||||
```rust
|
||||
#[no_mangle]
|
||||
pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 {
|
||||
let input_str = read_string(input_ptr, input_len);
|
||||
let input: PreHookInput = serde_json::from_str(&input_str).unwrap();
|
||||
|
||||
let (provider, model) = input.get_provider_model();
|
||||
|
||||
// Check if this should be mocked
|
||||
if model == "mock-model" {
|
||||
let mut output = PreHookOutput {
|
||||
context: input.context.clone(),
|
||||
has_short_circuit: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Build mock response
|
||||
let mock_response = BifrostResponse {
|
||||
chat_response: Some(BifrostChatResponse {
|
||||
id: format!("mock-{}", input.context.request_id.unwrap_or_default()),
|
||||
model: "mock-model".to_string(),
|
||||
choices: vec![ResponseChoice {
|
||||
index: 0,
|
||||
message: Some(ChatMessage {
|
||||
role: ChatMessageRole::Assistant,
|
||||
content: Some(ChatMessageContent::Text(
|
||||
"This is a mock response!".to_string()
|
||||
)),
|
||||
..Default::default()
|
||||
}),
|
||||
finish_reason: Some("stop".to_string()),
|
||||
..Default::default()
|
||||
}],
|
||||
usage: Some(LLMUsage {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 15,
|
||||
total_tokens: 25,
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
output.short_circuit = Some(LLMPluginShortCircuit {
|
||||
response: Some(mock_response),
|
||||
error: None,
|
||||
});
|
||||
|
||||
return write_string(&serde_json::to_string(&output).unwrap());
|
||||
}
|
||||
|
||||
// Pass through
|
||||
let output = PreHookOutput {
|
||||
context: input.context,
|
||||
..Default::default()
|
||||
};
|
||||
write_string(&serde_json::to_string(&output).unwrap())
|
||||
}
|
||||
```
|
||||
|
||||
### Short-Circuit with Error
|
||||
|
||||
```rust
|
||||
#[no_mangle]
|
||||
pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 {
|
||||
let input_str = read_string(input_ptr, input_len);
|
||||
let input: PreHookInput = serde_json::from_str(&input_str).unwrap();
|
||||
|
||||
// Check rate limit (example)
|
||||
if should_rate_limit(&input.context) {
|
||||
let mut output = PreHookOutput {
|
||||
context: input.context.clone(),
|
||||
has_short_circuit: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
output.short_circuit = Some(LLMPluginShortCircuit {
|
||||
response: None,
|
||||
error: Some(
|
||||
BifrostError::new("Rate limit exceeded")
|
||||
.with_type("rate_limit")
|
||||
.with_code("429")
|
||||
.with_status(429)
|
||||
),
|
||||
});
|
||||
|
||||
return write_string(&serde_json::to_string(&output).unwrap());
|
||||
}
|
||||
|
||||
// Pass through
|
||||
let output = PreHookOutput {
|
||||
context: input.context,
|
||||
..Default::default()
|
||||
};
|
||||
write_string(&serde_json::to_string(&output).unwrap())
|
||||
}
|
||||
```
|
||||
|
||||
### Modifying Responses in post_hook
|
||||
|
||||
```rust
|
||||
#[no_mangle]
|
||||
pub extern "C" fn post_hook(input_ptr: u32, input_len: u32) -> u64 {
|
||||
let input_str = read_string(input_ptr, input_len);
|
||||
let input: PostHookInput = serde_json::from_str(&input_str).unwrap();
|
||||
|
||||
let mut output = PostHookOutput {
|
||||
context: input.context.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Handle errors
|
||||
if input.has_error {
|
||||
output.has_error = true;
|
||||
output.error = input.error.clone();
|
||||
|
||||
// Optionally modify the error
|
||||
if let Some(mut error) = input.parse_error() {
|
||||
error.error.message = format!("{} (via rust plugin)", error.error.message);
|
||||
output.error = serde_json::to_value(&error).unwrap_or_default();
|
||||
}
|
||||
|
||||
return write_string(&serde_json::to_string(&output).unwrap());
|
||||
}
|
||||
|
||||
// Pass through or modify response
|
||||
if let Some(mut response) = input.parse_response() {
|
||||
if let Some(ref mut chat) = response.chat_response {
|
||||
// Add a marker to the model name
|
||||
chat.model = format!("{} (via rust-wasm)", chat.model);
|
||||
}
|
||||
output.response = serde_json::to_value(&response).unwrap_or_default();
|
||||
}
|
||||
|
||||
write_string(&serde_json::to_string(&output).unwrap())
|
||||
}
|
||||
```
|
||||
|
||||
## Usage with Bifrost
|
||||
|
||||
Configure the plugin in your Bifrost config:
|
||||
|
||||
```json
|
||||
{
|
||||
"plugins": [
|
||||
{
|
||||
"path": "/path/to/hello-world.wasm",
|
||||
"name": "hello-world-wasm-rust",
|
||||
"enabled": true,
|
||||
"config": {
|
||||
"custom_option": "value"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
The plugin includes unit tests that can be run with:
|
||||
|
||||
```bash
|
||||
cargo test
|
||||
```
|
||||
|
||||
## Benefits
|
||||
|
||||
1. **Performance**: Rust compiles to highly optimized WASM
|
||||
2. **Safety**: Memory safety without garbage collection
|
||||
3. **Small binaries**: Rust WASM binaries are typically very small
|
||||
4. **Cross-platform**: Single `.wasm` binary runs on any OS/architecture
|
||||
5. **Security**: WASM provides sandboxed execution
|
||||
6. **Type Safety**: Strongly typed structures with serde derive macros
|
||||
7. **Excellent JSON**: serde_json provides robust JSON handling
|
||||
327
examples/plugins/hello-world-wasm-rust/src/lib.rs
Normal file
327
examples/plugins/hello-world-wasm-rust/src/lib.rs
Normal file
@@ -0,0 +1,327 @@
|
||||
//! Bifrost WASM Plugin for Rust
|
||||
//!
|
||||
//! This plugin demonstrates the proper structure for parsing inputs,
|
||||
//! building responses, and handling context - similar to Go plugin patterns.
|
||||
//!
|
||||
//! Build with: cargo build --release --target wasm32-unknown-unknown
|
||||
|
||||
mod memory;
|
||||
mod types;
|
||||
|
||||
use memory::{read_string, write_string};
|
||||
use types::*;
|
||||
|
||||
// Global configuration storage
|
||||
static mut PLUGIN_CONFIG: Option<PluginConfig> = None;
|
||||
|
||||
// =============================================================================
|
||||
// Exported Plugin Functions
|
||||
// =============================================================================
|
||||
|
||||
/// Return the plugin name
|
||||
#[no_mangle]
|
||||
pub extern "C" fn get_name() -> u64 {
|
||||
write_string("hello-world-wasm-rust")
|
||||
}
|
||||
|
||||
/// Initialize the plugin with config
|
||||
/// Returns 0 on success, non-zero on error
|
||||
#[no_mangle]
|
||||
pub extern "C" fn init(config_ptr: u32, config_len: u32) -> i32 {
|
||||
let config_str = read_string(config_ptr, config_len);
|
||||
|
||||
// Parse configuration
|
||||
let config: PluginConfig = if config_str.is_empty() {
|
||||
PluginConfig::default()
|
||||
} else {
|
||||
match serde_json::from_str(&config_str) {
|
||||
Ok(c) => c,
|
||||
Err(_) => return 1, // Config parse error
|
||||
}
|
||||
};
|
||||
|
||||
// Store configuration
|
||||
unsafe {
|
||||
PLUGIN_CONFIG = Some(config);
|
||||
}
|
||||
|
||||
0 // Success
|
||||
}
|
||||
|
||||
/// HTTP transport intercept
|
||||
/// Called at the HTTP layer before request enters Bifrost core.
|
||||
/// Can modify headers, query params, or short-circuit with a response.
|
||||
#[no_mangle]
|
||||
pub extern "C" fn http_intercept(input_ptr: u32, input_len: u32) -> u64 {
|
||||
let input_str = read_string(input_ptr, input_len);
|
||||
|
||||
// Parse input
|
||||
let input: HTTPInterceptInput = match serde_json::from_str(&input_str) {
|
||||
Ok(i) => i,
|
||||
Err(e) => {
|
||||
// Include context around the error position for debugging
|
||||
let error_context = if let Some(col) = extract_column(&e.to_string()) {
|
||||
let start = col.saturating_sub(50);
|
||||
let end = (col + 50).min(input_str.len());
|
||||
format!(" | context: ...{}...", &input_str[start..end])
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
let output = HTTPInterceptOutput {
|
||||
error: format!("Failed to parse input: {}{}", e, error_context),
|
||||
..Default::default()
|
||||
};
|
||||
return write_string(&serde_json::to_string(&output).unwrap_or_default());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Add context value like Go plugin does
|
||||
let mut context = input.context;
|
||||
context.set_value("from-http", serde_json::json!("123"));
|
||||
|
||||
// Create output with context and request preserved (pass-through)
|
||||
// Serialize request to Value to ensure proper JSON structure
|
||||
let request_value = serde_json::to_value(&input.request).ok();
|
||||
|
||||
let output = HTTPInterceptOutput {
|
||||
context: input.context,
|
||||
request: input.request,
|
||||
has_response: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Pass through
|
||||
write_string(&serde_json::to_string(&output).unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Pre-request hook
|
||||
/// Called before request is sent to the provider.
|
||||
/// Can modify the request or short-circuit with a response/error.
|
||||
#[no_mangle]
|
||||
pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 {
|
||||
let input_str = read_string(input_ptr, input_len);
|
||||
|
||||
// Parse input
|
||||
let input: PreHookInput = match serde_json::from_str(&input_str) {
|
||||
Ok(i) => i,
|
||||
Err(e) => {
|
||||
let output = PreHookOutput {
|
||||
error: format!("Failed to parse input: {}", e),
|
||||
..Default::default()
|
||||
};
|
||||
return write_string(&serde_json::to_string(&output).unwrap_or_default());
|
||||
}
|
||||
};
|
||||
|
||||
// Create output with context preserved
|
||||
let mut output = PreHookOutput {
|
||||
context: input.context.clone(),
|
||||
request: input.request.clone(),
|
||||
has_short_circuit: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Get provider and model for potential modifications
|
||||
let (_provider, model) = input.get_provider_model();
|
||||
|
||||
// Example: Short-circuit with mock response for specific model
|
||||
// Uncomment to test:
|
||||
/*
|
||||
if model == "mock-model" {
|
||||
output.has_short_circuit = true;
|
||||
|
||||
let mock_response = BifrostResponse {
|
||||
chat_response: Some(BifrostChatResponse {
|
||||
id: format!("mock-{}", input.context.request_id.unwrap_or_default()),
|
||||
model: "mock-model".to_string(),
|
||||
choices: vec![ResponseChoice {
|
||||
index: 0,
|
||||
message: Some(ChatMessage {
|
||||
role: ChatMessageRole::Assistant,
|
||||
content: Some(ChatMessageContent::Text(
|
||||
"This is a mock response from the Rust WASM plugin!".to_string()
|
||||
)),
|
||||
..Default::default()
|
||||
}),
|
||||
finish_reason: Some("stop".to_string()),
|
||||
..Default::default()
|
||||
}],
|
||||
usage: Some(LLMUsage {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 15,
|
||||
total_tokens: 25,
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
output.short_circuit = Some(LLMPluginShortCircuit {
|
||||
response: Some(mock_response),
|
||||
error: None,
|
||||
});
|
||||
|
||||
return write_string(&serde_json::to_string(&output).unwrap_or_default());
|
||||
}
|
||||
*/
|
||||
|
||||
// Example: Short-circuit with rate limit error
|
||||
// Uncomment to test:
|
||||
/*
|
||||
if should_rate_limit(&input.context) {
|
||||
output.has_short_circuit = true;
|
||||
output.short_circuit = Some(LLMPluginShortCircuit {
|
||||
response: None,
|
||||
error: Some(
|
||||
BifrostError::new("Rate limit exceeded")
|
||||
.with_type("rate_limit")
|
||||
.with_code("429")
|
||||
.with_status(429)
|
||||
),
|
||||
});
|
||||
return write_string(&serde_json::to_string(&output).unwrap_or_default());
|
||||
}
|
||||
*/
|
||||
|
||||
// Silence unused variable warning in example code
|
||||
let _ = model;
|
||||
|
||||
// Pass through - empty request means use original
|
||||
write_string(&serde_json::to_string(&output).unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Post-response hook
|
||||
/// Called after response is received from provider.
|
||||
/// Can modify the response or error.
|
||||
#[no_mangle]
|
||||
pub extern "C" fn post_hook(input_ptr: u32, input_len: u32) -> u64 {
|
||||
let input_str = read_string(input_ptr, input_len);
|
||||
|
||||
// Parse input
|
||||
let input: PostHookInput = match serde_json::from_str(&input_str) {
|
||||
Ok(i) => i,
|
||||
Err(e) => {
|
||||
let output = PostHookOutput {
|
||||
hook_error: format!("Failed to parse input: {}", e),
|
||||
..Default::default()
|
||||
};
|
||||
return write_string(&serde_json::to_string(&output).unwrap_or_default());
|
||||
}
|
||||
};
|
||||
|
||||
// Add context value like Go plugin does
|
||||
let mut context = input.context.clone();
|
||||
context.set_value("from-post-hook", serde_json::json!("456"));
|
||||
|
||||
// Create output with context and response/error preserved (pass-through)
|
||||
// This matches Go plugin behavior exactly
|
||||
let output = PostHookOutput {
|
||||
context,
|
||||
response: Some(input.response.clone()),
|
||||
error: Some(input.error.clone()),
|
||||
has_error: input.has_error,
|
||||
hook_error: String::new(),
|
||||
};
|
||||
|
||||
// Example: Modify error message when has_error is true
|
||||
// Uncomment to test:
|
||||
/*
|
||||
if input.has_error {
|
||||
if let Some(mut error) = input.parse_error() {
|
||||
error.error.message = format!("{} (processed by Rust WASM plugin)", error.error.message);
|
||||
let mut output = output;
|
||||
output.error = Some(serde_json::to_value(&error).unwrap_or_default());
|
||||
return write_string(&serde_json::to_string(&output).unwrap_or_default());
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// Example: Modify response
|
||||
// Uncomment to test:
|
||||
/*
|
||||
if let Some(mut response) = input.parse_response() {
|
||||
// Add custom metadata, modify model name, etc.
|
||||
if let Some(ref mut chat) = response.chat_response {
|
||||
// Add a marker to the model name
|
||||
chat.model = format!("{} (via rust-wasm)", chat.model);
|
||||
}
|
||||
let mut output = output;
|
||||
output.response = Some(serde_json::to_value(&response).unwrap_or_default());
|
||||
return write_string(&serde_json::to_string(&output).unwrap_or_default());
|
||||
}
|
||||
*/
|
||||
|
||||
write_string(&serde_json::to_string(&output).unwrap_or_default())
|
||||
}
|
||||
|
||||
/// HTTP stream chunk hook
|
||||
/// Called for each chunk during streaming responses.
|
||||
/// Can modify, skip, or stop streaming based on return values.
|
||||
#[no_mangle]
|
||||
pub extern "C" fn http_stream_chunk_hook(input_ptr: u32, input_len: u32) -> u64 {
|
||||
let input_str = read_string(input_ptr, input_len);
|
||||
|
||||
// Parse input
|
||||
let input: HTTPStreamChunkHookInput = match serde_json::from_str(&input_str) {
|
||||
Ok(i) => i,
|
||||
Err(e) => {
|
||||
let output = HTTPStreamChunkHookOutput {
|
||||
error: format!("Failed to parse input: {}", e),
|
||||
..Default::default()
|
||||
};
|
||||
return write_string(&serde_json::to_string(&output).unwrap_or_default());
|
||||
}
|
||||
};
|
||||
|
||||
// Add context value like Go plugin does
|
||||
let mut context = input.context.clone();
|
||||
context.set_value("from-stream-chunk", serde_json::json!("rust-wasm"));
|
||||
|
||||
// Pass through chunk unchanged
|
||||
let output = HTTPStreamChunkHookOutput {
|
||||
context,
|
||||
chunk: Some(input.chunk),
|
||||
has_chunk: true,
|
||||
skip: false,
|
||||
error: String::new(),
|
||||
};
|
||||
|
||||
write_string(&serde_json::to_string(&output).unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Cleanup resources
|
||||
/// Called when plugin is being unloaded.
|
||||
/// Returns 0 on success, non-zero on error
|
||||
#[no_mangle]
|
||||
pub extern "C" fn cleanup() -> i32 {
|
||||
// Clear stored configuration
|
||||
unsafe {
|
||||
PLUGIN_CONFIG = None;
|
||||
}
|
||||
|
||||
0 // Success
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Helper Functions
|
||||
// =============================================================================
|
||||
|
||||
/// Extract column number from serde error message for debugging
|
||||
fn extract_column(error_msg: &str) -> Option<usize> {
|
||||
// Error format: "... at line X column Y"
|
||||
if let Some(idx) = error_msg.rfind("column ") {
|
||||
let col_str = &error_msg[idx + 7..];
|
||||
col_str.split_whitespace().next()?.parse().ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Example rate limit check function
|
||||
#[allow(dead_code)]
|
||||
fn should_rate_limit(_context: &BifrostContext) -> bool {
|
||||
// Implement your rate limiting logic here
|
||||
false
|
||||
}
|
||||
70
examples/plugins/hello-world-wasm-rust/src/memory.rs
Normal file
70
examples/plugins/hello-world-wasm-rust/src/memory.rs
Normal file
@@ -0,0 +1,70 @@
|
||||
//! Memory management utilities for WASM plugins.
|
||||
//! Handles allocation, deallocation, and string read/write operations.
|
||||
|
||||
use std::alloc::{alloc, dealloc, Layout};
|
||||
use std::slice;
|
||||
|
||||
/// Pack a pointer and length into a single u64
|
||||
/// Upper 32 bits: pointer, Lower 32 bits: length
|
||||
pub fn pack_result(ptr: u32, len: u32) -> u64 {
|
||||
((ptr as u64) << 32) | (len as u64)
|
||||
}
|
||||
|
||||
/// Write a string to WASM memory and return packed pointer+length
|
||||
pub fn write_string(s: &str) -> u64 {
|
||||
if s.is_empty() {
|
||||
return 0;
|
||||
}
|
||||
let bytes = s.as_bytes();
|
||||
let ptr = unsafe { malloc(bytes.len() as u32) };
|
||||
if ptr == 0 {
|
||||
return 0;
|
||||
}
|
||||
unsafe {
|
||||
std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr as *mut u8, bytes.len());
|
||||
}
|
||||
pack_result(ptr, bytes.len() as u32)
|
||||
}
|
||||
|
||||
/// Read a string from WASM memory given pointer and length
|
||||
pub fn read_string(ptr: u32, len: u32) -> String {
|
||||
if len == 0 {
|
||||
return String::new();
|
||||
}
|
||||
let bytes = unsafe { slice::from_raw_parts(ptr as *const u8, len as usize) };
|
||||
String::from_utf8_lossy(bytes).into_owned()
|
||||
}
|
||||
|
||||
/// Allocate memory for the host to write data
|
||||
///
|
||||
/// # Safety
|
||||
/// This function is marked as safe but performs unsafe operations internally.
|
||||
/// It is intended to be called from WASM host.
|
||||
#[no_mangle]
|
||||
pub extern "C" fn malloc(size: u32) -> u32 {
|
||||
if size == 0 {
|
||||
return 0;
|
||||
}
|
||||
let layout = match Layout::from_size_align(size as usize, 1) {
|
||||
Ok(l) => l,
|
||||
Err(_) => return 0,
|
||||
};
|
||||
unsafe { alloc(layout) as u32 }
|
||||
}
|
||||
|
||||
/// Free allocated memory
|
||||
///
|
||||
/// # Safety
|
||||
/// This function is marked as safe but performs unsafe operations internally.
|
||||
/// It is intended to be called from WASM host.
|
||||
#[no_mangle]
|
||||
pub extern "C" fn free(ptr: u32, size: u32) {
|
||||
if ptr == 0 || size == 0 {
|
||||
return;
|
||||
}
|
||||
let layout = match Layout::from_size_align(size as usize, 1) {
|
||||
Ok(l) => l,
|
||||
Err(_) => return,
|
||||
};
|
||||
unsafe { dealloc(ptr as *mut u8, layout) }
|
||||
}
|
||||
870
examples/plugins/hello-world-wasm-rust/src/types.rs
Normal file
870
examples/plugins/hello-world-wasm-rust/src/types.rs
Normal file
@@ -0,0 +1,870 @@
|
||||
//! Type definitions for Bifrost WASM plugins.
|
||||
//! These structures mirror the Go SDK types for interoperability.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
// =============================================================================
|
||||
// Nullable Deserializers
|
||||
// =============================================================================
|
||||
|
||||
/// Helper module for deserializing fields that may be null in JSON.
|
||||
/// Go's JSON encoder outputs `null` for nil slices/maps, but Rust's serde
|
||||
/// with `#[serde(default)]` only handles missing fields, not explicit nulls.
|
||||
mod nullable {
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Deserialize a string that may be null, converting null to empty string.
|
||||
pub fn string<'de, D>(deserializer: D) -> Result<String, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
Option::<String>::deserialize(deserializer).map(|opt| opt.unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Deserialize a HashMap<String, String> that may be null or contain null values.
|
||||
/// Handles both `null` (entire map is null) and `{"key": null}` (value is null).
|
||||
pub fn string_map<'de, D>(deserializer: D) -> Result<HashMap<String, String>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
// First deserialize as Option<HashMap<String, Option<String>>> to handle null values
|
||||
let opt_map: Option<HashMap<String, Option<String>>> = Option::deserialize(deserializer)?;
|
||||
|
||||
match opt_map {
|
||||
None => Ok(HashMap::new()),
|
||||
Some(map) => {
|
||||
// Filter out null values and unwrap the rest
|
||||
Ok(map
|
||||
.into_iter()
|
||||
.filter_map(|(k, v)| v.map(|val| (k, val)))
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Deserialize an i32 that may be null, converting null to 0.
|
||||
pub fn i32_field<'de, D>(deserializer: D) -> Result<i32, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
Option::<i32>::deserialize(deserializer).map(|opt| opt.unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Deserialize an HTTPRequest that may be null, converting null to default.
|
||||
pub fn http_request<'de, D>(deserializer: D) -> Result<super::HTTPRequest, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
Option::<super::HTTPRequest>::deserialize(deserializer).map(|opt| opt.unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Deserialize a BifrostContext that may be null, converting null to default.
|
||||
pub fn context<'de, D>(deserializer: D) -> Result<super::BifrostContext, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
Option::<super::BifrostContext>::deserialize(deserializer).map(|opt| opt.unwrap_or_default())
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Context Structure
|
||||
// =============================================================================
|
||||
|
||||
/// BifrostContext holds request-scoped values passed between hooks.
|
||||
/// This is a dynamic map (map[string]any in Go) that can hold any JSON values.
|
||||
/// Common keys include:
|
||||
/// - request_id: Unique identifier for the request
|
||||
/// - Custom plugin values can be added and will be persisted across hooks
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
#[serde(transparent)]
|
||||
pub struct BifrostContext(pub HashMap<String, serde_json::Value>);
|
||||
|
||||
impl BifrostContext {
|
||||
pub fn new() -> Self {
|
||||
Self(HashMap::new())
|
||||
}
|
||||
|
||||
/// Set a custom value in the context
|
||||
pub fn set_value(&mut self, key: &str, value: impl Into<serde_json::Value>) {
|
||||
self.0.insert(key.to_string(), value.into());
|
||||
}
|
||||
|
||||
/// Get a value from the context
|
||||
pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
|
||||
self.0.get(key)
|
||||
}
|
||||
|
||||
/// Get a string value from the context
|
||||
pub fn get_string(&self, key: &str) -> Option<&str> {
|
||||
self.0.get(key).and_then(|v| v.as_str())
|
||||
}
|
||||
|
||||
/// Get a boolean value from the context
|
||||
pub fn get_bool(&self, key: &str) -> Option<bool> {
|
||||
self.0.get(key).and_then(|v| v.as_bool())
|
||||
}
|
||||
|
||||
/// Get an i64 value from the context
|
||||
pub fn get_i64(&self, key: &str) -> Option<i64> {
|
||||
self.0.get(key).and_then(|v| v.as_i64())
|
||||
}
|
||||
|
||||
/// Check if a key exists in the context
|
||||
pub fn contains_key(&self, key: &str) -> bool {
|
||||
self.0.contains_key(key)
|
||||
}
|
||||
|
||||
/// Remove a value from the context
|
||||
pub fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
|
||||
self.0.remove(key)
|
||||
}
|
||||
|
||||
/// Get the underlying HashMap for iteration
|
||||
pub fn inner(&self) -> &HashMap<String, serde_json::Value> {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// Get mutable access to the underlying HashMap
|
||||
pub fn inner_mut(&mut self) -> &mut HashMap<String, serde_json::Value> {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// HTTP Transport Structures
|
||||
// =============================================================================
|
||||
|
||||
/// HTTPRequest represents an incoming HTTP request at the transport layer.
|
||||
/// Body is base64-encoded.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct HTTPRequest {
|
||||
#[serde(default, deserialize_with = "nullable::string")]
|
||||
pub method: String,
|
||||
|
||||
#[serde(default, deserialize_with = "nullable::string")]
|
||||
pub path: String,
|
||||
|
||||
#[serde(default, deserialize_with = "nullable::string_map")]
|
||||
pub headers: HashMap<String, String>,
|
||||
|
||||
#[serde(default, deserialize_with = "nullable::string_map")]
|
||||
pub query: HashMap<String, String>,
|
||||
|
||||
/// Base64-encoded request body
|
||||
#[serde(default, deserialize_with = "nullable::string")]
|
||||
pub body: String,
|
||||
}
|
||||
|
||||
/// HTTPResponse represents an HTTP response to return.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct HTTPResponse {
|
||||
#[serde(default, deserialize_with = "nullable::i32_field")]
|
||||
pub status_code: i32,
|
||||
|
||||
#[serde(default, deserialize_with = "nullable::string_map")]
|
||||
pub headers: HashMap<String, String>,
|
||||
|
||||
/// Base64-encoded response body
|
||||
#[serde(default, deserialize_with = "nullable::string")]
|
||||
pub body: String,
|
||||
}
|
||||
|
||||
/// HTTPInterceptInput is the input for http_intercept hook.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct HTTPInterceptInput {
|
||||
#[serde(default, deserialize_with = "nullable::context")]
|
||||
pub context: BifrostContext,
|
||||
|
||||
#[serde(default, deserialize_with = "nullable::http_request")]
|
||||
pub request: HTTPRequest,
|
||||
}
|
||||
|
||||
/// HTTPInterceptOutput is the output for http_intercept hook.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct HTTPInterceptOutput {
|
||||
pub context: BifrostContext,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub request: Option<serde_json::Value>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response: Option<HTTPResponse>,
|
||||
|
||||
#[serde(default)]
|
||||
pub has_response: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub error: String,
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Chat Completion Structures (BifrostRequest)
|
||||
// =============================================================================
|
||||
|
||||
/// ChatMessageRole represents the role of a message sender.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ChatMessageRole {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
Tool,
|
||||
Developer,
|
||||
}
|
||||
|
||||
impl Default for ChatMessageRole {
|
||||
fn default() -> Self {
|
||||
ChatMessageRole::User
|
||||
}
|
||||
}
|
||||
|
||||
/// ChatMessageContent can be either a string or an array of content blocks.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ChatMessageContent {
|
||||
Text(String),
|
||||
Blocks(Vec<ChatContentBlock>),
|
||||
}
|
||||
|
||||
impl Default for ChatMessageContent {
|
||||
fn default() -> Self {
|
||||
ChatMessageContent::Text(String::new())
|
||||
}
|
||||
}
|
||||
|
||||
/// ChatContentBlock represents a content block in a message.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ChatContentBlock {
|
||||
#[serde(rename = "type")]
|
||||
pub block_type: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub text: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub image_url: Option<ImageUrl>,
|
||||
}
|
||||
|
||||
/// ImageUrl represents an image URL in a content block.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ImageUrl {
|
||||
pub url: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub detail: Option<String>,
|
||||
}
|
||||
|
||||
/// ChatMessage represents a message in the conversation.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ChatMessage {
|
||||
#[serde(default)]
|
||||
pub role: ChatMessageRole,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<ChatMessageContent>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
/// ToolCall represents a tool call made by the assistant.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ToolCall {
|
||||
#[serde(default)]
|
||||
pub id: Option<String>,
|
||||
|
||||
#[serde(rename = "type", default)]
|
||||
pub call_type: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub function: ToolCallFunction,
|
||||
}
|
||||
|
||||
/// ToolCallFunction represents the function being called.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ToolCallFunction {
|
||||
#[serde(default)]
|
||||
pub name: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
/// ChatParameters contains optional parameters for chat completion.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ChatParameters {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f64>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_completion_tokens: Option<i32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f64>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub frequency_penalty: Option<f64>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub presence_penalty: Option<f64>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<Vec<String>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<ChatTool>>,
|
||||
|
||||
/// Catch-all for additional parameters
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
/// ChatTool represents a tool definition.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ChatTool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function: Option<ChatToolFunction>,
|
||||
}
|
||||
|
||||
/// ChatToolFunction represents a function definition.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ChatToolFunction {
|
||||
pub name: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub parameters: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// BifrostChatRequest represents a chat completion request.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct BifrostChatRequest {
|
||||
#[serde(default)]
|
||||
pub provider: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub model: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub input: Vec<ChatMessage>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<ChatParameters>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub fallbacks: Option<Vec<Fallback>>,
|
||||
}
|
||||
|
||||
/// Fallback represents a fallback provider/model.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct Fallback {
|
||||
pub provider: String,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
/// BifrostRequest is the unified request structure.
|
||||
/// Only one of the request types should be present.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct BifrostRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub chat_request: Option<BifrostChatRequest>,
|
||||
|
||||
// Add other request types as needed
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
impl BifrostRequest {
|
||||
/// Get provider and model from the request
|
||||
pub fn get_provider_model(&self) -> (String, String) {
|
||||
if let Some(ref chat) = self.chat_request {
|
||||
return (chat.provider.clone(), chat.model.clone());
|
||||
}
|
||||
(String::new(), String::new())
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Response Structures (BifrostResponse)
|
||||
// =============================================================================
|
||||
|
||||
/// LLMUsage contains token usage information.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct LLMUsage {
|
||||
#[serde(default)]
|
||||
pub prompt_tokens: i32,
|
||||
|
||||
#[serde(default)]
|
||||
pub completion_tokens: i32,
|
||||
|
||||
#[serde(default)]
|
||||
pub total_tokens: i32,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_tokens_details: Option<serde_json::Value>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub completion_tokens_details: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// ResponseChoice represents a single completion choice.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ResponseChoice {
|
||||
#[serde(default)]
|
||||
pub index: i32,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub message: Option<ChatMessage>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub delta: Option<ChatMessage>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub finish_reason: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logprobs: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// BifrostChatResponse represents a chat completion response.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct BifrostChatResponse {
|
||||
#[serde(default)]
|
||||
pub id: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub model: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub choices: Vec<ResponseChoice>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<LLMUsage>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub created: Option<i64>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub object: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system_fingerprint: Option<String>,
|
||||
}
|
||||
|
||||
/// BifrostResponse is the unified response structure.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct BifrostResponse {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub chat_response: Option<BifrostChatResponse>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Error Structure
|
||||
// =============================================================================
|
||||
|
||||
/// ErrorField contains the error details.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ErrorField {
|
||||
#[serde(default)]
|
||||
pub message: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename = "type")]
|
||||
pub error_type: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub code: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub param: Option<String>,
|
||||
}
|
||||
|
||||
/// BifrostError represents an error response.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct BifrostError {
|
||||
#[serde(default)]
|
||||
pub error: ErrorField,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub status_code: Option<i32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub allow_fallbacks: Option<bool>,
|
||||
}
|
||||
|
||||
impl BifrostError {
|
||||
/// Create a new error with a message
|
||||
pub fn new(message: &str) -> Self {
|
||||
Self {
|
||||
error: ErrorField {
|
||||
message: message.to_string(),
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the error type
|
||||
pub fn with_type(mut self, error_type: &str) -> Self {
|
||||
self.error.error_type = Some(error_type.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the error code
|
||||
pub fn with_code(mut self, code: &str) -> Self {
|
||||
self.error.code = Some(code.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the status code
|
||||
pub fn with_status(mut self, status: i32) -> Self {
|
||||
self.status_code = Some(status);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Short Circuit Structure
|
||||
// =============================================================================
|
||||
|
||||
/// LLMPluginShortCircuit allows plugins to short-circuit the request flow.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct LLMPluginShortCircuit {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response: Option<BifrostResponse>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<BifrostError>,
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Hook Input/Output Structures
|
||||
// =============================================================================
|
||||
|
||||
/// PreHookInput is the input for pre_hook.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct PreHookInput {
|
||||
#[serde(default)]
|
||||
pub context: BifrostContext,
|
||||
|
||||
#[serde(default)]
|
||||
pub request: serde_json::Value,
|
||||
}
|
||||
|
||||
impl PreHookInput {
|
||||
/// Parse the request as a BifrostRequest
|
||||
pub fn parse_request(&self) -> Option<BifrostRequest> {
|
||||
serde_json::from_value(self.request.clone()).ok()
|
||||
}
|
||||
|
||||
/// Get provider and model from the request
|
||||
pub fn get_provider_model(&self) -> (String, String) {
|
||||
if let Some(req) = self.parse_request() {
|
||||
return req.get_provider_model();
|
||||
}
|
||||
// Try direct access for simpler structures
|
||||
let provider = self.request.get("provider")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let model = self.request.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
(provider, model)
|
||||
}
|
||||
}
|
||||
|
||||
/// PreHookOutput is the output for pre_hook.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct PreHookOutput {
|
||||
pub context: BifrostContext,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub request: Option<serde_json::Value>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub short_circuit: Option<LLMPluginShortCircuit>,
|
||||
|
||||
#[serde(default)]
|
||||
pub has_short_circuit: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub error: String,
|
||||
}
|
||||
|
||||
/// PostHookInput is the input for post_hook.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct PostHookInput {
|
||||
#[serde(default)]
|
||||
pub context: BifrostContext,
|
||||
|
||||
#[serde(default)]
|
||||
pub response: serde_json::Value,
|
||||
|
||||
#[serde(default)]
|
||||
pub error: serde_json::Value,
|
||||
|
||||
#[serde(default)]
|
||||
pub has_error: bool,
|
||||
}
|
||||
|
||||
impl PostHookInput {
|
||||
/// Parse the response as a BifrostResponse
|
||||
pub fn parse_response(&self) -> Option<BifrostResponse> {
|
||||
serde_json::from_value(self.response.clone()).ok()
|
||||
}
|
||||
|
||||
/// Parse the error as a BifrostError
|
||||
pub fn parse_error(&self) -> Option<BifrostError> {
|
||||
if self.has_error {
|
||||
serde_json::from_value(self.error.clone()).ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// PostHookOutput is the output for post_hook.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct PostHookOutput {
|
||||
pub context: BifrostContext,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response: Option<serde_json::Value>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<serde_json::Value>,
|
||||
|
||||
#[serde(default)]
|
||||
pub has_error: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub hook_error: String,
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// HTTP Stream Chunk Hook Input/Output Structures
|
||||
// =============================================================================
|
||||
|
||||
/// HTTPStreamChunkHookInput is the input for http_stream_chunk_hook.
|
||||
/// Called for each chunk during streaming responses.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct HTTPStreamChunkHookInput {
|
||||
#[serde(default)]
|
||||
pub context: BifrostContext,
|
||||
|
||||
#[serde(default)]
|
||||
pub request: serde_json::Value,
|
||||
|
||||
#[serde(default)]
|
||||
pub chunk: serde_json::Value, // BifrostStreamChunk as JSON
|
||||
}
|
||||
|
||||
/// HTTPStreamChunkHookOutput is the output for http_stream_chunk_hook.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct HTTPStreamChunkHookOutput {
|
||||
pub context: BifrostContext,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub chunk: Option<serde_json::Value>, // BifrostStreamChunk as JSON, None to skip
|
||||
|
||||
#[serde(default)]
|
||||
pub has_chunk: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub skip: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub error: String,
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Plugin Configuration
|
||||
// =============================================================================
|
||||
|
||||
/// Plugin configuration (customize as needed)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct PluginConfig {
|
||||
#[serde(flatten)]
|
||||
pub values: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Tests
|
||||
// =============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_context_serialization() {
|
||||
let mut ctx = BifrostContext::new();
|
||||
ctx.set_value("request_id", "test-123");
|
||||
ctx.set_value("custom_key", "custom_value");
|
||||
ctx.set_value("is_enabled", true);
|
||||
ctx.set_value("count", 42);
|
||||
|
||||
let json = serde_json::to_string(&ctx).unwrap();
|
||||
assert!(json.contains("request_id"));
|
||||
assert!(json.contains("custom_key"));
|
||||
assert!(json.contains("is_enabled"));
|
||||
assert!(json.contains("count"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_context_deserialization() {
|
||||
let json = r#"{"request_id": "test-123", "custom_key": "custom_value", "is_enabled": true}"#;
|
||||
let ctx: BifrostContext = serde_json::from_str(json).unwrap();
|
||||
|
||||
assert_eq!(ctx.get_string("request_id"), Some("test-123"));
|
||||
assert_eq!(ctx.get_string("custom_key"), Some("custom_value"));
|
||||
assert_eq!(ctx.get_bool("is_enabled"), Some(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_context_methods() {
|
||||
let mut ctx = BifrostContext::new();
|
||||
ctx.set_value("key1", "value1");
|
||||
ctx.set_value("enabled", true);
|
||||
ctx.set_value("count", 42);
|
||||
|
||||
assert_eq!(ctx.get_string("key1"), Some("value1"));
|
||||
assert_eq!(ctx.get_bool("enabled"), Some(true));
|
||||
assert_eq!(ctx.get_i64("count"), Some(42));
|
||||
assert!(ctx.contains_key("key1"));
|
||||
assert!(!ctx.contains_key("nonexistent"));
|
||||
|
||||
ctx.remove("key1");
|
||||
assert!(!ctx.contains_key("key1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_message() {
|
||||
let msg = ChatMessage {
|
||||
role: ChatMessageRole::User,
|
||||
content: Some(ChatMessageContent::Text("Hello!".to_string())),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
assert!(json.contains("user"));
|
||||
assert!(json.contains("Hello!"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bifrost_error() {
|
||||
let error = BifrostError::new("Test error")
|
||||
.with_type("test_type")
|
||||
.with_code("500")
|
||||
.with_status(500);
|
||||
|
||||
let json = serde_json::to_string(&error).unwrap();
|
||||
assert!(json.contains("Test error"));
|
||||
assert!(json.contains("test_type"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pre_hook_input_parsing() {
|
||||
let json = r#"{
|
||||
"context": {"request_id": "test-123", "custom": "value"},
|
||||
"request": {"provider": "openai", "model": "gpt-4"}
|
||||
}"#;
|
||||
|
||||
let input: PreHookInput = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(input.context.get_string("request_id"), Some("test-123"));
|
||||
assert_eq!(input.context.get_string("custom"), Some("value"));
|
||||
|
||||
let (provider, model) = input.get_provider_model();
|
||||
assert_eq!(provider, "openai");
|
||||
assert_eq!(model, "gpt-4");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_http_request_with_null_fields() {
|
||||
// Simulates Go sending null for nil []byte and nil maps
|
||||
let json = r#"{
|
||||
"method": "POST",
|
||||
"path": "/v1/chat/completions",
|
||||
"headers": null,
|
||||
"query": null,
|
||||
"body": null
|
||||
}"#;
|
||||
|
||||
let req: HTTPRequest = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(req.method, "POST");
|
||||
assert_eq!(req.path, "/v1/chat/completions");
|
||||
assert!(req.headers.is_empty());
|
||||
assert!(req.query.is_empty());
|
||||
assert_eq!(req.body, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_http_request_with_missing_fields() {
|
||||
// Test that missing fields also work (default behavior)
|
||||
let json = r#"{
|
||||
"method": "GET",
|
||||
"path": "/health"
|
||||
}"#;
|
||||
|
||||
let req: HTTPRequest = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(req.method, "GET");
|
||||
assert_eq!(req.path, "/health");
|
||||
assert!(req.headers.is_empty());
|
||||
assert!(req.query.is_empty());
|
||||
assert_eq!(req.body, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_http_intercept_input_with_nulls() {
|
||||
// Simulates a full HTTP intercept input with null body from Go
|
||||
let json = r#"{
|
||||
"context": {"request_id": "abc-123"},
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"path": "/v1/chat/completions",
|
||||
"headers": {"content-type": "application/json"},
|
||||
"query": {},
|
||||
"body": null
|
||||
}
|
||||
}"#;
|
||||
|
||||
let input: HTTPInterceptInput = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(input.context.get_string("request_id"), Some("abc-123"));
|
||||
assert_eq!(input.request.method, "POST");
|
||||
assert_eq!(input.request.path, "/v1/chat/completions");
|
||||
assert_eq!(input.request.headers.get("content-type"), Some(&"application/json".to_string()));
|
||||
assert_eq!(input.request.body, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_http_response_with_null_fields() {
|
||||
let json = r#"{
|
||||
"status_code": null,
|
||||
"headers": null,
|
||||
"body": null
|
||||
}"#;
|
||||
|
||||
let resp: HTTPResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.status_code, 0);
|
||||
assert!(resp.headers.is_empty());
|
||||
assert_eq!(resp.body, "");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user