first commit
This commit is contained in:
107
transports/Dockerfile
Normal file
107
transports/Dockerfile
Normal file
@@ -0,0 +1,107 @@
|
||||
# --- UI Build Stage: Build the React + Vite frontend ---
|
||||
FROM node:25-alpine3.23@sha256:cf38e1f3c28ac9d81cdc0c51d8220320b3b618780e44ef96a39f76f7dbfef023 AS ui-builder
|
||||
WORKDIR /app
|
||||
ENV CI=1
|
||||
ENV NODE_OPTIONS=--max-old-space-size=4096
|
||||
|
||||
# Copy UI package files and install dependencies
|
||||
COPY ui/package*.json ./
|
||||
RUN npm ci
|
||||
|
||||
# Copy UI source code
|
||||
COPY ui/ ./
|
||||
|
||||
# Build UI (skip the copy-build step)
|
||||
RUN npm run build-enterprise -- --debug
|
||||
# Skip the copy-build step since we'll copy the files in the Go build stage
|
||||
|
||||
# --- Go Build Stage: Compile the Go binary ---
|
||||
FROM golang:1.26.2-alpine3.23@sha256:f85330846cde1e57ca9ec309382da3b8e6ae3ab943d2739500e08c86393a21b1 AS builder
|
||||
WORKDIR /app
|
||||
|
||||
# Install dependencies including gcc for CGO and sqlite
|
||||
RUN apk add --no-cache gcc musl-dev sqlite-dev binutils binutils-gold
|
||||
|
||||
# Set environment for CGO-enabled build (required for go-sqlite3)
|
||||
ENV CGO_ENABLED=1 GOOS=linux
|
||||
|
||||
COPY transports/go.mod transports/go.sum ./
|
||||
RUN ls
|
||||
RUN cat go.mod
|
||||
RUN go mod download
|
||||
|
||||
# Copy source code and dependencies
|
||||
COPY transports/ ./
|
||||
|
||||
COPY --from=ui-builder /app/out ./bifrost-http/ui
|
||||
|
||||
# Build the binary with CGO enabled and static SQLite linking
|
||||
ENV GOWORK=off
|
||||
ARG VERSION=unknown
|
||||
RUN go build \
|
||||
-ldflags="-w -s -X main.Version=v${VERSION} -extldflags '-static'" \
|
||||
-a -trimpath \
|
||||
-tags "sqlite_static" \
|
||||
-o /app/main \
|
||||
./bifrost-http
|
||||
|
||||
# Verify build succeeded
|
||||
RUN test -f /app/main || (echo "Build failed" && exit 1)
|
||||
|
||||
# --- Runtime Stage: Minimal runtime image ---
|
||||
FROM alpine:3.23.3@sha256:25109184c71bdad752c8312a8623239686a9a2071e8825f20acb8f2198c3f659
|
||||
WORKDIR /app
|
||||
|
||||
# Install runtime dependencies for CGO-enabled binary
|
||||
# musl: C standard library (required for CGO binaries)
|
||||
# libgcc: GCC runtime library
|
||||
# ca-certificates: For HTTPS connections
|
||||
RUN apk add --no-cache musl libgcc ca-certificates wget zlib=1.3.2-r0
|
||||
|
||||
# Create data directory and set up user
|
||||
COPY --from=builder /app/main .
|
||||
COPY --from=builder /app/docker-entrypoint.sh .
|
||||
|
||||
# Getting arguments
|
||||
ARG ARG_APP_PORT=8080
|
||||
ARG ARG_APP_HOST=0.0.0.0
|
||||
ARG ARG_LOG_LEVEL=info
|
||||
ARG ARG_LOG_STYLE=json
|
||||
ARG ARG_APP_DIR=/app/data
|
||||
|
||||
# Environment variables with defaults (can be overridden at runtime)
|
||||
ENV APP_PORT=$ARG_APP_PORT \
|
||||
APP_HOST=$ARG_APP_HOST \
|
||||
LOG_LEVEL=$ARG_LOG_LEVEL \
|
||||
LOG_STYLE=$ARG_LOG_STYLE \
|
||||
APP_DIR=$ARG_APP_DIR
|
||||
|
||||
# Go runtime performance tuning (override at runtime for your workload)
|
||||
# GOGC: GC target percentage. Higher = less frequent GC, more memory usage.
|
||||
# Default: 100. For high-throughput with available memory, try 200-400.
|
||||
# GOMEMLIMIT: Soft memory limit for Go runtime. Set to ~90% of container memory limit.
|
||||
# Example: "1800MiB" for a 2GB container, "3600MiB" for 4GB.
|
||||
# When set, Go will be more aggressive about GC as it approaches this limit.
|
||||
# Note: GOMAXPROCS is automatically detected from cgroup CPU limits via automaxprocs.
|
||||
ENV GOGC="" \
|
||||
GOMEMLIMIT=""
|
||||
|
||||
|
||||
RUN mkdir -p $APP_DIR/logs && \
|
||||
adduser -D -s /bin/sh appuser && \
|
||||
chown -R appuser:appuser /app && \
|
||||
chmod +x /app/docker-entrypoint.sh
|
||||
USER appuser
|
||||
|
||||
|
||||
# Declare volume for data persistence
|
||||
VOLUME ["/app/data"]
|
||||
EXPOSE $APP_PORT
|
||||
|
||||
# Health check for container status monitoring
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD wget --no-verbose --tries=1 -O /dev/null http://127.0.0.1:${APP_PORT}/health || exit 1
|
||||
|
||||
# Use entrypoint script that handles volume permissions and argument processing
|
||||
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||
CMD ["/app/main"]
|
||||
118
transports/Dockerfile.local
Normal file
118
transports/Dockerfile.local
Normal file
@@ -0,0 +1,118 @@
|
||||
# Dockerfile.local - Uses local module sources via go workspace
|
||||
# For pre-release CI builds where module versions aren't published yet
|
||||
|
||||
# --- UI Build Stage: Build the React + Vite frontend ---
|
||||
FROM node:25-alpine3.23@sha256:cf38e1f3c28ac9d81cdc0c51d8220320b3b618780e44ef96a39f76f7dbfef023 AS ui-builder
|
||||
WORKDIR /app
|
||||
ENV CI=1
|
||||
ENV NODE_OPTIONS=--max-old-space-size=4096
|
||||
|
||||
# Copy UI package files and install dependencies
|
||||
COPY ui/package*.json ./
|
||||
RUN npm ci
|
||||
|
||||
# Copy UI source code
|
||||
COPY ui/ ./
|
||||
|
||||
# Build UI (skip the copy-build step)
|
||||
RUN npm run build-enterprise -- --debug
|
||||
# Skip the copy-build step since we'll copy the files in the Go build stage
|
||||
|
||||
# --- Go Build Stage: Compile the Go binary using local modules ---
|
||||
FROM golang:1.26.2-alpine3.23@sha256:f85330846cde1e57ca9ec309382da3b8e6ae3ab943d2739500e08c86393a21b1 AS builder
|
||||
WORKDIR /build
|
||||
|
||||
# Install dependencies including gcc for CGO and sqlite
|
||||
RUN apk add --no-cache gcc musl-dev sqlite-dev binutils binutils-gold
|
||||
|
||||
# Set environment for CGO-enabled build (required for go-sqlite3)
|
||||
ENV CGO_ENABLED=1 GOOS=linux
|
||||
|
||||
# Copy all local modules
|
||||
COPY core/ ./core/
|
||||
COPY framework/ ./framework/
|
||||
COPY plugins/ ./plugins/
|
||||
COPY transports/ ./transports/
|
||||
|
||||
# Set up go workspace to resolve local module dependencies
|
||||
RUN go work init && \
|
||||
go work use ./core && \
|
||||
go work use ./framework && \
|
||||
go work use ./plugins/compat && \
|
||||
go work use ./plugins/governance && \
|
||||
go work use ./plugins/jsonparser && \
|
||||
go work use ./plugins/logging && \
|
||||
go work use ./plugins/maxim && \
|
||||
go work use ./plugins/mocker && \
|
||||
go work use ./plugins/otel && \
|
||||
go work use ./plugins/prompts && \
|
||||
go work use ./plugins/semanticcache && \
|
||||
go work use ./plugins/telemetry && \
|
||||
go work use ./transports
|
||||
|
||||
# Download external (non-local) dependencies
|
||||
RUN cd /build/transports && go mod download
|
||||
|
||||
# Copy UI build output into transports
|
||||
COPY --from=ui-builder /app/out ./transports/bifrost-http/ui
|
||||
|
||||
# Build the binary with CGO enabled and static SQLite linking
|
||||
ARG VERSION=unknown
|
||||
RUN cd /build/transports && \
|
||||
go build \
|
||||
-ldflags="-w -s -X main.Version=v${VERSION} -extldflags '-static'" \
|
||||
-a -trimpath \
|
||||
-tags "sqlite_static" \
|
||||
-o /app/main \
|
||||
./bifrost-http
|
||||
|
||||
# Verify build succeeded
|
||||
RUN test -f /app/main || (echo "Build failed" && exit 1)
|
||||
|
||||
# --- Runtime Stage: Minimal runtime image ---
|
||||
FROM alpine:3.23.3@sha256:25109184c71bdad752c8312a8623239686a9a2071e8825f20acb8f2198c3f659
|
||||
WORKDIR /app
|
||||
|
||||
# Install runtime dependencies for CGO-enabled binary
|
||||
# musl: C standard library (required for CGO binaries)
|
||||
# libgcc: GCC runtime library
|
||||
# ca-certificates: For HTTPS connections
|
||||
RUN apk add --no-cache musl libgcc ca-certificates wget zlib=1.3.2-r0
|
||||
|
||||
# Create data directory and set up user
|
||||
COPY --from=builder /app/main .
|
||||
COPY --from=builder /build/transports/docker-entrypoint.sh .
|
||||
|
||||
# Getting arguments
|
||||
ARG ARG_APP_PORT=8080
|
||||
ARG ARG_APP_HOST=0.0.0.0
|
||||
ARG ARG_LOG_LEVEL=info
|
||||
ARG ARG_LOG_STYLE=json
|
||||
ARG ARG_APP_DIR=/app/data
|
||||
|
||||
# Environment variables with defaults (can be overridden at runtime)
|
||||
ENV APP_PORT=$ARG_APP_PORT \
|
||||
APP_HOST=$ARG_APP_HOST \
|
||||
LOG_LEVEL=$ARG_LOG_LEVEL \
|
||||
LOG_STYLE=$ARG_LOG_STYLE \
|
||||
APP_DIR=$ARG_APP_DIR
|
||||
|
||||
|
||||
RUN mkdir -p $APP_DIR/logs && \
|
||||
adduser -D -s /bin/sh appuser && \
|
||||
chown -R appuser:appuser /app && \
|
||||
chmod +x /app/docker-entrypoint.sh
|
||||
USER appuser
|
||||
|
||||
|
||||
# Declare volume for data persistence
|
||||
VOLUME ["/app/data"]
|
||||
EXPOSE $APP_PORT
|
||||
|
||||
# Health check for container status monitoring
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD wget --no-verbose --tries=1 -O /dev/null http://127.0.0.1:${APP_PORT}/health || exit 1
|
||||
|
||||
# Use entrypoint script that handles volume permissions and argument processing
|
||||
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||
CMD ["/app/main"]
|
||||
166
transports/README.md
Normal file
166
transports/README.md
Normal file
@@ -0,0 +1,166 @@
|
||||
# Bifrost Gateway
|
||||
|
||||
Bifrost Gateway is a blazing-fast HTTP API that unifies access to 15+ AI providers (OpenAI, Anthropic, AWS Bedrock, Google Vertex, and more) through a single OpenAI-compatible interface. Deploy in seconds with zero configuration and get automatic fallbacks, semantic caching, tool calling, and enterprise-grade features.
|
||||
|
||||
**Complete Documentation**: [https://docs.getbifrost.ai](https://docs.getbifrost.ai)
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Installation
|
||||
|
||||
Choose your preferred method:
|
||||
|
||||
#### NPX (Recommended)
|
||||
|
||||
```bash
|
||||
# Install and run locally
|
||||
npx -y @maximhq/bifrost
|
||||
|
||||
# Open web interface at http://localhost:8080
|
||||
```
|
||||
|
||||
#### Docker
|
||||
|
||||
```bash
|
||||
# Pull and run Bifrost Gateway
|
||||
docker pull maximhq/bifrost
|
||||
docker run -p 8080:8080 maximhq/bifrost
|
||||
|
||||
# For persistent configuration
|
||||
docker run -p 8080:8080 -v $(pwd)/data:/app/data maximhq/bifrost
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
Bifrost starts with zero configuration needed. Configure providers through the **built-in web UI** at `http://localhost:8080` or via API:
|
||||
|
||||
```bash
|
||||
# Add OpenAI provider via API
|
||||
curl -X POST http://localhost:8080/api/providers \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"provider": "openai",
|
||||
"keys": [{"value": "sk-your-openai-key", "models": ["gpt-4o-mini"], "weight": 1.0}]
|
||||
}'
|
||||
```
|
||||
|
||||
For file-based configuration, create `config.json` in your app directory:
|
||||
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"openai": {
|
||||
"keys": [{"value": "env.OPENAI_API_KEY", "models": ["gpt-4o-mini"], "weight": 1.0}]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Your First API Call
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "openai/gpt-4o-mini",
|
||||
"messages": [{"role": "user", "content": "Hello, Bifrost!"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**That's it!** You now have a unified AI gateway running locally.
|
||||
|
||||
---
|
||||
|
||||
## Key Features
|
||||
|
||||
Bifrost Gateway provides enterprise-grade AI infrastructure with these core capabilities:
|
||||
|
||||
### Core Features
|
||||
|
||||
- **[Unified Interface](https://docs.getbifrost.ai/features/unified-interface)** - Single OpenAI-compatible API for all providers
|
||||
- **[Multi-Provider Support](https://docs.getbifrost.ai/quickstart/gateway/provider-configuration)** - OpenAI, Anthropic, AWS Bedrock, Google Vertex, Cerebras, Azure, Cohere, Mistral, Ollama, Groq, and more
|
||||
- **[Drop-in Replacement](https://docs.getbifrost.ai/features/drop-in-replacement)** - Replace OpenAI/Anthropic/GenAI SDKs with zero code changes
|
||||
- **[Automatic Fallbacks](https://docs.getbifrost.ai/features/fallbacks)** - Seamless failover between providers and models
|
||||
- **[Streaming Support](https://docs.getbifrost.ai/quickstart/gateway/streaming)** - Real-time response streaming for all providers
|
||||
|
||||
### Advanced Features
|
||||
|
||||
- **[Model Context Protocol (MCP)](https://docs.getbifrost.ai/features/mcp)** - Enable AI models to use external tools (filesystem, web search, databases)
|
||||
- **[Semantic Caching](https://docs.getbifrost.ai/features/semantic-caching)** - Intelligent response caching based on semantic similarity
|
||||
- **[Load Balancing](https://docs.getbifrost.ai/features/fallbacks)** - Distribute requests across multiple API keys and providers
|
||||
- **[Governance & Budget Management](https://docs.getbifrost.ai/features/governance)** - Usage tracking, rate limiting, and cost control
|
||||
- **[Custom Plugins](https://docs.getbifrost.ai/enterprise/custom-plugins)** - Extensible middleware for analytics, monitoring, and custom logic
|
||||
|
||||
### Enterprise Features
|
||||
|
||||
- **[Clustering](https://docs.getbifrost.ai/enterprise/clustering)** - Multi-node deployment with shared state
|
||||
- **[SSO Integration](https://docs.getbifrost.ai/features/sso-with-google-github)** - Google, GitHub authentication
|
||||
- **[Vault Support](https://docs.getbifrost.ai/enterprise/vault-support)** - Secure API key management
|
||||
- **[Custom Analytics](https://docs.getbifrost.ai/features/observability)** - Detailed usage insights and monitoring
|
||||
- **[In-VPC Deployments](https://docs.getbifrost.ai/enterprise/invpc-deployments)** - Private cloud deployment options
|
||||
|
||||
**Learn More**: [Complete Feature Documentation](https://docs.getbifrost.ai/features/unified-interface)
|
||||
|
||||
---
|
||||
|
||||
## SDK Integrations
|
||||
|
||||
Replace your existing SDK base URLs to unlock Bifrost's features instantly:
|
||||
|
||||
### OpenAI SDK
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
base_url="http://localhost:8080/openai",
|
||||
api_key="dummy" # Handled by Bifrost
|
||||
)
|
||||
```
|
||||
|
||||
### Anthropic SDK
|
||||
|
||||
```python
|
||||
import anthropic
|
||||
client = anthropic.Anthropic(
|
||||
base_url="http://localhost:8080/anthropic",
|
||||
api_key="dummy" # Handled by Bifrost
|
||||
)
|
||||
```
|
||||
|
||||
### Google GenAI SDK
|
||||
|
||||
```python
|
||||
import google.generativeai as genai
|
||||
genai.configure(
|
||||
transport="rest",
|
||||
api_endpoint="http://localhost:8080/genai",
|
||||
api_key="dummy" # Handled by Bifrost
|
||||
)
|
||||
```
|
||||
|
||||
**Complete Integration Guides**: [SDK Integrations](https://docs.getbifrost.ai/integrations/what-is-an-integration)
|
||||
|
||||
---
|
||||
|
||||
## Documentation
|
||||
|
||||
### Getting Started
|
||||
|
||||
- [Quick Setup Guide](https://docs.getbifrost.ai/quickstart/gateway/setting-up) - Detailed installation and configuration
|
||||
- [Provider Configuration](https://docs.getbifrost.ai/quickstart/gateway/provider-configuration) - Connect multiple AI providers
|
||||
- [Integration Guide](https://docs.getbifrost.ai/quickstart/gateway/integrations) - SDK replacements
|
||||
|
||||
### Advanced Topics
|
||||
|
||||
- [MCP Tool Calling](https://docs.getbifrost.ai/features/mcp) - External tool integration
|
||||
- [Semantic Caching](https://docs.getbifrost.ai/features/semantic-caching) - Intelligent response caching
|
||||
- [Fallbacks & Load Balancing](https://docs.getbifrost.ai/features/fallbacks) - Reliability and scaling
|
||||
- [Budget Management](https://docs.getbifrost.ai/features/governance) - Cost control and governance
|
||||
|
||||
**Browse All Documentation**: [https://docs.getbifrost.ai](https://docs.getbifrost.ai)
|
||||
|
||||
---
|
||||
|
||||
*Built with ❤️ by [Maxim](https://getmaxim.ai)*
|
||||
66
transports/bifrost-http/.air.debug.toml
Normal file
66
transports/bifrost-http/.air.debug.toml
Normal file
@@ -0,0 +1,66 @@
|
||||
root = "../.."
|
||||
testdata_dir = "testdata"
|
||||
tmp_dir = "transports/bifrost-http/tmp"
|
||||
|
||||
[build]
|
||||
args_bin = []
|
||||
bin = "tmp/main"
|
||||
# Build with debug flags: -N disables optimizations, -l disables inlining
|
||||
# Note: We don't set GOWORK=off so it uses the workspace go.work file
|
||||
cmd = "go build -tags dev -gcflags='all=-N -l' -o ./tmp/main ."
|
||||
delay = 1000
|
||||
exclude_dir = [
|
||||
"assets",
|
||||
"tmp",
|
||||
"vendor",
|
||||
"testdata",
|
||||
"ui",
|
||||
"node_modules",
|
||||
"transports/bifrost-http/ui",
|
||||
"core/tests",
|
||||
"tests",
|
||||
"docs",
|
||||
"npx",
|
||||
]
|
||||
exclude_file = []
|
||||
exclude_regex = ["_test.go"]
|
||||
exclude_unchanged = false
|
||||
follow_symlink = false
|
||||
# Run binary via delve debugger in headless mode on port 2345
|
||||
full_bin = "dlv exec ./tmp/main --listen=127.0.0.1:2345 --headless=true --api-version=2 --accept-multiclient --log --"
|
||||
watch_dirs = ["."]
|
||||
include_dir = []
|
||||
include_ext = ["go", "tpl", "tmpl", "html"]
|
||||
include_file = []
|
||||
kill_delay = "1s"
|
||||
log = "tmp/build-errors.log"
|
||||
poll = false
|
||||
stop_on_error = true
|
||||
poll_interval = 0
|
||||
rerun = false
|
||||
rerun_delay = 500
|
||||
send_interrupt = true
|
||||
stop_on_root = false
|
||||
|
||||
[color]
|
||||
app = ""
|
||||
build = "yellow"
|
||||
main = "magenta"
|
||||
runner = "green"
|
||||
watcher = "cyan"
|
||||
|
||||
[log]
|
||||
main_only = false
|
||||
time = false
|
||||
|
||||
[misc]
|
||||
clean_on_exit = false
|
||||
|
||||
[proxy]
|
||||
enabled = false
|
||||
proxy_port = 8090
|
||||
app_port = 8080
|
||||
|
||||
[screen]
|
||||
clear_on_rebuild = false
|
||||
keep_scroll = true
|
||||
65
transports/bifrost-http/.air.toml
Normal file
65
transports/bifrost-http/.air.toml
Normal file
@@ -0,0 +1,65 @@
|
||||
root = "../.."
|
||||
testdata_dir = "testdata"
|
||||
tmp_dir = "transports/bifrost-http/tmp"
|
||||
|
||||
[build]
|
||||
args_bin = []
|
||||
bin = "tmp/main"
|
||||
cmd = "go build -tags dev -o ./tmp/main ."
|
||||
delay = 1000
|
||||
exclude_dir = [
|
||||
"assets",
|
||||
"tmp",
|
||||
"vendor",
|
||||
"testdata",
|
||||
"ui",
|
||||
"node_modules",
|
||||
"transports/bifrost-http/ui",
|
||||
"core/tests",
|
||||
"tests",
|
||||
"docs",
|
||||
"npx",
|
||||
"test-reports",
|
||||
"playwright-report",
|
||||
]
|
||||
exclude_file = []
|
||||
exclude_regex = ["_test.go"]
|
||||
exclude_unchanged = false
|
||||
follow_symlink = false
|
||||
full_bin = ""
|
||||
watch_dirs = ["."]
|
||||
include_dir = []
|
||||
include_ext = ["go", "tpl", "tmpl", "html"]
|
||||
include_file = []
|
||||
kill_delay = "1s"
|
||||
log = "tmp/build-errors.log"
|
||||
poll = false
|
||||
stop_on_error = true
|
||||
poll_interval = 0
|
||||
rerun = false
|
||||
rerun_delay = 500
|
||||
send_interrupt = true
|
||||
stop_on_root = false
|
||||
|
||||
[color]
|
||||
app = ""
|
||||
build = "yellow"
|
||||
main = "magenta"
|
||||
runner = "green"
|
||||
watcher = "cyan"
|
||||
|
||||
[log]
|
||||
main_only = false
|
||||
time = false
|
||||
|
||||
[misc]
|
||||
clean_on_exit = false
|
||||
|
||||
[proxy]
|
||||
enabled = false
|
||||
proxy_port = 8090
|
||||
app_port = 8080
|
||||
|
||||
[screen]
|
||||
clear_on_rebuild = false
|
||||
keep_scroll = true
|
||||
548
transports/bifrost-http/handlers/asyncinference.go
Normal file
548
transports/bifrost-http/handlers/asyncinference.go
Normal file
@@ -0,0 +1,548 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/logstore"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// --- HTTP Handler ---
|
||||
|
||||
// AsyncHandler handles async job HTTP endpoints.
|
||||
type AsyncHandler struct {
|
||||
client *bifrost.Bifrost
|
||||
executor *logstore.AsyncJobExecutor
|
||||
handlerStore lib.HandlerStore
|
||||
config *lib.Config
|
||||
}
|
||||
|
||||
// AsyncPathToTypeMapping maps exact paths to request types (only for non-parameterized paths)
|
||||
// Parameterized paths are set per-route in RegisterRoutes
|
||||
var AsyncPathToTypeMapping = map[string]schemas.RequestType{
|
||||
"/v1/async/completions": schemas.TextCompletionRequest,
|
||||
"/v1/async/chat/completions": schemas.ChatCompletionRequest,
|
||||
"/v1/async/responses": schemas.ResponsesRequest,
|
||||
"/v1/async/embeddings": schemas.EmbeddingRequest,
|
||||
"/v1/async/audio/speech": schemas.SpeechRequest,
|
||||
"/v1/async/audio/transcriptions": schemas.TranscriptionRequest,
|
||||
"/v1/async/images/generations": schemas.ImageGenerationRequest,
|
||||
"/v1/async/images/edits": schemas.ImageEditRequest,
|
||||
"/v1/async/images/variations": schemas.ImageVariationRequest,
|
||||
"/v1/async/rerank": schemas.RerankRequest,
|
||||
"/v1/async/ocr": schemas.OCRRequest,
|
||||
}
|
||||
|
||||
// RegisterAsyncRequestTypeMiddleware handles exact path matching for non-parameterized routes
|
||||
func RegisterAsyncRequestTypeMiddleware(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
path := string(ctx.Path())
|
||||
if requestType, ok := AsyncPathToTypeMapping[path]; ok {
|
||||
ctx.SetUserValue(schemas.BifrostContextKeyHTTPRequestType, requestType)
|
||||
}
|
||||
next(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// NewAsyncHandler creates a new AsyncHandler.
|
||||
// If the async job executor is not available (e.g., LogsStore or governance plugin not configured),
|
||||
// the handler is created with a nil executor and RegisterRoutes will skip async route registration.
|
||||
func NewAsyncHandler(client *bifrost.Bifrost, config *lib.Config) *AsyncHandler {
|
||||
return &AsyncHandler{
|
||||
client: client,
|
||||
executor: config.GetAsyncJobExecutor(),
|
||||
handlerStore: config,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers async job endpoints.
|
||||
func (h *AsyncHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
if h.executor == nil {
|
||||
return // LogStore not configured, skip async routes
|
||||
}
|
||||
|
||||
baseMiddlewares := append([]schemas.BifrostHTTPMiddleware{RegisterAsyncRequestTypeMiddleware}, middlewares...)
|
||||
|
||||
// Async submission endpoints (non-parameterized, request type set via AsyncPathToTypeMapping)
|
||||
r.POST("/v1/async/completions", lib.ChainMiddlewares(h.asyncTextCompletion, baseMiddlewares...))
|
||||
r.POST("/v1/async/chat/completions", lib.ChainMiddlewares(h.asyncChatCompletion, baseMiddlewares...))
|
||||
r.POST("/v1/async/responses", lib.ChainMiddlewares(h.asyncResponses, baseMiddlewares...))
|
||||
r.POST("/v1/async/embeddings", lib.ChainMiddlewares(h.asyncEmbeddings, baseMiddlewares...))
|
||||
r.POST("/v1/async/audio/speech", lib.ChainMiddlewares(h.asyncSpeech, baseMiddlewares...))
|
||||
r.POST("/v1/async/audio/transcriptions", lib.ChainMiddlewares(h.asyncTranscription, baseMiddlewares...))
|
||||
r.POST("/v1/async/images/generations", lib.ChainMiddlewares(h.asyncImageGeneration, baseMiddlewares...))
|
||||
r.POST("/v1/async/images/edits", lib.ChainMiddlewares(h.asyncImageEdit, baseMiddlewares...))
|
||||
r.POST("/v1/async/images/variations", lib.ChainMiddlewares(h.asyncImageVariation, baseMiddlewares...))
|
||||
r.POST("/v1/async/rerank", lib.ChainMiddlewares(h.asyncRerank, baseMiddlewares...))
|
||||
r.POST("/v1/async/ocr", lib.ChainMiddlewares(h.asyncOCR, baseMiddlewares...))
|
||||
|
||||
// Async job retrieval endpoints
|
||||
r.GET("/v1/async/completions/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.TextCompletionRequest), middlewares...))
|
||||
r.GET("/v1/async/chat/completions/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.ChatCompletionRequest), middlewares...))
|
||||
r.GET("/v1/async/responses/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.ResponsesRequest), middlewares...))
|
||||
r.GET("/v1/async/embeddings/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.EmbeddingRequest), middlewares...))
|
||||
r.GET("/v1/async/audio/speech/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.SpeechRequest), middlewares...))
|
||||
r.GET("/v1/async/audio/transcriptions/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.TranscriptionRequest), middlewares...))
|
||||
r.GET("/v1/async/images/generations/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.ImageGenerationRequest), middlewares...))
|
||||
r.GET("/v1/async/images/edits/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.ImageEditRequest), middlewares...))
|
||||
r.GET("/v1/async/images/variations/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.ImageVariationRequest), middlewares...))
|
||||
r.GET("/v1/async/rerank/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.RerankRequest), middlewares...))
|
||||
r.GET("/v1/async/ocr/{job_id}", lib.ChainMiddlewares(h.getJob(schemas.OCRRequest), middlewares...))
|
||||
}
|
||||
|
||||
// --- Async submission handlers ---
|
||||
|
||||
// asyncTextCompletion handles POST /v1/async/completions
|
||||
func (h *AsyncHandler) asyncTextCompletion(ctx *fasthttp.RequestCtx) {
|
||||
req, bifrostTextReq, err := prepareTextCompletionRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.Stream != nil && *req.Stream {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async text completions")
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.TextCompletionRequest(bgCtx, bifrostTextReq)
|
||||
},
|
||||
schemas.TextCompletionRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncChatCompletion handles POST /v1/async/chat/completions
|
||||
func (h *AsyncHandler) asyncChatCompletion(ctx *fasthttp.RequestCtx) {
|
||||
req, bifrostChatReq, err := prepareChatCompletionRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.Stream != nil && *req.Stream {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async chat completions")
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.ChatCompletionRequest(bgCtx, bifrostChatReq)
|
||||
},
|
||||
schemas.ChatCompletionRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncResponses handles POST /v1/async/responses
|
||||
func (h *AsyncHandler) asyncResponses(ctx *fasthttp.RequestCtx) {
|
||||
req, bifrostResponsesReq, err := prepareResponsesRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.Stream != nil && *req.Stream {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async responses")
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.ResponsesRequest(bgCtx, bifrostResponsesReq)
|
||||
},
|
||||
schemas.ResponsesRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Failed to create async job: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncEmbeddings handles POST /v1/async/embeddings
|
||||
func (h *AsyncHandler) asyncEmbeddings(ctx *fasthttp.RequestCtx) {
|
||||
_, bifrostEmbeddingReq, err := prepareEmbeddingRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.EmbeddingRequest(bgCtx, bifrostEmbeddingReq)
|
||||
},
|
||||
schemas.EmbeddingRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncSpeech handles POST /v1/async/audio/speech
|
||||
func (h *AsyncHandler) asyncSpeech(ctx *fasthttp.RequestCtx) {
|
||||
req, bifrostSpeechReq, err := prepareSpeechRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.StreamFormat != nil && *req.StreamFormat == "sse" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async speech")
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.SpeechRequest(bgCtx, bifrostSpeechReq)
|
||||
},
|
||||
schemas.SpeechRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncTranscription handles POST /v1/async/audio/transcriptions
|
||||
func (h *AsyncHandler) asyncTranscription(ctx *fasthttp.RequestCtx) {
|
||||
bifrostTranscriptionReq, stream, err := prepareTranscriptionRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if stream {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async transcriptions")
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.TranscriptionRequest(bgCtx, bifrostTranscriptionReq)
|
||||
},
|
||||
schemas.TranscriptionRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncImageGeneration handles POST /v1/async/images/generations
|
||||
func (h *AsyncHandler) asyncImageGeneration(ctx *fasthttp.RequestCtx) {
|
||||
req, bifrostReq, err := prepareImageGenerationRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.BifrostParams.Stream != nil && *req.BifrostParams.Stream {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async image generations")
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.ImageGenerationRequest(bgCtx, bifrostReq)
|
||||
},
|
||||
schemas.ImageGenerationRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncImageEdit handles POST /v1/async/images/edits
|
||||
func (h *AsyncHandler) asyncImageEdit(ctx *fasthttp.RequestCtx) {
|
||||
req, bifrostReq, err := prepareImageEditRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.Stream != nil && *req.Stream {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "stream is not supported for async image edits")
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.ImageEditRequest(bgCtx, bifrostReq)
|
||||
},
|
||||
schemas.ImageEditRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncImageVariation handles POST /v1/async/images/variations
|
||||
func (h *AsyncHandler) asyncImageVariation(ctx *fasthttp.RequestCtx) {
|
||||
bifrostReq, err := prepareImageVariationRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.ImageVariationRequest(bgCtx, bifrostReq)
|
||||
},
|
||||
schemas.ImageVariationRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncRerank handles POST /v1/async/rerank
|
||||
func (h *AsyncHandler) asyncRerank(ctx *fasthttp.RequestCtx) {
|
||||
_, bifrostReq, err := prepareRerankRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.RerankRequest(bgCtx, bifrostReq)
|
||||
},
|
||||
schemas.RerankRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// asyncOCR handles POST /v1/async/ocr
|
||||
func (h *AsyncHandler) asyncOCR(ctx *fasthttp.RequestCtx) {
|
||||
_, bifrostReq, err := prepareOCRRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
resultTTL := getResultTTLFromHeaderWithDefault(ctx, h.config.ClientConfig.AsyncJobResultTTL)
|
||||
|
||||
job, err := h.executor.SubmitJob(
|
||||
bifrostCtx,
|
||||
resultTTL,
|
||||
func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) {
|
||||
return h.client.OCRRequest(bgCtx, bifrostReq)
|
||||
},
|
||||
schemas.OCRRequest,
|
||||
)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
SendJSONWithStatus(ctx, job.ToResponse(), fasthttp.StatusAccepted)
|
||||
}
|
||||
|
||||
// --- Job retrieval handler ---
|
||||
|
||||
// getJob handles GET /v1/async/{type}/{job_id}
|
||||
func (h *AsyncHandler) getJob(operationType schemas.RequestType) fasthttp.RequestHandler {
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
jobID, ok := ctx.UserValue("job_id").(string)
|
||||
if !ok || jobID == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "job_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Get the requesting user's VK for auth check
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
job, err := h.executor.RetrieveJob(bifrostCtx, jobID, getVirtualKeyFromContext(bifrostCtx), operationType)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp := job.ToResponse()
|
||||
|
||||
// Return 202 for pending/processing, 200 for completed/failed
|
||||
switch job.Status {
|
||||
case schemas.AsyncJobStatusPending, schemas.AsyncJobStatusProcessing:
|
||||
SendJSONWithStatus(ctx, resp, fasthttp.StatusAccepted)
|
||||
default:
|
||||
SendJSON(ctx, resp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper functions ---
|
||||
|
||||
// getVirtualKeyFromContext extracts the virtual key value from context.
|
||||
// Returns nil if no VK is present (e.g., direct key mode or no governance).
|
||||
func getVirtualKeyFromContext(ctx *schemas.BifrostContext) *string {
|
||||
vkValue := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)
|
||||
if vkValue == "" {
|
||||
return nil
|
||||
}
|
||||
return &vkValue
|
||||
}
|
||||
|
||||
func getResultTTLFromHeaderWithDefault(ctx *fasthttp.RequestCtx, defaultTTL int) int {
|
||||
resultTTL := string(ctx.Request.Header.Peek(schemas.AsyncHeaderResultTTL))
|
||||
if resultTTL == "" {
|
||||
return defaultTTL
|
||||
}
|
||||
resultTTLInt, err := strconv.Atoi(resultTTL)
|
||||
if err != nil || resultTTLInt < 0 {
|
||||
return defaultTTL
|
||||
}
|
||||
return resultTTLInt
|
||||
}
|
||||
61
transports/bifrost-http/handlers/cache.go
Normal file
61
transports/bifrost-http/handlers/cache.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/plugins/semanticcache"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type CacheHandler struct {
|
||||
plugin *semanticcache.Plugin
|
||||
}
|
||||
|
||||
func NewCacheHandler(plugin schemas.LLMPlugin) *CacheHandler {
|
||||
semanticCachePlugin, ok := plugin.(*semanticcache.Plugin)
|
||||
if !ok {
|
||||
logger.Fatal("Cache handler requires a semantic cache plugin")
|
||||
}
|
||||
|
||||
return &CacheHandler{
|
||||
plugin: semanticCachePlugin,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *CacheHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.DELETE("/api/cache/clear/{requestId}", lib.ChainMiddlewares(h.clearCache, middlewares...))
|
||||
r.DELETE("/api/cache/clear-by-key/{cacheKey}", lib.ChainMiddlewares(h.clearCacheByKey, middlewares...))
|
||||
}
|
||||
|
||||
func (h *CacheHandler) clearCache(ctx *fasthttp.RequestCtx) {
|
||||
requestID, ok := ctx.UserValue("requestId").(string)
|
||||
if !ok {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid request ID")
|
||||
return
|
||||
}
|
||||
if err := h.plugin.ClearCacheForRequestID(requestID); err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to clear cache")
|
||||
return
|
||||
}
|
||||
|
||||
SendJSON(ctx, map[string]any{
|
||||
"message": "Cache cleared successfully",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *CacheHandler) clearCacheByKey(ctx *fasthttp.RequestCtx) {
|
||||
cacheKey, ok := ctx.UserValue("cacheKey").(string)
|
||||
if !ok {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid cache key")
|
||||
return
|
||||
}
|
||||
if err := h.plugin.ClearCacheForKey(cacheKey); err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to clear cache")
|
||||
return
|
||||
}
|
||||
|
||||
SendJSON(ctx, map[string]any{
|
||||
"message": "Cache cleared successfully",
|
||||
})
|
||||
}
|
||||
887
transports/bifrost-http/handlers/config.go
Normal file
887
transports/bifrost-http/handlers/config.go
Normal file
@@ -0,0 +1,887 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/network"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"github.com/maximhq/bifrost/framework/modelcatalog"
|
||||
"github.com/maximhq/bifrost/plugins/compat"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// securityHeaders is the list of headers that cannot be configured in allowlist/denylist
|
||||
// These headers are always blocked for security reasons regardless of user configuration
|
||||
var securityHeaders = []string{
|
||||
"authorization",
|
||||
"proxy-authorization",
|
||||
"cookie",
|
||||
"host",
|
||||
"content-length",
|
||||
"connection",
|
||||
"transfer-encoding",
|
||||
"x-api-key",
|
||||
"x-goog-api-key",
|
||||
"x-bf-api-key",
|
||||
"x-bf-vk",
|
||||
}
|
||||
|
||||
// ConfigManager is the interface for the config manager
|
||||
type ConfigManager interface {
|
||||
UpdateAuthConfig(ctx context.Context, authConfig *configstore.AuthConfig) error
|
||||
ReloadClientConfigFromConfigStore(ctx context.Context) error
|
||||
UpdateSyncConfig(ctx context.Context) error
|
||||
ForceReloadPricing(ctx context.Context) error
|
||||
UpdateDropExcessRequests(ctx context.Context, value bool)
|
||||
UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string, disableAutoToolInject bool) error
|
||||
ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any, placement *schemas.PluginPlacement, order *int) error
|
||||
RemovePlugin(ctx context.Context, name string) error
|
||||
ReloadProxyConfig(ctx context.Context, config *configstoreTables.GlobalProxyConfig) error
|
||||
ReloadHeaderFilterConfig(ctx context.Context, config *configstoreTables.GlobalHeaderFilterConfig) error
|
||||
}
|
||||
|
||||
// ConfigHandler manages runtime configuration updates for Bifrost.
|
||||
// It provides endpoints to update and retrieve settings persisted via the ConfigStore backed by sql database.
|
||||
type ConfigHandler struct {
|
||||
store *lib.Config
|
||||
configManager ConfigManager
|
||||
}
|
||||
|
||||
// NewConfigHandler creates a new handler for configuration management.
|
||||
// It requires the Bifrost client, a logger, and the config store.
|
||||
func NewConfigHandler(configManager ConfigManager, store *lib.Config) *ConfigHandler {
|
||||
return &ConfigHandler{
|
||||
configManager: configManager,
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the configuration-related routes.
|
||||
// It adds the `PUT /api/config` endpoint.
|
||||
func (h *ConfigHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.GET("/api/config", lib.ChainMiddlewares(h.getConfig, middlewares...))
|
||||
r.PUT("/api/config", lib.ChainMiddlewares(h.updateConfig, middlewares...))
|
||||
r.GET("/api/version", lib.ChainMiddlewares(h.getVersion, middlewares...))
|
||||
r.GET("/api/proxy-config", lib.ChainMiddlewares(h.getProxyConfig, middlewares...))
|
||||
r.PUT("/api/proxy-config", lib.ChainMiddlewares(h.updateProxyConfig, middlewares...))
|
||||
r.POST("/api/pricing/force-sync", lib.ChainMiddlewares(h.forceSyncPricing, middlewares...))
|
||||
}
|
||||
|
||||
// getVersion handles GET /api/version - Get the current version
|
||||
func (h *ConfigHandler) getVersion(ctx *fasthttp.RequestCtx) {
|
||||
SendJSON(ctx, version)
|
||||
}
|
||||
|
||||
// getConfig handles GET /config - Get the current configuration
|
||||
func (h *ConfigHandler) getConfig(ctx *fasthttp.RequestCtx) {
|
||||
mapConfig := make(map[string]any)
|
||||
|
||||
if query := string(ctx.QueryArgs().Peek("from_db")); query == "true" {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "config store not available")
|
||||
return
|
||||
}
|
||||
cc, err := h.store.ConfigStore.GetClientConfig(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError,
|
||||
fmt.Sprintf("failed to fetch config from db: %v", err))
|
||||
return
|
||||
}
|
||||
if cc != nil {
|
||||
mapConfig["client_config"] = *cc
|
||||
}
|
||||
// Fetching framework config
|
||||
fc, err := h.store.ConfigStore.GetFrameworkConfig(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to fetch framework config from db: %v", err))
|
||||
return
|
||||
}
|
||||
normalizedFrameworkConfig, _, _ := lib.ResolveFrameworkPricingConfig(fc, nil)
|
||||
mapConfig["framework_config"] = *normalizedFrameworkConfig
|
||||
} else {
|
||||
mapConfig["client_config"] = h.store.ClientConfig
|
||||
normalizedFrameworkConfig, _, _ := lib.ResolveFrameworkPricingConfig(nil, h.store.FrameworkConfig)
|
||||
mapConfig["framework_config"] = *normalizedFrameworkConfig
|
||||
}
|
||||
if h.store.ConfigStore != nil {
|
||||
// Fetching governance config
|
||||
authConfig, err := h.store.ConfigStore.GetAuthConfig(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("failed to get auth config from store: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get auth config from store: %v", err))
|
||||
return
|
||||
}
|
||||
// Getting username and password from auth config
|
||||
// This username password is for the dashboard authentication
|
||||
if authConfig != nil {
|
||||
// For password, return EnvVar structure with redacted value
|
||||
// If from env, preserve env_var reference but clear value
|
||||
// If not from env, show <redacted> as the value
|
||||
var passwordEnvVar *schemas.EnvVar
|
||||
if authConfig.AdminPassword != nil && authConfig.AdminPassword.IsFromEnv() {
|
||||
passwordEnvVar = &schemas.EnvVar{
|
||||
Val: "",
|
||||
EnvVar: authConfig.AdminPassword.EnvVar,
|
||||
FromEnv: true,
|
||||
}
|
||||
} else {
|
||||
passwordEnvVar = &schemas.EnvVar{
|
||||
Val: "<redacted>",
|
||||
EnvVar: "",
|
||||
FromEnv: false,
|
||||
}
|
||||
}
|
||||
mapConfig["auth_config"] = map[string]any{
|
||||
"admin_username": authConfig.AdminUserName,
|
||||
"admin_password": passwordEnvVar,
|
||||
"is_enabled": authConfig.IsEnabled,
|
||||
"disable_auth_on_inference": authConfig.DisableAuthOnInference,
|
||||
}
|
||||
} else {
|
||||
// No auth config exists yet, return default empty EnvVar values
|
||||
mapConfig["auth_config"] = map[string]any{
|
||||
"admin_username": &schemas.EnvVar{Val: "", EnvVar: "", FromEnv: false},
|
||||
"admin_password": &schemas.EnvVar{Val: "", EnvVar: "", FromEnv: false},
|
||||
"is_enabled": false,
|
||||
"disable_auth_on_inference": false,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mapConfig["auth_config"] = map[string]any{
|
||||
"admin_username": &schemas.EnvVar{Val: "", EnvVar: "", FromEnv: false},
|
||||
"admin_password": &schemas.EnvVar{Val: "", EnvVar: "", FromEnv: false},
|
||||
"is_enabled": false,
|
||||
"disable_auth_on_inference": false,
|
||||
}
|
||||
}
|
||||
mapConfig["is_db_connected"] = h.store.ConfigStore != nil
|
||||
mapConfig["is_cache_connected"] = h.store.VectorStore != nil
|
||||
mapConfig["is_logs_connected"] = h.store.LogsStore != nil
|
||||
// Fetching proxy config
|
||||
if h.store.ConfigStore != nil {
|
||||
proxyConfig, err := h.store.ConfigStore.GetProxyConfig(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("failed to get proxy config from store: %v", err)
|
||||
} else if proxyConfig != nil {
|
||||
// Redact password if present
|
||||
if proxyConfig.Password != "" {
|
||||
proxyConfig.Password = "<redacted>"
|
||||
}
|
||||
mapConfig["proxy_config"] = proxyConfig
|
||||
}
|
||||
// Fetching restart required config
|
||||
restartConfig, err := h.store.ConfigStore.GetRestartRequiredConfig(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("failed to get restart required config from store: %v", err)
|
||||
} else if restartConfig != nil {
|
||||
mapConfig["restart_required"] = restartConfig
|
||||
}
|
||||
}
|
||||
SendJSON(ctx, mapConfig)
|
||||
}
|
||||
|
||||
// updateConfig updates the core configuration settings.
|
||||
// Currently, it supports hot-reloading of the `drop_excess_requests` setting.
|
||||
// Note that settings like `prometheus_labels` cannot be changed at runtime.
|
||||
func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Config store not initialized")
|
||||
return
|
||||
}
|
||||
|
||||
payload := struct {
|
||||
ClientConfig configstore.ClientConfig `json:"client_config"`
|
||||
FrameworkConfig configstoreTables.TableFrameworkConfig `json:"framework_config"`
|
||||
AuthConfig *configstore.AuthConfig `json:"auth_config"`
|
||||
}{}
|
||||
|
||||
if err := json.Unmarshal(ctx.PostBody(), &payload); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Validating framework config
|
||||
if payload.FrameworkConfig.PricingURL != nil && *payload.FrameworkConfig.PricingURL != modelcatalog.DefaultPricingURL {
|
||||
// Checking the accessibility of the pricing URL
|
||||
resp, err := http.Get(*payload.FrameworkConfig.PricingURL)
|
||||
if err != nil {
|
||||
logger.Warn("failed to check the accessibility of the pricing URL: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to check the accessibility of the pricing URL: %v", err))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.Warn("failed to check the accessibility of the pricing URL: %v", resp.StatusCode)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to check the accessibility of the pricing URL: %v", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Checking the pricing sync interval
|
||||
if payload.FrameworkConfig.PricingSyncInterval != nil && *payload.FrameworkConfig.PricingSyncInterval <= 0 {
|
||||
logger.Warn("pricing sync interval must be greater than 0")
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "pricing sync interval must be greater than 0")
|
||||
return
|
||||
}
|
||||
|
||||
// Get current config with proper locking
|
||||
currentConfig := h.store.ClientConfig
|
||||
updatedConfig := currentConfig
|
||||
|
||||
var restartReasons []string
|
||||
|
||||
if payload.ClientConfig.DropExcessRequests != currentConfig.DropExcessRequests {
|
||||
h.configManager.UpdateDropExcessRequests(ctx, payload.ClientConfig.DropExcessRequests)
|
||||
updatedConfig.DropExcessRequests = payload.ClientConfig.DropExcessRequests
|
||||
}
|
||||
|
||||
if payload.ClientConfig.MCPCodeModeBindingLevel != "" {
|
||||
if payload.ClientConfig.MCPCodeModeBindingLevel != string(schemas.CodeModeBindingLevelServer) && payload.ClientConfig.MCPCodeModeBindingLevel != string(schemas.CodeModeBindingLevelTool) {
|
||||
logger.Warn("mcp_code_mode_binding_level must be 'server' or 'tool'")
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "mcp_code_mode_binding_level must be 'server' or 'tool'")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
shouldReloadMCPToolManagerConfig := false
|
||||
|
||||
// Only process MCPAgentDepth if explicitly provided (> 0) and different from current
|
||||
if payload.ClientConfig.MCPAgentDepth > 0 && payload.ClientConfig.MCPAgentDepth != currentConfig.MCPAgentDepth {
|
||||
updatedConfig.MCPAgentDepth = payload.ClientConfig.MCPAgentDepth
|
||||
shouldReloadMCPToolManagerConfig = true
|
||||
}
|
||||
|
||||
// Only process MCPToolExecutionTimeout if explicitly provided (> 0) and different from current
|
||||
if payload.ClientConfig.MCPToolExecutionTimeout > 0 && payload.ClientConfig.MCPToolExecutionTimeout != currentConfig.MCPToolExecutionTimeout {
|
||||
updatedConfig.MCPToolExecutionTimeout = payload.ClientConfig.MCPToolExecutionTimeout
|
||||
shouldReloadMCPToolManagerConfig = true
|
||||
}
|
||||
|
||||
if payload.ClientConfig.MCPCodeModeBindingLevel != "" && payload.ClientConfig.MCPCodeModeBindingLevel != currentConfig.MCPCodeModeBindingLevel {
|
||||
updatedConfig.MCPCodeModeBindingLevel = payload.ClientConfig.MCPCodeModeBindingLevel
|
||||
shouldReloadMCPToolManagerConfig = true
|
||||
}
|
||||
|
||||
if payload.ClientConfig.MCPDisableAutoToolInject != currentConfig.MCPDisableAutoToolInject {
|
||||
updatedConfig.MCPDisableAutoToolInject = payload.ClientConfig.MCPDisableAutoToolInject
|
||||
shouldReloadMCPToolManagerConfig = true
|
||||
}
|
||||
|
||||
// Reload MCP tool manager config with all current values in one call
|
||||
if shouldReloadMCPToolManagerConfig && h.store.MCPConfig != nil {
|
||||
if err := h.configManager.UpdateMCPToolManagerConfig(ctx, updatedConfig.MCPAgentDepth, updatedConfig.MCPToolExecutionTimeout, updatedConfig.MCPCodeModeBindingLevel, updatedConfig.MCPDisableAutoToolInject); err != nil {
|
||||
logger.Warn(fmt.Sprintf("failed to update mcp tool manager config: %v", err))
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update mcp tool manager config: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !slices.Equal(payload.ClientConfig.PrometheusLabels, currentConfig.PrometheusLabels) {
|
||||
updatedConfig.PrometheusLabels = payload.ClientConfig.PrometheusLabels
|
||||
restartReasons = append(restartReasons, "Prometheus labels")
|
||||
}
|
||||
|
||||
if !slices.Equal(payload.ClientConfig.AllowedOrigins, currentConfig.AllowedOrigins) {
|
||||
updatedConfig.AllowedOrigins = payload.ClientConfig.AllowedOrigins
|
||||
restartReasons = append(restartReasons, "Allowed origins")
|
||||
}
|
||||
|
||||
if !slices.Equal(payload.ClientConfig.AllowedHeaders, currentConfig.AllowedHeaders) {
|
||||
updatedConfig.AllowedHeaders = payload.ClientConfig.AllowedHeaders
|
||||
restartReasons = append(restartReasons, "Allowed headers")
|
||||
}
|
||||
|
||||
// Only update InitialPoolSize if explicitly provided (> 0) to avoid clearing stored value
|
||||
if payload.ClientConfig.InitialPoolSize > 0 {
|
||||
if payload.ClientConfig.InitialPoolSize != currentConfig.InitialPoolSize {
|
||||
restartReasons = append(restartReasons, "Initial pool size")
|
||||
}
|
||||
updatedConfig.InitialPoolSize = payload.ClientConfig.InitialPoolSize
|
||||
}
|
||||
|
||||
if payload.ClientConfig.EnableLogging != nil {
|
||||
payloadLogging := *payload.ClientConfig.EnableLogging
|
||||
currentLogging := currentConfig.EnableLogging == nil || *currentConfig.EnableLogging
|
||||
if payloadLogging != currentLogging {
|
||||
restartReasons = append(restartReasons, "Logging changed")
|
||||
}
|
||||
updatedConfig.EnableLogging = payload.ClientConfig.EnableLogging
|
||||
}
|
||||
|
||||
if payload.ClientConfig.DisableContentLogging != currentConfig.DisableContentLogging {
|
||||
restartReasons = append(restartReasons, "Content logging")
|
||||
}
|
||||
updatedConfig.DisableContentLogging = payload.ClientConfig.DisableContentLogging
|
||||
updatedConfig.DisableDBPingsInHealth = payload.ClientConfig.DisableDBPingsInHealth
|
||||
updatedConfig.AllowDirectKeys = payload.ClientConfig.AllowDirectKeys
|
||||
|
||||
updatedConfig.EnforceAuthOnInference = payload.ClientConfig.EnforceAuthOnInference
|
||||
// Sync deprecated columns to match new field so they stay consistent in the DB
|
||||
updatedConfig.EnforceGovernanceHeader = payload.ClientConfig.EnforceAuthOnInference
|
||||
updatedConfig.EnforceSCIMAuth = payload.ClientConfig.EnforceAuthOnInference
|
||||
|
||||
// Only update MaxRequestBodySizeMB if explicitly provided (> 0) to avoid clearing stored value
|
||||
if payload.ClientConfig.MaxRequestBodySizeMB > 0 {
|
||||
if payload.ClientConfig.MaxRequestBodySizeMB != currentConfig.MaxRequestBodySizeMB {
|
||||
restartReasons = append(restartReasons, "Max request body size")
|
||||
}
|
||||
updatedConfig.MaxRequestBodySizeMB = payload.ClientConfig.MaxRequestBodySizeMB
|
||||
}
|
||||
|
||||
// Handle compat plugin toggle
|
||||
newCompat := payload.ClientConfig.Compat
|
||||
oldCompat := currentConfig.Compat
|
||||
if newCompat != oldCompat {
|
||||
newEnabled := newCompat.ConvertTextToChat || newCompat.ConvertChatToResponses || newCompat.ShouldDropParams || newCompat.ShouldConvertParams
|
||||
if newEnabled {
|
||||
compatCfg := &compat.Config{
|
||||
ConvertTextToChat: newCompat.ConvertTextToChat,
|
||||
ConvertChatToResponses: newCompat.ConvertChatToResponses,
|
||||
ShouldDropParams: newCompat.ShouldDropParams,
|
||||
ShouldConvertParams: newCompat.ShouldConvertParams,
|
||||
}
|
||||
if err := h.configManager.ReloadPlugin(ctx, compat.PluginName, nil, compatCfg, nil, nil); err != nil {
|
||||
logger.Warn("failed to load compat plugin: %v", err)
|
||||
SendError(ctx, 400, "Failed to load compat plugin")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
disabledCtx := context.WithValue(ctx, PluginDisabledKey, true)
|
||||
if err := h.configManager.RemovePlugin(disabledCtx, compat.PluginName); err != nil {
|
||||
logger.Warn("failed to remove compat plugin: %v", err)
|
||||
SendError(ctx, 400, "Failed to remove compat plugin")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
updatedConfig.Compat = newCompat
|
||||
// Only update MCP fields if explicitly provided (non-zero) to avoid clearing stored values
|
||||
if payload.ClientConfig.MCPAgentDepth > 0 {
|
||||
updatedConfig.MCPAgentDepth = payload.ClientConfig.MCPAgentDepth
|
||||
}
|
||||
if payload.ClientConfig.MCPToolExecutionTimeout > 0 {
|
||||
updatedConfig.MCPToolExecutionTimeout = payload.ClientConfig.MCPToolExecutionTimeout
|
||||
}
|
||||
// Only update MCPCodeModeBindingLevel if payload is non-empty to avoid clearing stored value
|
||||
if payload.ClientConfig.MCPCodeModeBindingLevel != "" {
|
||||
updatedConfig.MCPCodeModeBindingLevel = payload.ClientConfig.MCPCodeModeBindingLevel
|
||||
}
|
||||
|
||||
// Only update AsyncJobResultTTL if explicitly provided (> 0) to avoid clearing stored value
|
||||
if payload.ClientConfig.AsyncJobResultTTL > 0 {
|
||||
updatedConfig.AsyncJobResultTTL = payload.ClientConfig.AsyncJobResultTTL
|
||||
}
|
||||
|
||||
// Handle RequiredHeaders changes (no restart needed - governance plugin reads via pointer)
|
||||
updatedConfig.RequiredHeaders = payload.ClientConfig.RequiredHeaders
|
||||
|
||||
// Handle LoggingHeaders changes (no restart needed - logging plugin reads via pointer)
|
||||
updatedConfig.LoggingHeaders = payload.ClientConfig.LoggingHeaders
|
||||
|
||||
// Handle WhitelistedRoutes changes (updated dynamically via AuthMiddleware)
|
||||
updatedConfig.WhitelistedRoutes = payload.ClientConfig.WhitelistedRoutes
|
||||
|
||||
// Toggle whether deleted virtual keys should appear in logs filter data.
|
||||
updatedConfig.HideDeletedVirtualKeysInFilters = payload.ClientConfig.HideDeletedVirtualKeysInFilters
|
||||
|
||||
// No restart needed - routing engine reads via pointer, change is effective immediately.
|
||||
if payload.ClientConfig.RoutingChainMaxDepth > 0 {
|
||||
updatedConfig.RoutingChainMaxDepth = payload.ClientConfig.RoutingChainMaxDepth
|
||||
}
|
||||
|
||||
// Handle HeaderFilterConfig changes
|
||||
if !headerFilterConfigEqual(payload.ClientConfig.HeaderFilterConfig, currentConfig.HeaderFilterConfig) {
|
||||
// Validate that no security headers are in the allowlist or denylist
|
||||
if err := validateHeaderFilterConfig(payload.ClientConfig.HeaderFilterConfig); err != nil {
|
||||
logger.Warn("invalid header filter config: %v", err)
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
updatedConfig.HeaderFilterConfig = payload.ClientConfig.HeaderFilterConfig
|
||||
if err := h.configManager.ReloadHeaderFilterConfig(ctx, payload.ClientConfig.HeaderFilterConfig); err != nil {
|
||||
logger.Warn("failed to reload header filter config: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to reload header filter config: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Validate LogRetentionDays
|
||||
if payload.ClientConfig.LogRetentionDays < 1 {
|
||||
logger.Warn("log_retention_days must be at least 1")
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "log_retention_days must be at least 1")
|
||||
return
|
||||
}
|
||||
updatedConfig.LogRetentionDays = payload.ClientConfig.LogRetentionDays
|
||||
|
||||
// Update the store with the new config
|
||||
h.store.ClientConfig = updatedConfig
|
||||
|
||||
if err := h.store.ConfigStore.UpdateClientConfig(ctx, updatedConfig); err != nil {
|
||||
logger.Warn("failed to save configuration: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to save configuration: %v", err))
|
||||
return
|
||||
}
|
||||
// Reloading client config from config store
|
||||
if err := h.configManager.ReloadClientConfigFromConfigStore(ctx); err != nil {
|
||||
logger.Warn("failed to reload client config from config store: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to reload client config from config store: %v", err))
|
||||
return
|
||||
}
|
||||
// Fetching existing framework config
|
||||
frameworkConfig, err := h.store.ConfigStore.GetFrameworkConfig(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("failed to get framework config from store: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get framework config from store: %v", err))
|
||||
return
|
||||
}
|
||||
// if framework config is nil, we will use the default pricing config
|
||||
if frameworkConfig == nil {
|
||||
frameworkConfig = &configstoreTables.TableFrameworkConfig{
|
||||
ID: 0,
|
||||
PricingURL: bifrost.Ptr(modelcatalog.DefaultPricingURL),
|
||||
PricingSyncInterval: bifrost.Ptr(int64(modelcatalog.DefaultSyncInterval.Seconds())),
|
||||
}
|
||||
}
|
||||
// Handling individual nil cases
|
||||
if frameworkConfig.PricingURL == nil {
|
||||
frameworkConfig.PricingURL = bifrost.Ptr(modelcatalog.DefaultPricingURL)
|
||||
}
|
||||
if frameworkConfig.PricingSyncInterval == nil {
|
||||
frameworkConfig.PricingSyncInterval = bifrost.Ptr(int64(modelcatalog.DefaultSyncInterval.Seconds()))
|
||||
}
|
||||
// Updating framework config
|
||||
shouldReloadFrameworkConfig := false
|
||||
if payload.FrameworkConfig.PricingURL != nil && *payload.FrameworkConfig.PricingURL != *frameworkConfig.PricingURL {
|
||||
// Checking the accessibility of the pricing URL
|
||||
resp, err := http.Get(*payload.FrameworkConfig.PricingURL)
|
||||
if err != nil {
|
||||
logger.Warn("failed to check the accessibility of the pricing URL: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to check the accessibility of the pricing URL: %v", err))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.Warn("failed to check the accessibility of the pricing URL: %v", resp.StatusCode)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to check the accessibility of the pricing URL: %v", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
frameworkConfig.PricingURL = payload.FrameworkConfig.PricingURL
|
||||
shouldReloadFrameworkConfig = true
|
||||
}
|
||||
if payload.FrameworkConfig.PricingSyncInterval != nil {
|
||||
syncInterval := int64(*payload.FrameworkConfig.PricingSyncInterval)
|
||||
if syncInterval != *frameworkConfig.PricingSyncInterval {
|
||||
frameworkConfig.PricingSyncInterval = &syncInterval
|
||||
shouldReloadFrameworkConfig = true
|
||||
}
|
||||
}
|
||||
// Reload config if required
|
||||
if shouldReloadFrameworkConfig {
|
||||
var syncSeconds int64
|
||||
if frameworkConfig.PricingSyncInterval != nil {
|
||||
syncSeconds = *frameworkConfig.PricingSyncInterval
|
||||
} else {
|
||||
syncSeconds = int64(modelcatalog.DefaultSyncInterval.Seconds())
|
||||
}
|
||||
h.store.FrameworkConfig = &framework.FrameworkConfig{
|
||||
Pricing: &modelcatalog.Config{
|
||||
PricingURL: frameworkConfig.PricingURL,
|
||||
PricingSyncInterval: &syncSeconds,
|
||||
},
|
||||
}
|
||||
// Saving framework config
|
||||
if err := h.store.ConfigStore.UpdateFrameworkConfig(ctx, frameworkConfig); err != nil {
|
||||
logger.Warn("failed to save framework configuration: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to save framework configuration: %v", err))
|
||||
return
|
||||
}
|
||||
// Reloading pricing manager
|
||||
h.configManager.UpdateSyncConfig(ctx)
|
||||
}
|
||||
// Checking auth config and trying to update if required
|
||||
if payload.AuthConfig != nil {
|
||||
// Getting current governance config
|
||||
authConfig, err := h.store.ConfigStore.GetAuthConfig(ctx)
|
||||
if err != nil {
|
||||
if !errors.Is(err, configstore.ErrNotFound) {
|
||||
logger.Warn("failed to get auth config from store: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get auth config from store: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if auth config has changed
|
||||
authChanged := false
|
||||
if authConfig == nil {
|
||||
// No existing config, any enabled state is a change
|
||||
if payload.AuthConfig.IsEnabled {
|
||||
authChanged = true
|
||||
}
|
||||
} else {
|
||||
// Compare with existing config using value comparison (not pointer comparison)
|
||||
// Password is considered changed only if it's NOT redacted and has a value
|
||||
// (IsRedacted() returns true for <redacted>, asterisk patterns, and env var references)
|
||||
passwordChanged := payload.AuthConfig.AdminPassword != nil &&
|
||||
!payload.AuthConfig.AdminPassword.IsRedacted() &&
|
||||
payload.AuthConfig.AdminPassword.GetValue() != ""
|
||||
usernameChanged := payload.AuthConfig.AdminUserName != nil &&
|
||||
!payload.AuthConfig.AdminUserName.Equals(authConfig.AdminUserName)
|
||||
if payload.AuthConfig.IsEnabled != authConfig.IsEnabled ||
|
||||
usernameChanged ||
|
||||
passwordChanged {
|
||||
authChanged = true
|
||||
}
|
||||
}
|
||||
|
||||
if payload.AuthConfig.IsEnabled {
|
||||
// Initialize nil pointers to empty EnvVar to prevent nil-pointer dereference
|
||||
if payload.AuthConfig.AdminUserName == nil {
|
||||
payload.AuthConfig.AdminUserName = &schemas.EnvVar{}
|
||||
}
|
||||
if payload.AuthConfig.AdminPassword == nil {
|
||||
payload.AuthConfig.AdminPassword = &schemas.EnvVar{}
|
||||
}
|
||||
|
||||
// Validate env variables are set if referenced
|
||||
if payload.AuthConfig.AdminUserName.IsFromEnv() && payload.AuthConfig.AdminUserName.GetValue() == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("environment variable %s is not set", payload.AuthConfig.AdminUserName.EnvVar))
|
||||
return
|
||||
}
|
||||
if payload.AuthConfig.AdminPassword.IsFromEnv() && payload.AuthConfig.AdminPassword.GetValue() == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("environment variable %s is not set", payload.AuthConfig.AdminPassword.EnvVar))
|
||||
return
|
||||
}
|
||||
|
||||
if authConfig == nil && (payload.AuthConfig.AdminUserName.GetValue() == "" || payload.AuthConfig.AdminPassword.GetValue() == "") {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "auth username and password must be provided")
|
||||
return
|
||||
}
|
||||
// Fetching current Auth config
|
||||
if payload.AuthConfig.AdminUserName.GetValue() != "" {
|
||||
if payload.AuthConfig.AdminPassword.IsRedacted() {
|
||||
if authConfig == nil || authConfig.AdminPassword.GetValue() == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "auth password must be provided")
|
||||
return
|
||||
}
|
||||
// Assuming that password hasn't been changed
|
||||
payload.AuthConfig.AdminPassword = authConfig.AdminPassword
|
||||
} else {
|
||||
// Password has been changed
|
||||
// We will hash the password
|
||||
hashedPassword, err := encrypt.Hash(payload.AuthConfig.AdminPassword.GetValue())
|
||||
if err != nil {
|
||||
logger.Warn("failed to hash password: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to hash password: %v", err))
|
||||
return
|
||||
}
|
||||
// Preserve env-var metadata when storing hashed password
|
||||
payload.AuthConfig.AdminPassword = &schemas.EnvVar{
|
||||
Val: hashedPassword,
|
||||
FromEnv: payload.AuthConfig.AdminPassword.IsFromEnv(),
|
||||
EnvVar: payload.AuthConfig.AdminPassword.EnvVar,
|
||||
}
|
||||
}
|
||||
}
|
||||
// Save auth config - this handles both first-time creation and updates
|
||||
err = h.configManager.UpdateAuthConfig(ctx, payload.AuthConfig)
|
||||
if err != nil {
|
||||
logger.Warn("failed to update auth config: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update auth config: %v", err))
|
||||
return
|
||||
}
|
||||
} else if authConfig != nil {
|
||||
// Auth is being disabled but there's an existing config - preserve credentials and update disabled state
|
||||
if payload.AuthConfig.AdminPassword == nil || payload.AuthConfig.AdminPassword.IsRedacted() || payload.AuthConfig.AdminPassword.GetValue() == "" {
|
||||
payload.AuthConfig.AdminPassword = authConfig.AdminPassword
|
||||
}
|
||||
if payload.AuthConfig.AdminUserName == nil || payload.AuthConfig.AdminUserName.GetValue() == "" {
|
||||
payload.AuthConfig.AdminUserName = authConfig.AdminUserName
|
||||
}
|
||||
err = h.configManager.UpdateAuthConfig(ctx, payload.AuthConfig)
|
||||
if err != nil {
|
||||
logger.Warn("failed to update auth config: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update auth config: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Flush all existing sessions if auth details have been changed
|
||||
if authChanged {
|
||||
if err := h.store.ConfigStore.FlushSessions(ctx); err != nil {
|
||||
logger.Warn("updated auth config but failed to flush existing sessions, please restart the server: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("updated auth config but failed to flush existing sessions, please restart the server: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
// Note: AuthMiddleware is updated via ServerCallbacks.UpdateAuthConfig (handled by BifrostHTTPServer)
|
||||
}
|
||||
|
||||
// Set restart required flag if any restart-requiring configs changed
|
||||
if len(restartReasons) > 0 {
|
||||
reason := fmt.Sprintf("%s settings have been updated. A restart is required for changes to take full effect.", strings.Join(restartReasons, ", "))
|
||||
if err := h.store.ConfigStore.SetRestartRequiredConfig(ctx, &configstoreTables.RestartRequiredConfig{
|
||||
Required: true,
|
||||
Reason: reason,
|
||||
}); err != nil {
|
||||
logger.Warn("failed to set restart required config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
SendJSON(ctx, map[string]any{
|
||||
"status": "success",
|
||||
"message": "configuration updated successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// forceSyncPricing triggers an immediate pricing sync and resets the pricing sync timer
|
||||
func (h *ConfigHandler) forceSyncPricing(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "config store not available")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.configManager.ForceReloadPricing(ctx); err != nil {
|
||||
logger.Warn("failed to force pricing sync: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to force pricing sync: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
SendJSON(ctx, map[string]any{
|
||||
"status": "success",
|
||||
"message": "pricing sync triggered",
|
||||
})
|
||||
}
|
||||
|
||||
// getProxyConfig handles GET /api/proxy-config - Get the current proxy configuration
|
||||
func (h *ConfigHandler) getProxyConfig(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "config store not available")
|
||||
return
|
||||
}
|
||||
proxyConfig, err := h.store.ConfigStore.GetProxyConfig(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get proxy config: %v", err))
|
||||
return
|
||||
}
|
||||
if proxyConfig == nil {
|
||||
// Return default empty config
|
||||
SendJSON(ctx, configstoreTables.GlobalProxyConfig{
|
||||
Enabled: false,
|
||||
Type: network.GlobalProxyTypeHTTP,
|
||||
})
|
||||
return
|
||||
}
|
||||
// Redact password if present
|
||||
if proxyConfig.Password != "" {
|
||||
proxyConfig.Password = "<redacted>"
|
||||
}
|
||||
SendJSON(ctx, proxyConfig)
|
||||
}
|
||||
|
||||
// updateProxyConfig handles PUT /api/proxy-config - Update the proxy configuration
|
||||
func (h *ConfigHandler) updateProxyConfig(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "config store not initialized")
|
||||
return
|
||||
}
|
||||
|
||||
var payload configstoreTables.GlobalProxyConfig
|
||||
if err := json.Unmarshal(ctx.PostBody(), &payload); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("invalid request format: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate proxy config
|
||||
if payload.Enabled {
|
||||
// Validate proxy type
|
||||
switch payload.Type {
|
||||
case network.GlobalProxyTypeHTTP:
|
||||
// HTTP proxy is supported
|
||||
// Make sure the URL is provided
|
||||
if payload.URL == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "proxy URL is required when proxy is enabled")
|
||||
return
|
||||
}
|
||||
// Validate timeout if provided
|
||||
if payload.Timeout < 0 {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "proxy timeout must be non-negative")
|
||||
return
|
||||
}
|
||||
case network.GlobalProxyTypeSOCKS5, network.GlobalProxyTypeTCP:
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("proxy type %s is not yet supported", payload.Type))
|
||||
return
|
||||
default:
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("invalid proxy type: %s", payload.Type))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate URL is provided when enabled
|
||||
if payload.URL == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "proxy URL is required when proxy is enabled")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate timeout if provided
|
||||
if payload.Timeout < 0 {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "proxy timeout must be non-negative")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Handle password - if it's "<redacted>", keep the existing password
|
||||
if payload.Password == "<redacted>" {
|
||||
existingConfig, err := h.store.ConfigStore.GetProxyConfig(ctx)
|
||||
if err != nil && !errors.Is(err, configstore.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get existing proxy config: %v", err))
|
||||
return
|
||||
}
|
||||
if existingConfig != nil {
|
||||
payload.Password = existingConfig.Password
|
||||
} else {
|
||||
payload.Password = ""
|
||||
}
|
||||
}
|
||||
|
||||
// Save proxy config
|
||||
if err := h.store.ConfigStore.UpdateProxyConfig(ctx, &payload); err != nil {
|
||||
logger.Warn("failed to save proxy configuration: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to save proxy configuration: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Pulling the proxy config from the config store
|
||||
newProxyConfig, err := h.store.ConfigStore.GetProxyConfig(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("failed to get proxy config from store: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get proxy config from store: %v", err))
|
||||
return
|
||||
}
|
||||
if newProxyConfig == nil {
|
||||
newProxyConfig = &configstoreTables.GlobalProxyConfig{
|
||||
Enabled: false,
|
||||
Type: network.GlobalProxyTypeHTTP,
|
||||
URL: "",
|
||||
Username: "",
|
||||
Password: "",
|
||||
NoProxy: "",
|
||||
Timeout: 0,
|
||||
SkipTLSVerify: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Reload proxy config in the server
|
||||
if err := h.configManager.ReloadProxyConfig(ctx, newProxyConfig); err != nil {
|
||||
logger.Warn("failed to reload proxy config: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to reload proxy config: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Set restart required flag for proxy config changes
|
||||
if err := h.store.ConfigStore.SetRestartRequiredConfig(ctx, &configstoreTables.RestartRequiredConfig{
|
||||
Required: true,
|
||||
Reason: "Proxy configuration has been updated. A restart is required for all changes to take full effect.",
|
||||
}); err != nil {
|
||||
logger.Warn("failed to set restart required config: %v", err)
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
SendJSON(ctx, map[string]any{
|
||||
"status": "success",
|
||||
"message": "proxy configuration updated successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// headerFilterConfigEqual compares two GlobalHeaderFilterConfig for equality
|
||||
func headerFilterConfigEqual(a, b *configstoreTables.GlobalHeaderFilterConfig) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return slices.Equal(a.Allowlist, b.Allowlist) && slices.Equal(a.Denylist, b.Denylist)
|
||||
}
|
||||
|
||||
// validateHeaderFilterConfig validates that no exact security header names are in the allowlist or denylist
|
||||
// and that wildcard patterns use valid syntax (only trailing * is supported).
|
||||
// Wildcard patterns that would match security headers are allowed because security headers
|
||||
// are unconditionally stripped at runtime regardless of configuration.
|
||||
// Returns an error if any exact security headers are found or patterns are invalid.
|
||||
func validateHeaderFilterConfig(config *configstoreTables.GlobalHeaderFilterConfig) error {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate pattern syntax and normalize entries (trim, lowercase, drop empties)
|
||||
filteredAllow := config.Allowlist[:0]
|
||||
for _, header := range config.Allowlist {
|
||||
h := strings.ToLower(strings.TrimSpace(header))
|
||||
if h == "" {
|
||||
continue
|
||||
}
|
||||
if idx := strings.Index(h, "*"); idx != -1 && idx != len(h)-1 {
|
||||
return fmt.Errorf("invalid pattern %q: wildcard (*) is only supported at the end of a pattern", h)
|
||||
}
|
||||
filteredAllow = append(filteredAllow, h)
|
||||
}
|
||||
config.Allowlist = filteredAllow
|
||||
filteredDeny := config.Denylist[:0]
|
||||
for _, header := range config.Denylist {
|
||||
h := strings.ToLower(strings.TrimSpace(header))
|
||||
if h == "" {
|
||||
continue
|
||||
}
|
||||
if idx := strings.Index(h, "*"); idx != -1 && idx != len(h)-1 {
|
||||
return fmt.Errorf("invalid pattern %q: wildcard (*) is only supported at the end of a pattern", h)
|
||||
}
|
||||
filteredDeny = append(filteredDeny, h)
|
||||
}
|
||||
config.Denylist = filteredDeny
|
||||
|
||||
var foundSecurityHeaders []string
|
||||
|
||||
// Check allowlist for exact security header names.
|
||||
// Wildcard patterns are allowed — security headers are always stripped at runtime
|
||||
// unconditionally in ctx.go, regardless of allowlist/denylist configuration.
|
||||
for _, header := range config.Allowlist {
|
||||
headerLower := strings.ToLower(strings.TrimSpace(header))
|
||||
if strings.Contains(headerLower, "*") {
|
||||
continue
|
||||
}
|
||||
if slices.Contains(securityHeaders, headerLower) {
|
||||
foundSecurityHeaders = append(foundSecurityHeaders, headerLower)
|
||||
}
|
||||
}
|
||||
|
||||
// Check denylist for exact security header names.
|
||||
for _, header := range config.Denylist {
|
||||
headerLower := strings.ToLower(strings.TrimSpace(header))
|
||||
if strings.Contains(headerLower, "*") {
|
||||
continue
|
||||
}
|
||||
if slices.Contains(securityHeaders, headerLower) && !slices.Contains(foundSecurityHeaders, headerLower) {
|
||||
foundSecurityHeaders = append(foundSecurityHeaders, headerLower)
|
||||
}
|
||||
}
|
||||
|
||||
if len(foundSecurityHeaders) > 0 {
|
||||
return fmt.Errorf("the following headers are not allowed to be configured: %s. These headers are security headers and are always blocked", strings.Join(foundSecurityHeaders, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
198
transports/bifrost-http/handlers/config_headerfilter_test.go
Normal file
198
transports/bifrost-http/handlers/config_headerfilter_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
func TestValidateHeaderFilterConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *configstoreTables.GlobalHeaderFilterConfig
|
||||
wantErr bool
|
||||
errSubstr string
|
||||
}{
|
||||
{
|
||||
name: "nil config",
|
||||
config: nil,
|
||||
},
|
||||
{
|
||||
name: "empty lists",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{},
|
||||
},
|
||||
{
|
||||
name: "empty allowlist and denylist slices",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{},
|
||||
Denylist: []string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid allowlist patterns",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-beta", "x-custom-*", "content-type"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid denylist patterns",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"x-internal-*", "x-debug"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid allowlist and denylist together",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-*", "content-type"},
|
||||
Denylist: []string{"x-internal-*"},
|
||||
},
|
||||
},
|
||||
// Empty/whitespace entries should be silently dropped, not cause errors
|
||||
{
|
||||
name: "whitespace-only entries in allowlist are dropped",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{" ", "anthropic-beta", ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "whitespace-only entries in denylist are dropped",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"", "x-debug", " "},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all-empty allowlist becomes effectively empty",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"", " ", "\t"},
|
||||
},
|
||||
},
|
||||
// Security header checks
|
||||
{
|
||||
name: "security header in allowlist rejected",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"authorization"},
|
||||
},
|
||||
wantErr: true,
|
||||
errSubstr: "not allowed to be configured",
|
||||
},
|
||||
{
|
||||
name: "security header in denylist rejected",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"x-api-key"},
|
||||
},
|
||||
wantErr: true,
|
||||
errSubstr: "not allowed to be configured",
|
||||
},
|
||||
{
|
||||
name: "wildcard matching security header allowed (runtime strips security headers)",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"authorization*"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "wildcard prefix matching security headers allowed (runtime strips security headers)",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"x-api-*"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bare wildcard in allowlist allowed (runtime strips security headers)",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"*"},
|
||||
},
|
||||
},
|
||||
// Invalid wildcard syntax
|
||||
{
|
||||
name: "wildcard in middle of pattern rejected",
|
||||
config: &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"x-*-header"},
|
||||
},
|
||||
wantErr: true,
|
||||
errSubstr: "wildcard (*) is only supported at the end",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateHeaderFilterConfig(tt.config)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error containing %q, got nil", tt.errSubstr)
|
||||
}
|
||||
if tt.errSubstr != "" && !contains(err.Error(), tt.errSubstr) {
|
||||
t.Fatalf("expected error containing %q, got %q", tt.errSubstr, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateHeaderFilterConfig_EmptyEntriesDropped(t *testing.T) {
|
||||
// Verify that empty/whitespace entries are actually removed from the stored config
|
||||
config := &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{" ", "anthropic-beta", "", "content-type", "\t"},
|
||||
Denylist: []string{"", "x-debug", " "},
|
||||
}
|
||||
if err := validateHeaderFilterConfig(config); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(config.Allowlist) != 2 {
|
||||
t.Fatalf("expected allowlist length 2, got %d: %v", len(config.Allowlist), config.Allowlist)
|
||||
}
|
||||
if config.Allowlist[0] != "anthropic-beta" || config.Allowlist[1] != "content-type" {
|
||||
t.Fatalf("unexpected allowlist: %v", config.Allowlist)
|
||||
}
|
||||
if len(config.Denylist) != 1 {
|
||||
t.Fatalf("expected denylist length 1, got %d: %v", len(config.Denylist), config.Denylist)
|
||||
}
|
||||
if config.Denylist[0] != "x-debug" {
|
||||
t.Fatalf("unexpected denylist: %v", config.Denylist)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateHeaderFilterConfig_EmptyConfigStillForwardsHeaders verifies that when
|
||||
// all entries are empty/whitespace, validation strips them and the compiled matcher
|
||||
// allows all headers through (same behavior as no config — x-bf-eh-* headers forwarded as-is).
|
||||
func TestValidateHeaderFilterConfig_EmptyConfigStillForwardsHeaders(t *testing.T) {
|
||||
// Config where all entries are whitespace-only
|
||||
config := &configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"", " ", "\t"},
|
||||
Denylist: []string{"", " "},
|
||||
}
|
||||
if err := validateHeaderFilterConfig(config); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// After validation, both lists should be empty
|
||||
if len(config.Allowlist) != 0 {
|
||||
t.Fatalf("expected empty allowlist, got %v", config.Allowlist)
|
||||
}
|
||||
if len(config.Denylist) != 0 {
|
||||
t.Fatalf("expected empty denylist, got %v", config.Denylist)
|
||||
}
|
||||
// Compile the validated config into a matcher — should allow everything
|
||||
m := lib.NewHeaderMatcher(config)
|
||||
// Matcher with empty lists should allow all headers (x-bf-eh-* forwarded as-is)
|
||||
for _, header := range []string{"anthropic-beta", "x-custom-header", "content-type", "x-anything"} {
|
||||
if !m.ShouldAllow(header) {
|
||||
t.Errorf("expected header %q to be allowed with empty config, but it was denied", header)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && searchString(s, substr)
|
||||
}
|
||||
|
||||
func searchString(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
778
transports/bifrost-http/handlers/devpprof.go
Normal file
778
transports/bifrost-http/handlers/devpprof.go
Normal file
@@ -0,0 +1,778 @@
|
||||
//go:build dev
|
||||
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/google/pprof/profile"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
|
||||
// Collection interval for metrics
|
||||
metricsCollectionInterval = 10 * time.Second
|
||||
// Number of data points to keep (5 minutes / 10 seconds = 30 points)
|
||||
historySize = 30
|
||||
// Top allocations to return per table (cumulative and in-use)
|
||||
topAllocationsCount = 50
|
||||
)
|
||||
|
||||
// MemoryStats represents memory statistics at a point in time
|
||||
type MemoryStats struct {
|
||||
Alloc uint64 `json:"alloc"`
|
||||
TotalAlloc uint64 `json:"total_alloc"`
|
||||
HeapInuse uint64 `json:"heap_inuse"`
|
||||
HeapObjects uint64 `json:"heap_objects"`
|
||||
Sys uint64 `json:"sys"`
|
||||
}
|
||||
|
||||
// CPUStats represents CPU statistics
|
||||
type CPUStats struct {
|
||||
UsagePercent float64 `json:"usage_percent"`
|
||||
UserTime float64 `json:"user_time"`
|
||||
SystemTime float64 `json:"system_time"`
|
||||
}
|
||||
|
||||
// RuntimeStats represents runtime statistics
|
||||
type RuntimeStats struct {
|
||||
NumGoroutine int `json:"num_goroutine"`
|
||||
NumGC uint32 `json:"num_gc"`
|
||||
GCPauseNs uint64 `json:"gc_pause_ns"`
|
||||
NumCPU int `json:"num_cpu"`
|
||||
GOMAXPROCS int `json:"gomaxprocs"`
|
||||
}
|
||||
|
||||
// AllocationInfo represents a single allocation site
|
||||
type AllocationInfo struct {
|
||||
Function string `json:"function"`
|
||||
File string `json:"file"`
|
||||
Line int `json:"line"`
|
||||
Bytes int64 `json:"bytes"`
|
||||
Count int64 `json:"count"`
|
||||
Stack []string `json:"stack"`
|
||||
}
|
||||
|
||||
// GoroutineGroup represents a group of goroutines with the same stack trace
|
||||
type GoroutineGroup struct {
|
||||
Count int `json:"count"`
|
||||
State string `json:"state"`
|
||||
WaitReason string `json:"wait_reason,omitempty"`
|
||||
WaitMinutes int `json:"wait_minutes,omitempty"` // Parsed wait time in minutes
|
||||
TopFunc string `json:"top_func"`
|
||||
Stack []string `json:"stack"`
|
||||
Category string `json:"category"` // "background", "per-request", "unknown"
|
||||
}
|
||||
|
||||
// GoroutineProfile represents the goroutine profile response
|
||||
type GoroutineProfile struct {
|
||||
Timestamp string `json:"timestamp"`
|
||||
TotalGoroutines int `json:"total_goroutines"`
|
||||
Groups []GoroutineGroup `json:"groups"`
|
||||
Summary GoroutineSummary `json:"summary"`
|
||||
RawProfile string `json:"raw_profile,omitempty"`
|
||||
}
|
||||
|
||||
// GoroutineSummary provides a quick overview of goroutine health
|
||||
type GoroutineSummary struct {
|
||||
Background int `json:"background"` // Expected long-running goroutines
|
||||
PerRequest int `json:"per_request"` // Goroutines that should complete with requests
|
||||
LongWaiting int `json:"long_waiting"` // Goroutines waiting > 1 minute (potential leaks)
|
||||
PotentiallyStuck int `json:"potentially_stuck"` // Per-request goroutines waiting > 1 minute
|
||||
}
|
||||
|
||||
// HistoryPoint represents a single point in the metrics history
|
||||
type HistoryPoint struct {
|
||||
Timestamp string `json:"timestamp"`
|
||||
Alloc uint64 `json:"alloc"`
|
||||
HeapInuse uint64 `json:"heap_inuse"`
|
||||
Goroutines int `json:"goroutines"`
|
||||
GCPauseNs uint64 `json:"gc_pause_ns"`
|
||||
CPUPercent float64 `json:"cpu_percent"`
|
||||
}
|
||||
|
||||
// PprofData represents the complete pprof response
|
||||
type PprofData struct {
|
||||
Timestamp string `json:"timestamp"`
|
||||
Memory MemoryStats `json:"memory"`
|
||||
CPU CPUStats `json:"cpu"`
|
||||
Runtime RuntimeStats `json:"runtime"`
|
||||
TopAllocations []AllocationInfo `json:"top_allocations"`
|
||||
InuseAllocations []AllocationInfo `json:"inuse_allocations"`
|
||||
History []HistoryPoint `json:"history"`
|
||||
}
|
||||
|
||||
// cpuSample holds a CPU time sample for calculating usage
|
||||
type cpuSample struct {
|
||||
timestamp time.Time
|
||||
userTime time.Duration
|
||||
systemTime time.Duration
|
||||
}
|
||||
|
||||
// MetricsCollector collects and stores runtime metrics
|
||||
type MetricsCollector struct {
|
||||
mu sync.RWMutex
|
||||
history []HistoryPoint
|
||||
stopCh chan struct{}
|
||||
started bool
|
||||
lastCPUSample cpuSample
|
||||
currentCPU CPUStats
|
||||
}
|
||||
|
||||
// DevPprofHandler handles development profiling endpoints
|
||||
type DevPprofHandler struct {
|
||||
collector *MetricsCollector
|
||||
}
|
||||
|
||||
// Global collector instance
|
||||
var globalCollector *MetricsCollector
|
||||
var collectorOnce sync.Once
|
||||
|
||||
// IsDevMode checks if dev mode is enabled via environment variable
|
||||
func IsDevMode() bool {
|
||||
return os.Getenv("BIFROST_UI_DEV") == "true"
|
||||
}
|
||||
|
||||
// getOrCreateCollector returns the global metrics collector, creating it if needed
|
||||
func getOrCreateCollector() *MetricsCollector {
|
||||
collectorOnce.Do(func() {
|
||||
globalCollector = &MetricsCollector{
|
||||
history: make([]HistoryPoint, 0, historySize),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
})
|
||||
return globalCollector
|
||||
}
|
||||
|
||||
// NewDevPprofHandler creates a new dev pprof handler
|
||||
func NewDevPprofHandler() *DevPprofHandler {
|
||||
return &DevPprofHandler{
|
||||
collector: getOrCreateCollector(),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the background metrics collection
|
||||
func (c *MetricsCollector) Start() {
|
||||
c.mu.Lock()
|
||||
if c.started {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
c.stopCh = make(chan struct{})
|
||||
c.started = true
|
||||
c.mu.Unlock()
|
||||
|
||||
go c.collectLoop()
|
||||
}
|
||||
|
||||
// Stop stops the background metrics collection
|
||||
func (c *MetricsCollector) Stop() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if !c.started {
|
||||
return
|
||||
}
|
||||
close(c.stopCh)
|
||||
c.stopCh = nil
|
||||
c.started = false
|
||||
}
|
||||
|
||||
func (c *MetricsCollector) collectLoop() {
|
||||
// Initialize CPU sample
|
||||
c.lastCPUSample = getCPUSample()
|
||||
|
||||
// Wait a bit before first collection to get accurate CPU reading
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Collect immediately on start
|
||||
c.collect()
|
||||
|
||||
ticker := time.NewTicker(metricsCollectionInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.collect()
|
||||
case <-c.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// calculateCPUUsage calculates CPU usage percentage between two samples
|
||||
func calculateCPUUsage(prev, curr cpuSample, numCPU int) CPUStats {
|
||||
elapsed := curr.timestamp.Sub(prev.timestamp)
|
||||
if elapsed <= 0 {
|
||||
return CPUStats{}
|
||||
}
|
||||
|
||||
userDelta := curr.userTime - prev.userTime
|
||||
systemDelta := curr.systemTime - prev.systemTime
|
||||
totalCPUTime := userDelta + systemDelta
|
||||
|
||||
// Calculate percentage: (CPU time used / wall time) * 100
|
||||
// Normalized by number of CPUs to get 0-100% range
|
||||
cpuPercent := (float64(totalCPUTime) / float64(elapsed)) * 100.0
|
||||
|
||||
// Cap at 100% * numCPU (in case of measurement errors)
|
||||
maxPercent := float64(numCPU) * 100.0
|
||||
if cpuPercent > maxPercent {
|
||||
cpuPercent = maxPercent
|
||||
}
|
||||
|
||||
return CPUStats{
|
||||
UsagePercent: cpuPercent,
|
||||
UserTime: userDelta.Seconds(),
|
||||
SystemTime: systemDelta.Seconds(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MetricsCollector) collect() {
|
||||
var memStats runtime.MemStats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
|
||||
// Get current CPU sample and calculate usage
|
||||
currentSample := getCPUSample()
|
||||
cpuStats := calculateCPUUsage(c.lastCPUSample, currentSample, runtime.NumCPU())
|
||||
c.lastCPUSample = currentSample
|
||||
|
||||
point := HistoryPoint{
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
Alloc: memStats.Alloc,
|
||||
HeapInuse: memStats.HeapInuse,
|
||||
Goroutines: runtime.NumGoroutine(),
|
||||
GCPauseNs: memStats.PauseNs[(memStats.NumGC+255)%256],
|
||||
CPUPercent: cpuStats.UsagePercent,
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Store current CPU stats for API response
|
||||
c.currentCPU = cpuStats
|
||||
|
||||
// Append to history, maintaining ring buffer behavior
|
||||
if len(c.history) >= historySize {
|
||||
// Shift left by one and append
|
||||
copy(c.history, c.history[1:])
|
||||
c.history[len(c.history)-1] = point
|
||||
} else {
|
||||
c.history = append(c.history, point)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MetricsCollector) getHistory() []HistoryPoint {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
// Return a copy to avoid race conditions
|
||||
result := make([]HistoryPoint, len(c.history))
|
||||
copy(result, c.history)
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *MetricsCollector) getCPUStats() CPUStats {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.currentCPU
|
||||
}
|
||||
|
||||
// getAllocations analyzes the heap profile and returns two allocation lists
|
||||
// aggregated by full call stack:
|
||||
// - cumulative: alloc_space / alloc_objects (total since process start)
|
||||
// - inuse: inuse_space / inuse_objects (currently live on the heap)
|
||||
//
|
||||
// Both are produced from a single pprof.WriteHeapProfile call.
|
||||
func getAllocations() (cumulative, inuse []AllocationInfo) {
|
||||
var buf bytes.Buffer
|
||||
if err := pprof.WriteHeapProfile(&buf); err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
p, err := profile.Parse(&buf)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
allocObjectsIdx, allocSpaceIdx := -1, -1
|
||||
inuseObjectsIdx, inuseSpaceIdx := -1, -1
|
||||
for i, st := range p.SampleType {
|
||||
switch st.Type {
|
||||
case "alloc_objects":
|
||||
allocObjectsIdx = i
|
||||
case "alloc_space":
|
||||
allocSpaceIdx = i
|
||||
case "inuse_objects":
|
||||
inuseObjectsIdx = i
|
||||
case "inuse_space":
|
||||
inuseSpaceIdx = i
|
||||
}
|
||||
}
|
||||
|
||||
allocMap := make(map[string]*AllocationInfo)
|
||||
inuseMap := make(map[string]*AllocationInfo)
|
||||
|
||||
for _, sample := range p.Sample {
|
||||
if len(sample.Location) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
topLoc := sample.Location[0]
|
||||
if len(topLoc.Line) == 0 {
|
||||
continue
|
||||
}
|
||||
topLine := topLoc.Line[0]
|
||||
topFn := topLine.Function
|
||||
if topFn == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Filter only the top frame — filtering inner frames would drop real
|
||||
// user allocations that merely pass through runtime/profiler code.
|
||||
if isProfilerFunction(topFn.Name, topFn.Filename) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build full stack in goroutine-dump format: alternating "funcName" and
|
||||
// "\tfile:line" entries, top-down. Matches GoroutineGroup.Stack so the
|
||||
// UI can render both with the same code path.
|
||||
stack := make([]string, 0, len(sample.Location)*2)
|
||||
for _, loc := range sample.Location {
|
||||
if len(loc.Line) == 0 {
|
||||
continue
|
||||
}
|
||||
frame := loc.Line[0]
|
||||
if frame.Function == nil {
|
||||
continue
|
||||
}
|
||||
stack = append(stack, frame.Function.Name)
|
||||
stack = append(stack, "\t"+frame.Function.Filename+":"+strconv.FormatInt(frame.Line, 10))
|
||||
}
|
||||
if len(stack) == 0 {
|
||||
continue
|
||||
}
|
||||
key := strings.Join(stack, "\n")
|
||||
|
||||
if allocSpaceIdx >= 0 && allocObjectsIdx >= 0 {
|
||||
b := sample.Value[allocSpaceIdx]
|
||||
c := sample.Value[allocObjectsIdx]
|
||||
if existing, ok := allocMap[key]; ok {
|
||||
existing.Bytes += b
|
||||
existing.Count += c
|
||||
} else {
|
||||
allocMap[key] = &AllocationInfo{
|
||||
Function: topFn.Name,
|
||||
File: topFn.Filename,
|
||||
Line: int(topLine.Line),
|
||||
Bytes: b,
|
||||
Count: c,
|
||||
Stack: stack,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if inuseSpaceIdx >= 0 && inuseObjectsIdx >= 0 {
|
||||
b := sample.Value[inuseSpaceIdx]
|
||||
c := sample.Value[inuseObjectsIdx]
|
||||
// Most samples have inuse=0 (already freed) — skip them so the live
|
||||
// table isn't padded with noise.
|
||||
if b == 0 && c == 0 {
|
||||
continue
|
||||
}
|
||||
if existing, ok := inuseMap[key]; ok {
|
||||
existing.Bytes += b
|
||||
existing.Count += c
|
||||
} else {
|
||||
inuseMap[key] = &AllocationInfo{
|
||||
Function: topFn.Name,
|
||||
File: topFn.Filename,
|
||||
Line: int(topLine.Line),
|
||||
Bytes: b,
|
||||
Count: c,
|
||||
Stack: stack,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return flattenAndTopN(allocMap), flattenAndTopN(inuseMap)
|
||||
}
|
||||
|
||||
// flattenAndTopN sorts an allocation map by bytes desc and caps it.
|
||||
func flattenAndTopN(m map[string]*AllocationInfo) []AllocationInfo {
|
||||
out := make([]AllocationInfo, 0, len(m))
|
||||
for _, a := range m {
|
||||
out = append(out, *a)
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool { return out[i].Bytes > out[j].Bytes })
|
||||
if len(out) > topAllocationsCount {
|
||||
out = out[:topAllocationsCount]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the dev pprof routes
|
||||
func (h *DevPprofHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
// Start the collector when routes are registered
|
||||
h.collector.Start()
|
||||
|
||||
r.GET("/api/dev/pprof", lib.ChainMiddlewares(h.getPprof, middlewares...))
|
||||
r.GET("/api/dev/pprof/goroutines", lib.ChainMiddlewares(h.getGoroutines, middlewares...))
|
||||
}
|
||||
|
||||
// getPprof handles GET /api/dev/pprof
|
||||
func (h *DevPprofHandler) getPprof(ctx *fasthttp.RequestCtx) {
|
||||
var memStats runtime.MemStats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
|
||||
data := PprofData{
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
Memory: MemoryStats{
|
||||
Alloc: memStats.Alloc,
|
||||
TotalAlloc: memStats.TotalAlloc,
|
||||
HeapInuse: memStats.HeapInuse,
|
||||
HeapObjects: memStats.HeapObjects,
|
||||
Sys: memStats.Sys,
|
||||
},
|
||||
CPU: h.collector.getCPUStats(),
|
||||
Runtime: RuntimeStats{
|
||||
NumGoroutine: runtime.NumGoroutine(),
|
||||
NumGC: memStats.NumGC,
|
||||
GCPauseNs: memStats.PauseNs[(memStats.NumGC+255)%256],
|
||||
NumCPU: runtime.NumCPU(),
|
||||
GOMAXPROCS: runtime.GOMAXPROCS(0),
|
||||
},
|
||||
History: h.collector.getHistory(),
|
||||
}
|
||||
data.TopAllocations, data.InuseAllocations = getAllocations()
|
||||
|
||||
SendJSON(ctx, data)
|
||||
}
|
||||
|
||||
// getGoroutines handles GET /api/dev/pprof/goroutines
|
||||
// Returns goroutine stack traces grouped by stack signature
|
||||
func (h *DevPprofHandler) getGoroutines(ctx *fasthttp.RequestCtx) {
|
||||
// Check if raw output is requested
|
||||
includeRaw := string(ctx.QueryArgs().Peek("raw")) == "true"
|
||||
|
||||
// Get goroutine profile
|
||||
var buf bytes.Buffer
|
||||
if err := pprof.Lookup("goroutine").WriteTo(&buf, 2); err != nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
SendJSON(ctx, map[string]string{"error": "failed to get goroutine profile"})
|
||||
return
|
||||
}
|
||||
|
||||
rawProfile := buf.String()
|
||||
allGroups := parseGoroutineProfile(rawProfile)
|
||||
|
||||
// Filter out profiler goroutines and calculate summary
|
||||
groups := make([]GoroutineGroup, 0, len(allGroups))
|
||||
summary := GoroutineSummary{}
|
||||
profilerGoroutineCount := 0
|
||||
|
||||
for i := range allGroups {
|
||||
categorizeGoroutine(&allGroups[i])
|
||||
|
||||
// Skip profiler's own goroutines
|
||||
if isProfilerGoroutine(&allGroups[i]) {
|
||||
profilerGoroutineCount += allGroups[i].Count
|
||||
continue
|
||||
}
|
||||
|
||||
groups = append(groups, allGroups[i])
|
||||
|
||||
switch allGroups[i].Category {
|
||||
case "background":
|
||||
summary.Background += allGroups[i].Count
|
||||
case "per-request":
|
||||
summary.PerRequest += allGroups[i].Count
|
||||
}
|
||||
|
||||
if allGroups[i].WaitMinutes >= 1 {
|
||||
summary.LongWaiting += allGroups[i].Count
|
||||
if allGroups[i].Category == "per-request" {
|
||||
summary.PotentiallyStuck += allGroups[i].Count
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort: potentially stuck first, then by wait time, then by count
|
||||
sort.Slice(groups, func(i, j int) bool {
|
||||
// Potentially stuck (per-request + long wait) first
|
||||
iStuck := groups[i].Category == "per-request" && groups[i].WaitMinutes >= 1
|
||||
jStuck := groups[j].Category == "per-request" && groups[j].WaitMinutes >= 1
|
||||
if iStuck != jStuck {
|
||||
return iStuck
|
||||
}
|
||||
// Then by wait time
|
||||
if groups[i].WaitMinutes != groups[j].WaitMinutes {
|
||||
return groups[i].WaitMinutes > groups[j].WaitMinutes
|
||||
}
|
||||
// Then by count
|
||||
return groups[i].Count > groups[j].Count
|
||||
})
|
||||
|
||||
// Calculate app goroutines (total minus profiler goroutines)
|
||||
// Calculate total goroutines from profile snapshot
|
||||
totalFromProfile := 0
|
||||
for _, g := range groups {
|
||||
totalFromProfile += g.Count
|
||||
}
|
||||
|
||||
response := GoroutineProfile{
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
TotalGoroutines: totalFromProfile,
|
||||
Groups: groups,
|
||||
Summary: summary,
|
||||
}
|
||||
|
||||
if includeRaw {
|
||||
response.RawProfile = rawProfile
|
||||
}
|
||||
|
||||
SendJSON(ctx, response)
|
||||
}
|
||||
|
||||
// categorizeGoroutine determines if a goroutine is a background worker or per-request
|
||||
func categorizeGoroutine(g *GoroutineGroup) {
|
||||
// Parse wait time from wait reason (e.g., "5 minutes" -> 5)
|
||||
g.WaitMinutes = parseWaitMinutes(g.WaitReason)
|
||||
|
||||
stackStr := strings.Join(g.Stack, " ")
|
||||
|
||||
// Background goroutines - expected to run forever
|
||||
backgroundPatterns := []string{
|
||||
"requestWorker", // Provider queue workers
|
||||
"collectLoop", // Metrics collector
|
||||
"cleanupWorker", // Various cleanup workers
|
||||
"startAccumulatorMapCleanup", // Stream accumulator cleanup
|
||||
"cleanupOldTraces", // Trace store cleanup
|
||||
"startCleanup", // Generic cleanup
|
||||
"monitorLoop", // Health monitor
|
||||
"StartHeartbeat", // WebSocket heartbeat
|
||||
"time.Sleep", // Ticker-based workers
|
||||
"runtime.gopark", // Runtime parking (often tickers)
|
||||
"sync.(*Cond).Wait", // Condition variable waits
|
||||
"net/http.(*persistConn)", // HTTP connection pool
|
||||
"internal/poll.runtime_pollWait", // Network polling
|
||||
}
|
||||
|
||||
for _, pattern := range backgroundPatterns {
|
||||
if strings.Contains(stackStr, pattern) {
|
||||
g.Category = "background"
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Per-request goroutines - should complete when request ends
|
||||
perRequestPatterns := []string{
|
||||
"PreLLMHook",
|
||||
"PostLLMHook",
|
||||
"PreMCPHook",
|
||||
"PostMCPHook",
|
||||
"HTTPTransportPreHook",
|
||||
"HTTPTransportPostHook",
|
||||
"CompleteAndFlushTrace",
|
||||
"ProcessAndSend",
|
||||
"handleProvider",
|
||||
"Inject", // Observability plugin inject
|
||||
"insertInitialLogEntry", // Logging
|
||||
"updateLogEntry", // Logging
|
||||
"retryOnNotFound",
|
||||
"BroadcastLogUpdate",
|
||||
}
|
||||
|
||||
for _, pattern := range perRequestPatterns {
|
||||
if strings.Contains(stackStr, pattern) {
|
||||
g.Category = "per-request"
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
g.Category = "unknown"
|
||||
}
|
||||
|
||||
// parseWaitMinutes extracts wait time in minutes from wait reason string
|
||||
func parseWaitMinutes(waitReason string) int {
|
||||
if waitReason == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Match patterns like "5 minutes", "1 minute", "30 seconds", "2 hours"
|
||||
minuteRegex := regexp.MustCompile(`(\d+)\s*minute`)
|
||||
if matches := minuteRegex.FindStringSubmatch(waitReason); len(matches) >= 2 {
|
||||
if mins, err := strconv.Atoi(matches[1]); err == nil {
|
||||
return mins
|
||||
}
|
||||
}
|
||||
|
||||
hourRegex := regexp.MustCompile(`(\d+)\s*hour`)
|
||||
if matches := hourRegex.FindStringSubmatch(waitReason); len(matches) >= 2 {
|
||||
if hours, err := strconv.Atoi(matches[1]); err == nil {
|
||||
return hours * 60
|
||||
}
|
||||
}
|
||||
|
||||
secondRegex := regexp.MustCompile(`(\d+)\s*second`)
|
||||
if matches := secondRegex.FindStringSubmatch(waitReason); len(matches) >= 2 {
|
||||
if secs, err := strconv.Atoi(matches[1]); err == nil {
|
||||
return secs / 60 // Convert to minutes, will be 0 for < 60 seconds
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// parseGoroutineProfile parses the text output of pprof goroutine profile
|
||||
// and groups goroutines by their stack trace
|
||||
func parseGoroutineProfile(profile string) []GoroutineGroup {
|
||||
// Regex to match goroutine header: "goroutine N [state, wait reason]:"
|
||||
// Examples:
|
||||
// goroutine 1 [running]:
|
||||
// goroutine 42 [select, 5 minutes]:
|
||||
// goroutine 100 [chan receive]:
|
||||
headerRegex := regexp.MustCompile(`goroutine \d+ \[([^\]]+)\]:`)
|
||||
|
||||
// Split by "goroutine " to get individual goroutine blocks
|
||||
blocks := strings.Split(profile, "goroutine ")
|
||||
|
||||
// Map to group goroutines by stack signature
|
||||
groupMap := make(map[string]*GoroutineGroup)
|
||||
|
||||
for _, block := range blocks {
|
||||
block = strings.TrimSpace(block)
|
||||
if block == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Re-add "goroutine " prefix for regex matching
|
||||
fullBlock := "goroutine " + block
|
||||
|
||||
// Extract state from header
|
||||
matches := headerRegex.FindStringSubmatch(fullBlock)
|
||||
if len(matches) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
stateInfo := matches[1]
|
||||
state := stateInfo
|
||||
waitReason := ""
|
||||
|
||||
// Parse state and wait reason (e.g., "select, 5 minutes" -> state="select", waitReason="5 minutes")
|
||||
if idx := strings.Index(stateInfo, ","); idx != -1 {
|
||||
state = strings.TrimSpace(stateInfo[:idx])
|
||||
waitReason = strings.TrimSpace(stateInfo[idx+1:])
|
||||
}
|
||||
|
||||
// Get stack trace (everything after the header line)
|
||||
lines := strings.Split(block, "\n")
|
||||
if len(lines) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract stack frames (skip the header line which is lines[0])
|
||||
var stackLines []string
|
||||
var topFunc string
|
||||
for i := 1; i < len(lines); i++ {
|
||||
line := strings.TrimSpace(lines[i])
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
stackLines = append(stackLines, line)
|
||||
|
||||
// First function line (not a file:line) is the top function
|
||||
if topFunc == "" && !strings.HasPrefix(line, "/") && !strings.Contains(line, ".go:") {
|
||||
topFunc = line
|
||||
}
|
||||
}
|
||||
|
||||
if len(stackLines) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create a signature from the stack (top 10 frames for grouping)
|
||||
maxFrames := 10
|
||||
if len(stackLines) < maxFrames {
|
||||
maxFrames = len(stackLines)
|
||||
}
|
||||
signature := state + "|" + strings.Join(stackLines[:maxFrames], "|")
|
||||
|
||||
// Group by signature
|
||||
if existing, ok := groupMap[signature]; ok {
|
||||
existing.Count++
|
||||
} else {
|
||||
groupMap[signature] = &GoroutineGroup{
|
||||
Count: 1,
|
||||
State: state,
|
||||
WaitReason: waitReason,
|
||||
TopFunc: topFunc,
|
||||
Stack: stackLines,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert map to slice
|
||||
groups := make([]GoroutineGroup, 0, len(groupMap))
|
||||
for _, group := range groupMap {
|
||||
groups = append(groups, *group)
|
||||
}
|
||||
|
||||
return groups
|
||||
}
|
||||
|
||||
// profilerPatterns contains patterns to identify profiler-related code
|
||||
var profilerPatterns = []string{
|
||||
"devpprof",
|
||||
"pprof.WriteHeapProfile",
|
||||
"pprof.Lookup",
|
||||
"profile.Parse",
|
||||
"MetricsCollector",
|
||||
"collectLoop",
|
||||
"getAllocations",
|
||||
"flattenAndTopN",
|
||||
"parseGoroutineProfile",
|
||||
"getGoroutines",
|
||||
"getCPUSample",
|
||||
}
|
||||
|
||||
// isProfilerFunction checks if a function belongs to the profiler itself
|
||||
func isProfilerFunction(funcName, fileName string) bool {
|
||||
for _, pattern := range profilerPatterns {
|
||||
if strings.Contains(funcName, pattern) || strings.Contains(fileName, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isProfilerGoroutine checks if a goroutine belongs to the profiler
|
||||
func isProfilerGoroutine(g *GoroutineGroup) bool {
|
||||
stackStr := strings.Join(g.Stack, " ")
|
||||
for _, pattern := range profilerPatterns {
|
||||
if strings.Contains(stackStr, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Cleanup stops the metrics collector
|
||||
func (h *DevPprofHandler) Cleanup() {
|
||||
if h.collector != nil {
|
||||
h.collector.Stop()
|
||||
}
|
||||
}
|
||||
23
transports/bifrost-http/handlers/devpprof_prod.go
Normal file
23
transports/bifrost-http/handlers/devpprof_prod.go
Normal file
@@ -0,0 +1,23 @@
|
||||
//go:build !dev
|
||||
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// DevPprofHandler is a no-op stub for production builds (built without the "dev" tag).
|
||||
type DevPprofHandler struct{}
|
||||
|
||||
// IsDevMode always returns false in production builds.
|
||||
func IsDevMode() bool { return false }
|
||||
|
||||
// NewDevPprofHandler returns nil in production builds.
|
||||
func NewDevPprofHandler() *DevPprofHandler { return nil }
|
||||
|
||||
// RegisterRoutes is a no-op in production builds.
|
||||
func (h *DevPprofHandler) RegisterRoutes(_ *router.Router, _ ...schemas.BifrostHTTPMiddleware) {}
|
||||
|
||||
// Cleanup is a no-op in production builds.
|
||||
func (h *DevPprofHandler) Cleanup() {}
|
||||
26
transports/bifrost-http/handlers/devpprof_unix.go
Normal file
26
transports/bifrost-http/handlers/devpprof_unix.go
Normal file
@@ -0,0 +1,26 @@
|
||||
//go:build dev && !windows
|
||||
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// getCPUSample gets the current CPU time sample using syscall
|
||||
func getCPUSample() cpuSample {
|
||||
var rusage syscall.Rusage
|
||||
if err := syscall.Getrusage(syscall.RUSAGE_SELF, &rusage); err != nil {
|
||||
return cpuSample{timestamp: time.Now()}
|
||||
}
|
||||
|
||||
userTime := time.Duration(rusage.Utime.Sec)*time.Second + time.Duration(rusage.Utime.Usec)*time.Microsecond
|
||||
systemTime := time.Duration(rusage.Stime.Sec)*time.Second + time.Duration(rusage.Stime.Usec)*time.Microsecond
|
||||
|
||||
return cpuSample{
|
||||
timestamp: time.Now(),
|
||||
userTime: userTime,
|
||||
systemTime: systemTime,
|
||||
}
|
||||
}
|
||||
|
||||
12
transports/bifrost-http/handlers/devpprof_windows.go
Normal file
12
transports/bifrost-http/handlers/devpprof_windows.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build dev && windows
|
||||
|
||||
package handlers
|
||||
|
||||
import "time"
|
||||
|
||||
// getCPUSample returns a zeroed CPU sample on Windows
|
||||
// Windows does not support syscall.Getrusage
|
||||
func getCPUSample() cpuSample {
|
||||
return cpuSample{timestamp: time.Now()}
|
||||
}
|
||||
|
||||
3850
transports/bifrost-http/handlers/governance.go
Normal file
3850
transports/bifrost-http/handlers/governance.go
Normal file
File diff suppressed because it is too large
Load Diff
337
transports/bifrost-http/handlers/governance_test.go
Normal file
337
transports/bifrost-http/handlers/governance_test.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// mockGovernanceManagerForVK embeds the interface so unimplemented methods panic.
|
||||
// Only GetGovernanceData is needed for the getVirtualKeys handler path.
|
||||
type mockGovernanceManagerForVK struct {
|
||||
GovernanceManager
|
||||
}
|
||||
|
||||
func (m *mockGovernanceManagerForVK) GetGovernanceData(ctx context.Context) *governance.GovernanceData {
|
||||
return nil
|
||||
}
|
||||
|
||||
// mockConfigStoreForVK embeds the interface so unimplemented methods panic.
|
||||
// Only GetVirtualKeysPaginated is called in the non-from_memory path.
|
||||
type mockConfigStoreForVK struct {
|
||||
configstore.ConfigStore
|
||||
}
|
||||
|
||||
func (m *mockConfigStoreForVK) GetVirtualKeysPaginated(_ context.Context, _ configstore.VirtualKeyQueryParams) ([]configstoreTables.TableVirtualKey, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
func (m *mockConfigStoreForVK) GetVirtualKeys(_ context.Context) ([]configstoreTables.TableVirtualKey, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TestGetVirtualKeys_PaginatedEndpoint_ResponseShape verifies the JSON response
|
||||
// from the paginated virtual keys endpoint contains all expected fields.
|
||||
func TestGetVirtualKeys_PaginatedEndpoint_ResponseShape(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := &GovernanceHandler{
|
||||
configStore: &mockConfigStoreForVK{},
|
||||
governanceManager: &mockGovernanceManagerForVK{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/governance/virtual-keys?limit=10&offset=0")
|
||||
|
||||
h.getVirtualKeys(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Fatalf("expected status 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// Assert expected fields exist with correct types
|
||||
requiredFields := []struct {
|
||||
key string
|
||||
wantType string
|
||||
}{
|
||||
{"virtual_keys", "array"},
|
||||
{"total_count", "number"},
|
||||
{"count", "number"},
|
||||
{"limit", "number"},
|
||||
{"offset", "number"},
|
||||
}
|
||||
|
||||
for _, f := range requiredFields {
|
||||
val, ok := resp[f.key]
|
||||
if !ok {
|
||||
t.Errorf("response missing required field %q", f.key)
|
||||
continue
|
||||
}
|
||||
switch f.wantType {
|
||||
case "array":
|
||||
if _, ok := val.([]interface{}); !ok {
|
||||
// nil decodes as nil, which is fine — JSON null for empty array
|
||||
if val != nil {
|
||||
t.Errorf("field %q: expected array, got %T", f.key, val)
|
||||
}
|
||||
}
|
||||
case "number":
|
||||
if _, ok := val.(float64); !ok {
|
||||
t.Errorf("field %q: expected number, got %T", f.key, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no unexpected extra top-level fields
|
||||
allowedKeys := map[string]bool{
|
||||
"virtual_keys": true,
|
||||
"total_count": true,
|
||||
"count": true,
|
||||
"limit": true,
|
||||
"offset": true,
|
||||
}
|
||||
for key := range resp {
|
||||
if !allowedKeys[key] {
|
||||
t.Errorf("unexpected field %q in response", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetVirtualKeys_PaginatedEndpoint_QueryParams verifies query parameters are
|
||||
// parsed and reflected in the response.
|
||||
func TestGetVirtualKeys_PaginatedEndpoint_QueryParams(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := &GovernanceHandler{
|
||||
configStore: &mockConfigStoreForVK{},
|
||||
governanceManager: &mockGovernanceManagerForVK{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uri string
|
||||
wantLimit float64
|
||||
wantOffset float64
|
||||
}{
|
||||
{
|
||||
name: "explicit limit and offset",
|
||||
uri: "/api/governance/virtual-keys?limit=10&offset=5",
|
||||
wantLimit: 10,
|
||||
wantOffset: 5,
|
||||
},
|
||||
{
|
||||
name: "no params uses defaults",
|
||||
uri: "/api/governance/virtual-keys",
|
||||
wantLimit: 0,
|
||||
wantOffset: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI(tt.uri)
|
||||
|
||||
h.getVirtualKeys(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Fatalf("expected status 200, got %d", ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse JSON: %v", err)
|
||||
}
|
||||
|
||||
if got := resp["limit"].(float64); got != tt.wantLimit {
|
||||
t.Errorf("limit: got %v, want %v", got, tt.wantLimit)
|
||||
}
|
||||
if got := resp["offset"].(float64); got != tt.wantOffset {
|
||||
t.Errorf("offset: got %v, want %v", got, tt.wantOffset)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure mockLogger satisfies schemas.Logger (already defined in middlewares_test.go
|
||||
// but we reference it here — same package, so no redeclaration needed).
|
||||
var _ schemas.Logger = (*mockLogger)(nil)
|
||||
|
||||
func TestBudgetRemovalRequestDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *UpdateBudgetRequest
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil request is not removal",
|
||||
req: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty object is removal",
|
||||
req: &UpdateBudgetRequest{},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "max limit present is not removal",
|
||||
req: &UpdateBudgetRequest{MaxLimit: bifrostFloat(10)},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "reset duration only is not removal",
|
||||
req: &UpdateBudgetRequest{ResetDuration: bifrostString("1h")},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "calendar aligned only is treated as removal",
|
||||
req: &UpdateBudgetRequest{CalendarAligned: bifrostBool(true)},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isBudgetRemovalRequest(tt.req); got != tt.want {
|
||||
t.Fatalf("isBudgetRemovalRequest() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitRemovalRequestDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *UpdateRateLimitRequest
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil request is not removal",
|
||||
req: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty object is removal",
|
||||
req: &UpdateRateLimitRequest{},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "token limit present is not removal",
|
||||
req: &UpdateRateLimitRequest{TokenMaxLimit: bifrostInt64(100)},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "request limit present is not removal",
|
||||
req: &UpdateRateLimitRequest{RequestMaxLimit: bifrostInt64(10)},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "durations only is not removal",
|
||||
req: &UpdateRateLimitRequest{
|
||||
TokenResetDuration: bifrostString("1h"),
|
||||
RequestResetDuration: bifrostString("1h"),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isRateLimitRemovalRequest(tt.req); got != tt.want {
|
||||
t.Fatalf("isRateLimitRemovalRequest() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectProviderConfigDeleteIDs(t *testing.T) {
|
||||
budgetID := "budget-1"
|
||||
rateLimitID := "rate-limit-1"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config configstoreTables.TableVirtualKeyProviderConfig
|
||||
initialBudgetIDs []string
|
||||
initialRateIDs []string
|
||||
wantBudgetIDs []string
|
||||
wantRateIDs []string
|
||||
}{
|
||||
{
|
||||
name: "collects both IDs",
|
||||
config: configstoreTables.TableVirtualKeyProviderConfig{
|
||||
Budgets: []configstoreTables.TableBudget{{ID: budgetID}},
|
||||
RateLimitID: &rateLimitID,
|
||||
},
|
||||
wantBudgetIDs: []string{budgetID},
|
||||
wantRateIDs: []string{rateLimitID},
|
||||
},
|
||||
{
|
||||
name: "appends to existing slices",
|
||||
config: configstoreTables.TableVirtualKeyProviderConfig{
|
||||
Budgets: []configstoreTables.TableBudget{{ID: budgetID}},
|
||||
RateLimitID: &rateLimitID,
|
||||
},
|
||||
initialBudgetIDs: []string{"budget-0"},
|
||||
initialRateIDs: []string{"rate-limit-0"},
|
||||
wantBudgetIDs: []string{"budget-0", budgetID},
|
||||
wantRateIDs: []string{"rate-limit-0", rateLimitID},
|
||||
},
|
||||
{
|
||||
name: "ignores missing IDs",
|
||||
config: configstoreTables.TableVirtualKeyProviderConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotBudgetIDs, gotRateIDs := collectProviderConfigDeleteIDs(tt.config, tt.initialBudgetIDs, tt.initialRateIDs)
|
||||
|
||||
if len(gotBudgetIDs) != len(tt.wantBudgetIDs) {
|
||||
t.Fatalf("budget IDs length = %d, want %d", len(gotBudgetIDs), len(tt.wantBudgetIDs))
|
||||
}
|
||||
for i := range gotBudgetIDs {
|
||||
if gotBudgetIDs[i] != tt.wantBudgetIDs[i] {
|
||||
t.Fatalf("budget IDs[%d] = %q, want %q", i, gotBudgetIDs[i], tt.wantBudgetIDs[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(gotRateIDs) != len(tt.wantRateIDs) {
|
||||
t.Fatalf("rate limit IDs length = %d, want %d", len(gotRateIDs), len(tt.wantRateIDs))
|
||||
}
|
||||
for i := range gotRateIDs {
|
||||
if gotRateIDs[i] != tt.wantRateIDs[i] {
|
||||
t.Fatalf("rate limit IDs[%d] = %q, want %q", i, gotRateIDs[i], tt.wantRateIDs[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func bifrostFloat(v float64) *float64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func bifrostInt64(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func bifrostString(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
func bifrostBool(v bool) *bool {
|
||||
return &v
|
||||
}
|
||||
90
transports/bifrost-http/handlers/health.go
Normal file
90
transports/bifrost-http/handlers/health.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// HealthHandler manages HTTP requests for health checks.
|
||||
type HealthHandler struct {
|
||||
config *lib.Config
|
||||
}
|
||||
|
||||
// NewHealthHandler creates a new health handler instance.
|
||||
func NewHealthHandler(config *lib.Config) *HealthHandler {
|
||||
return &HealthHandler{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the health-related routes.
|
||||
func (h *HealthHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.GET("/health", lib.ChainMiddlewares(h.getHealth, middlewares...))
|
||||
}
|
||||
|
||||
// getHealth handles GET /api/health - Get the health status of the server.
|
||||
func (h *HealthHandler) getHealth(ctx *fasthttp.RequestCtx) {
|
||||
// If DB pings are disabled, just return OK
|
||||
if h.config.ClientConfig.DisableDBPingsInHealth {
|
||||
SendJSON(ctx, map[string]any{"status": "ok", "components": map[string]any{"db_pings": "disabled"}})
|
||||
return
|
||||
}
|
||||
// Pinging config store
|
||||
reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
var errors []string
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
|
||||
if h.config.ConfigStore != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := h.config.ConfigStore.Ping(reqCtx); err != nil {
|
||||
mu.Lock()
|
||||
errors = append(errors, "config store not available")
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Pinging log store
|
||||
if h.config.LogsStore != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := h.config.LogsStore.Ping(reqCtx); err != nil {
|
||||
mu.Lock()
|
||||
errors = append(errors, "log store not available")
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Pinging vector store
|
||||
if h.config.VectorStore != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := h.config.VectorStore.Ping(reqCtx); err != nil {
|
||||
mu.Lock()
|
||||
errors = append(errors, "vector store not available")
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if len(errors) > 0 {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, errors[0])
|
||||
return
|
||||
}
|
||||
SendJSON(ctx, map[string]any{"status": "ok", "components": map[string]any{"db_pings": "ok"}})
|
||||
}
|
||||
3807
transports/bifrost-http/handlers/inference.go
Normal file
3807
transports/bifrost-http/handlers/inference.go
Normal file
File diff suppressed because it is too large
Load Diff
20
transports/bifrost-http/handlers/init.go
Normal file
20
transports/bifrost-http/handlers/init.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package handlers
|
||||
|
||||
import "github.com/maximhq/bifrost/core/schemas"
|
||||
|
||||
var version string
|
||||
var logger schemas.Logger
|
||||
|
||||
// SetLogger sets the logger for the application.
|
||||
func SetLogger(l schemas.Logger) {
|
||||
logger = l
|
||||
}
|
||||
|
||||
// SetVersion sets the version for the application.
|
||||
func SetVersion(v string) {
|
||||
version = v
|
||||
}
|
||||
|
||||
func GetVersion() string {
|
||||
return version
|
||||
}
|
||||
111
transports/bifrost-http/handlers/integrations.go
Normal file
111
transports/bifrost-http/handlers/integrations.go
Normal file
@@ -0,0 +1,111 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file contains integration management handlers for AI provider integrations.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/fasthttp/router"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
// IntegrationHandler manages HTTP requests for AI provider integrations
|
||||
type IntegrationHandler struct {
|
||||
extensions []integrations.ExtensionRouter
|
||||
wsResponses *WSResponsesHandler
|
||||
wsRealtime *WSRealtimeHandler
|
||||
webrtcRealtime *WebRTCRealtimeHandler
|
||||
realtimeClientSecrets *RealtimeClientSecretsHandler
|
||||
}
|
||||
|
||||
// NewIntegrationHandler creates a new integration handler instance.
|
||||
// WebSocket handlers may be nil if WebSocket support is not configured.
|
||||
func NewIntegrationHandler(client *bifrost.Bifrost, handlerStore lib.HandlerStore, wsResponses *WSResponsesHandler, wsRealtime *WSRealtimeHandler, webrtcRealtime *WebRTCRealtimeHandler, realtimeClientSecrets *RealtimeClientSecretsHandler) *IntegrationHandler {
|
||||
// Initialize all available integration routers
|
||||
extensions := []integrations.ExtensionRouter{
|
||||
integrations.NewOpenAIRouter(client, handlerStore, logger),
|
||||
integrations.NewAnthropicRouter(client, handlerStore, logger),
|
||||
integrations.NewGenAIRouter(client, handlerStore, logger),
|
||||
integrations.NewLiteLLMRouter(client, handlerStore, logger),
|
||||
integrations.NewCohereRouter(client, handlerStore, logger),
|
||||
integrations.NewLangChainRouter(client, handlerStore, logger),
|
||||
integrations.NewPydanticAIRouter(client, handlerStore, logger),
|
||||
integrations.NewBedrockRouter(client, handlerStore, logger),
|
||||
// passthrough routers
|
||||
integrations.NewGenAIPassthroughRouter(client, handlerStore, logger),
|
||||
integrations.NewOpenAIPassthroughRouter(client, handlerStore, logger),
|
||||
integrations.NewAnthropicPassthroughRouter(client, handlerStore, logger),
|
||||
integrations.NewAzurePassthroughRouter(client, handlerStore, logger),
|
||||
integrations.NewCursorRouter(client, handlerStore, logger),
|
||||
}
|
||||
|
||||
return &IntegrationHandler{
|
||||
extensions: extensions,
|
||||
wsResponses: wsResponses,
|
||||
wsRealtime: wsRealtime,
|
||||
webrtcRealtime: webrtcRealtime,
|
||||
realtimeClientSecrets: realtimeClientSecrets,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers all integration routes for AI provider compatibility endpoints
|
||||
func (h *IntegrationHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
// Register routes for each integration extension
|
||||
for _, extension := range h.extensions {
|
||||
extension.RegisterRoutes(r, middlewares...)
|
||||
}
|
||||
// Register WebSocket routes (base path + integration paths)
|
||||
if h.wsResponses != nil {
|
||||
h.wsResponses.RegisterRoutes(r, middlewares...)
|
||||
}
|
||||
if h.wsRealtime != nil {
|
||||
h.wsRealtime.RegisterRoutes(r, middlewares...)
|
||||
}
|
||||
if h.webrtcRealtime != nil {
|
||||
h.webrtcRealtime.RegisterRoutes(r, middlewares...)
|
||||
}
|
||||
if h.realtimeClientSecrets != nil {
|
||||
h.realtimeClientSecrets.RegisterRoutes(r, middlewares...)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *IntegrationHandler) Close() {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
if h.wsResponses != nil {
|
||||
h.wsResponses.Close()
|
||||
}
|
||||
if h.wsRealtime != nil {
|
||||
h.wsRealtime.Close()
|
||||
}
|
||||
if h.webrtcRealtime != nil {
|
||||
h.webrtcRealtime.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// SetLargePayloadHook sets the large payload detection hook on all integration routers
|
||||
// that support it. This is used by enterprise to inject large payload optimization.
|
||||
func (h *IntegrationHandler) SetLargePayloadHook(hook integrations.LargePayloadHook) {
|
||||
for _, extension := range h.extensions {
|
||||
if setter, ok := extension.(interface {
|
||||
SetLargePayloadHook(integrations.LargePayloadHook)
|
||||
}); ok {
|
||||
setter.SetLargePayloadHook(hook)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetLargeResponseHook sets the large response scanning hook on all integration routers
|
||||
// that support it. Enterprise uses this to inject Phase B usage extraction into the
|
||||
// response stream without embedding scanning logic in the OSS router.
|
||||
func (h *IntegrationHandler) SetLargeResponseHook(hook integrations.LargeResponseHook) {
|
||||
for _, extension := range h.extensions {
|
||||
if setter, ok := extension.(interface {
|
||||
SetLargeResponseHook(integrations.LargeResponseHook)
|
||||
}); ok {
|
||||
setter.SetLargeResponseHook(hook)
|
||||
}
|
||||
}
|
||||
}
|
||||
1653
transports/bifrost-http/handlers/logging.go
Normal file
1653
transports/bifrost-http/handlers/logging.go
Normal file
File diff suppressed because it is too large
Load Diff
1164
transports/bifrost-http/handlers/mcp.go
Normal file
1164
transports/bifrost-http/handlers/mcp.go
Normal file
File diff suppressed because it is too large
Load Diff
112
transports/bifrost-http/handlers/mcpinference.go
Normal file
112
transports/bifrost-http/handlers/mcpinference.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/fasthttp/router"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type MCPInferenceHandler struct {
|
||||
client *bifrost.Bifrost
|
||||
config *lib.Config
|
||||
}
|
||||
|
||||
// NewMCPInferenceHandler creates a new MCP inference handler instance
|
||||
func NewMCPInferenceHandler(client *bifrost.Bifrost, config *lib.Config) *MCPInferenceHandler {
|
||||
return &MCPInferenceHandler{
|
||||
client: client,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the MCP inference routes
|
||||
func (h *MCPInferenceHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.POST("/v1/mcp/tool/execute", lib.ChainMiddlewares(h.executeTool, middlewares...))
|
||||
}
|
||||
|
||||
// executeTool handles POST /v1/mcp/tool/execute - Execute MCP tool
|
||||
func (h *MCPInferenceHandler) executeTool(ctx *fasthttp.RequestCtx) {
|
||||
// Check format query parameter
|
||||
format := strings.ToLower(string(ctx.QueryArgs().Peek("format")))
|
||||
switch format {
|
||||
case "chat", "":
|
||||
h.executeChatMCPTool(ctx)
|
||||
case "responses":
|
||||
h.executeResponsesMCPTool(ctx)
|
||||
default:
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid format value, must be 'chat' or 'responses'")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// executeChatMCPTool handles POST /v1/mcp/tool/execute?format=chat - Execute MCP tool
|
||||
func (h *MCPInferenceHandler) executeChatMCPTool(ctx *fasthttp.RequestCtx) {
|
||||
var req schemas.ChatAssistantMessageToolCall
|
||||
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Function.Name == nil || *req.Function.Name == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Tool function name is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Convert context
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
defer cancel() // Ensure cleanup on function exit
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
|
||||
// Execute MCP tool
|
||||
toolMessage, bifrostErr := h.client.ExecuteChatMCPTool(bifrostCtx, &req)
|
||||
if bifrostErr != nil {
|
||||
SendBifrostError(ctx, bifrostErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Send successful response
|
||||
SendJSON(ctx, toolMessage)
|
||||
}
|
||||
|
||||
// executeResponsesMCPTool handles POST /v1/mcp/tool/execute?format=responses - Execute MCP tool
|
||||
func (h *MCPInferenceHandler) executeResponsesMCPTool(ctx *fasthttp.RequestCtx) {
|
||||
var req schemas.ResponsesToolMessage
|
||||
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Name == nil || *req.Name == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Tool function name is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Convert context
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
defer cancel() // Ensure cleanup on function exit
|
||||
if bifrostCtx == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
|
||||
return
|
||||
}
|
||||
|
||||
// Execute MCP tool
|
||||
toolMessage, bifrostErr := h.client.ExecuteResponsesMCPTool(bifrostCtx, &req)
|
||||
if bifrostErr != nil {
|
||||
SendBifrostError(ctx, bifrostErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Send successful response
|
||||
SendJSON(ctx, toolMessage)
|
||||
}
|
||||
568
transports/bifrost-http/handlers/mcpserver.go
Normal file
568
transports/bifrost-http/handlers/mcpserver.go
Normal file
@@ -0,0 +1,568 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file contains MCP (Model Context Protocol) server implementation for HTTP streaming.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// MCPToolExecutor interface defines the method needed for executing MCP tools
|
||||
type MCPToolManager interface {
|
||||
GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool
|
||||
ExecuteChatMCPTool(ctx context.Context, toolCall *schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError)
|
||||
ExecuteResponsesMCPTool(ctx context.Context, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, *schemas.BifrostError)
|
||||
}
|
||||
|
||||
// MCPServerHandler manages HTTP requests for MCP server operations
|
||||
// It implements the MCP protocol over HTTP streaming (SSE) for MCP clients
|
||||
type MCPServerHandler struct {
|
||||
toolManager MCPToolManager
|
||||
globalMCPServer *server.MCPServer
|
||||
vkMCPServers map[string]*server.MCPServer // Map of vk value -> mcp server
|
||||
config *lib.Config
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewMCPServerHandler creates a new MCP server handler instance
|
||||
func NewMCPServerHandler(ctx context.Context, config *lib.Config, toolManager MCPToolManager) (*MCPServerHandler, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
if toolManager == nil {
|
||||
return nil, fmt.Errorf("tool manager is required")
|
||||
}
|
||||
|
||||
// Create MCP server instance using mcp-go
|
||||
globalMCPServer := server.NewMCPServer(
|
||||
"global",
|
||||
version,
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
|
||||
handler := &MCPServerHandler{
|
||||
toolManager: toolManager,
|
||||
globalMCPServer: globalMCPServer,
|
||||
config: config,
|
||||
vkMCPServers: make(map[string]*server.MCPServer),
|
||||
}
|
||||
|
||||
// Register per-request tool filter so x-bf-mcp-include-clients and x-bf-mcp-include-tools are respected on tools/list
|
||||
server.WithToolFilter(handler.makeIncludeClientsFilter())(handler.globalMCPServer)
|
||||
|
||||
// Register per-request tool filter so x-bf-mcp-include-clients and x-bf-mcp-include-tools are respected on tools/list
|
||||
server.WithToolFilter(handler.makeIncludeClientsFilter())(handler.globalMCPServer)
|
||||
|
||||
if err := handler.SyncAllMCPServers(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to sync all MCP servers: %w", err)
|
||||
}
|
||||
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the MCP server route
|
||||
func (h *MCPServerHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
// MCP server endpoint - supports both POST (JSON-RPC) and GET (SSE)
|
||||
r.POST("/mcp", lib.ChainMiddlewares(h.handleMCPServer, middlewares...))
|
||||
r.GET("/mcp", lib.ChainMiddlewares(h.handleMCPServerSSE, middlewares...))
|
||||
}
|
||||
|
||||
// handleMCPServer handles POST requests for MCP JSON-RPC 2.0 messages
|
||||
// injectMCPSessionIdentity sets the MCP gateway flag and, if a per-user OAuth
|
||||
// session exists, injects the session token and identity (VK / User ID) directly
|
||||
// into the BifrostContext. This avoids header-based identity propagation which
|
||||
// would be vulnerable to spoofing by upstream callers.
|
||||
//
|
||||
// Governance context keys are set here intentionally (bypassing governance plugin)
|
||||
// because in the MCP gateway path, identity is pre-authenticated via the OAuth session.
|
||||
func injectMCPSessionIdentity(bifrostCtx *schemas.BifrostContext, session *tables.TablePerUserOAuthSession) {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyIsMCPGateway, true)
|
||||
if session != nil {
|
||||
if session.AccessToken != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyMCPUserSession, session.AccessToken)
|
||||
}
|
||||
if session.VirtualKeyID != nil && *session.VirtualKeyID != "" && session.VirtualKey != nil && session.VirtualKey.Value != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, session.VirtualKey.Value)
|
||||
}
|
||||
if session.UserID != nil && *session.UserID != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyUserID, *session.UserID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) {
|
||||
mcpServer, session, err := h.getMCPServerForRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Convert context
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
defer cancel()
|
||||
|
||||
injectMCPSessionIdentity(bifrostCtx, session)
|
||||
|
||||
// Use mcp-go server to handle the request
|
||||
// HandleMessage processes JSON-RPC messages and returns appropriate responses
|
||||
response := mcpServer.HandleMessage(bifrostCtx, ctx.PostBody())
|
||||
|
||||
// Check if response is nil (notification - no response needed)
|
||||
if response == nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
// Marshal and send response
|
||||
responseJSON, err := sonic.Marshal(response)
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to marshal MCP response: %v", err))
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.SetContentType("application/json")
|
||||
ctx.SetBody(responseJSON)
|
||||
}
|
||||
|
||||
// handleMCPServerSSE handles GET requests for MCP Server-Sent Events streaming
|
||||
func (h *MCPServerHandler) handleMCPServerSSE(ctx *fasthttp.RequestCtx) {
|
||||
_, session, err := h.getMCPServerForRequest(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
ctx.SetContentType("text/event-stream")
|
||||
ctx.Response.Header.Set("Cache-Control", "no-cache")
|
||||
ctx.Response.Header.Set("Connection", "keep-alive")
|
||||
|
||||
// Convert context
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
|
||||
|
||||
injectMCPSessionIdentity(bifrostCtx, session)
|
||||
|
||||
// Use SSEStreamReader to bypass fasthttp's internal pipe batching
|
||||
reader := lib.NewSSEStreamReader()
|
||||
ctx.Response.SetBodyStream(reader, -1)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
cancel()
|
||||
reader.Done()
|
||||
}()
|
||||
|
||||
// Send initial connection message
|
||||
initMessage := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "connection/opened",
|
||||
}
|
||||
if initJSON, err := sonic.Marshal(initMessage); err == nil {
|
||||
buf := make([]byte, 0, len(initJSON)+8)
|
||||
buf = append(buf, "data: "...)
|
||||
buf = append(buf, initJSON...)
|
||||
buf = append(buf, '\n', '\n')
|
||||
reader.Send(buf)
|
||||
}
|
||||
|
||||
// Wait for context cancellation (client disconnect or server-side cancel)
|
||||
<-(*bifrostCtx).Done()
|
||||
}()
|
||||
}
|
||||
|
||||
// Sync methods for MCP servers
|
||||
|
||||
func (h *MCPServerHandler) SyncAllMCPServers(ctx context.Context) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
availableTools := h.toolManager.GetAvailableMCPTools(ctx)
|
||||
h.syncServer(h.globalMCPServer, availableTools, nil)
|
||||
logger.Debug("Synced global MCP server with %d tools", len(availableTools))
|
||||
|
||||
// initialize vkMCPServers map
|
||||
if h.config.ConfigStore != nil {
|
||||
virtualKeys, err := h.config.ConfigStore.GetVirtualKeys(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get virtual keys: %w", err)
|
||||
}
|
||||
h.vkMCPServers = make(map[string]*server.MCPServer)
|
||||
for i := range virtualKeys {
|
||||
vk := &virtualKeys[i]
|
||||
vkServer := server.NewMCPServer(
|
||||
vk.Name,
|
||||
version,
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
server.WithToolFilter(h.makeIncludeClientsFilter())(vkServer)
|
||||
h.vkMCPServers[vk.Value] = vkServer
|
||||
availableTools, toolFilter := h.fetchToolsForVK(vk)
|
||||
h.syncServer(h.vkMCPServers[vk.Value], availableTools, toolFilter)
|
||||
logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(availableTools))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MCPServerHandler) SyncVKMCPServer(vk *tables.TableVirtualKey) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
vkServer, ok := h.vkMCPServers[vk.Value]
|
||||
if !ok {
|
||||
// Add new server
|
||||
vkServer = server.NewMCPServer(
|
||||
vk.Name,
|
||||
version,
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
server.WithToolFilter(h.makeIncludeClientsFilter())(vkServer)
|
||||
h.vkMCPServers[vk.Value] = vkServer
|
||||
}
|
||||
availableTools, toolFilter := h.fetchToolsForVK(vk)
|
||||
h.syncServer(vkServer, availableTools, toolFilter)
|
||||
h.vkMCPServers[vk.Value] = vkServer
|
||||
logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(availableTools))
|
||||
}
|
||||
|
||||
func (h *MCPServerHandler) DeleteVKMCPServer(vkValue string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
delete(h.vkMCPServers, vkValue)
|
||||
}
|
||||
|
||||
func (h *MCPServerHandler) syncServer(server *server.MCPServer, availableTools []schemas.ChatTool, toolFilter []string) {
|
||||
// Clear existing tools
|
||||
toolMap := server.ListTools()
|
||||
for toolName, _ := range toolMap {
|
||||
server.DeleteTools(toolName)
|
||||
}
|
||||
|
||||
// Register tools from all connected clients
|
||||
for _, tool := range availableTools {
|
||||
// Only process function tools (skip custom tools)
|
||||
if tool.Function == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Capture tool name for closure
|
||||
toolName := tool.Function.Name
|
||||
|
||||
handler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
// Inject tool filter into execution context if present
|
||||
if toolFilter != nil {
|
||||
ctx = context.WithValue(ctx, schemas.MCPContextKeyIncludeTools, toolFilter)
|
||||
}
|
||||
// Convert to Bifrost tool call format
|
||||
toolCallType := "function"
|
||||
toolCallID := fmt.Sprintf("mcp-%s", toolName)
|
||||
argsJSON, jsonErr := sonic.Marshal(request.GetArguments())
|
||||
if jsonErr != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal tool arguments: %v", jsonErr)), nil
|
||||
}
|
||||
toolCall := schemas.ChatAssistantMessageToolCall{
|
||||
ID: &toolCallID,
|
||||
Type: &toolCallType,
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: &toolName,
|
||||
Arguments: string(argsJSON),
|
||||
},
|
||||
}
|
||||
|
||||
// Execute the tool via tool executor
|
||||
toolMessage, err := h.toolManager.ExecuteChatMCPTool(ctx, &toolCall)
|
||||
if err != nil {
|
||||
if err.ExtraFields.MCPAuthRequired != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf(
|
||||
"Authentication required for %s. Open this URL to connect your account: %s",
|
||||
err.ExtraFields.MCPAuthRequired.MCPClientName, err.ExtraFields.MCPAuthRequired.AuthorizeURL,
|
||||
)), nil
|
||||
}
|
||||
return mcp.NewToolResultError(fmt.Sprintf("Tool execution failed: %v", bifrost.GetErrorMessage(err))), nil
|
||||
}
|
||||
|
||||
// Extract content from tool message
|
||||
var resultText string
|
||||
if toolMessage != nil && toolMessage.Content != nil {
|
||||
// Handle ContentStr (string content)
|
||||
if toolMessage.Content.ContentStr != nil {
|
||||
resultText = *toolMessage.Content.ContentStr
|
||||
} else if toolMessage.Content.ContentBlocks != nil {
|
||||
// Handle ContentBlocks (structured content)
|
||||
for _, block := range toolMessage.Content.ContentBlocks {
|
||||
if block.Type == schemas.ChatContentBlockTypeText && block.Text != nil {
|
||||
resultText += *block.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return result using mcp-go helper
|
||||
return mcp.NewToolResultText(resultText), nil
|
||||
}
|
||||
|
||||
// Convert description from *string to string
|
||||
description := ""
|
||||
if tool.Function.Description != nil {
|
||||
description = *tool.Function.Description
|
||||
}
|
||||
|
||||
// Convert Parameters to mcp.ToolInputSchema
|
||||
var inputSchema mcp.ToolInputSchema
|
||||
if tool.Function.Parameters != nil {
|
||||
inputSchema.Type = tool.Function.Parameters.Type
|
||||
if tool.Function.Parameters.Properties != nil {
|
||||
// Convert *map[string]interface{} to map[string]any
|
||||
props := make(map[string]any)
|
||||
tool.Function.Parameters.Properties.Range(func(key string, value interface{}) bool {
|
||||
props[key] = value
|
||||
return true
|
||||
})
|
||||
inputSchema.Properties = props
|
||||
}
|
||||
if tool.Function.Parameters.Required != nil {
|
||||
inputSchema.Required = tool.Function.Parameters.Required
|
||||
}
|
||||
} else {
|
||||
// Default to empty object schema if no parameters
|
||||
inputSchema.Type = "object"
|
||||
inputSchema.Properties = make(map[string]any)
|
||||
}
|
||||
|
||||
// Map Bifrost annotations back to MCP tool annotations
|
||||
var toolAnnotation mcp.ToolAnnotation
|
||||
if tool.Annotations != nil {
|
||||
toolAnnotation = mcp.ToolAnnotation{
|
||||
Title: tool.Annotations.Title,
|
||||
ReadOnlyHint: tool.Annotations.ReadOnlyHint,
|
||||
DestructiveHint: tool.Annotations.DestructiveHint,
|
||||
IdempotentHint: tool.Annotations.IdempotentHint,
|
||||
OpenWorldHint: tool.Annotations.OpenWorldHint,
|
||||
}
|
||||
}
|
||||
|
||||
// Register tool with the server
|
||||
server.AddTool(mcp.Tool{
|
||||
Name: toolName,
|
||||
Description: description,
|
||||
InputSchema: inputSchema,
|
||||
Annotations: toolAnnotation,
|
||||
}, handler)
|
||||
}
|
||||
}
|
||||
|
||||
// fetchToolsForVK fetches the tools for a given virtual key value.
|
||||
// vkValue is the virtual key value for the server, if empty, all tools will be fetched for global mcp server.
|
||||
// Returns the list of available tools and the tool filter to be applied during execution.
|
||||
func (h *MCPServerHandler) fetchToolsForVK(vk *tables.TableVirtualKey) ([]schemas.ChatTool, []string) {
|
||||
ctx := context.Background()
|
||||
var toolFilter []string
|
||||
|
||||
executeOnlyTools := make([]string, 0)
|
||||
|
||||
// Build a lookup of AllowOnAllVirtualKeys clients: clientID -> clientName.
|
||||
// Explicit VK MCPConfigs always take precedence over AllowOnAllVirtualKeys.
|
||||
allowAllVKsClients := h.config.GetAllowOnAllVirtualKeysClients()
|
||||
if allowAllVKsClients == nil {
|
||||
allowAllVKsClients = make(map[string]string)
|
||||
}
|
||||
|
||||
// Process explicit VK MCPConfigs first.
|
||||
handledClients := make(map[string]bool)
|
||||
for _, vkMcpConfig := range vk.MCPConfigs {
|
||||
clientID := vkMcpConfig.MCPClient.ClientID
|
||||
if _, isAllowAll := allowAllVKsClients[clientID]; isAllowAll {
|
||||
// Explicit config exists — it takes precedence; mark handled regardless of tool list.
|
||||
handledClients[clientID] = true
|
||||
}
|
||||
if vkMcpConfig.ToolsToExecute.IsEmpty() {
|
||||
continue
|
||||
}
|
||||
if vkMcpConfig.ToolsToExecute.IsUnrestricted() {
|
||||
executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", vkMcpConfig.MCPClient.Name))
|
||||
continue
|
||||
}
|
||||
for _, tool := range vkMcpConfig.ToolsToExecute {
|
||||
if tool != "" {
|
||||
// Add the tool - client config filtering will be handled by mcp.go
|
||||
// Note: Use '-' separator for individual tools (wildcard uses '-*' after client name, e.g., "client-*")
|
||||
executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-%s", vkMcpConfig.MCPClient.Name, tool))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For AllowOnAllVirtualKeys clients with no explicit VK config, allow all their tools.
|
||||
for clientID, clientName := range allowAllVKsClients {
|
||||
if !handledClients[clientID] {
|
||||
executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", clientName))
|
||||
}
|
||||
}
|
||||
|
||||
// Always set the include-tools filter (empty = deny-all when no MCPConfigs and no AllowOnAllVirtualKeys clients)
|
||||
ctx = context.WithValue(ctx, schemas.MCPContextKeyIncludeTools, executeOnlyTools)
|
||||
toolFilter = executeOnlyTools
|
||||
|
||||
return h.toolManager.GetAvailableMCPTools(ctx), toolFilter
|
||||
}
|
||||
|
||||
// makeIncludeClientsFilter returns a ToolFilterFunc that dynamically filters the tools/list
|
||||
// response based on the x-bf-mcp-include-clients and x-bf-mcp-include-tools request headers.
|
||||
// When neither header is present the filter is a no-op, preserving existing behaviour.
|
||||
func (h *MCPServerHandler) makeIncludeClientsFilter() server.ToolFilterFunc {
|
||||
return func(ctx context.Context, tools []mcp.Tool) []mcp.Tool {
|
||||
if ctx.Value(schemas.MCPContextKeyIncludeClients) == nil && ctx.Value(schemas.MCPContextKeyIncludeTools) == nil {
|
||||
return tools
|
||||
}
|
||||
allowed := h.toolManager.GetAvailableMCPTools(ctx)
|
||||
allowedNames := make(map[string]bool, len(allowed))
|
||||
for _, t := range allowed {
|
||||
if t.Function != nil {
|
||||
allowedNames[t.Function.Name] = true
|
||||
}
|
||||
}
|
||||
result := make([]mcp.Tool, 0, len(tools))
|
||||
for _, tool := range tools {
|
||||
if allowedNames[tool.Name] {
|
||||
result = append(result, tool)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// Utility methods
|
||||
|
||||
func (h *MCPServerHandler) getMCPServerForRequest(ctx *fasthttp.RequestCtx) (*server.MCPServer, *tables.TablePerUserOAuthSession, error) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
h.config.Mu.RLock()
|
||||
enforceVK := h.config.ClientConfig.EnforceAuthOnInference
|
||||
h.config.Mu.RUnlock()
|
||||
|
||||
vk := getVKFromRequest(ctx)
|
||||
|
||||
// Check for Bifrost per-user OAuth Bearer token (not a VK)
|
||||
userOauthSession, sessionErr := h.getPerUserOAuthSession(ctx)
|
||||
if sessionErr != nil {
|
||||
return nil, nil, fmt.Errorf("failed to look up OAuth session: %w", sessionErr)
|
||||
}
|
||||
|
||||
// If per_user_oauth MCP clients are configured and no valid auth, return 401 with discovery
|
||||
if clients := h.config.GetPerUserOAuthMCPClients(); len(clients) > 0 && userOauthSession == nil && vk == "" {
|
||||
scheme := "http"
|
||||
if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
host := string(ctx.Host())
|
||||
resourceMetadataURL := fmt.Sprintf("%s://%s/.well-known/oauth-protected-resource", scheme, host)
|
||||
ctx.Response.Header.Set("WWW-Authenticate",
|
||||
fmt.Sprintf(`Bearer resource_metadata="%s"`, resourceMetadataURL))
|
||||
return nil, nil, fmt.Errorf("oauth authentication required for mcp access")
|
||||
}
|
||||
|
||||
if userOauthSession != nil {
|
||||
if !enforceVK && (userOauthSession.VirtualKeyID == nil || *userOauthSession.VirtualKeyID == "") {
|
||||
return h.globalMCPServer, userOauthSession, nil
|
||||
}
|
||||
|
||||
if userOauthSession.VirtualKeyID == nil || *userOauthSession.VirtualKeyID == "" || userOauthSession.VirtualKey == nil {
|
||||
return nil, nil, fmt.Errorf("virtual key required in oauth session to access mcp server, please re-authenticate with a virtual key")
|
||||
}
|
||||
|
||||
vkServer, ok := h.vkMCPServers[userOauthSession.VirtualKey.Value]
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("virtual key not found")
|
||||
}
|
||||
|
||||
return vkServer, userOauthSession, nil
|
||||
}
|
||||
|
||||
// Return global MCP server if not enforcing virtual key header and no virtual key is provided
|
||||
if !enforceVK && vk == "" {
|
||||
return h.globalMCPServer, nil, nil
|
||||
}
|
||||
|
||||
if vk == "" {
|
||||
return nil, nil, fmt.Errorf("virtual key header required to access mcp server")
|
||||
}
|
||||
|
||||
vkServer, ok := h.vkMCPServers[vk]
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("virtual key not found")
|
||||
}
|
||||
|
||||
return vkServer, nil, nil
|
||||
}
|
||||
|
||||
// getPerUserOAuthSession extracts and validates a Bifrost-issued per-user OAuth
|
||||
// token from the Authorization header. Returns the session if valid, nil otherwise.
|
||||
func (h *MCPServerHandler) getPerUserOAuthSession(ctx *fasthttp.RequestCtx) (*tables.TablePerUserOAuthSession, error) {
|
||||
authHeader := strings.TrimSpace(string(ctx.Request.Header.Peek("Authorization")))
|
||||
if authHeader == "" || !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
return nil, nil
|
||||
}
|
||||
token := strings.TrimSpace(authHeader[7:])
|
||||
if token == "" || strings.HasPrefix(strings.ToLower(token), governance.VirtualKeyPrefix) {
|
||||
return nil, nil // It's a virtual key, not a per-user OAuth token
|
||||
}
|
||||
|
||||
if h.config.ConfigStore == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
session, err := h.config.ConfigStore.GetPerUserOAuthSessionByAccessToken(ctx, token)
|
||||
if err != nil {
|
||||
logger.Warn("[mcp/auth] GetPerUserOAuthSessionByAccessToken error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
if session == nil {
|
||||
logger.Debug("[mcp/auth] Session not found for token")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check expiry
|
||||
if session.ExpiresAt.Before(time.Now()) {
|
||||
logger.Debug("[mcp/auth] Session expired: session_id=%s expires_at=%v", session.ID, session.ExpiresAt)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func getVKFromRequest(ctx *fasthttp.RequestCtx) string {
|
||||
if value := strings.TrimSpace(string(ctx.Request.Header.Peek(string(schemas.BifrostContextKeyVirtualKey)))); value != "" {
|
||||
return value
|
||||
}
|
||||
|
||||
authHeader := strings.TrimSpace(string(ctx.Request.Header.Peek("Authorization")))
|
||||
if authHeader != "" {
|
||||
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
token := strings.TrimSpace(authHeader[7:])
|
||||
if token != "" && strings.HasPrefix(strings.ToLower(token), governance.VirtualKeyPrefix) {
|
||||
return token
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if apiKey := strings.TrimSpace(string(ctx.Request.Header.Peek("x-api-key"))); apiKey != "" {
|
||||
if strings.HasPrefix(strings.ToLower(apiKey), governance.VirtualKeyPrefix) {
|
||||
return apiKey
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
1092
transports/bifrost-http/handlers/middlewares.go
Normal file
1092
transports/bifrost-http/handlers/middlewares.go
Normal file
File diff suppressed because it is too large
Load Diff
2022
transports/bifrost-http/handlers/middlewares_test.go
Normal file
2022
transports/bifrost-http/handlers/middlewares_test.go
Normal file
File diff suppressed because it is too large
Load Diff
320
transports/bifrost-http/handlers/oauth2.go
Normal file
320
transports/bifrost-http/handlers/oauth2.go
Normal file
@@ -0,0 +1,320 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file contains OAuth 2.0 authentication flow handlers.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/oauth2"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// OAuth2Handler manages HTTP requests for OAuth2 operations
|
||||
type OAuthHandler struct {
|
||||
client *bifrost.Bifrost
|
||||
store *lib.Config
|
||||
oauthProvider *oauth2.OAuth2Provider
|
||||
}
|
||||
|
||||
// NewOAuthHandler creates a new OAuth handler instance
|
||||
func NewOAuthHandler(oauthProvider *oauth2.OAuth2Provider, client *bifrost.Bifrost, store *lib.Config) *OAuthHandler {
|
||||
return &OAuthHandler{
|
||||
client: client,
|
||||
store: store,
|
||||
oauthProvider: oauthProvider,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers all OAuth-related routes
|
||||
func (h *OAuthHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.GET("/api/oauth/callback", lib.ChainMiddlewares(h.handleOAuthCallback, middlewares...))
|
||||
r.GET("/api/oauth/config/{id}/status", lib.ChainMiddlewares(h.getOAuthConfigStatus, middlewares...))
|
||||
r.DELETE("/api/oauth/config/{id}", lib.ChainMiddlewares(h.revokeOAuthConfig, middlewares...))
|
||||
}
|
||||
|
||||
// handleOAuthCallback handles the OAuth provider callback
|
||||
// GET /api/oauth/callback?state=xxx&code=yyy&error=zzz
|
||||
func (h *OAuthHandler) handleOAuthCallback(ctx *fasthttp.RequestCtx) {
|
||||
state := string(ctx.QueryArgs().Peek("state"))
|
||||
code := string(ctx.QueryArgs().Peek("code"))
|
||||
errorParam := string(ctx.QueryArgs().Peek("error"))
|
||||
errorDescription := string(ctx.QueryArgs().Peek("error_description"))
|
||||
|
||||
// Handle authorization denial
|
||||
if errorParam != "" {
|
||||
h.handleCallbackError(ctx, state, errorParam, errorDescription)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required parameters
|
||||
if state == "" || code == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Missing required parameters: state and code")
|
||||
return
|
||||
}
|
||||
|
||||
// Try per-user OAuth runtime flow first (state from oauth_user_sessions table).
|
||||
// This handles the case where an end-user authenticates during inference.
|
||||
sessionToken, perUserErr := h.oauthProvider.CompleteUserOAuthFlow(context.Background(), state, code)
|
||||
if perUserErr != nil && !errors.Is(perUserErr, schemas.ErrOAuth2NotPerUserSession) {
|
||||
// Real per-user error (not "state not found") — don't fall through to admin flow
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Per-user OAuth flow failed: %v", perUserErr))
|
||||
return
|
||||
}
|
||||
if perUserErr == nil && sessionToken != "" {
|
||||
// Consent flow: session token is a flow proxy ("flow:<flowID>:<mcpClientID>").
|
||||
// Redirect back to the MCPs consent page so the user can continue.
|
||||
if strings.HasPrefix(sessionToken, "flow:") {
|
||||
rest := strings.TrimPrefix(sessionToken, "flow:")
|
||||
flowID := strings.SplitN(rest, ":", 2)[0]
|
||||
mcpsURL := fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID))
|
||||
ctx.Redirect(mcpsURL, fasthttp.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Per-user runtime OAuth flow completed — show success page.
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.SetContentType("text/html")
|
||||
ctx.SetBodyString(oauthSuccessPage(`
|
||||
if (window.opener) {
|
||||
window.opener.postMessage({ type: 'oauth_success' }, window.location.origin);
|
||||
window.close();
|
||||
}
|
||||
`, "Authorization Successful", "You can close this tab."))
|
||||
return
|
||||
}
|
||||
|
||||
// Fall through to standard OAuth flow (handles both admin test logins for
|
||||
// per_user_oauth setup and regular server-level OAuth).
|
||||
if err := h.oauthProvider.CompleteOAuthFlow(context.Background(), state, code); err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("OAuth flow completion failed: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect to success page (or close popup)
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.SetContentType("text/html")
|
||||
ctx.SetBodyString(oauthSuccessPage(`
|
||||
if (window.opener) {
|
||||
window.opener.postMessage({ type: 'oauth_success' }, window.location.origin);
|
||||
window.close();
|
||||
}
|
||||
`, "Authorization Successful", "OAuth authorization successful! You can close this window."))
|
||||
}
|
||||
|
||||
// handleCallbackError handles OAuth callback errors
|
||||
func (h *OAuthHandler) handleCallbackError(ctx *fasthttp.RequestCtx, state, errorParam, errorDescription string) {
|
||||
// Update OAuth config status to failed if state is provided
|
||||
if state != "" {
|
||||
oauthConfig, err := h.store.ConfigStore.GetOauthConfigByState(context.Background(), state)
|
||||
if err == nil && oauthConfig != nil {
|
||||
oauthConfig.Status = "failed"
|
||||
h.store.ConfigStore.UpdateOauthConfig(context.Background(), oauthConfig)
|
||||
}
|
||||
}
|
||||
|
||||
// Show error page
|
||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
ctx.SetContentType("text/html")
|
||||
errorMsg := errorParam
|
||||
if errorDescription != "" {
|
||||
errorMsg = fmt.Sprintf("%s: %s", errorParam, errorDescription)
|
||||
}
|
||||
// JSON-encode for safe embedding in JavaScript context (prevents JS injection)
|
||||
jsEscaped, _ := json.Marshal(errorMsg)
|
||||
// HTML-escape for safe embedding in HTML body (prevents HTML injection)
|
||||
htmlEscaped := html.EscapeString(errorMsg)
|
||||
ctx.SetBodyString(oauthErrorPage(string(jsEscaped), htmlEscaped))
|
||||
}
|
||||
|
||||
// getOAuthConfigStatus returns the current status of an OAuth config
|
||||
// GET /api/oauth/config/{id}/status
|
||||
func (h *OAuthHandler) getOAuthConfigStatus(ctx *fasthttp.RequestCtx) {
|
||||
configID := ctx.UserValue("id").(string)
|
||||
|
||||
oauthConfig, err := h.store.ConfigStore.GetOauthConfigByID(context.Background(), configID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get OAuth config: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if oauthConfig == nil {
|
||||
SendError(ctx, fasthttp.StatusNotFound, "OAuth config not found")
|
||||
return
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": oauthConfig.ID,
|
||||
"status": oauthConfig.Status,
|
||||
"created_at": oauthConfig.CreatedAt,
|
||||
"expires_at": oauthConfig.ExpiresAt,
|
||||
}
|
||||
|
||||
if oauthConfig.Status == "authorized" && oauthConfig.TokenID != nil {
|
||||
response["token_id"] = *oauthConfig.TokenID
|
||||
|
||||
// Get token metadata
|
||||
token, err := h.store.ConfigStore.GetOauthTokenByID(context.Background(), *oauthConfig.TokenID)
|
||||
if err == nil && token != nil {
|
||||
response["token_expires_at"] = token.ExpiresAt
|
||||
response["token_scopes"] = token.Scopes
|
||||
}
|
||||
}
|
||||
|
||||
SendJSON(ctx, response)
|
||||
}
|
||||
|
||||
// revokeOAuthConfig revokes an OAuth configuration and its associated token
|
||||
// DELETE /api/oauth/config/{id}
|
||||
func (h *OAuthHandler) revokeOAuthConfig(ctx *fasthttp.RequestCtx) {
|
||||
configID := ctx.UserValue("id").(string)
|
||||
|
||||
if err := h.oauthProvider.RevokeToken(context.Background(), configID); err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to revoke OAuth token: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
SendJSON(ctx, map[string]interface{}{
|
||||
"message": "OAuth token revoked successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// OAuthInitiationRequest represents the request to initiate an OAuth flow
|
||||
type OAuthInitiationRequest struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
AuthorizeURL string `json:"authorize_url"`
|
||||
TokenURL string `json:"token_url"`
|
||||
RegistrationURL string `json:"registration_url"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
Scopes []string `json:"scopes"`
|
||||
ServerURL string `json:"server_url"` // For OAuth discovery
|
||||
}
|
||||
|
||||
// InitiateOAuthFlow initiates an OAuth flow and returns the authorization URL
|
||||
// This is called internally by the MCP client creation endpoint
|
||||
func (h *OAuthHandler) InitiateOAuthFlow(ctx context.Context, req OAuthInitiationRequest) (*schemas.OAuth2FlowInitiation, error) {
|
||||
var registrationURL *string
|
||||
if req.RegistrationURL != "" {
|
||||
registrationURL = &req.RegistrationURL
|
||||
}
|
||||
|
||||
config := &schemas.OAuth2Config{
|
||||
ClientID: req.ClientID,
|
||||
ClientSecret: req.ClientSecret,
|
||||
AuthorizeURL: req.AuthorizeURL,
|
||||
TokenURL: req.TokenURL,
|
||||
RegistrationURL: registrationURL,
|
||||
RedirectURI: req.RedirectURI,
|
||||
Scopes: req.Scopes,
|
||||
ServerURL: req.ServerURL, // MCP server URL for OAuth discovery
|
||||
}
|
||||
|
||||
return h.oauthProvider.InitiateOAuthFlow(ctx, config)
|
||||
}
|
||||
|
||||
// StorePendingMCPClient stores an MCP client config in the database while waiting for OAuth completion
|
||||
// This supports multi-instance deployments where OAuth callback may hit a different server instance
|
||||
func (h *OAuthHandler) StorePendingMCPClient(oauthConfigID string, mcpClientConfig schemas.MCPClientConfig) error {
|
||||
return h.oauthProvider.StorePendingMCPClient(oauthConfigID, mcpClientConfig)
|
||||
}
|
||||
|
||||
// GetPendingMCPClient retrieves a pending MCP client config by oauth_config_id
|
||||
func (h *OAuthHandler) GetPendingMCPClient(oauthConfigID string) (*schemas.MCPClientConfig, error) {
|
||||
return h.oauthProvider.GetPendingMCPClient(oauthConfigID)
|
||||
}
|
||||
|
||||
// GetPendingMCPClientByState retrieves a pending MCP client config by OAuth state token
|
||||
func (h *OAuthHandler) GetPendingMCPClientByState(state string) (*schemas.MCPClientConfig, string, error) {
|
||||
return h.oauthProvider.GetPendingMCPClientByState(state)
|
||||
}
|
||||
|
||||
// RemovePendingMCPClient removes a pending MCP client after OAuth completion.
|
||||
func (h *OAuthHandler) RemovePendingMCPClient(oauthConfigID string) error {
|
||||
return h.oauthProvider.RemovePendingMCPClient(oauthConfigID)
|
||||
}
|
||||
|
||||
// GetAccessToken retrieves the access token for a given oauth_config_id.
|
||||
// Used during per-user OAuth setup to get the admin's temporary token for verification.
|
||||
func (h *OAuthHandler) GetAccessToken(ctx context.Context, oauthConfigID string) (string, error) {
|
||||
return h.oauthProvider.GetAccessToken(ctx, oauthConfigID)
|
||||
}
|
||||
|
||||
// oauthSuccessPage renders a Bifrost-themed success HTML page.
|
||||
// extraScript is injected verbatim into a <script> tag (caller is responsible for safety).
|
||||
func oauthSuccessPage(extraScript, title, message string) string {
|
||||
return fmt.Sprintf(`<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>%s</title>
|
||||
<style>%s
|
||||
.icon{font-size:2.5rem;margin-bottom:16px}
|
||||
.msg{font-size:0.9rem;color:oklch(0.552 0.016 285.938);margin-top:8px}
|
||||
</style>
|
||||
<script>%s</script>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card" style="text-align:center">
|
||||
<div class="icon">✓</div>
|
||||
<h1>%s</h1>
|
||||
<p class="msg">%s</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`, html.EscapeString(title), bifrostPageCSS, extraScript, html.EscapeString(title), html.EscapeString(message))
|
||||
}
|
||||
|
||||
// oauthErrorPage renders a Bifrost-themed error HTML page.
|
||||
// jsEscapedError must already be JSON-encoded (with quotes) for safe JS embedding.
|
||||
// htmlError must already be HTML-escaped for safe body embedding.
|
||||
func oauthErrorPage(jsEscapedError, htmlError string) string {
|
||||
return fmt.Sprintf(`<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authorization Failed</title>
|
||||
<style>%s
|
||||
.icon{font-size:2.5rem;margin-bottom:16px;color:oklch(0.50 0.18 27)}
|
||||
.err-msg{font-size:0.9rem;color:oklch(0.552 0.016 285.938);margin-top:8px}
|
||||
.hint{font-size:0.8rem;color:oklch(0.65 0.01 286);margin-top:16px}
|
||||
</style>
|
||||
<script>
|
||||
if (window.opener) {
|
||||
window.opener.postMessage({ type: 'oauth_failed', error: %s }, window.location.origin);
|
||||
window.close();
|
||||
}
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card" style="text-align:center">
|
||||
<div class="icon">✗</div>
|
||||
<h1>Authorization Failed</h1>
|
||||
<p class="err-msg">%s</p>
|
||||
<p class="hint">You can close this window.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`, bifrostPageCSS, jsEscapedError, htmlError)
|
||||
}
|
||||
|
||||
// jsEscapeString returns a JSON-encoded string (with quotes) safe for embedding in JavaScript.
|
||||
func jsEscapeString(s string) string {
|
||||
b, _ := json.Marshal(s)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// RevokeToken revokes the OAuth token for a given oauth_config_id.
|
||||
// Used during per-user OAuth setup to discard the admin's temporary token after verification.
|
||||
func (h *OAuthHandler) RevokeToken(ctx context.Context, oauthConfigID string) error {
|
||||
return h.oauthProvider.RevokeToken(ctx, oauthConfigID)
|
||||
}
|
||||
643
transports/bifrost-http/handlers/oauth2_consent.go
Normal file
643
transports/bifrost-http/handlers/oauth2_consent.go
Normal file
@@ -0,0 +1,643 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file implements the per-user OAuth consent flow — the intermediate screens
|
||||
// shown between the MCP client's authorize request and the final authorization code
|
||||
// issuance. The flow is:
|
||||
//
|
||||
// 1. GET /oauth/consent?flow_id=xxx → VK input page (HTML)
|
||||
// 2. POST /api/oauth/per-user/consent/vk → validate VK, update PendingFlow, redirect
|
||||
// 3. GET /oauth/consent/mcps?flow_id=xxx → MCPs page (HTML, server-rendered)
|
||||
// 4. POST /api/oauth/per-user/consent/submit → create session + code, redirect to client
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"html"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ConsentHandler manages the per-user OAuth consent flow screens.
|
||||
type ConsentHandler struct {
|
||||
store *lib.Config
|
||||
}
|
||||
|
||||
// NewConsentHandler creates a new consent handler instance.
|
||||
func NewConsentHandler(store *lib.Config) *ConsentHandler {
|
||||
return &ConsentHandler{store: store}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the consent flow routes.
|
||||
// All routes are public — no auth middleware — since they are part of the OAuth
|
||||
// flow for unauthenticated users acquiring credentials.
|
||||
func (h *ConsentHandler) RegisterRoutes(r *router.Router) {
|
||||
// HTML pages (GET, served by Go)
|
||||
r.GET("/oauth/consent", h.handleIdentityPage)
|
||||
r.GET("/oauth/consent/mcps", h.handleMCPsPage)
|
||||
|
||||
// API actions (POST)
|
||||
// NOTE: All state-mutating endpoints use POST. CSRF protection relies on the
|
||||
// SameSite=Lax browser-binding cookie (__bifrost_flow_secret) combined with
|
||||
// the flow_id — SameSite=Lax blocks cross-site POST, and the cookie is
|
||||
// HttpOnly+Secure. This is sufficient for the threat model here.
|
||||
r.POST("/api/oauth/per-user/consent/vk", h.handleSubmitVK)
|
||||
r.POST("/api/oauth/per-user/consent/user-id", h.handleSubmitUserID)
|
||||
r.POST("/api/oauth/per-user/consent/skip", h.handleSkip)
|
||||
r.POST("/api/oauth/per-user/consent/submit", h.handleSubmit)
|
||||
}
|
||||
|
||||
// ---------- HTML pages ----------
|
||||
|
||||
// handleIdentityPage renders the identity selection page with three options:
|
||||
// User ID, Virtual Key, or skip (lazy auth when tools are called).
|
||||
// GET /oauth/consent?flow_id=xxx[&error=xxx]
|
||||
func (h *ConsentHandler) handleIdentityPage(ctx *fasthttp.RequestCtx) {
|
||||
flowID := string(ctx.QueryArgs().Peek("flow_id"))
|
||||
errorMsg := string(ctx.QueryArgs().Peek("error"))
|
||||
|
||||
if flowID == "" {
|
||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
ctx.SetBodyString("Missing flow_id")
|
||||
return
|
||||
}
|
||||
|
||||
if h.store.ConfigStore == nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusServiceUnavailable)
|
||||
ctx.SetBodyString("Config store unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
ctx.SetBodyString("Failed to load consent flow.")
|
||||
return
|
||||
}
|
||||
if flow == nil || time.Now().After(flow.ExpiresAt) {
|
||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
ctx.SetBodyString("Invalid or expired consent flow. Please restart the authentication process.")
|
||||
return
|
||||
}
|
||||
if !validateFlowBrowserSecret(ctx, flow) {
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
ctx.SetBodyString("Flow does not belong to this browser session. Please restart the authentication process.")
|
||||
return
|
||||
}
|
||||
|
||||
h.store.Mu.RLock()
|
||||
enforceVK := h.store.ClientConfig.EnforceAuthOnInference
|
||||
h.store.Mu.RUnlock()
|
||||
|
||||
safeFlowID := html.EscapeString(flowID)
|
||||
safeError := html.EscapeString(errorMsg)
|
||||
|
||||
errorBanner := ""
|
||||
if safeError != "" {
|
||||
errorBanner = fmt.Sprintf(`<div class="error-banner">%s</div>`, safeError)
|
||||
}
|
||||
|
||||
skipOption := ""
|
||||
if !enforceVK {
|
||||
skipOption = fmt.Sprintf(`
|
||||
<div class="option">
|
||||
<span class="option-title">Skip for now</span>
|
||||
<span class="option-desc">Connect to services when a tool is called</span>
|
||||
<form action="/api/oauth/per-user/consent/skip" method="POST" style="margin-top:10px">
|
||||
<input type="hidden" name="flow_id" value="%s">
|
||||
<button type="submit" class="btn btn-ghost">Skip</button>
|
||||
</form>
|
||||
</div>`, safeFlowID)
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetBodyString(fmt.Sprintf(`<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Connect to Bifrost</title>
|
||||
<style>
|
||||
%s
|
||||
.option{border:1px solid oklch(0.92 0.004 286.32);border-radius:0.5rem;padding:16px 18px;margin-bottom:10px}
|
||||
.option-title{display:block;font-size:0.9rem;font-weight:600;color:oklch(0.141 0.005 285.823);margin-bottom:2px}
|
||||
.option-desc{display:block;font-size:0.8rem;color:oklch(0.552 0.016 285.938);margin-bottom:12px}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card">
|
||||
<h1>Connect to Bifrost</h1>
|
||||
<p class="subtitle">Choose how to identify yourself for this session.</p>
|
||||
<p style="font-size:0.75rem;color:oklch(0.65 0.01 286);margin-bottom:18px">This setup page expires in 15 minutes.</p>
|
||||
%s
|
||||
<div class="option">
|
||||
<span class="option-title">User ID</span>
|
||||
<span class="option-desc">Use a stable identifier — access all available services</span>
|
||||
<form action="/api/oauth/per-user/consent/user-id" method="POST">
|
||||
<input type="hidden" name="flow_id" value="%s">
|
||||
<label for="user_id">User ID</label>
|
||||
<input type="text" id="user_id" name="user_id" placeholder="e.g. alice" autocomplete="off" spellcheck="false" autocapitalize="none" autocorrect="off">
|
||||
<button type="submit" class="btn btn-primary">Continue with User ID</button>
|
||||
</form>
|
||||
</div>
|
||||
<div class="option">
|
||||
<span class="option-title">Virtual Key</span>
|
||||
<span class="option-desc">Use a VK — access services within your key's limits</span>
|
||||
<form action="/api/oauth/per-user/consent/vk" method="POST">
|
||||
<input type="hidden" name="flow_id" value="%s">
|
||||
<label for="vk">Virtual Key</label>
|
||||
<input type="password" id="vk" name="vk" placeholder="sk-bf-..." autocomplete="off" spellcheck="false" autocapitalize="none">
|
||||
<button type="submit" class="btn btn-primary">Continue with Virtual Key</button>
|
||||
</form>
|
||||
</div>
|
||||
%s
|
||||
</div>
|
||||
</body>
|
||||
</html>`, bifrostPageCSS, errorBanner, safeFlowID, safeFlowID, skipOption))
|
||||
}
|
||||
|
||||
// handleMCPsPage renders the MCP authentication list page.
|
||||
// GET /oauth/consent/mcps?flow_id=xxx
|
||||
func (h *ConsentHandler) handleMCPsPage(ctx *fasthttp.RequestCtx) {
|
||||
flowID := string(ctx.QueryArgs().Peek("flow_id"))
|
||||
|
||||
if flowID == "" {
|
||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
ctx.SetBodyString("Missing flow_id")
|
||||
return
|
||||
}
|
||||
|
||||
if h.store.ConfigStore == nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusServiceUnavailable)
|
||||
ctx.SetBodyString("Config store unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
ctx.SetBodyString("Failed to load consent flow.")
|
||||
return
|
||||
}
|
||||
if flow == nil || time.Now().After(flow.ExpiresAt) {
|
||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
ctx.SetBodyString("Invalid or expired consent flow. Please restart the authentication process.")
|
||||
return
|
||||
}
|
||||
if !validateFlowBrowserSecret(ctx, flow) {
|
||||
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||
ctx.SetBodyString("Flow does not belong to this browser session. Please restart the authentication process.")
|
||||
return
|
||||
}
|
||||
|
||||
// Find which MCP clients the user has already authed.
|
||||
// Check both: tokens stored under the flow proxy (connected during this flow)
|
||||
// and tokens already stored under the VK/user identity (connected in a prior flow).
|
||||
completedTokens, err := h.store.ConfigStore.GetOauthUserTokensByGatewaySessionID(ctx, flowID)
|
||||
if err != nil {
|
||||
completedTokens = nil // non-fatal; just show no checkmarks
|
||||
}
|
||||
completedMCPs := make(map[string]bool, len(completedTokens))
|
||||
for _, t := range completedTokens {
|
||||
completedMCPs[t.MCPClientID] = true
|
||||
}
|
||||
|
||||
// Per_user_oauth MCP clients visible to this identity — sorted for deterministic rendering.
|
||||
// When a VK is set on the flow, only show clients that VK is allowed to use.
|
||||
perUserClients := h.store.GetPerUserOAuthMCPClientsForVirtualKey(ctx, strVal(flow.VirtualKeyID))
|
||||
clientIDs := make([]string, 0, len(perUserClients))
|
||||
for id := range perUserClients {
|
||||
clientIDs = append(clientIDs, id)
|
||||
}
|
||||
sort.Strings(clientIDs)
|
||||
|
||||
safeFlowID := html.EscapeString(flowID)
|
||||
|
||||
// Determine if user skipped identity selection.
|
||||
isSkipped := strVal(flow.VirtualKeyID) == "" && strVal(flow.UserID) == ""
|
||||
|
||||
// Build MCP rows — only show connect buttons if user has an identity.
|
||||
var mcpRows strings.Builder
|
||||
if isSkipped {
|
||||
mcpRows.WriteString(`<p style="color:#6b7280;font-size:14px;">You skipped identity selection. Services will be connected when you first use their tools. Since no identity is attached, your connections will only persist as long as the service keeps the OAuth token active — they will not be remembered across sessions.</p>`)
|
||||
} else {
|
||||
for _, clientID := range clientIDs {
|
||||
clientName := perUserClients[clientID]
|
||||
safeName := html.EscapeString(clientName)
|
||||
|
||||
// Also check if a token already exists under the user's identity (e.g. from a prior LLM gateway auth).
|
||||
alreadyConnected := completedMCPs[clientID]
|
||||
if !alreadyConnected && (strVal(flow.VirtualKeyID) != "" || strVal(flow.UserID) != "") {
|
||||
existing, tokenErr := h.store.ConfigStore.GetOauthUserTokenByIdentity(ctx, strVal(flow.VirtualKeyID), strVal(flow.UserID), "", clientID)
|
||||
if tokenErr != nil {
|
||||
logger.Warn("[consent/mcps] failed to check existing token: mcp_client_id=%s err=%v", clientID, tokenErr)
|
||||
}
|
||||
alreadyConnected = existing != nil
|
||||
}
|
||||
|
||||
if alreadyConnected {
|
||||
mcpRows.WriteString(fmt.Sprintf(`
|
||||
<div class="mcp-row">
|
||||
<div class="mcp-name">%s</div>
|
||||
<span class="badge connected">✓ Connected</span>
|
||||
</div>`, safeName))
|
||||
} else {
|
||||
connectURL := fmt.Sprintf("/api/oauth/per-user/upstream/authorize?mcp_client_id=%s&flow_id=%s",
|
||||
url.QueryEscape(clientID), url.QueryEscape(flowID))
|
||||
mcpRows.WriteString(fmt.Sprintf(`
|
||||
<div class="mcp-row">
|
||||
<div class="mcp-name">%s</div>
|
||||
<a class="badge connect" href="%s">Connect</a>
|
||||
</div>`, safeName, html.EscapeString(connectURL)))
|
||||
}
|
||||
}
|
||||
if len(perUserClients) == 0 {
|
||||
mcpRows.WriteString(`<p style="color:#6b7280;font-size:14px;">No MCP services require authentication.</p>`)
|
||||
}
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetBodyString(fmt.Sprintf(`<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Connect Your Apps — Bifrost</title>
|
||||
<style>
|
||||
%s
|
||||
.mcp-row{display:flex;align-items:center;justify-content:space-between;padding:12px 0;border-bottom:1px solid oklch(0.92 0.004 286.32)}
|
||||
.mcp-row:last-of-type{border-bottom:none}
|
||||
.mcp-name{font-size:0.9rem;font-weight:500;color:oklch(0.141 0.005 285.823)}
|
||||
.badge{font-size:0.8rem;font-weight:500;padding:4px 12px;border-radius:20px;text-decoration:none;display:inline-block}
|
||||
.badge.connected{background:oklch(0.95 0.05 160);color:oklch(0.35 0.08 160)}
|
||||
.badge.connect{background:oklch(0.5081 0.1049 165.61);color:oklch(0.985 0 0);cursor:pointer;
|
||||
padding:8px 18px;border-radius:0.5rem;font-weight:500;
|
||||
transition:background .15s}
|
||||
.badge.connect:hover{background:oklch(0.43 0.1049 165.61)}
|
||||
.mcp-list{margin-bottom:4px}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card">
|
||||
<h1>Connect Your Apps</h1>
|
||||
<p class="subtitle">Authenticate with the services below to enable their tools.</p>
|
||||
<p style="font-size:0.75rem;color:oklch(0.65 0.01 286);margin-bottom:18px">This setup page expires in 15 minutes.</p>
|
||||
<div class="mcp-list">%s</div>
|
||||
<form action="/api/oauth/per-user/consent/submit" method="POST" style="margin-top:24px">
|
||||
<input type="hidden" name="flow_id" value="%s">
|
||||
<button type="submit" class="btn btn-primary">Finish Setup</button>
|
||||
</form>
|
||||
<div style="text-align:center;margin-top:12px">
|
||||
<a href="/oauth/consent?flow_id=%s" style="font-size:0.8rem;color:oklch(0.552 0.016 285.938);text-decoration:none">Change identity</a>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>`, bifrostPageCSS, mcpRows.String(), safeFlowID, safeFlowID))
|
||||
}
|
||||
|
||||
// ---------- API action handlers ----------
|
||||
|
||||
// handleSubmitVK validates the submitted Virtual Key, links it to the pending flow,
|
||||
// and redirects to the MCPs page.
|
||||
// POST /api/oauth/per-user/consent/vk (form: flow_id, vk)
|
||||
func (h *ConsentHandler) handleSubmitVK(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
flowID := string(ctx.FormValue("flow_id"))
|
||||
vkValue := strings.TrimSpace(string(ctx.FormValue("vk")))
|
||||
|
||||
if flowID == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow")
|
||||
return
|
||||
}
|
||||
if flow == nil || time.Now().After(flow.ExpiresAt) {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid or expired consent flow")
|
||||
return
|
||||
}
|
||||
if !validateFlowBrowserSecret(ctx, flow) {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session")
|
||||
return
|
||||
}
|
||||
|
||||
if vkValue == "" {
|
||||
redirectToIdentityPage(ctx, flowID, "Please enter a Virtual Key.")
|
||||
return
|
||||
}
|
||||
|
||||
vk, err := h.store.ConfigStore.GetVirtualKeyByValue(ctx, vkValue)
|
||||
if err != nil {
|
||||
redirectToIdentityPage(ctx, flowID, "Failed to validate Virtual Key. Please try again.")
|
||||
return
|
||||
}
|
||||
if vk == nil || !vk.IsActive {
|
||||
redirectToIdentityPage(ctx, flowID, "Virtual Key not found or inactive. Please check and try again.")
|
||||
return
|
||||
}
|
||||
|
||||
flow.VirtualKeyID = &vk.ID
|
||||
flow.UserID = nil // Clear other identity to keep selection exclusive
|
||||
if err := h.store.ConfigStore.UpdatePerUserOAuthPendingFlow(ctx, flow); err != nil {
|
||||
redirectToIdentityPage(ctx, flowID, "Failed to save Virtual Key. Please try again.")
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Redirect(fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)), fasthttp.StatusFound)
|
||||
}
|
||||
|
||||
// handleSubmitUserID links a user-supplied User ID to the pending flow and proceeds to MCPs page.
|
||||
// SECURITY: The User ID is self-declared (typed in a form) with no server-side verification.
|
||||
// This matches the trust model of X-Bf-User-Id in the LLM gateway path. Deployments requiring
|
||||
// verified identity should use Virtual Keys or an auth layer in front of Bifrost.
|
||||
// POST /api/oauth/per-user/consent/user-id (form: flow_id, user_id)
|
||||
func (h *ConsentHandler) handleSubmitUserID(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
flowID := string(ctx.FormValue("flow_id"))
|
||||
userID := strings.TrimSpace(string(ctx.FormValue("user_id")))
|
||||
|
||||
if flowID == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow")
|
||||
return
|
||||
}
|
||||
if flow == nil || time.Now().After(flow.ExpiresAt) {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid or expired consent flow")
|
||||
return
|
||||
}
|
||||
if !validateFlowBrowserSecret(ctx, flow) {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session")
|
||||
return
|
||||
}
|
||||
|
||||
if userID == "" {
|
||||
redirectToIdentityPage(ctx, flowID, "Please enter a User ID.")
|
||||
return
|
||||
}
|
||||
if len(userID) > 255 {
|
||||
redirectToIdentityPage(ctx, flowID, "User ID is too long (max 255 characters).")
|
||||
return
|
||||
}
|
||||
|
||||
if userID != "" {
|
||||
flow.UserID = &userID
|
||||
}
|
||||
flow.VirtualKeyID = nil // Clear other identity to keep selection exclusive
|
||||
if err := h.store.ConfigStore.UpdatePerUserOAuthPendingFlow(ctx, flow); err != nil {
|
||||
redirectToIdentityPage(ctx, flowID, "Failed to save User ID. Please try again.")
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Redirect(fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)), fasthttp.StatusFound)
|
||||
}
|
||||
|
||||
// handleSkip skips identity selection and proceeds directly to the MCPs page.
|
||||
// Upstream services will be connected lazily when tools are first called.
|
||||
// POST /api/oauth/per-user/consent/skip (form: flow_id)
|
||||
func (h *ConsentHandler) handleSkip(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
flowID := string(ctx.FormValue("flow_id"))
|
||||
if flowID == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow")
|
||||
return
|
||||
}
|
||||
if flow == nil || time.Now().After(flow.ExpiresAt) {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid or expired consent flow")
|
||||
return
|
||||
}
|
||||
if !validateFlowBrowserSecret(ctx, flow) {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session")
|
||||
return
|
||||
}
|
||||
|
||||
h.store.Mu.RLock()
|
||||
enforceVK := h.store.ClientConfig.EnforceAuthOnInference
|
||||
h.store.Mu.RUnlock()
|
||||
|
||||
if enforceVK {
|
||||
redirectToIdentityPage(ctx, flowID, "An identity (Virtual Key or User ID) is required. Please choose one to continue.")
|
||||
return
|
||||
}
|
||||
|
||||
// Clear any previously selected identity so skip truly resets the flow.
|
||||
if strVal(flow.VirtualKeyID) != "" || strVal(flow.UserID) != "" {
|
||||
flow.VirtualKeyID = nil
|
||||
flow.UserID = nil
|
||||
if err := h.store.ConfigStore.UpdatePerUserOAuthPendingFlow(ctx, flow); err != nil {
|
||||
redirectToIdentityPage(ctx, flowID, "Failed to clear identity. Please try again.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Skip goes straight to MCPs page; no identity means only lazy auth is available.
|
||||
ctx.Redirect(fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)), fasthttp.StatusFound)
|
||||
}
|
||||
|
||||
// handleSubmit finalises the consent flow:
|
||||
// 1. Creates a real Bifrost session (TablePerUserOAuthSession)
|
||||
// 2. Migrates upstream tokens from the flow proxy to the real session
|
||||
// 3. Issues a TablePerUserOAuthCode
|
||||
// 4. Deletes the PendingFlow
|
||||
// 5. Redirects to the original MCP client callback URL with code + state
|
||||
//
|
||||
// POST /api/oauth/per-user/consent/submit (form: flow_id)
|
||||
func (h *ConsentHandler) handleSubmit(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
flowID := string(ctx.FormValue("flow_id"))
|
||||
if flowID == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required")
|
||||
return
|
||||
}
|
||||
flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow")
|
||||
return
|
||||
}
|
||||
if flow == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid consent flow")
|
||||
return
|
||||
}
|
||||
if time.Now().After(flow.ExpiresAt) {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Consent flow has expired. Please restart the authentication process.")
|
||||
return
|
||||
}
|
||||
if !validateFlowBrowserSecret(ctx, flow) {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session")
|
||||
return
|
||||
}
|
||||
|
||||
// Server-side enforcement: reject if identity is required but not provided.
|
||||
h.store.Mu.RLock()
|
||||
enforceAuth := h.store.ClientConfig.EnforceAuthOnInference
|
||||
h.store.Mu.RUnlock()
|
||||
if enforceAuth && strVal(flow.VirtualKeyID) == "" && strVal(flow.UserID) == "" {
|
||||
redirectToIdentityPage(ctx, flowID, "An identity (Virtual Key or User ID) is required. Please choose one to continue.")
|
||||
return
|
||||
}
|
||||
|
||||
// 1. Generate session credentials.
|
||||
accessToken, err := generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate session token")
|
||||
return
|
||||
}
|
||||
refreshToken, err := generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate refresh token")
|
||||
return
|
||||
}
|
||||
|
||||
session := &tables.TablePerUserOAuthSession{
|
||||
ID: uuid.New().String(),
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ClientID: flow.ClientID,
|
||||
VirtualKeyID: flow.VirtualKeyID,
|
||||
UserID: flow.UserID,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
// 2. Generate authorization code.
|
||||
code, err := generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate authorization code")
|
||||
return
|
||||
}
|
||||
codeRecord := &tables.TablePerUserOAuthCode{
|
||||
ID: uuid.New().String(),
|
||||
Code: code,
|
||||
ClientID: flow.ClientID,
|
||||
RedirectURI: flow.RedirectURI,
|
||||
CodeChallenge: flow.CodeChallenge,
|
||||
SessionID: session.ID, // Links token endpoint to this session so it can return the same access token
|
||||
// Scopes intentionally omitted: the consent flow has no scope selection step.
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
}
|
||||
|
||||
// 3. Atomically consume the pending flow, create session, and create auth code.
|
||||
// If another concurrent request already consumed the flow, rowsAffected will be 0.
|
||||
rowsAffected, err := h.store.ConfigStore.FinalizePerUserOAuthConsent(ctx, flowID, session, codeRecord)
|
||||
if err != nil {
|
||||
if errors.Is(err, schemas.ErrPerUserOAuthPendingFlowExpired) {
|
||||
SendError(ctx, fasthttp.StatusGone, "Consent flow has expired. Please restart the authentication process.")
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to finalize consent flow")
|
||||
return
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
SendError(ctx, fasthttp.StatusConflict, "Consent flow has already been submitted")
|
||||
return
|
||||
}
|
||||
logger.Debug("[consent/submit] session created: session_id=%s flow_id=%s", session.ID, flowID)
|
||||
|
||||
// 4. Migrate upstream tokens from flow proxy sessions to real session (non-fatal).
|
||||
if err := h.store.ConfigStore.TransferOauthUserTokensFromGatewaySession(ctx, flowID, accessToken, strVal(flow.VirtualKeyID), strVal(flow.UserID)); err != nil {
|
||||
// Non-fatal: tokens can be re-acquired on first tool use.
|
||||
logger.Warn("[consent/submit] failed to transfer upstream tokens: flow_id=%s err=%v", flowID, err)
|
||||
}
|
||||
|
||||
// 5. Redirect to MCP client callback with code + original state.
|
||||
redirectURL, err := url.Parse(flow.RedirectURI)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Invalid redirect URI in pending flow")
|
||||
return
|
||||
}
|
||||
q := redirectURL.Query()
|
||||
q.Set("code", code)
|
||||
if flow.State != "" {
|
||||
q.Set("state", flow.State)
|
||||
}
|
||||
redirectURL.RawQuery = q.Encode()
|
||||
|
||||
ctx.Redirect(redirectURL.String(), fasthttp.StatusFound)
|
||||
}
|
||||
|
||||
// ---------- helpers ----------
|
||||
|
||||
// bifrostPageCSS is the shared inline CSS for all Go-rendered consent/callback pages.
|
||||
// It mirrors Bifrost's UI design tokens: teal primary, zinc palette, Geist font stack.
|
||||
const bifrostPageCSS = `
|
||||
*,*::before,*::after{box-sizing:border-box;margin:0;padding:0}
|
||||
body{font-family:"Geist",system-ui,-apple-system,sans-serif;font-size:0.95rem;
|
||||
line-height:1.5;background:#f4f4f5;color:oklch(0.141 0.005 285.823);
|
||||
display:flex;align-items:center;justify-content:center;min-height:100vh;
|
||||
-webkit-font-smoothing:antialiased}
|
||||
.card{background:#fff;border:1px solid oklch(0.92 0.004 286.32);border-radius:12px;
|
||||
padding:40px;width:100%;max-width:480px}
|
||||
h1{font-size:1.25rem;font-weight:600;color:oklch(0.141 0.005 285.823);margin-bottom:6px}
|
||||
.subtitle{font-size:0.825rem;color:oklch(0.552 0.016 285.938);line-height:1.5;margin-bottom:24px}
|
||||
label{display:block;font-size:0.825rem;font-weight:500;color:oklch(0.141 0.005 285.823);margin-bottom:5px}
|
||||
input[type=text],input[type=password]{width:100%;padding:8px 12px;border:1px solid oklch(0.92 0.004 286.32);
|
||||
border-radius:0.5rem;font-size:0.875rem;outline:none;
|
||||
transition:border-color .15s,box-shadow .15s;margin-bottom:10px;
|
||||
background:#fff;color:oklch(0.141 0.005 285.823)}
|
||||
input[type=text]:focus,input[type=password]:focus{border-color:oklch(0.5081 0.1049 165.61);
|
||||
box-shadow:0 0 0 3px oklch(0.5081 0.1049 165.61 / 0.15)}
|
||||
.btn{display:block;width:100%;padding:9px 16px;border-radius:0.5rem;font-size:0.875rem;
|
||||
font-weight:500;cursor:pointer;border:none;text-align:center;text-decoration:none;
|
||||
transition:background .15s;font-family:inherit}
|
||||
.btn-primary{background:oklch(0.5081 0.1049 165.61);color:oklch(0.985 0 0)}
|
||||
.btn-primary:hover{background:oklch(0.43 0.1049 165.61)}
|
||||
.btn-ghost{background:transparent;border:1px solid oklch(0.92 0.004 286.32);
|
||||
color:oklch(0.552 0.016 285.938);display:inline-block;width:auto;padding:8px 16px}
|
||||
.btn-ghost:hover{background:#f4f4f5}
|
||||
.error-banner{background:oklch(0.97 0.02 27);border:1px solid oklch(0.88 0.06 27);
|
||||
border-radius:0.5rem;padding:12px 14px;margin-bottom:18px;
|
||||
color:oklch(0.50 0.18 27);font-size:0.825rem}
|
||||
`
|
||||
|
||||
// redirectToIdentityPage redirects to the identity selection page with an error message.
|
||||
func redirectToIdentityPage(ctx *fasthttp.RequestCtx, flowID, errorMsg string) {
|
||||
u := fmt.Sprintf("/oauth/consent?flow_id=%s&error=%s",
|
||||
url.QueryEscape(flowID), url.QueryEscape(errorMsg))
|
||||
ctx.Redirect(u, fasthttp.StatusFound)
|
||||
}
|
||||
|
||||
// strVal safely dereferences a *string, returning "" for nil.
|
||||
func strVal(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
||||
93
transports/bifrost-http/handlers/oauth2_metadata.go
Normal file
93
transports/bifrost-http/handlers/oauth2_metadata.go
Normal file
@@ -0,0 +1,93 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file implements OAuth 2.0 metadata discovery endpoints per RFC 9728
|
||||
// (Protected Resource Metadata) and RFC 8414 (Authorization Server Metadata).
|
||||
// These endpoints enable MCP-spec-compliant clients (like Claude Code) to
|
||||
// automatically discover Bifrost's OAuth configuration and authenticate.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// OAuthMetadataHandler serves OAuth 2.0 discovery metadata endpoints.
|
||||
// It provides the Protected Resource Metadata (RFC 9728) and Authorization
|
||||
// Server Metadata (RFC 8414) that MCP clients use to discover how to
|
||||
// authenticate with Bifrost's MCP server endpoint.
|
||||
type OAuthMetadataHandler struct {
|
||||
store *lib.Config
|
||||
}
|
||||
|
||||
// NewOAuthMetadataHandler creates a new OAuth metadata handler instance.
|
||||
func NewOAuthMetadataHandler(store *lib.Config) *OAuthMetadataHandler {
|
||||
return &OAuthMetadataHandler{store: store}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the well-known metadata discovery routes.
|
||||
// These routes do NOT go through auth middleware since they must be
|
||||
// accessible to unauthenticated clients during OAuth discovery.
|
||||
func (h *OAuthMetadataHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
// RFC 9728: Protected Resource Metadata
|
||||
r.GET("/.well-known/oauth-protected-resource", lib.ChainMiddlewares(h.handleProtectedResourceMetadata, middlewares...))
|
||||
// RFC 8414: Authorization Server Metadata
|
||||
r.GET("/.well-known/oauth-authorization-server", lib.ChainMiddlewares(h.handleAuthorizationServerMetadata, middlewares...))
|
||||
}
|
||||
|
||||
// handleProtectedResourceMetadata serves the Protected Resource Metadata
|
||||
// document per RFC 9728. MCP clients fetch this after receiving a 401 response
|
||||
// to discover which authorization server(s) protect the MCP resource.
|
||||
//
|
||||
// GET /.well-known/oauth-protected-resource
|
||||
func (h *OAuthMetadataHandler) handleProtectedResourceMetadata(ctx *fasthttp.RequestCtx) {
|
||||
if clients := h.store.GetPerUserOAuthMCPClients(); len(clients) == 0 {
|
||||
sendStringError(ctx, fasthttp.StatusNotFound, "Not Found")
|
||||
return
|
||||
}
|
||||
scheme := "http"
|
||||
if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
host := string(ctx.Host())
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
SendJSON(ctx, map[string]interface{}{
|
||||
"resource": baseURL + "/mcp",
|
||||
"authorization_servers": []string{baseURL},
|
||||
"scopes_supported": []string{"mcp:read", "mcp:write"},
|
||||
"bearer_methods_supported": []string{"header"},
|
||||
})
|
||||
}
|
||||
|
||||
// handleAuthorizationServerMetadata serves the Authorization Server Metadata
|
||||
// document per RFC 8414. MCP clients use this to discover Bifrost's OAuth
|
||||
// endpoints (authorize, token, register) and supported capabilities.
|
||||
//
|
||||
// GET /.well-known/oauth-authorization-server
|
||||
func (h *OAuthMetadataHandler) handleAuthorizationServerMetadata(ctx *fasthttp.RequestCtx) {
|
||||
if clients := h.store.GetPerUserOAuthMCPClients(); len(clients) == 0 {
|
||||
sendStringError(ctx, fasthttp.StatusNotFound, "Not Found")
|
||||
return
|
||||
}
|
||||
scheme := "http"
|
||||
if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
host := string(ctx.Host())
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
SendJSON(ctx, map[string]interface{}{
|
||||
"issuer": baseURL,
|
||||
"authorization_endpoint": baseURL + "/api/oauth/per-user/authorize",
|
||||
"token_endpoint": baseURL + "/api/oauth/per-user/token",
|
||||
"registration_endpoint": baseURL + "/api/oauth/per-user/register",
|
||||
"response_types_supported": []string{"code"},
|
||||
"grant_types_supported": []string{"authorization_code"},
|
||||
"code_challenge_methods_supported": []string{"S256"},
|
||||
"token_endpoint_auth_methods_supported": []string{"none"},
|
||||
"scopes_supported": []string{"mcp:read", "mcp:write"},
|
||||
})
|
||||
}
|
||||
577
transports/bifrost-http/handlers/oauth2_per_user.go
Normal file
577
transports/bifrost-http/handlers/oauth2_per_user.go
Normal file
@@ -0,0 +1,577 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file implements Bifrost's OAuth 2.1 Authorization Server for per-user MCP
|
||||
// authentication. It provides Dynamic Client Registration (RFC 7591), Authorization
|
||||
// Code flow with PKCE, and token issuance. MCP clients (Claude Code, IDEs) use
|
||||
// these endpoints to authenticate users before accessing Bifrost's /mcp endpoint.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// PerUserOAuthHandler implements Bifrost's OAuth 2.1 Authorization Server.
|
||||
// It handles dynamic client registration, authorization code issuance with PKCE,
|
||||
// and token exchange for MCP per-user authentication.
|
||||
type PerUserOAuthHandler struct {
|
||||
store *lib.Config
|
||||
}
|
||||
|
||||
// NewPerUserOAuthHandler creates a new per-user OAuth handler instance.
|
||||
func NewPerUserOAuthHandler(store *lib.Config) *PerUserOAuthHandler {
|
||||
return &PerUserOAuthHandler{store: store}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the per-user OAuth authorization server routes.
|
||||
// These routes do NOT go through auth middleware since they are part of the
|
||||
// OAuth flow that unauthenticated clients use to obtain tokens.
|
||||
func (h *PerUserOAuthHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.POST("/api/oauth/per-user/register", lib.ChainMiddlewares(h.handleDynamicClientRegistration, middlewares...))
|
||||
r.GET("/api/oauth/per-user/authorize", lib.ChainMiddlewares(h.handleAuthorize, middlewares...))
|
||||
r.POST("/api/oauth/per-user/token", lib.ChainMiddlewares(h.handleToken, middlewares...))
|
||||
r.GET("/api/oauth/per-user/upstream/authorize", lib.ChainMiddlewares(h.handleUpstreamAuthorize, middlewares...))
|
||||
}
|
||||
|
||||
// handleDynamicClientRegistration handles OAuth 2.0 Dynamic Client Registration
|
||||
// per RFC 7591. MCP clients register themselves to obtain a client_id.
|
||||
//
|
||||
// POST /api/oauth/per-user/register
|
||||
func (h *PerUserOAuthHandler) handleDynamicClientRegistration(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth registration unavailable: config store is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
if len(h.store.GetPerUserOAuthMCPClients()) == 0 {
|
||||
sendStringError(ctx, fasthttp.StatusNotFound, "Not found")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
ClientName string `json:"client_name"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
ResponseTypes []string `json:"response_types"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(ctx.PostBody(), &req); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid registration request: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.RedirectURIs) == 0 {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "redirect_uris is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Generate client_id
|
||||
clientID := uuid.New().String()
|
||||
|
||||
// Serialize arrays
|
||||
redirectURIsJSON, _ := json.Marshal(req.RedirectURIs)
|
||||
grantTypes := req.GrantTypes
|
||||
if len(grantTypes) == 0 {
|
||||
grantTypes = []string{"authorization_code"}
|
||||
}
|
||||
grantTypesJSON, _ := json.Marshal(grantTypes)
|
||||
|
||||
client := &tables.TablePerUserOAuthClient{
|
||||
ID: uuid.New().String(),
|
||||
ClientID: clientID,
|
||||
ClientName: req.ClientName,
|
||||
RedirectURIs: string(redirectURIsJSON),
|
||||
GrantTypes: string(grantTypesJSON),
|
||||
}
|
||||
|
||||
if err := h.store.ConfigStore.CreatePerUserOAuthClient(ctx, client); err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to register client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Return RFC 7591 response
|
||||
ctx.SetStatusCode(fasthttp.StatusCreated)
|
||||
SendJSON(ctx, map[string]interface{}{
|
||||
"client_id": clientID,
|
||||
"client_name": req.ClientName,
|
||||
"redirect_uris": req.RedirectURIs,
|
||||
"grant_types": grantTypes,
|
||||
"response_types": req.ResponseTypes,
|
||||
"token_endpoint_auth_method": "none",
|
||||
})
|
||||
}
|
||||
|
||||
// handleAuthorize handles the OAuth 2.1 authorization endpoint.
|
||||
// Instead of issuing a code immediately, it validates the request parameters,
|
||||
// creates a PendingFlow record, and redirects the user to the consent screen.
|
||||
// The code is only issued after the user completes the consent flow (VK + MCP auths).
|
||||
//
|
||||
// GET /api/oauth/per-user/authorize?response_type=code&client_id=xxx&redirect_uri=xxx&code_challenge=xxx&code_challenge_method=S256[&state=xxx]
|
||||
func (h *PerUserOAuthHandler) handleAuthorize(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth authorization unavailable: config store is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
if len(h.store.GetPerUserOAuthMCPClients()) == 0 {
|
||||
sendStringError(ctx, fasthttp.StatusNotFound, "Not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract parameters
|
||||
responseType := string(ctx.QueryArgs().Peek("response_type"))
|
||||
clientID := string(ctx.QueryArgs().Peek("client_id"))
|
||||
redirectURI := string(ctx.QueryArgs().Peek("redirect_uri"))
|
||||
state := string(ctx.QueryArgs().Peek("state"))
|
||||
codeChallenge := string(ctx.QueryArgs().Peek("code_challenge"))
|
||||
codeChallengeMethod := string(ctx.QueryArgs().Peek("code_challenge_method"))
|
||||
|
||||
// Validate required parameters
|
||||
if responseType != "code" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "response_type must be 'code'")
|
||||
return
|
||||
}
|
||||
if clientID == "" || redirectURI == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "client_id and redirect_uri are required")
|
||||
return
|
||||
}
|
||||
if codeChallenge == "" || codeChallengeMethod != "S256" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "PKCE is required: code_challenge and code_challenge_method=S256")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate client exists and redirect_uri is registered
|
||||
client, err := h.store.ConfigStore.GetPerUserOAuthClientByClientID(ctx, clientID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to validate client: %v", err))
|
||||
return
|
||||
}
|
||||
if client == nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Unknown client_id")
|
||||
return
|
||||
}
|
||||
var allowedURIs []string
|
||||
json.Unmarshal([]byte(client.RedirectURIs), &allowedURIs)
|
||||
uriAllowed := false
|
||||
for _, allowed := range allowedURIs {
|
||||
if allowed == redirectURI {
|
||||
uriAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !uriAllowed {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "redirect_uri not registered for this client")
|
||||
return
|
||||
}
|
||||
|
||||
// Generate a browser-binding secret so only the initiating browser can resume this flow.
|
||||
browserSecret, err := generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate browser secret")
|
||||
return
|
||||
}
|
||||
browserSecretHash := fmt.Sprintf("%x", sha256.Sum256([]byte(browserSecret)))
|
||||
|
||||
// Create a PendingFlow to carry OAuth params through the consent screen.
|
||||
flow := &tables.TablePerUserOAuthPendingFlow{
|
||||
ID: uuid.New().String(),
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
CodeChallenge: codeChallenge,
|
||||
State: state,
|
||||
BrowserSecretHash: browserSecretHash,
|
||||
ExpiresAt: time.Now().Add(15 * time.Minute),
|
||||
}
|
||||
if err := h.store.ConfigStore.CreatePerUserOAuthPendingFlow(ctx, flow); err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create pending flow: %v", err))
|
||||
return
|
||||
}
|
||||
logger.Debug("[oauth/authorize] PendingFlow created: flow_id=%s client_id=%s", flow.ID, clientID)
|
||||
|
||||
// Set HttpOnly cookie binding this flow to the current browser.
|
||||
var cookie fasthttp.Cookie
|
||||
cookie.SetKey("__bifrost_flow_secret")
|
||||
cookie.SetValue(browserSecret)
|
||||
cookie.SetPath("/")
|
||||
cookie.SetHTTPOnly(true)
|
||||
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
||||
isSecure := ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https"
|
||||
cookie.SetSecure(isSecure)
|
||||
cookie.SetMaxAge(15 * 60) // 15 minutes, matching flow TTL
|
||||
ctx.Response.Header.SetCookie(&cookie)
|
||||
|
||||
// Redirect to consent screen with flow_id (relative path — stays on current origin).
|
||||
consentURL := fmt.Sprintf("/oauth/consent?flow_id=%s", url.QueryEscape(flow.ID))
|
||||
ctx.Redirect(consentURL, fasthttp.StatusFound)
|
||||
}
|
||||
|
||||
// handleToken handles the OAuth 2.1 token endpoint.
|
||||
// It validates the authorization code + PKCE verifier and issues access/refresh tokens.
|
||||
//
|
||||
// POST /api/oauth/per-user/token
|
||||
func (h *PerUserOAuthHandler) handleToken(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth token endpoint unavailable: config store is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
if len(h.store.GetPerUserOAuthMCPClients()) == 0 {
|
||||
sendStringError(ctx, fasthttp.StatusNotFound, "Not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse form-encoded body
|
||||
grantType := string(ctx.FormValue("grant_type"))
|
||||
code := string(ctx.FormValue("code"))
|
||||
redirectURI := string(ctx.FormValue("redirect_uri"))
|
||||
clientID := string(ctx.FormValue("client_id"))
|
||||
codeVerifier := string(ctx.FormValue("code_verifier"))
|
||||
|
||||
if grantType != "authorization_code" {
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "unsupported_grant_type", "Only authorization_code grant is supported")
|
||||
return
|
||||
}
|
||||
|
||||
if code == "" || codeVerifier == "" {
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_request", "code and code_verifier are required")
|
||||
return
|
||||
}
|
||||
|
||||
// Atomically claim authorization code (prevents concurrent redemption)
|
||||
codeRecord, err := h.store.ConfigStore.ClaimPerUserOAuthCode(ctx, code)
|
||||
if err != nil {
|
||||
sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to validate code")
|
||||
return
|
||||
}
|
||||
if codeRecord == nil {
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "Invalid or already used authorization code")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate code is not expired
|
||||
if time.Now().After(codeRecord.ExpiresAt) {
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "Authorization code expired")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate client_id if provided — some public clients omit it (RFC 6749 §4.1.3 allows
|
||||
// omitting client_id when the client is not authenticating with the server).
|
||||
// The code record already binds the code to the correct client, so this is safe.
|
||||
if clientID != "" && codeRecord.ClientID != clientID {
|
||||
logger.Debug("[oauth/token] client_id mismatch: code_client=%s request_client=%s", codeRecord.ClientID, clientID)
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "client_id mismatch")
|
||||
return
|
||||
}
|
||||
// Use the client_id from the code record as the authoritative value.
|
||||
clientID = codeRecord.ClientID
|
||||
|
||||
// Validate redirect_uri matches
|
||||
if redirectURI != "" && codeRecord.RedirectURI != redirectURI {
|
||||
logger.Debug("[oauth/token] redirect_uri mismatch: code=%s request=%s", codeRecord.RedirectURI, redirectURI)
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "redirect_uri mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate PKCE: SHA256(code_verifier) must match code_challenge
|
||||
verifierHash := sha256.Sum256([]byte(codeVerifier))
|
||||
computedChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:])
|
||||
if computedChallenge != codeRecord.CodeChallenge {
|
||||
logger.Debug("[oauth/token] PKCE verification failed")
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "PKCE verification failed")
|
||||
return
|
||||
}
|
||||
|
||||
// If the code was issued by the consent flow (handleSubmit), the session already exists
|
||||
// with the upstream tokens transferred to it. Reuse that session's access token so the
|
||||
// client receives the token that the upstream (Notion, GitHub, etc.) tokens are linked to.
|
||||
var accessToken string
|
||||
var expiresAt time.Time
|
||||
|
||||
if codeRecord.SessionID != "" {
|
||||
existingSession, err := h.store.ConfigStore.GetPerUserOAuthSessionByID(ctx, codeRecord.SessionID)
|
||||
if err != nil {
|
||||
logger.Info("[oauth/token] Failed to load existing session: session_id=%s err=%v", codeRecord.SessionID, err)
|
||||
sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to load session")
|
||||
return
|
||||
}
|
||||
if existingSession == nil {
|
||||
logger.Info("[oauth/token] Existing session not found: session_id=%s", codeRecord.SessionID)
|
||||
sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Session not found")
|
||||
return
|
||||
}
|
||||
if !existingSession.ExpiresAt.After(time.Now()) {
|
||||
sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "Session expired")
|
||||
return
|
||||
}
|
||||
accessToken = existingSession.AccessToken
|
||||
expiresAt = existingSession.ExpiresAt
|
||||
logger.Debug("[oauth/token] reusing consent session: session_id=%s", existingSession.ID)
|
||||
} else {
|
||||
// Fallback: no linked session (legacy path) — create a new one.
|
||||
var newAccessToken, newRefreshToken string
|
||||
newAccessToken, err = generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to generate access token")
|
||||
return
|
||||
}
|
||||
newRefreshToken, err = generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to generate refresh token")
|
||||
return
|
||||
}
|
||||
expiresAt = time.Now().Add(24 * time.Hour)
|
||||
newSession := &tables.TablePerUserOAuthSession{
|
||||
ID: uuid.New().String(),
|
||||
AccessToken: newAccessToken,
|
||||
RefreshToken: newRefreshToken,
|
||||
ClientID: clientID,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
if err := h.store.ConfigStore.CreatePerUserOAuthSession(ctx, newSession); err != nil {
|
||||
sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to create session")
|
||||
return
|
||||
}
|
||||
accessToken = newAccessToken
|
||||
logger.Debug("[oauth/token] created new session (legacy path): session_id=%s", newSession.ID)
|
||||
}
|
||||
// Return OAuth token response
|
||||
ctx.SetContentType("application/json")
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
SendJSON(ctx, map[string]interface{}{
|
||||
"access_token": accessToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": int(time.Until(expiresAt).Seconds()),
|
||||
"scope": codeRecord.Scopes,
|
||||
})
|
||||
}
|
||||
|
||||
// sendOAuthError sends an OAuth 2.0 error response per RFC 6749 Section 5.2.
|
||||
func sendOAuthError(ctx *fasthttp.RequestCtx, statusCode int, errorCode, description string) {
|
||||
ctx.SetContentType("application/json")
|
||||
ctx.SetStatusCode(statusCode)
|
||||
resp, _ := json.Marshal(map[string]string{
|
||||
"error": errorCode,
|
||||
"error_description": description,
|
||||
})
|
||||
ctx.SetBody(resp)
|
||||
}
|
||||
|
||||
func sendStringError(ctx *fasthttp.RequestCtx, statusCode int, message string) {
|
||||
ctx.SetContentType("text/plain")
|
||||
ctx.SetStatusCode(statusCode)
|
||||
ctx.SetBodyString(message)
|
||||
}
|
||||
|
||||
// generateOpaqueToken generates a cryptographically secure random token.
|
||||
// validateFlowBrowserSecret checks that the request carries the __bifrost_flow_secret
|
||||
// cookie matching the hash stored on the pending flow. Returns true if valid.
|
||||
func validateFlowBrowserSecret(ctx *fasthttp.RequestCtx, flow *tables.TablePerUserOAuthPendingFlow) bool {
|
||||
if flow.BrowserSecretHash == "" {
|
||||
// Legacy flow without browser binding — allow for backwards compatibility.
|
||||
return true
|
||||
}
|
||||
secret := ctx.Request.Header.Cookie("__bifrost_flow_secret")
|
||||
if len(secret) == 0 {
|
||||
return false
|
||||
}
|
||||
hash := fmt.Sprintf("%x", sha256.Sum256(secret))
|
||||
return hash == flow.BrowserSecretHash
|
||||
}
|
||||
|
||||
func generateOpaqueToken(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// handleUpstreamAuthorize handles the upstream OAuth proxy for per-user OAuth.
|
||||
// When a user needs to authenticate with an upstream MCP server (e.g., Notion),
|
||||
// this endpoint redirects them to the upstream provider's OAuth authorize URL.
|
||||
// After the user authenticates, the callback stores their upstream token linked
|
||||
// to either their Bifrost session (runtime flow) or a PendingFlow (consent flow).
|
||||
//
|
||||
// Runtime flow: GET /api/oauth/per-user/upstream/authorize?mcp_client_id=xxx&session=xxx
|
||||
// Consent flow: GET /api/oauth/per-user/upstream/authorize?mcp_client_id=xxx&flow_id=xxx
|
||||
func (h *PerUserOAuthHandler) handleUpstreamAuthorize(ctx *fasthttp.RequestCtx) {
|
||||
if h.store.ConfigStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth upstream authorization unavailable: config store is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
mcpClientID := string(ctx.QueryArgs().Peek("mcp_client_id"))
|
||||
sessionID := string(ctx.QueryArgs().Peek("session"))
|
||||
flowID := string(ctx.QueryArgs().Peek("flow_id"))
|
||||
|
||||
if mcpClientID == "" || (sessionID == "" && flowID == "") {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "mcp_client_id and either session or flow_id are required")
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve identity depending on whether this is a runtime session or a consent flow.
|
||||
var virtualKeyID, userID, proxySessionToken, gatewaySessionID string
|
||||
if flowID != "" {
|
||||
// Consent flow: use the pending flow for identity and proxy token.
|
||||
flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID)
|
||||
if err != nil || flow == nil || time.Now().After(flow.ExpiresAt) {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired consent flow")
|
||||
return
|
||||
}
|
||||
if !validateFlowBrowserSecret(ctx, flow) {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session")
|
||||
return
|
||||
}
|
||||
if strVal(flow.VirtualKeyID) != "" {
|
||||
virtualKeyID = *flow.VirtualKeyID
|
||||
}
|
||||
if strVal(flow.UserID) != "" {
|
||||
userID = *flow.UserID
|
||||
}
|
||||
// Use a prefixed flow token so the callback can detect the consent path.
|
||||
// Include mcpClientID to avoid unique constraint violations when multiple
|
||||
// MCP services are connected in the same consent flow.
|
||||
proxySessionToken = "flow:" + flowID + ":" + mcpClientID
|
||||
gatewaySessionID = flowID
|
||||
} else {
|
||||
// Runtime flow: validate the existing Bifrost session.
|
||||
bifrostSession, err := h.store.ConfigStore.GetPerUserOAuthSessionByID(ctx, sessionID)
|
||||
if err != nil || bifrostSession == nil {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
if !bifrostSession.ExpiresAt.After(time.Now()) {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
virtualKeyID = strVal(bifrostSession.VirtualKeyID)
|
||||
userID = strVal(bifrostSession.UserID)
|
||||
proxySessionToken = "runtime:" + sessionID + ":" + mcpClientID
|
||||
gatewaySessionID = sessionID
|
||||
}
|
||||
|
||||
// Look up the MCP client config to get the template OAuth config.
|
||||
mcpClient, err := h.store.ConfigStore.GetMCPClientByID(ctx, mcpClientID)
|
||||
if err != nil || mcpClient == nil {
|
||||
SendError(ctx, fasthttp.StatusNotFound, "MCP client not found")
|
||||
return
|
||||
}
|
||||
if mcpClient.AuthType != string(schemas.MCPAuthTypePerUserOauth) {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "MCP client does not use per-user OAuth")
|
||||
return
|
||||
}
|
||||
if mcpClient.OauthConfigID == nil || *mcpClient.OauthConfigID == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "MCP client has no OAuth configuration")
|
||||
return
|
||||
}
|
||||
|
||||
// Load template OAuth config (has upstream authorize_url, client_id, etc.)
|
||||
templateConfig, err := h.store.ConfigStore.GetOauthConfigByID(ctx, *mcpClient.OauthConfigID)
|
||||
if err != nil || templateConfig == nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load OAuth template config")
|
||||
return
|
||||
}
|
||||
|
||||
// Generate PKCE challenge for upstream.
|
||||
codeVerifier, err := generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate PKCE verifier")
|
||||
return
|
||||
}
|
||||
verifierHash := sha256.Sum256([]byte(codeVerifier))
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:])
|
||||
|
||||
// Generate state for upstream.
|
||||
state, err := generateOpaqueToken(32)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate state token")
|
||||
return
|
||||
}
|
||||
|
||||
// Build redirect URI (Bifrost's callback endpoint).
|
||||
scheme := "http"
|
||||
if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
host := string(ctx.Host())
|
||||
redirectURI := fmt.Sprintf("%s://%s/api/oauth/callback", scheme, host)
|
||||
var vkId *string
|
||||
if virtualKeyID != "" {
|
||||
vkId = &virtualKeyID
|
||||
}
|
||||
var uid *string
|
||||
if userID != "" {
|
||||
uid = &userID
|
||||
}
|
||||
// Store upstream OAuth session linking state → MCP client + identity.
|
||||
upstreamSession := &tables.TableOauthUserSession{
|
||||
ID: uuid.New().String(),
|
||||
MCPClientID: mcpClientID,
|
||||
OauthConfigID: *mcpClient.OauthConfigID,
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
SessionToken: proxySessionToken, // "runtime:xxx" for runtime flow; "flow:xxx" for consent flow
|
||||
GatewaySessionID: gatewaySessionID,
|
||||
VirtualKeyID: vkId,
|
||||
UserID: uid,
|
||||
Status: "pending",
|
||||
ExpiresAt: time.Now().Add(15 * time.Minute),
|
||||
}
|
||||
logger.Debug("[oauth/upstream-authorize] creating upstream session: mcp_client=%s flow=%s", mcpClientID, proxySessionToken)
|
||||
if err := h.store.ConfigStore.CreateOauthUserSession(ctx, upstreamSession); err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create upstream OAuth session: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Parse scopes from template config.
|
||||
var scopes []string
|
||||
if templateConfig.Scopes != "" {
|
||||
json.Unmarshal([]byte(templateConfig.Scopes), &scopes)
|
||||
}
|
||||
|
||||
// Build upstream authorize URL with PKCE.
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", templateConfig.ClientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
if len(scopes) > 0 {
|
||||
params.Set("scope", strings.Join(scopes, " "))
|
||||
}
|
||||
|
||||
baseURL, err := url.Parse(templateConfig.AuthorizeURL)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Invalid upstream authorize URL")
|
||||
return
|
||||
}
|
||||
existing := baseURL.Query()
|
||||
for k, vals := range params {
|
||||
for _, v := range vals {
|
||||
existing.Set(k, v)
|
||||
}
|
||||
}
|
||||
baseURL.RawQuery = existing.Encode()
|
||||
ctx.Redirect(baseURL.String(), fasthttp.StatusFound)
|
||||
}
|
||||
|
||||
// Ensure unused imports are referenced.
|
||||
var _ = html.EscapeString
|
||||
var _ configstore.ConfigStore
|
||||
491
transports/bifrost-http/handlers/plugins.go
Normal file
491
transports/bifrost-http/handlers/plugins.go
Normal file
@@ -0,0 +1,491 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/framework/plugins"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type PluginsLoader interface {
|
||||
ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any, placement *schemas.PluginPlacement, order *int) error
|
||||
RemovePlugin(ctx context.Context, name string) error
|
||||
GetPluginStatus(ctx context.Context) map[string]schemas.PluginStatus
|
||||
}
|
||||
|
||||
// PluginsHandler is the handler for the plugins API
|
||||
type PluginsHandler struct {
|
||||
configStore configstore.ConfigStore
|
||||
pluginsLoader PluginsLoader
|
||||
}
|
||||
|
||||
// NewPluginsHandler creates a new PluginsHandler
|
||||
func NewPluginsHandler(pluginsLoader PluginsLoader, configStore configstore.ConfigStore) *PluginsHandler {
|
||||
return &PluginsHandler{
|
||||
pluginsLoader: pluginsLoader,
|
||||
configStore: configStore,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
// CreatePluginRequest is the request body for creating a plugin
|
||||
type CreatePluginRequest struct {
|
||||
Name string `json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Config map[string]any `json:"config"`
|
||||
Path *string `json:"path"`
|
||||
Placement *schemas.PluginPlacement `json:"placement,omitempty"`
|
||||
Order *int `json:"order,omitempty"`
|
||||
}
|
||||
|
||||
// UpdatePluginRequest is the request body for updating a plugin
|
||||
type UpdatePluginRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Path *string `json:"path"`
|
||||
Config map[string]any `json:"config"`
|
||||
Placement *schemas.PluginPlacement `json:"placement,omitempty"`
|
||||
Order *int `json:"order,omitempty"`
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the routes for the PluginsHandler
|
||||
func (h *PluginsHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.GET("/api/plugins", lib.ChainMiddlewares(h.getPlugins, middlewares...))
|
||||
r.GET("/api/plugins/{name}", lib.ChainMiddlewares(h.getPlugin, middlewares...))
|
||||
r.POST("/api/plugins", lib.ChainMiddlewares(h.createPlugin, middlewares...))
|
||||
r.PUT("/api/plugins/{name}", lib.ChainMiddlewares(h.updatePlugin, middlewares...))
|
||||
r.DELETE("/api/plugins/{name}", lib.ChainMiddlewares(h.deletePlugin, middlewares...))
|
||||
}
|
||||
|
||||
type PluginResponse struct {
|
||||
Name string `json:"name"`
|
||||
ActualName string `json:"actualName"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Config any `json:"config"`
|
||||
IsCustom bool `json:"isCustom"`
|
||||
Path *string `json:"path"`
|
||||
Placement *schemas.PluginPlacement `json:"placement,omitempty"`
|
||||
Order *int `json:"order,omitempty"`
|
||||
Status schemas.PluginStatus `json:"status"`
|
||||
}
|
||||
|
||||
// buildPluginResponse constructs a PluginResponse with status for a given TablePlugin.
|
||||
func (h *PluginsHandler) buildPluginResponse(ctx context.Context, plugin *configstoreTables.TablePlugin) PluginResponse {
|
||||
pluginStatus := schemas.PluginStatus{
|
||||
Name: plugin.Name,
|
||||
Status: schemas.PluginStatusUninitialized,
|
||||
Logs: []string{},
|
||||
}
|
||||
if !plugin.Enabled {
|
||||
pluginStatus.Status = schemas.PluginStatusDisabled
|
||||
} else {
|
||||
for _, status := range h.pluginsLoader.GetPluginStatus(ctx) {
|
||||
if plugin.Name == status.Name {
|
||||
pluginStatus = status
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return PluginResponse{
|
||||
Name: plugin.Name,
|
||||
ActualName: pluginStatus.Name,
|
||||
Enabled: plugin.Enabled,
|
||||
Config: plugin.Config,
|
||||
IsCustom: plugin.IsCustom,
|
||||
Path: plugin.Path,
|
||||
Placement: plugin.Placement,
|
||||
Order: plugin.Order,
|
||||
Status: pluginStatus,
|
||||
}
|
||||
}
|
||||
|
||||
// getPlugins gets all plugins
|
||||
func (h *PluginsHandler) getPlugins(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
pluginStatus := h.pluginsLoader.GetPluginStatus(ctx)
|
||||
finalPlugins := []PluginResponse{}
|
||||
for name, pluginStatus := range pluginStatus {
|
||||
finalPlugins = append(finalPlugins, PluginResponse{
|
||||
Name: pluginStatus.Name,
|
||||
ActualName: name,
|
||||
Enabled: true,
|
||||
Config: map[string]any{},
|
||||
IsCustom: true,
|
||||
Path: nil,
|
||||
Status: pluginStatus,
|
||||
})
|
||||
}
|
||||
SendJSON(ctx, map[string]any{
|
||||
"plugins": finalPlugins,
|
||||
"count": len(finalPlugins),
|
||||
})
|
||||
return
|
||||
}
|
||||
plugins, err := h.configStore.GetPlugins(ctx)
|
||||
if err != nil {
|
||||
logger.Error("failed to get plugins: %v", err)
|
||||
SendError(ctx, 500, "Failed to retrieve plugins")
|
||||
return
|
||||
}
|
||||
// Fetching status
|
||||
pluginStatuses := h.pluginsLoader.GetPluginStatus(ctx)
|
||||
// Creating ephemeral struct for the plugins
|
||||
finalPlugins := []PluginResponse{}
|
||||
|
||||
// Iterating over plugin status to get the plugin info
|
||||
for _, plugin := range plugins {
|
||||
pluginStatus := schemas.PluginStatus{
|
||||
Name: plugin.Name,
|
||||
Status: schemas.PluginStatusUninitialized,
|
||||
Logs: []string{},
|
||||
}
|
||||
if !plugin.Enabled {
|
||||
pluginStatus.Status = schemas.PluginStatusDisabled
|
||||
}
|
||||
for _, status := range pluginStatuses {
|
||||
if plugin.Name == status.Name {
|
||||
pluginStatus = status
|
||||
break
|
||||
}
|
||||
}
|
||||
finalPlugins = append(finalPlugins, PluginResponse{
|
||||
Name: plugin.Name,
|
||||
ActualName: pluginStatus.Name,
|
||||
Enabled: plugin.Enabled,
|
||||
Config: plugin.Config,
|
||||
IsCustom: plugin.IsCustom,
|
||||
Path: plugin.Path,
|
||||
Placement: plugin.Placement,
|
||||
Order: plugin.Order,
|
||||
Status: pluginStatus,
|
||||
})
|
||||
}
|
||||
// Creating ephemeral struct
|
||||
SendJSON(ctx, map[string]any{
|
||||
"plugins": finalPlugins,
|
||||
"count": len(finalPlugins),
|
||||
})
|
||||
}
|
||||
|
||||
// getPlugin gets a plugin by name
|
||||
func (h *PluginsHandler) getPlugin(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
pluginStatus := h.pluginsLoader.GetPluginStatus(ctx)
|
||||
pluginInfo := PluginResponse{}
|
||||
for name, pluginStatus := range pluginStatus {
|
||||
if pluginStatus.Name == ctx.UserValue("name") {
|
||||
pluginInfo = PluginResponse{
|
||||
Name: pluginStatus.Name,
|
||||
ActualName: name,
|
||||
Enabled: true,
|
||||
Config: map[string]any{},
|
||||
IsCustom: true,
|
||||
Path: nil,
|
||||
Status: pluginStatus,
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
SendJSON(ctx, pluginInfo)
|
||||
return
|
||||
}
|
||||
// Safely validate the "name" parameter
|
||||
nameValue := ctx.UserValue("name")
|
||||
if nameValue == nil {
|
||||
logger.Warn("missing required 'name' parameter in request")
|
||||
SendError(ctx, 400, "Missing required 'name' parameter")
|
||||
return
|
||||
}
|
||||
|
||||
name, ok := nameValue.(string)
|
||||
if !ok {
|
||||
logger.Warn("invalid 'name' parameter type, expected string but got %T", nameValue)
|
||||
SendError(ctx, 400, "Invalid 'name' parameter type, expected string")
|
||||
return
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
logger.Warn("empty 'name' parameter provided")
|
||||
SendError(ctx, 400, "Empty 'name' parameter not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
plugin, err := h.configStore.GetPlugin(ctx, name)
|
||||
if err != nil {
|
||||
if errors.Is(err, configstore.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, "Plugin not found")
|
||||
return
|
||||
}
|
||||
logger.Error("failed to get plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to retrieve plugin")
|
||||
return
|
||||
}
|
||||
SendJSON(ctx, plugin)
|
||||
}
|
||||
|
||||
// createPlugin creates a new plugin
|
||||
func (h *PluginsHandler) createPlugin(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
SendError(ctx, 400, "Plugins creation is not supported when configstore is disabled")
|
||||
return
|
||||
}
|
||||
var request CreatePluginRequest
|
||||
if err := json.Unmarshal(ctx.PostBody(), &request); err != nil {
|
||||
logger.Error("failed to unmarshal create plugin request: %v", err)
|
||||
SendError(ctx, 400, "Invalid request body")
|
||||
return
|
||||
}
|
||||
// Validate required fields
|
||||
if request.Name == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Plugin name is required")
|
||||
return
|
||||
}
|
||||
// Validate placement value
|
||||
if request.Placement != nil && *request.Placement != "" &&
|
||||
*request.Placement != schemas.PluginPlacementPreBuiltin &&
|
||||
*request.Placement != schemas.PluginPlacementPostBuiltin {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid placement value. Must be 'pre_builtin' or 'post_builtin'")
|
||||
return
|
||||
}
|
||||
if request.Placement != nil && *request.Placement == "" {
|
||||
request.Placement = nil
|
||||
}
|
||||
// Normalize empty path to nil (treat empty string as built-in plugin)
|
||||
if request.Path != nil && *request.Path == "" {
|
||||
request.Path = nil
|
||||
}
|
||||
// Check if plugin already exists
|
||||
existingPlugin, err := h.configStore.GetPlugin(ctx, request.Name)
|
||||
if err == nil && existingPlugin != nil {
|
||||
SendError(ctx, fasthttp.StatusConflict, "Plugin already exists")
|
||||
return
|
||||
}
|
||||
// Determine if this is a built-in or custom plugin
|
||||
isBuiltin := lib.IsBuiltinPlugin(request.Name)
|
||||
// Built-in plugins should not have a path
|
||||
if isBuiltin && request.Path != nil {
|
||||
request.Path = nil
|
||||
}
|
||||
// Create DB entry first to avoid orphaned in-memory state if DB write fails
|
||||
if err := h.configStore.CreatePlugin(ctx, &configstoreTables.TablePlugin{
|
||||
Name: request.Name,
|
||||
Enabled: request.Enabled,
|
||||
Config: request.Config,
|
||||
Path: request.Path,
|
||||
IsCustom: !isBuiltin,
|
||||
Placement: request.Placement,
|
||||
Order: request.Order,
|
||||
}); err != nil {
|
||||
logger.Error("failed to create plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to create plugin")
|
||||
return
|
||||
}
|
||||
|
||||
// Reload the plugin into memory if it's enabled
|
||||
if request.Enabled {
|
||||
if err := h.pluginsLoader.ReloadPlugin(ctx, request.Name, request.Path, request.Config, request.Placement, request.Order); err != nil {
|
||||
logger.Error("failed to load plugin: %v", err)
|
||||
if rbErr := h.configStore.DeletePlugin(ctx, request.Name); rbErr != nil {
|
||||
logger.Error("failed to rollback plugin creation: %v", rbErr)
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Plugin created in database but failed to load: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
plugin, err := h.configStore.GetPlugin(ctx, request.Name)
|
||||
if err != nil {
|
||||
logger.Error("failed to get plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to retrieve plugin")
|
||||
return
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusCreated)
|
||||
SendJSON(ctx, map[string]any{
|
||||
"message": "Plugin created successfully",
|
||||
"plugin": h.buildPluginResponse(ctx, plugin),
|
||||
})
|
||||
}
|
||||
|
||||
// updatePlugin updates an existing plugin
|
||||
func (h *PluginsHandler) updatePlugin(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
SendError(ctx, 400, "Plugins update is not supported when configstore is disabled")
|
||||
return
|
||||
}
|
||||
// Safely validate the "name" parameter
|
||||
nameValue := ctx.UserValue("name")
|
||||
if nameValue == nil {
|
||||
logger.Warn("missing required 'name' parameter in update plugin request")
|
||||
SendError(ctx, 400, "Missing required 'name' parameter")
|
||||
return
|
||||
}
|
||||
|
||||
name, ok := nameValue.(string)
|
||||
if !ok {
|
||||
logger.Warn("invalid 'name' parameter type in update plugin request, expected string but got %T", nameValue)
|
||||
SendError(ctx, 400, "Invalid 'name' parameter type, expected string")
|
||||
return
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
logger.Warn("empty 'name' parameter provided in update plugin request")
|
||||
SendError(ctx, 400, "Empty 'name' parameter not allowed")
|
||||
return
|
||||
}
|
||||
var plugin *configstoreTables.TablePlugin
|
||||
var err error
|
||||
// Check if plugin exists
|
||||
_, err = h.configStore.GetPlugin(ctx, name)
|
||||
if err != nil {
|
||||
// If doesn't exist, create it
|
||||
if errors.Is(err, configstore.ErrNotFound) {
|
||||
plugin = &configstoreTables.TablePlugin{
|
||||
Name: name,
|
||||
Enabled: false,
|
||||
Config: map[string]any{},
|
||||
Path: nil,
|
||||
IsCustom: false,
|
||||
}
|
||||
if err := h.configStore.CreatePlugin(ctx, plugin); err != nil {
|
||||
logger.Error("failed to create plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to create plugin")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
logger.Error("failed to get plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to update plugin")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Unmarshalling the request body
|
||||
var request UpdatePluginRequest
|
||||
if err := json.Unmarshal(ctx.PostBody(), &request); err != nil {
|
||||
logger.Error("failed to unmarshal update plugin request: %v", err)
|
||||
SendError(ctx, 400, "Invalid request body")
|
||||
return
|
||||
}
|
||||
// Validate placement value
|
||||
if request.Placement != nil && *request.Placement != "" &&
|
||||
*request.Placement != schemas.PluginPlacementPreBuiltin &&
|
||||
*request.Placement != schemas.PluginPlacementPostBuiltin {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Invalid placement value. Must be 'pre_builtin' or 'post_builtin'")
|
||||
return
|
||||
}
|
||||
if request.Placement != nil && *request.Placement == "" {
|
||||
request.Placement = nil
|
||||
}
|
||||
// Normalize empty path to nil (treat empty string as built-in plugin)
|
||||
if request.Path != nil && *request.Path == "" {
|
||||
request.Path = nil
|
||||
}
|
||||
// Determine if this is a built-in plugin
|
||||
isBuiltin := lib.IsBuiltinPlugin(name)
|
||||
// Built-in plugins should not have a path
|
||||
if isBuiltin && request.Path != nil {
|
||||
request.Path = nil
|
||||
}
|
||||
// Updating the plugin
|
||||
if err := h.configStore.UpdatePlugin(ctx, &configstoreTables.TablePlugin{
|
||||
Name: name,
|
||||
Enabled: request.Enabled,
|
||||
Config: request.Config,
|
||||
Path: request.Path,
|
||||
IsCustom: !isBuiltin,
|
||||
Placement: request.Placement,
|
||||
Order: request.Order,
|
||||
}); err != nil {
|
||||
logger.Error("failed to update plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to update plugin")
|
||||
return
|
||||
}
|
||||
plugin, err = h.configStore.GetPlugin(ctx, name)
|
||||
if err != nil {
|
||||
if errors.Is(err, configstore.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, "Plugin not found")
|
||||
return
|
||||
}
|
||||
logger.Error("failed to get plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to retrieve plugin")
|
||||
return
|
||||
}
|
||||
// We reload the plugin if its enabled, otherwise we stop it
|
||||
if request.Enabled {
|
||||
if err := h.pluginsLoader.ReloadPlugin(ctx, name, request.Path, request.Config, request.Placement, request.Order); err != nil {
|
||||
logger.Error("failed to load plugin: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Plugin updated in database but failed to load: %v", err))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
ctx.SetUserValue(PluginDisabledKey, true)
|
||||
if err := h.pluginsLoader.RemovePlugin(ctx, name); err != nil {
|
||||
if !errors.Is(err, plugins.ErrPluginNotFound) {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Plugin updated in database but failed to stop: %v", err))
|
||||
return
|
||||
}
|
||||
// If not found then we don't need to do anything
|
||||
}
|
||||
}
|
||||
|
||||
SendJSON(ctx, map[string]interface{}{
|
||||
"message": "Plugin updated successfully",
|
||||
"plugin": h.buildPluginResponse(ctx, plugin),
|
||||
})
|
||||
}
|
||||
|
||||
// deletePlugin deletes an existing plugin
|
||||
func (h *PluginsHandler) deletePlugin(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
SendError(ctx, 400, "Plugins deletion is not supported when configstore is disabled")
|
||||
return
|
||||
}
|
||||
// Safely validate the "name" parameter
|
||||
nameValue := ctx.UserValue("name")
|
||||
if nameValue == nil {
|
||||
logger.Warn("missing required 'name' parameter in delete plugin request")
|
||||
SendError(ctx, 400, "Missing required 'name' parameter")
|
||||
return
|
||||
}
|
||||
|
||||
name, ok := nameValue.(string)
|
||||
if !ok {
|
||||
logger.Warn("invalid 'name' parameter type in delete plugin request, expected string but got %T", nameValue)
|
||||
SendError(ctx, 400, "Invalid 'name' parameter type, expected string")
|
||||
return
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
logger.Warn("empty 'name' parameter provided in delete plugin request")
|
||||
SendError(ctx, 400, "Empty 'name' parameter not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.configStore.DeletePlugin(ctx, name); err != nil {
|
||||
if errors.Is(err, configstore.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, "Plugin not found")
|
||||
return
|
||||
}
|
||||
logger.Error("failed to delete plugin: %v", err)
|
||||
SendError(ctx, 500, "Failed to delete plugin")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.pluginsLoader.RemovePlugin(ctx, name); err != nil {
|
||||
if !errors.Is(err, plugins.ErrPluginNotFound) {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Plugin deleted in database but failed to stop: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
SendJSON(ctx, map[string]interface{}{
|
||||
"message": "Plugin deleted successfully",
|
||||
})
|
||||
}
|
||||
149
transports/bifrost-http/handlers/pricing_override_test.go
Normal file
149
transports/bifrost-http/handlers/pricing_override_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/framework/modelcatalog"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type pricingOverrideTestGovernanceManager struct{}
|
||||
|
||||
func (pricingOverrideTestGovernanceManager) GetGovernanceData(ctx context.Context) *governance.GovernanceData {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) ReloadVirtualKey(context.Context, string) (*configstoreTables.TableVirtualKey, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) RemoveVirtualKey(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) ReloadTeam(context.Context, string) (*configstoreTables.TableTeam, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) RemoveTeam(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) ReloadCustomer(context.Context, string) (*configstoreTables.TableCustomer, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) RemoveCustomer(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) ReloadModelConfig(context.Context, string) (*configstoreTables.TableModelConfig, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) RemoveModelConfig(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) ReloadProvider(context.Context, schemas.ModelProvider) (*configstoreTables.TableProvider, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) RemoveProvider(context.Context, schemas.ModelProvider) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) ReloadRoutingRule(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) RemoveRoutingRule(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) UpsertPricingOverride(context.Context, *configstoreTables.TablePricingOverride) error {
|
||||
return nil
|
||||
}
|
||||
func (pricingOverrideTestGovernanceManager) DeletePricingOverride(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupPricingOverrideHandlerStore(t *testing.T) configstore.ConfigStore {
|
||||
t.Helper()
|
||||
|
||||
dbPath := t.TempDir() + "/config.db"
|
||||
store, err := configstore.NewConfigStore(context.Background(), &configstore.Config{
|
||||
Enabled: true,
|
||||
Type: configstore.ConfigStoreTypeSQLite,
|
||||
Config: &configstore.SQLiteConfig{
|
||||
Path: dbPath,
|
||||
},
|
||||
}, &mockLogger{})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove(dbPath)
|
||||
})
|
||||
return store
|
||||
}
|
||||
|
||||
func newTestRequestCtx(body string) *fasthttp.RequestCtx {
|
||||
var req fasthttp.Request
|
||||
req.SetBodyString(body)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 12345}, nil)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func TestUpdatePricingOverride_ReplacesFullBody(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
store := setupPricingOverrideHandlerStore(t)
|
||||
handler := &GovernanceHandler{
|
||||
configStore: store,
|
||||
governanceManager: pricingOverrideTestGovernanceManager{},
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
override := configstoreTables.TablePricingOverride{
|
||||
ID: "override-1",
|
||||
Name: "Original",
|
||||
ScopeKind: string(modelcatalog.ScopeKindGlobal),
|
||||
MatchType: string(modelcatalog.MatchTypeExact),
|
||||
Pattern: "gpt-4.1",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
PricingPatchJSON: `{"input_cost_per_token":1,"output_cost_per_token":2}`,
|
||||
RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest},
|
||||
}
|
||||
require.NoError(t, store.CreatePricingOverride(context.Background(), &override))
|
||||
|
||||
// Patch replaces in full: send only input_cost_per_token.
|
||||
// output_cost_per_token must be absent from the stored patch afterwards,
|
||||
// confirming full-replace (not merge) semantics.
|
||||
body := `{
|
||||
"name":"Updated",
|
||||
"scope_kind":"global",
|
||||
"match_type":"exact",
|
||||
"pattern":"gpt-4.1",
|
||||
"request_types":["chat_completion"],
|
||||
"patch":{"input_cost_per_token":1.5}
|
||||
}`
|
||||
ctx := newTestRequestCtx(body)
|
||||
ctx.SetUserValue("id", override.ID)
|
||||
|
||||
handler.updatePricingOverride(ctx)
|
||||
|
||||
require.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
|
||||
stored, err := store.GetPricingOverrideByID(context.Background(), override.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated", stored.Name)
|
||||
|
||||
var patch modelcatalog.PricingOptions
|
||||
require.NoError(t, json.Unmarshal([]byte(stored.PricingPatchJSON), &patch))
|
||||
// Sent field must reflect the new value.
|
||||
require.NotNil(t, patch.InputCostPerToken)
|
||||
assert.Equal(t, 1.5, *patch.InputCostPerToken)
|
||||
// Omitted field must be cleared — patch is always fully replaced, not merged.
|
||||
assert.Nil(t, patch.OutputCostPerToken)
|
||||
assert.Empty(t, stored.ConfigHash)
|
||||
}
|
||||
1121
transports/bifrost-http/handlers/prompts.go
Normal file
1121
transports/bifrost-http/handlers/prompts.go
Normal file
File diff suppressed because it is too large
Load Diff
495
transports/bifrost-http/handlers/provider_keys.go
Normal file
495
transports/bifrost-http/handlers/provider_keys.go
Normal file
@@ -0,0 +1,495 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/google/uuid"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ListProviderKeysResponse represents the response for listing keys for a provider.
|
||||
type ListProviderKeysResponse struct {
|
||||
Keys []schemas.Key `json:"keys"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
func (h *ProviderHandler) listProviderKeys(ctx *fasthttp.RequestCtx) {
|
||||
provider, err := getProviderFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
keys, err := h.inMemoryStore.GetProviderKeysRedacted(provider)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider keys: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
SendJSON(ctx, ListProviderKeysResponse{Keys: keys, Total: len(keys)})
|
||||
}
|
||||
|
||||
func (h *ProviderHandler) getProviderKey(ctx *fasthttp.RequestCtx) {
|
||||
provider, err := getProviderFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := getKeyIDFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
key, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
SendJSON(ctx, key)
|
||||
}
|
||||
|
||||
func (h *ProviderHandler) createProviderKey(ctx *fasthttp.RequestCtx) {
|
||||
provider, err := getProviderFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
var key schemas.Key
|
||||
if err := sonic.Unmarshal(ctx.PostBody(), &key); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
providerConfig, err := h.inMemoryStore.GetProviderConfigRaw(provider)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider config: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.IsKeyLess {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Cannot add keys to a keyless provider")
|
||||
return
|
||||
}
|
||||
|
||||
baseProvider := provider
|
||||
if providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.BaseProviderType != "" {
|
||||
baseProvider = providerConfig.CustomProviderConfig.BaseProviderType
|
||||
}
|
||||
|
||||
if !bifrost.CanProviderKeyValueBeEmpty(baseProvider) && key.Value.GetValue() == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Key value must not be empty")
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateProviderKeyURL(provider, key); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := key.BlacklistedModels.Validate(); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid blacklisted_models: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := key.Aliases.Validate(); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid aliases: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if key.ID == "" {
|
||||
key.ID = uuid.NewString()
|
||||
}
|
||||
if key.Enabled == nil {
|
||||
key.Enabled = bifrost.Ptr(true)
|
||||
}
|
||||
|
||||
if err := h.inMemoryStore.AddProviderKey(ctx, provider, key); err != nil {
|
||||
logger.Warn("Failed to create key for provider %s: %v", provider, err)
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err))
|
||||
return
|
||||
}
|
||||
if errors.Is(err, lib.ErrAlreadyExists) {
|
||||
SendError(ctx, fasthttp.StatusConflict, err.Error())
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.attemptModelDiscovery(ctx, provider, providerConfig.CustomProviderConfig); err != nil {
|
||||
logger.Warn("Model discovery failed for provider %s after key create: %v", provider, err)
|
||||
}
|
||||
|
||||
redactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, key.ID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get created provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
SendJSON(ctx, redactedKey)
|
||||
}
|
||||
|
||||
func (h *ProviderHandler) updateProviderKey(ctx *fasthttp.RequestCtx) {
|
||||
provider, err := getProviderFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := getKeyIDFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var updateKey schemas.Key
|
||||
if err := sonic.Unmarshal(ctx.PostBody(), &updateKey); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
providerConfig, err := h.inMemoryStore.GetProviderConfigRaw(provider)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider config: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.IsKeyLess {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Cannot update keys on a keyless provider")
|
||||
return
|
||||
}
|
||||
|
||||
oldRawKey, err := h.inMemoryStore.GetProviderKeyRaw(provider, keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
oldRedactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
updateKey.ID = keyID
|
||||
mergedKey := h.mergeUpdatedKey(*oldRawKey, *oldRedactedKey, updateKey)
|
||||
|
||||
baseProvider := provider
|
||||
if providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.BaseProviderType != "" {
|
||||
baseProvider = providerConfig.CustomProviderConfig.BaseProviderType
|
||||
}
|
||||
|
||||
if !bifrost.CanProviderKeyValueBeEmpty(baseProvider) && mergedKey.Value.GetValue() == "" {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Key value must not be empty")
|
||||
return
|
||||
}
|
||||
|
||||
if err := mergedKey.BlacklistedModels.Validate(); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid blacklisted_models: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := mergedKey.Aliases.Validate(); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid aliases: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateProviderKeyURL(provider, mergedKey); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.inMemoryStore.UpdateProviderKey(ctx, provider, keyID, mergedKey); err != nil {
|
||||
logger.Warn("Failed to update key %s for provider %s: %v", keyID, provider, err)
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to update provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.attemptModelDiscovery(ctx, provider, providerConfig.CustomProviderConfig); err != nil {
|
||||
logger.Warn("Model discovery failed for provider %s after key update: %v", provider, err)
|
||||
}
|
||||
|
||||
redactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get updated provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
SendJSON(ctx, redactedKey)
|
||||
}
|
||||
|
||||
func (h *ProviderHandler) deleteProviderKey(ctx *fasthttp.RequestCtx) {
|
||||
provider, err := getProviderFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
keyID, err := getKeyIDFromCtx(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
providerConfig, err := h.inMemoryStore.GetProviderConfigRaw(provider)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider config: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.IsKeyLess {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, "Cannot delete keys on a keyless provider")
|
||||
return
|
||||
}
|
||||
|
||||
redactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.inMemoryStore.RemoveProviderKey(ctx, provider, keyID); err != nil {
|
||||
logger.Warn("Failed to delete key %s for provider %s: %v", keyID, provider, err)
|
||||
if errors.Is(err, lib.ErrNotFound) {
|
||||
SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err))
|
||||
return
|
||||
}
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to delete provider key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.attemptModelDiscovery(ctx, provider, providerConfig.CustomProviderConfig); err != nil {
|
||||
logger.Warn("Model discovery failed for provider %s after key delete: %v", provider, err)
|
||||
}
|
||||
|
||||
SendJSON(ctx, redactedKey)
|
||||
}
|
||||
|
||||
// mergeUpdatedKey merges an updated key with the old raw/redacted versions,
|
||||
// preserving real values for fields that were sent back in redacted form.
|
||||
func (h *ProviderHandler) mergeUpdatedKey(oldRawKey, oldRedactedKey, updateKey schemas.Key) schemas.Key {
|
||||
mergedKey := updateKey
|
||||
|
||||
if updateKey.Value.IsRedacted() && updateKey.Value.Equals(&oldRedactedKey.Value) {
|
||||
mergedKey.Value = oldRawKey.Value
|
||||
}
|
||||
|
||||
if updateKey.AzureKeyConfig != nil && oldRedactedKey.AzureKeyConfig != nil && oldRawKey.AzureKeyConfig != nil {
|
||||
if updateKey.AzureKeyConfig.Endpoint.IsRedacted() &&
|
||||
updateKey.AzureKeyConfig.Endpoint.Equals(&oldRedactedKey.AzureKeyConfig.Endpoint) {
|
||||
mergedKey.AzureKeyConfig.Endpoint = oldRawKey.AzureKeyConfig.Endpoint
|
||||
}
|
||||
if updateKey.AzureKeyConfig.APIVersion != nil &&
|
||||
oldRedactedKey.AzureKeyConfig.APIVersion != nil &&
|
||||
oldRawKey.AzureKeyConfig != nil &&
|
||||
updateKey.AzureKeyConfig.APIVersion.IsRedacted() &&
|
||||
updateKey.AzureKeyConfig.APIVersion.Equals(oldRedactedKey.AzureKeyConfig.APIVersion) {
|
||||
mergedKey.AzureKeyConfig.APIVersion = oldRawKey.AzureKeyConfig.APIVersion
|
||||
}
|
||||
if updateKey.AzureKeyConfig.ClientID != nil &&
|
||||
oldRedactedKey.AzureKeyConfig.ClientID != nil &&
|
||||
oldRawKey.AzureKeyConfig != nil &&
|
||||
updateKey.AzureKeyConfig.ClientID.IsRedacted() &&
|
||||
updateKey.AzureKeyConfig.ClientID.Equals(oldRedactedKey.AzureKeyConfig.ClientID) {
|
||||
mergedKey.AzureKeyConfig.ClientID = oldRawKey.AzureKeyConfig.ClientID
|
||||
}
|
||||
if updateKey.AzureKeyConfig.ClientSecret != nil &&
|
||||
oldRedactedKey.AzureKeyConfig.ClientSecret != nil &&
|
||||
oldRawKey.AzureKeyConfig != nil &&
|
||||
updateKey.AzureKeyConfig.ClientSecret.IsRedacted() &&
|
||||
updateKey.AzureKeyConfig.ClientSecret.Equals(oldRedactedKey.AzureKeyConfig.ClientSecret) {
|
||||
mergedKey.AzureKeyConfig.ClientSecret = oldRawKey.AzureKeyConfig.ClientSecret
|
||||
}
|
||||
if updateKey.AzureKeyConfig.TenantID != nil &&
|
||||
oldRedactedKey.AzureKeyConfig.TenantID != nil &&
|
||||
oldRawKey.AzureKeyConfig != nil &&
|
||||
updateKey.AzureKeyConfig.TenantID.IsRedacted() &&
|
||||
updateKey.AzureKeyConfig.TenantID.Equals(oldRedactedKey.AzureKeyConfig.TenantID) {
|
||||
mergedKey.AzureKeyConfig.TenantID = oldRawKey.AzureKeyConfig.TenantID
|
||||
}
|
||||
}
|
||||
|
||||
if updateKey.VertexKeyConfig != nil && oldRedactedKey.VertexKeyConfig != nil && oldRawKey.VertexKeyConfig != nil {
|
||||
if updateKey.VertexKeyConfig.ProjectID.IsRedacted() &&
|
||||
updateKey.VertexKeyConfig.ProjectID.Equals(&oldRedactedKey.VertexKeyConfig.ProjectID) {
|
||||
mergedKey.VertexKeyConfig.ProjectID = oldRawKey.VertexKeyConfig.ProjectID
|
||||
}
|
||||
if updateKey.VertexKeyConfig.ProjectNumber.IsRedacted() &&
|
||||
updateKey.VertexKeyConfig.ProjectNumber.Equals(&oldRedactedKey.VertexKeyConfig.ProjectNumber) {
|
||||
mergedKey.VertexKeyConfig.ProjectNumber = oldRawKey.VertexKeyConfig.ProjectNumber
|
||||
}
|
||||
if updateKey.VertexKeyConfig.Region.IsRedacted() &&
|
||||
updateKey.VertexKeyConfig.Region.Equals(&oldRedactedKey.VertexKeyConfig.Region) {
|
||||
mergedKey.VertexKeyConfig.Region = oldRawKey.VertexKeyConfig.Region
|
||||
}
|
||||
if updateKey.VertexKeyConfig.AuthCredentials.IsRedacted() &&
|
||||
updateKey.VertexKeyConfig.AuthCredentials.Equals(&oldRedactedKey.VertexKeyConfig.AuthCredentials) {
|
||||
mergedKey.VertexKeyConfig.AuthCredentials = oldRawKey.VertexKeyConfig.AuthCredentials
|
||||
}
|
||||
}
|
||||
|
||||
if updateKey.BedrockKeyConfig != nil && oldRedactedKey.BedrockKeyConfig != nil && oldRawKey.BedrockKeyConfig != nil {
|
||||
if updateKey.BedrockKeyConfig.AccessKey.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.AccessKey.Equals(&oldRedactedKey.BedrockKeyConfig.AccessKey) {
|
||||
mergedKey.BedrockKeyConfig.AccessKey = oldRawKey.BedrockKeyConfig.AccessKey
|
||||
}
|
||||
if updateKey.BedrockKeyConfig.SecretKey.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.SecretKey.Equals(&oldRedactedKey.BedrockKeyConfig.SecretKey) {
|
||||
mergedKey.BedrockKeyConfig.SecretKey = oldRawKey.BedrockKeyConfig.SecretKey
|
||||
}
|
||||
if updateKey.BedrockKeyConfig.SessionToken != nil &&
|
||||
oldRedactedKey.BedrockKeyConfig.SessionToken != nil &&
|
||||
oldRawKey.BedrockKeyConfig != nil &&
|
||||
updateKey.BedrockKeyConfig.SessionToken.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.SessionToken.Equals(oldRedactedKey.BedrockKeyConfig.SessionToken) {
|
||||
mergedKey.BedrockKeyConfig.SessionToken = oldRawKey.BedrockKeyConfig.SessionToken
|
||||
}
|
||||
if updateKey.BedrockKeyConfig.Region != nil &&
|
||||
oldRedactedKey.BedrockKeyConfig.Region != nil &&
|
||||
oldRawKey.BedrockKeyConfig != nil &&
|
||||
updateKey.BedrockKeyConfig.Region.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.Region.Equals(oldRedactedKey.BedrockKeyConfig.Region) {
|
||||
mergedKey.BedrockKeyConfig.Region = oldRawKey.BedrockKeyConfig.Region
|
||||
}
|
||||
if updateKey.BedrockKeyConfig.ARN != nil &&
|
||||
oldRedactedKey.BedrockKeyConfig.ARN != nil &&
|
||||
oldRawKey.BedrockKeyConfig != nil &&
|
||||
updateKey.BedrockKeyConfig.ARN.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.ARN.Equals(oldRedactedKey.BedrockKeyConfig.ARN) {
|
||||
mergedKey.BedrockKeyConfig.ARN = oldRawKey.BedrockKeyConfig.ARN
|
||||
}
|
||||
if updateKey.BedrockKeyConfig.RoleARN != nil &&
|
||||
oldRedactedKey.BedrockKeyConfig.RoleARN != nil &&
|
||||
oldRawKey.BedrockKeyConfig != nil &&
|
||||
updateKey.BedrockKeyConfig.RoleARN.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.RoleARN.Equals(oldRedactedKey.BedrockKeyConfig.RoleARN) {
|
||||
mergedKey.BedrockKeyConfig.RoleARN = oldRawKey.BedrockKeyConfig.RoleARN
|
||||
}
|
||||
if updateKey.BedrockKeyConfig.ExternalID != nil &&
|
||||
oldRedactedKey.BedrockKeyConfig.ExternalID != nil &&
|
||||
oldRawKey.BedrockKeyConfig != nil &&
|
||||
updateKey.BedrockKeyConfig.ExternalID.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.ExternalID.Equals(oldRedactedKey.BedrockKeyConfig.ExternalID) {
|
||||
mergedKey.BedrockKeyConfig.ExternalID = oldRawKey.BedrockKeyConfig.ExternalID
|
||||
}
|
||||
if updateKey.BedrockKeyConfig.RoleSessionName != nil &&
|
||||
oldRedactedKey.BedrockKeyConfig.RoleSessionName != nil &&
|
||||
oldRawKey.BedrockKeyConfig != nil &&
|
||||
updateKey.BedrockKeyConfig.RoleSessionName.IsRedacted() &&
|
||||
updateKey.BedrockKeyConfig.RoleSessionName.Equals(oldRedactedKey.BedrockKeyConfig.RoleSessionName) {
|
||||
mergedKey.BedrockKeyConfig.RoleSessionName = oldRawKey.BedrockKeyConfig.RoleSessionName
|
||||
}
|
||||
}
|
||||
|
||||
if updateKey.VLLMKeyConfig != nil && oldRedactedKey.VLLMKeyConfig != nil && oldRawKey.VLLMKeyConfig != nil {
|
||||
if updateKey.VLLMKeyConfig.URL.IsRedacted() &&
|
||||
updateKey.VLLMKeyConfig.URL.Equals(&oldRedactedKey.VLLMKeyConfig.URL) {
|
||||
mergedKey.VLLMKeyConfig.URL = oldRawKey.VLLMKeyConfig.URL
|
||||
}
|
||||
}
|
||||
|
||||
// ReplicateKeyConfig has no sensitive fields — pass through as-is
|
||||
if updateKey.ReplicateKeyConfig == nil && oldRawKey.ReplicateKeyConfig != nil {
|
||||
mergedKey.ReplicateKeyConfig = oldRawKey.ReplicateKeyConfig
|
||||
}
|
||||
|
||||
if updateKey.OllamaKeyConfig != nil && oldRedactedKey.OllamaKeyConfig != nil && oldRawKey.OllamaKeyConfig != nil {
|
||||
if updateKey.OllamaKeyConfig.URL.IsRedacted() &&
|
||||
updateKey.OllamaKeyConfig.URL.Equals(&oldRedactedKey.OllamaKeyConfig.URL) {
|
||||
mergedKey.OllamaKeyConfig.URL = oldRawKey.OllamaKeyConfig.URL
|
||||
}
|
||||
}
|
||||
|
||||
if updateKey.SGLKeyConfig != nil && oldRedactedKey.SGLKeyConfig != nil && oldRawKey.SGLKeyConfig != nil {
|
||||
if updateKey.SGLKeyConfig.URL.IsRedacted() &&
|
||||
updateKey.SGLKeyConfig.URL.Equals(&oldRedactedKey.SGLKeyConfig.URL) {
|
||||
mergedKey.SGLKeyConfig.URL = oldRawKey.SGLKeyConfig.URL
|
||||
}
|
||||
}
|
||||
|
||||
mergedKey.ConfigHash = oldRawKey.ConfigHash
|
||||
mergedKey.Status = oldRawKey.Status
|
||||
|
||||
return mergedKey
|
||||
}
|
||||
|
||||
func getKeyIDFromCtx(ctx *fasthttp.RequestCtx) (string, error) {
|
||||
keyValue := ctx.UserValue("key_id")
|
||||
if keyValue == nil {
|
||||
return "", fmt.Errorf("missing key_id parameter")
|
||||
}
|
||||
|
||||
keyID, ok := keyValue.(string)
|
||||
if !ok || keyID == "" {
|
||||
return "", fmt.Errorf("invalid key_id parameter")
|
||||
}
|
||||
|
||||
decoded, err := url.PathUnescape(keyID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid key_id parameter encoding: %v", err)
|
||||
}
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// validateProviderKeyURL checks that Ollama/SGL keys have a server URL configured.
|
||||
func validateProviderKeyURL(provider schemas.ModelProvider, key schemas.Key) error {
|
||||
switch provider {
|
||||
case schemas.Ollama:
|
||||
if key.OllamaKeyConfig == nil || !key.OllamaKeyConfig.URL.IsDefined() {
|
||||
return fmt.Errorf("ollama_key_config.url is required for Ollama keys")
|
||||
}
|
||||
case schemas.SGL:
|
||||
if key.SGLKeyConfig == nil || !key.SGLKeyConfig.URL.IsDefined() {
|
||||
return fmt.Errorf("sgl_key_config.url is required for SGL keys")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
1097
transports/bifrost-http/handlers/providers.go
Normal file
1097
transports/bifrost-http/handlers/providers.go
Normal file
File diff suppressed because it is too large
Load Diff
550
transports/bifrost-http/handlers/providers_test.go
Normal file
550
transports/bifrost-http/handlers/providers_test.go
Normal file
@@ -0,0 +1,550 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/framework/modelcatalog"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// mockModelsManager returns stable filtered and unfiltered model lists for handler tests.
|
||||
type mockModelsManager struct {
|
||||
filtered map[schemas.ModelProvider][]string
|
||||
unfiltered map[schemas.ModelProvider][]string
|
||||
reloadCalls []schemas.ModelProvider
|
||||
reloadErr error
|
||||
}
|
||||
|
||||
func (m *mockModelsManager) ReloadProvider(_ context.Context, provider schemas.ModelProvider) (*configstoreTables.TableProvider, error) {
|
||||
m.reloadCalls = append(m.reloadCalls, provider)
|
||||
if m.reloadErr != nil {
|
||||
return nil, m.reloadErr
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockModelsManager) RemoveProvider(_ context.Context, _ schemas.ModelProvider) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockModelsManager) GetModelsForProvider(provider schemas.ModelProvider) []string {
|
||||
models := m.filtered[provider]
|
||||
result := make([]string, len(models))
|
||||
copy(result, models)
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *mockModelsManager) GetUnfilteredModelsForProvider(provider schemas.ModelProvider) []string {
|
||||
models := m.unfiltered[provider]
|
||||
result := make([]string, len(models))
|
||||
copy(result, models)
|
||||
return result
|
||||
}
|
||||
|
||||
// providerHandlerForTest builds a handler with fixed provider config and model sets.
|
||||
func providerHandlerForTest(provider schemas.ModelProvider, keys []schemas.Key, filtered, unfiltered []string) *ProviderHandler {
|
||||
return &ProviderHandler{
|
||||
inMemoryStore: &lib.Config{
|
||||
Providers: map[schemas.ModelProvider]configstore.ProviderConfig{
|
||||
provider: {
|
||||
Keys: keys,
|
||||
},
|
||||
},
|
||||
},
|
||||
modelsManager: &mockModelsManager{
|
||||
filtered: map[schemas.ModelProvider][]string{
|
||||
provider: filtered,
|
||||
},
|
||||
unfiltered: map[schemas.ModelProvider][]string{
|
||||
provider: unfiltered,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddProvider_ReloadsRuntimeEvenWhenModelDiscoveryIsSkipped(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
lib.SetLogger(&mockLogger{})
|
||||
|
||||
modelsManager := &mockModelsManager{}
|
||||
h := &ProviderHandler{
|
||||
inMemoryStore: &lib.Config{Providers: map[schemas.ModelProvider]configstore.ProviderConfig{}},
|
||||
modelsManager: modelsManager,
|
||||
}
|
||||
|
||||
body, err := sonic.Marshal(providerCreatePayload{
|
||||
Provider: "mock-openai",
|
||||
CustomProviderConfig: &schemas.CustomProviderConfig{
|
||||
BaseProviderType: schemas.OpenAI,
|
||||
IsKeyLess: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fasthttp.MethodPost)
|
||||
ctx.Request.SetRequestURI("/api/providers")
|
||||
ctx.Request.SetBody(body)
|
||||
|
||||
h.addProvider(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
if len(modelsManager.reloadCalls) != 1 || modelsManager.reloadCalls[0] != "mock-openai" {
|
||||
t.Fatalf("expected provider reload for mock-openai, got %#v", modelsManager.reloadCalls)
|
||||
}
|
||||
if _, exists := h.inMemoryStore.Providers["mock-openai"]; !exists {
|
||||
t.Fatalf("expected provider to be added to in-memory store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddProvider_ReturnsErrorWhenRuntimeReloadFails(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
lib.SetLogger(&mockLogger{})
|
||||
|
||||
modelsManager := &mockModelsManager{reloadErr: context.DeadlineExceeded}
|
||||
h := &ProviderHandler{
|
||||
inMemoryStore: &lib.Config{Providers: map[schemas.ModelProvider]configstore.ProviderConfig{}},
|
||||
modelsManager: modelsManager,
|
||||
}
|
||||
|
||||
body, err := sonic.Marshal(providerCreatePayload{
|
||||
Provider: "mock-openai",
|
||||
CustomProviderConfig: &schemas.CustomProviderConfig{
|
||||
BaseProviderType: schemas.OpenAI,
|
||||
IsKeyLess: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fasthttp.MethodPost)
|
||||
ctx.Request.SetRequestURI("/api/providers")
|
||||
ctx.Request.SetBody(body)
|
||||
|
||||
h.addProvider(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
if len(modelsManager.reloadCalls) != 1 || modelsManager.reloadCalls[0] != "mock-openai" {
|
||||
t.Fatalf("expected single provider reload for mock-openai, got %#v", modelsManager.reloadCalls)
|
||||
}
|
||||
var bifrostErr schemas.BifrostError
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &bifrostErr); err != nil {
|
||||
t.Fatalf("failed to unmarshal error response: %v", err)
|
||||
}
|
||||
if bifrostErr.Error == nil || bifrostErr.Error.Message == "" {
|
||||
t.Fatalf("expected error message in response, got %#v", bifrostErr)
|
||||
}
|
||||
if bifrostErr.Error.Message != "Failed to initialize provider after add: context deadline exceeded" {
|
||||
t.Fatalf("unexpected error message: %q", bifrostErr.Error.Message)
|
||||
}
|
||||
if _, exists := h.inMemoryStore.Providers["mock-openai"]; exists {
|
||||
t.Fatalf("expected provider rollback after reload failure")
|
||||
}
|
||||
}
|
||||
|
||||
// boolPtr keeps pointer-valued key fixtures inline without pulling in pointer helpers.
|
||||
func boolPtr(v bool) *bool {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestListModels_UnknownKeysDoNotFilter(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{{ID: "key-a"}},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models?provider=openai&keys=missing")
|
||||
|
||||
h.listModels(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 2 {
|
||||
t.Fatalf("expected total=2, got %d", resp.Total)
|
||||
}
|
||||
if len(resp.Models) != 2 {
|
||||
t.Fatalf("expected all models to be returned, got %#v", resp.Models)
|
||||
}
|
||||
for _, model := range resp.Models {
|
||||
if len(model.AccessibleByKeys) != 0 {
|
||||
t.Fatalf("expected no accessible_by_keys annotations, got %#v", resp.Models)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModels_ReturnsExactAccessibleByKeysAndSkipsDisabledKeys(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{
|
||||
{ID: "key-a", Models: []string{"gpt-4o"}},
|
||||
{ID: "key-b", Models: []string{"gpt-4o", "gpt-4o-mini"}},
|
||||
{ID: "key-disabled", Enabled: boolPtr(false)},
|
||||
},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models?provider=openai&keys=key-a,key-b,key-disabled")
|
||||
|
||||
h.listModels(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 2 {
|
||||
t.Fatalf("expected total=2, got %d", resp.Total)
|
||||
}
|
||||
|
||||
got := map[string][]string{}
|
||||
for _, model := range resp.Models {
|
||||
got[model.Name] = model.AccessibleByKeys
|
||||
}
|
||||
|
||||
if len(got["gpt-4o"]) != 2 || got["gpt-4o"][0] != "key-a" || got["gpt-4o"][1] != "key-b" {
|
||||
t.Fatalf("expected gpt-4o to be accessible by [key-a key-b], got %#v", got["gpt-4o"])
|
||||
}
|
||||
if len(got["gpt-4o-mini"]) != 1 || got["gpt-4o-mini"][0] != "key-b" {
|
||||
t.Fatalf("expected gpt-4o-mini to be accessible by [key-b], got %#v", got["gpt-4o-mini"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModels_AppliesQueryAndLimitAfterFiltering(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{{ID: "key-a"}},
|
||||
[]string{"gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet"},
|
||||
)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models?provider=openai&query=gpt&limit=1")
|
||||
|
||||
h.listModels(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 2 {
|
||||
t.Fatalf("expected total=2 after query filtering, got %d", resp.Total)
|
||||
}
|
||||
if len(resp.Models) != 1 {
|
||||
t.Fatalf("expected limit to truncate response to 1 model, got %#v", resp.Models)
|
||||
}
|
||||
if resp.Models[0].Name != "gpt-4o" {
|
||||
t.Fatalf("expected first filtered model to be gpt-4o, got %#v", resp.Models[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModels_UnfilteredIgnoresKeys(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{
|
||||
{ID: "key-b", Models: []string{"gpt-4o-mini"}},
|
||||
},
|
||||
[]string{"gpt-4o"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models?provider=openai&keys=key-b&unfiltered=true")
|
||||
|
||||
h.listModels(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 2 || len(resp.Models) != 2 {
|
||||
t.Fatalf("expected both unfiltered models, got %#v", resp.Models)
|
||||
}
|
||||
|
||||
for _, model := range resp.Models {
|
||||
if len(model.AccessibleByKeys) != 0 {
|
||||
t.Fatalf("expected no accessible_by_keys when unfiltered bypasses key filtering, got %#v", resp.Models)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModels_UnfilteredWithoutKeysReturnsAllUnfilteredModels(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{
|
||||
{ID: "key-b", Models: []string{"gpt-4o-mini"}},
|
||||
},
|
||||
[]string{"gpt-4o"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models?provider=openai&unfiltered=true")
|
||||
|
||||
h.listModels(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 2 || len(resp.Models) != 2 {
|
||||
t.Fatalf("expected both unfiltered models, got %#v", resp.Models)
|
||||
}
|
||||
|
||||
for _, model := range resp.Models {
|
||||
if len(model.AccessibleByKeys) != 0 {
|
||||
t.Fatalf("expected no accessible_by_keys when no key filter is requested, got %#v", resp.Models)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModelDetails_ErrorsWhenModelCatalogUnavailable(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{{ID: "key-a"}},
|
||||
[]string{"gpt-4o"},
|
||||
[]string{"gpt-4o"},
|
||||
)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models/details?provider=openai")
|
||||
|
||||
h.listModelDetails(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModelDetails_UnknownKeysDoNotFilter(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{{ID: "key-a"}},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
h.inMemoryStore.ModelCatalog = &modelcatalog.ModelCatalog{}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models/details?provider=openai&keys=missing")
|
||||
|
||||
h.listModelDetails(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelDetailsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 2 || len(resp.Models) != 2 {
|
||||
t.Fatalf("expected all models when keys are unknown, got %#v", resp.Models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModelDetails_SkipsUnknownKeysAndFiltersWithValid(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{{ID: "key-a", Models: []string{"gpt-4o"}}},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
h.inMemoryStore.ModelCatalog = &modelcatalog.ModelCatalog{}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models/details?provider=openai&keys=key-a,missing")
|
||||
|
||||
h.listModelDetails(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelDetailsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 1 || len(resp.Models) != 1 {
|
||||
t.Fatalf("expected 1 model filtered by valid key, got %#v", resp.Models)
|
||||
}
|
||||
if resp.Models[0].Name != "gpt-4o" {
|
||||
t.Fatalf("expected gpt-4o, got %s", resp.Models[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModelDetails_SkipsDisabledKeysAndFiltersWithValid(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{
|
||||
{ID: "key-a", Models: []string{"gpt-4o"}},
|
||||
{ID: "key-disabled", Enabled: boolPtr(false)},
|
||||
},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
h.inMemoryStore.ModelCatalog = &modelcatalog.ModelCatalog{}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models/details?provider=openai&keys=key-a,key-disabled")
|
||||
|
||||
h.listModelDetails(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelDetailsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 1 || len(resp.Models) != 1 {
|
||||
t.Fatalf("expected 1 model filtered by valid key, got %#v", resp.Models)
|
||||
}
|
||||
if resp.Models[0].Name != "gpt-4o" {
|
||||
t.Fatalf("expected gpt-4o, got %s", resp.Models[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModelDetails_UnfilteredIgnoresKeys(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{
|
||||
{ID: "key-b", Models: []string{"gpt-4o-mini"}},
|
||||
},
|
||||
[]string{"gpt-4o"},
|
||||
[]string{"gpt-4o", "gpt-4o-mini"},
|
||||
)
|
||||
h.inMemoryStore.ModelCatalog = &modelcatalog.ModelCatalog{}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models/details?provider=openai&keys=key-b&unfiltered=true")
|
||||
|
||||
h.listModelDetails(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelDetailsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 2 || len(resp.Models) != 2 {
|
||||
t.Fatalf("expected all unfiltered models when unfiltered=true, got %#v", resp.Models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModels_UsesCatalogAwareAliasMatchingForKeyAllowlist(t *testing.T) {
|
||||
SetLogger(&mockLogger{})
|
||||
|
||||
h := providerHandlerForTest(
|
||||
schemas.OpenAI,
|
||||
[]schemas.Key{
|
||||
{ID: "key-a", Models: []string{"gpt-4o-2024-08-06"}},
|
||||
},
|
||||
[]string{"gpt-4o"},
|
||||
[]string{"gpt-4o"},
|
||||
)
|
||||
h.inMemoryStore.ModelCatalog = modelcatalog.NewTestCatalog(map[string]string{
|
||||
"gpt-4o-2024-08-06": "gpt-4o",
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/api/models?provider=openai&keys=key-a")
|
||||
|
||||
h.listModels(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", ctx.Response.StatusCode(), string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
var resp ListModelsResponse
|
||||
if err := json.Unmarshal(ctx.Response.Body(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Total != 1 || len(resp.Models) != 1 || resp.Models[0].Name != "gpt-4o" {
|
||||
t.Fatalf("expected gpt-4o to be matched through alias allowlist, got %#v", resp.Models)
|
||||
}
|
||||
}
|
||||
419
transports/bifrost-http/handlers/realtime_client_secrets.go
Normal file
419
transports/bifrost-http/handlers/realtime_client_secrets.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// RealtimeClientSecretsHandler exposes OpenAI-compatible HTTP routes for
|
||||
// minting short-lived Realtime client secrets.
|
||||
type RealtimeClientSecretsHandler struct {
|
||||
client *bifrost.Bifrost
|
||||
config *lib.Config
|
||||
handlerStore lib.HandlerStore
|
||||
routeSpecs map[string]schemas.RealtimeSessionRoute
|
||||
}
|
||||
|
||||
func NewRealtimeClientSecretsHandler(client *bifrost.Bifrost, config *lib.Config) *RealtimeClientSecretsHandler {
|
||||
return &RealtimeClientSecretsHandler{
|
||||
client: client,
|
||||
config: config,
|
||||
handlerStore: config,
|
||||
routeSpecs: make(map[string]schemas.RealtimeSessionRoute),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RealtimeClientSecretsHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
handler := lib.ChainMiddlewares(h.handleRequest, middlewares...)
|
||||
for _, route := range h.realtimeSessionRoutes() {
|
||||
h.routeSpecs[route.Path] = route
|
||||
r.POST(route.Path, handler)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RealtimeClientSecretsHandler) findGovernancePlugin() governance.BaseGovernancePlugin {
|
||||
basePlugins := h.config.BasePlugins.Load()
|
||||
if basePlugins == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, plugin := range *basePlugins {
|
||||
if governancePlugin, ok := plugin.(governance.BaseGovernancePlugin); ok {
|
||||
return governancePlugin
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *RealtimeClientSecretsHandler) handleRequest(ctx *fasthttp.RequestCtx) {
|
||||
if !isJSONContentType(string(ctx.Request.Header.ContentType())) {
|
||||
SendBifrostError(ctx, newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusBadRequest,
|
||||
"invalid_request_error",
|
||||
"Content-Type must be application/json",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
body := append([]byte(nil), ctx.Request.Body()...)
|
||||
route, ok := h.routeSpecs[string(ctx.Path())]
|
||||
if !ok {
|
||||
SendBifrostError(ctx, newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusNotFound,
|
||||
"invalid_request_error",
|
||||
"unsupported realtime client secret route",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
providerKey, model, normalizedBody, err := resolveRealtimeClientSecretTarget(route, body)
|
||||
if err != nil {
|
||||
SendBifrostError(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := lib.ConvertToBifrostContext(
|
||||
ctx,
|
||||
h.handlerStore.ShouldAllowDirectKeys(),
|
||||
h.config.GetHeaderMatcher(),
|
||||
h.config.GetMCPHeaderCombinedAllowlist(),
|
||||
)
|
||||
defer cancel()
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest)
|
||||
if route.DefaultProvider == schemas.OpenAI {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai")
|
||||
}
|
||||
if governanceUserID, ok := ctx.UserValue(schemas.BifrostContextKeyUserID).(string); ok && governanceUserID != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyUserID, governanceUserID)
|
||||
}
|
||||
if userName, ok := ctx.UserValue(schemas.BifrostContextKeyUserName).(string); ok && userName != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyUserName, userName)
|
||||
}
|
||||
if bifrostErr := h.evaluateMintingGovernance(bifrostCtx, providerKey, model); bifrostErr != nil {
|
||||
SendBifrostError(ctx, bifrostErr)
|
||||
return
|
||||
}
|
||||
|
||||
provider := h.client.GetProviderByKey(providerKey)
|
||||
if provider == nil {
|
||||
SendBifrostError(ctx, newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusBadRequest,
|
||||
"invalid_request_error",
|
||||
"provider not found: "+string(providerKey),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
key, keyErr := h.client.SelectKeyForProviderRequestType(bifrostCtx, schemas.RealtimeRequest, providerKey, model)
|
||||
if keyErr != nil {
|
||||
SendBifrostError(ctx, newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusBadRequest,
|
||||
"invalid_request_error",
|
||||
keyErr.Error(),
|
||||
keyErr,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve model aliases now that the key is selected so the forwarded body
|
||||
// carries the provider's canonical model, matching wsrealtime/webrtc flows.
|
||||
if resolved := key.Aliases.Resolve(model); resolved != "" && resolved != model {
|
||||
model = resolved
|
||||
reparsed, parseErr := schemas.ParseRealtimeClientSecretBody(normalizedBody)
|
||||
if parseErr != nil {
|
||||
SendBifrostError(ctx, parseErr)
|
||||
return
|
||||
}
|
||||
rewritten, normalizeErr := normalizeRealtimeClientSecretBody(reparsed, model)
|
||||
if normalizeErr != nil {
|
||||
SendBifrostError(ctx, normalizeErr)
|
||||
return
|
||||
}
|
||||
normalizedBody = rewritten
|
||||
}
|
||||
|
||||
sessionProvider, ok := provider.(schemas.RealtimeSessionProvider)
|
||||
if !ok {
|
||||
SendBifrostError(ctx, realtimeSessionNotSupportedError(providerKey, provider))
|
||||
return
|
||||
}
|
||||
|
||||
resp, bifrostErr := sessionProvider.CreateRealtimeClientSecret(bifrostCtx, key, route.EndpointType, normalizedBody)
|
||||
if bifrostErr != nil {
|
||||
SendBifrostError(ctx, bifrostErr)
|
||||
return
|
||||
}
|
||||
cacheRealtimeEphemeralKeyMapping(
|
||||
h.handlerStore.GetKVStore(),
|
||||
resp.Body,
|
||||
key.ID,
|
||||
bifrost.GetStringFromContext(bifrostCtx, schemas.BifrostContextKeyVirtualKey),
|
||||
)
|
||||
|
||||
writeRealtimeClientSecretResponse(ctx, resp)
|
||||
}
|
||||
|
||||
func (h *RealtimeClientSecretsHandler) evaluateMintingGovernance(
|
||||
bifrostCtx *schemas.BifrostContext,
|
||||
providerKey schemas.ModelProvider,
|
||||
model string,
|
||||
) *schemas.BifrostError {
|
||||
governancePlugin := h.findGovernancePlugin()
|
||||
if governancePlugin == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, bifrostErr := governancePlugin.EvaluateGovernanceRequest(bifrostCtx, &governance.EvaluationRequest{
|
||||
VirtualKey: bifrost.GetStringFromContext(bifrostCtx, schemas.BifrostContextKeyVirtualKey),
|
||||
Provider: providerKey,
|
||||
Model: model,
|
||||
UserID: bifrost.GetStringFromContext(bifrostCtx, schemas.BifrostContextKeyUserID),
|
||||
}, schemas.RealtimeRequest)
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
func (h *RealtimeClientSecretsHandler) realtimeSessionRoutes() []schemas.RealtimeSessionRoute {
|
||||
routes := []schemas.RealtimeSessionRoute{
|
||||
{
|
||||
Path: "/v1/realtime/client_secrets",
|
||||
EndpointType: schemas.RealtimeSessionEndpointClientSecrets,
|
||||
},
|
||||
{
|
||||
Path: "/v1/realtime/sessions",
|
||||
EndpointType: schemas.RealtimeSessionEndpointSessions,
|
||||
},
|
||||
}
|
||||
|
||||
for _, path := range integrations.OpenAIRealtimeClientSecretPaths("/openai") {
|
||||
endpointType := schemas.RealtimeSessionEndpointClientSecrets
|
||||
if strings.HasSuffix(path, "/realtime/sessions") {
|
||||
endpointType = schemas.RealtimeSessionEndpointSessions
|
||||
}
|
||||
routes = append(routes, schemas.RealtimeSessionRoute{
|
||||
Path: path,
|
||||
EndpointType: endpointType,
|
||||
DefaultProvider: schemas.OpenAI,
|
||||
})
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
func resolveRealtimeClientSecretTarget(route schemas.RealtimeSessionRoute, body []byte) (schemas.ModelProvider, string, []byte, *schemas.BifrostError) {
|
||||
root, err := schemas.ParseRealtimeClientSecretBody(body)
|
||||
if err != nil {
|
||||
return "", "", nil, err
|
||||
}
|
||||
|
||||
rawModel, err := schemas.ExtractRealtimeClientSecretModel(root)
|
||||
if err != nil {
|
||||
return "", "", nil, err
|
||||
}
|
||||
|
||||
defaultProvider := route.DefaultProvider
|
||||
providerKey, model := schemas.ParseModelString(rawModel, defaultProvider)
|
||||
if defaultProvider == "" && providerKey == "" {
|
||||
return "", "", nil, newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusBadRequest,
|
||||
"invalid_request_error",
|
||||
"session.model must use provider/model on /v1 realtime client secret routes",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
if providerKey == "" || model == "" {
|
||||
return "", "", nil, newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusBadRequest,
|
||||
"invalid_request_error",
|
||||
"session.model is required",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
// Normalize the forwarded body so the upstream provider sees the bare model
|
||||
// (strip provider prefix). Mirrors resolveRealtimeSDPTarget normalization.
|
||||
normalizedBody, normalizeErr := normalizeRealtimeClientSecretBody(root, model)
|
||||
if normalizeErr != nil {
|
||||
return "", "", nil, normalizeErr
|
||||
}
|
||||
|
||||
return providerKey, model, normalizedBody, nil
|
||||
}
|
||||
|
||||
func normalizeRealtimeClientSecretBody(root map[string]json.RawMessage, bareModel string) ([]byte, *schemas.BifrostError) {
|
||||
normalizedModel, marshalErr := json.Marshal(bareModel)
|
||||
if marshalErr != nil {
|
||||
return nil, newRealtimeClientSecretHandlerError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr)
|
||||
}
|
||||
|
||||
// Normalize session.model if present
|
||||
if sessionJSON, ok := root["session"]; ok && len(sessionJSON) > 0 {
|
||||
var session map[string]json.RawMessage
|
||||
if err := json.Unmarshal(sessionJSON, &session); err == nil {
|
||||
if _, hasModel := session["model"]; hasModel {
|
||||
session["model"] = normalizedModel
|
||||
rewritten, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
return nil, newRealtimeClientSecretHandlerError(fasthttp.StatusInternalServerError, "server_error", "failed to re-encode session", err)
|
||||
}
|
||||
root["session"] = rewritten
|
||||
}
|
||||
}
|
||||
}
|
||||
// Normalize top-level model if present
|
||||
if _, ok := root["model"]; ok {
|
||||
root["model"] = normalizedModel
|
||||
}
|
||||
|
||||
normalized, marshalErr := json.Marshal(root)
|
||||
if marshalErr != nil {
|
||||
return nil, newRealtimeClientSecretHandlerError(fasthttp.StatusInternalServerError, "server_error", "failed to re-encode body", marshalErr)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
const realtimeEphemeralKeyMappingPrefix = "realtime:ephemeral-key:"
|
||||
|
||||
type realtimeEphemeralKeyMapping struct {
|
||||
KeyID string `json:"key_id,omitempty"`
|
||||
VirtualKey string `json:"virtual_key,omitempty"`
|
||||
}
|
||||
|
||||
func cacheRealtimeEphemeralKeyMapping(kv schemas.KVStore, body []byte, keyID string, virtualKey string) {
|
||||
if kv == nil || len(body) == 0 || strings.TrimSpace(keyID) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
token, ttl, ok := parseRealtimeEphemeralKeyMapping(body)
|
||||
if !ok || strings.TrimSpace(token) == "" || ttl <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(realtimeEphemeralKeyMapping{
|
||||
KeyID: strings.TrimSpace(keyID),
|
||||
VirtualKey: strings.TrimSpace(virtualKey),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn("failed to encode realtime ephemeral key mapping for key_id=%s: %v", keyID, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := kv.SetWithTTL(buildRealtimeEphemeralKeyMappingKey(token), payload, ttl); err != nil {
|
||||
logger.Warn("failed to cache realtime ephemeral key mapping for key_id=%s: %v", keyID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func parseRealtimeEphemeralKeyMapping(body []byte) (string, time.Duration, bool) {
|
||||
var root map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &root); err != nil {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
var clientSecret struct {
|
||||
Value string `json:"value"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
}
|
||||
|
||||
// OpenAI client_secrets responses expose the ephemeral token at the top level.
|
||||
// Keep accepting the nested shape too so the mapping logic stays compatible
|
||||
// with any provider/session endpoint variants that wrap the secret object.
|
||||
if err := json.Unmarshal(body, &clientSecret); err != nil || strings.TrimSpace(clientSecret.Value) == "" || clientSecret.ExpiresAt <= 0 {
|
||||
clientSecretRaw, ok := root["client_secret"]
|
||||
if !ok || len(clientSecretRaw) == 0 || string(clientSecretRaw) == "null" {
|
||||
return "", 0, false
|
||||
}
|
||||
if err := json.Unmarshal(clientSecretRaw, &clientSecret); err != nil {
|
||||
return "", 0, false
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(clientSecret.Value) == "" || clientSecret.ExpiresAt <= 0 {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
ttl := time.Until(time.Unix(clientSecret.ExpiresAt, 0))
|
||||
if ttl <= 0 {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
return clientSecret.Value, ttl, true
|
||||
}
|
||||
|
||||
func buildRealtimeEphemeralKeyMappingKey(token string) string {
|
||||
return realtimeEphemeralKeyMappingPrefix + strings.TrimSpace(token)
|
||||
}
|
||||
|
||||
func realtimeSessionNotSupportedError(providerKey schemas.ModelProvider, provider schemas.Provider) *schemas.BifrostError {
|
||||
if rtProvider, ok := provider.(schemas.RealtimeProvider); ok && rtProvider.SupportsRealtimeAPI() {
|
||||
return newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusBadRequest,
|
||||
"invalid_request_error",
|
||||
fmt.Sprintf("provider %s supports realtime websocket connections but not realtime client secret creation", providerKey),
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
return newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusBadRequest,
|
||||
"invalid_request_error",
|
||||
fmt.Sprintf("provider %s does not support realtime client secret creation", providerKey),
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func newRealtimeClientSecretHandlerError(status int, errorType, message string, err error) *schemas.BifrostError {
|
||||
return &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
StatusCode: schemas.Ptr(status),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr(errorType),
|
||||
Message: message,
|
||||
Error: err,
|
||||
},
|
||||
ExtraFields: schemas.BifrostErrorExtraFields{
|
||||
RequestType: schemas.RealtimeRequest,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func writeRealtimeClientSecretResponse(ctx *fasthttp.RequestCtx, resp *schemas.BifrostPassthroughResponse) {
|
||||
if resp == nil {
|
||||
SendBifrostError(ctx, newRealtimeClientSecretHandlerError(
|
||||
fasthttp.StatusInternalServerError,
|
||||
"server_error",
|
||||
"provider returned an empty realtime client secret response",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
for key, value := range resp.Headers {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
if len(ctx.Response.Header.ContentType()) == 0 {
|
||||
ctx.SetContentType("application/json")
|
||||
}
|
||||
ctx.SetStatusCode(resp.StatusCode)
|
||||
ctx.SetBody(resp.Body)
|
||||
}
|
||||
|
||||
func isJSONContentType(contentType string) bool {
|
||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
mediaType = strings.ToLower(mediaType)
|
||||
return mediaType == "application/json" || strings.HasSuffix(mediaType, "+json")
|
||||
}
|
||||
414
transports/bifrost-http/handlers/realtime_client_secrets_test.go
Normal file
414
transports/bifrost-http/handlers/realtime_client_secrets_test.go
Normal file
@@ -0,0 +1,414 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/kvstore"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestResolveRealtimeClientSecretTarget(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
route schemas.RealtimeSessionRoute
|
||||
body []byte
|
||||
wantProvider schemas.ModelProvider
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "base route with session model",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets},
|
||||
body: []byte(`{"session":{"model":"openai/gpt-4o-realtime-preview"}}`),
|
||||
wantProvider: schemas.OpenAI,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
{
|
||||
name: "base route with top level model",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/sessions", EndpointType: schemas.RealtimeSessionEndpointSessions},
|
||||
body: []byte(`{"model":"openai/gpt-4o-realtime-preview"}`),
|
||||
wantProvider: schemas.OpenAI,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
{
|
||||
name: "openai alias uses bare model",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI},
|
||||
body: []byte(`{"session":{"model":"gpt-4o-realtime-preview"}}`),
|
||||
wantProvider: schemas.OpenAI,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
{
|
||||
name: "base route rejects bare model",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets},
|
||||
body: []byte(`{"session":{"model":"gpt-4o-realtime-preview"}}`),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing model",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI},
|
||||
body: []byte(`{"session":{}}`),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gotProvider, gotModel, _, err := resolveRealtimeClientSecretTarget(tt.route, tt.body)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("resolveRealtimeClientSecretTarget() error = %v", err)
|
||||
}
|
||||
if gotProvider != tt.wantProvider {
|
||||
t.Fatalf("provider = %q, want %q", gotProvider, tt.wantProvider)
|
||||
}
|
||||
if gotModel != tt.wantModel {
|
||||
t.Fatalf("model = %q, want %q", gotModel, tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRealtimeClientSecretTarget_NormalizesModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
route schemas.RealtimeSessionRoute
|
||||
body string
|
||||
wantModel string // bare model expected in normalized body
|
||||
}{
|
||||
{
|
||||
name: "session.model provider prefix stripped",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets},
|
||||
body: `{"session":{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}}`,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
{
|
||||
name: "top-level model provider prefix stripped",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/sessions", EndpointType: schemas.RealtimeSessionEndpointSessions},
|
||||
body: `{"model":"openai/gpt-4o-realtime-preview"}`,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
{
|
||||
name: "bare model unchanged on alias route",
|
||||
route: schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI},
|
||||
body: `{"session":{"model":"gpt-4o-realtime-preview"}}`,
|
||||
wantModel: "gpt-4o-realtime-preview",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, normalizedBody, err := resolveRealtimeClientSecretTarget(tt.route, []byte(tt.body))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var root map[string]json.RawMessage
|
||||
if unmarshalErr := json.Unmarshal(normalizedBody, &root); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal normalized body: %v", unmarshalErr)
|
||||
}
|
||||
|
||||
// Check session.model if present
|
||||
if sessionJSON, ok := root["session"]; ok {
|
||||
var session map[string]json.RawMessage
|
||||
if unmarshalErr := json.Unmarshal(sessionJSON, &session); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal session: %v", unmarshalErr)
|
||||
}
|
||||
if modelJSON, ok := session["model"]; ok {
|
||||
var model string
|
||||
if unmarshalErr := json.Unmarshal(modelJSON, &model); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal session.model: %v", unmarshalErr)
|
||||
}
|
||||
if model != tt.wantModel {
|
||||
t.Fatalf("session.model = %q, want %q", model, tt.wantModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check top-level model if present
|
||||
if modelJSON, ok := root["model"]; ok {
|
||||
var model string
|
||||
if unmarshalErr := json.Unmarshal(modelJSON, &model); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal model: %v", unmarshalErr)
|
||||
}
|
||||
if model != tt.wantModel {
|
||||
t.Fatalf("model = %q, want %q", model, tt.wantModel)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRealtimeEphemeralKeyMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
token, ttl, ok := parseRealtimeEphemeralKeyMapping([]byte(`{
|
||||
"value": "ek_test_123",
|
||||
"expires_at": 4102444800
|
||||
}`))
|
||||
if !ok {
|
||||
t.Fatal("expected ephemeral mapping to be parsed")
|
||||
}
|
||||
if token != "ek_test_123" {
|
||||
t.Fatalf("token = %q, want %q", token, "ek_test_123")
|
||||
}
|
||||
if ttl <= 0 {
|
||||
t.Fatalf("ttl = %v, want > 0", ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRealtimeEphemeralKeyMapping_NestedFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
token, ttl, ok := parseRealtimeEphemeralKeyMapping([]byte(`{
|
||||
"client_secret": {
|
||||
"value": "ek_test_nested",
|
||||
"expires_at": 4102444800
|
||||
}
|
||||
}`))
|
||||
if !ok {
|
||||
t.Fatal("expected nested ephemeral mapping to be parsed")
|
||||
}
|
||||
if token != "ek_test_nested" {
|
||||
t.Fatalf("token = %q, want %q", token, "ek_test_nested")
|
||||
}
|
||||
if ttl <= 0 {
|
||||
t.Fatalf("ttl = %v, want > 0", ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheRealtimeEphemeralKeyMappingStoresKeyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := kvstore.New(kvstore.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("kvstore.New() error = %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
body := []byte(`{
|
||||
"value": "ek_test_456",
|
||||
"expires_at": ` + "4102444800" + `
|
||||
}`)
|
||||
cacheRealtimeEphemeralKeyMapping(store, body, "key_123", "sk-bf-test")
|
||||
|
||||
raw, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_test_456"))
|
||||
if err != nil {
|
||||
t.Fatalf("store.Get() error = %v", err)
|
||||
}
|
||||
value, ok := raw.([]byte)
|
||||
if !ok {
|
||||
t.Fatalf("cached value type = %T, want []byte", raw)
|
||||
}
|
||||
var mapping realtimeEphemeralKeyMapping
|
||||
if err := json.Unmarshal(value, &mapping); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
if mapping.KeyID != "key_123" {
|
||||
t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_123")
|
||||
}
|
||||
if mapping.VirtualKey != "sk-bf-test" {
|
||||
t.Fatalf("mapping.VirtualKey = %q, want %q", mapping.VirtualKey, "sk-bf-test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheRealtimeEphemeralKeyMappingSkipsExpiredSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := kvstore.New(kvstore.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("kvstore.New() error = %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
expired := time.Now().Add(-time.Minute).Unix()
|
||||
body := fmt.Appendf(nil, `{
|
||||
"value": "ek_expired",
|
||||
"expires_at": %d
|
||||
}`, expired)
|
||||
cacheRealtimeEphemeralKeyMapping(store, body, "key_123", "")
|
||||
|
||||
if _, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_expired")); err == nil {
|
||||
t.Fatal("expected no cached mapping for expired token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsJSONContentType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if !isJSONContentType("application/json; charset=utf-8") {
|
||||
t.Fatal("expected application/json content type to pass")
|
||||
}
|
||||
if !isJSONContentType("application/vnd.openai+json") {
|
||||
t.Fatal("expected +json content type to pass")
|
||||
}
|
||||
if isJSONContentType("text/plain") {
|
||||
t.Fatal("expected text/plain content type to fail")
|
||||
}
|
||||
}
|
||||
|
||||
type mockRealtimeMintingGovernancePlugin struct {
|
||||
err *schemas.BifrostError
|
||||
seenUserID string
|
||||
seenVirtualKey string
|
||||
seenProvider schemas.ModelProvider
|
||||
seenModel string
|
||||
evaluateCalls int
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) GetName() string {
|
||||
return governance.PluginName
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) EvaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *governance.EvaluationRequest, _ schemas.RequestType) (*governance.EvaluationResult, *schemas.BifrostError) {
|
||||
m.evaluateCalls++
|
||||
m.seenUserID = ""
|
||||
m.seenVirtualKey = ""
|
||||
m.seenProvider = ""
|
||||
m.seenModel = ""
|
||||
if evaluationRequest != nil {
|
||||
m.seenUserID = evaluationRequest.UserID
|
||||
m.seenVirtualKey = evaluationRequest.VirtualKey
|
||||
m.seenProvider = evaluationRequest.Provider
|
||||
m.seenModel = evaluationRequest.Model
|
||||
}
|
||||
if ctx != nil && m.seenVirtualKey == "" {
|
||||
m.seenVirtualKey = bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)
|
||||
}
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return &governance.EvaluationResult{Decision: governance.DecisionAllow}, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) HTTPTransportPreHook(_ *schemas.BifrostContext, _ *schemas.HTTPRequest) (*schemas.HTTPResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) HTTPTransportPostHook(_ *schemas.BifrostContext, _ *schemas.HTTPRequest, _ *schemas.HTTPResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) PreLLMHook(_ *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) {
|
||||
return req, nil, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) PostLLMHook(_ *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) {
|
||||
return result, bifrostErr, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) PreMCPHook(_ *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error) {
|
||||
return req, nil, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) PostMCPHook(_ *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error) {
|
||||
return resp, bifrostErr, nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) Cleanup() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRealtimeMintingGovernancePlugin) GetGovernanceStore() governance.GovernanceStore {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRealtimeClientSecretsEvaluateMintingGovernance_RequiresAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &lib.Config{}
|
||||
plugin := &mockRealtimeMintingGovernancePlugin{
|
||||
err: &schemas.BifrostError{
|
||||
Type: schemas.Ptr("virtual_key_required"),
|
||||
StatusCode: schemas.Ptr(401),
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "virtual key is required. Provide a virtual key via the x-bf-vk header.",
|
||||
},
|
||||
},
|
||||
}
|
||||
plugins := []schemas.BasePlugin{plugin}
|
||||
config.BasePlugins.Store(&plugins)
|
||||
|
||||
handler := NewRealtimeClientSecretsHandler(nil, config)
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
defer bifrostCtx.Done()
|
||||
|
||||
err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime")
|
||||
if err == nil {
|
||||
t.Fatal("expected governance error")
|
||||
}
|
||||
if err.StatusCode == nil {
|
||||
t.Fatal("expected status code")
|
||||
}
|
||||
if got, want := *err.StatusCode, fasthttp.StatusUnauthorized; got != want {
|
||||
t.Fatalf("status = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeClientSecretsEvaluateMintingGovernance_PassesContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &lib.Config{}
|
||||
plugin := &mockRealtimeMintingGovernancePlugin{}
|
||||
plugins := []schemas.BasePlugin{
|
||||
plugin,
|
||||
}
|
||||
config.BasePlugins.Store(&plugins)
|
||||
|
||||
handler := NewRealtimeClientSecretsHandler(nil, config)
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
defer bifrostCtx.Done()
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyUserID, "user_123")
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, "sk-bf-123")
|
||||
|
||||
if err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime"); err != nil {
|
||||
t.Fatalf("unexpected governance error: %v", err)
|
||||
}
|
||||
if plugin.evaluateCalls != 1 {
|
||||
t.Fatalf("evaluate calls = %d, want 1", plugin.evaluateCalls)
|
||||
}
|
||||
if plugin.seenUserID != "user_123" {
|
||||
t.Fatalf("governance user id = %q, want %q", plugin.seenUserID, "user_123")
|
||||
}
|
||||
if plugin.seenVirtualKey != "sk-bf-123" {
|
||||
t.Fatalf("virtual key = %q, want %q", plugin.seenVirtualKey, "sk-bf-123")
|
||||
}
|
||||
if plugin.seenProvider != schemas.OpenAI {
|
||||
t.Fatalf("provider = %q, want %q", plugin.seenProvider, schemas.OpenAI)
|
||||
}
|
||||
if plugin.seenModel != "gpt-realtime" {
|
||||
t.Fatalf("model = %q, want %q", plugin.seenModel, "gpt-realtime")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeClientSecretsEvaluateMintingGovernance_ContinuesWithoutGovernance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewRealtimeClientSecretsHandler(nil, &lib.Config{})
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
defer bifrostCtx.Done()
|
||||
|
||||
if err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime"); err != nil {
|
||||
t.Fatalf("unexpected governance error without plugin: %v", err)
|
||||
}
|
||||
}
|
||||
441
transports/bifrost-http/handlers/realtime_logging.go
Normal file
441
transports/bifrost-http/handlers/realtime_logging.go
Normal file
@@ -0,0 +1,441 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
|
||||
)
|
||||
|
||||
type realtimeTurnSource string
|
||||
|
||||
const (
|
||||
realtimeTurnSourceEI realtimeTurnSource = "ei"
|
||||
realtimeTurnSourceLM realtimeTurnSource = "lm"
|
||||
)
|
||||
|
||||
const (
|
||||
realtimeMissingTranscriptText = "[Audio transcription unavailable]"
|
||||
)
|
||||
|
||||
func extractRealtimeTurnSummary(event *schemas.BifrostRealtimeEvent, contentOverride string) string {
|
||||
if strings.TrimSpace(contentOverride) != "" {
|
||||
return strings.TrimSpace(contentOverride)
|
||||
}
|
||||
if event == nil {
|
||||
return ""
|
||||
}
|
||||
if event.Error != nil && strings.TrimSpace(event.Error.Message) != "" {
|
||||
return strings.TrimSpace(event.Error.Message)
|
||||
}
|
||||
if event.Delta != nil {
|
||||
if text := strings.TrimSpace(event.Delta.Text); text != "" {
|
||||
return text
|
||||
}
|
||||
if transcript := strings.TrimSpace(event.Delta.Transcript); transcript != "" {
|
||||
return transcript
|
||||
}
|
||||
}
|
||||
if event.Item != nil {
|
||||
if summary := extractRealtimeItemSummary(event.Item); summary != "" {
|
||||
return summary
|
||||
}
|
||||
}
|
||||
if event.Session != nil && strings.TrimSpace(event.Session.Instructions) != "" {
|
||||
return strings.TrimSpace(event.Session.Instructions)
|
||||
}
|
||||
if len(event.RawData) > 0 {
|
||||
return strings.TrimSpace(string(event.RawData))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractRealtimeItemSummary(item *schemas.RealtimeItem) string {
|
||||
if item == nil {
|
||||
return ""
|
||||
}
|
||||
if summary := extractRealtimeContentSummary(item.Content); summary != "" {
|
||||
return summary
|
||||
}
|
||||
switch {
|
||||
case strings.TrimSpace(item.Output) != "":
|
||||
return strings.TrimSpace(item.Output)
|
||||
case strings.TrimSpace(item.Arguments) != "":
|
||||
return strings.TrimSpace(item.Arguments)
|
||||
case strings.TrimSpace(item.Name) != "":
|
||||
return strings.TrimSpace(item.Name)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func extractRealtimeContentSummary(raw []byte) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := sonic.Unmarshal(raw, &decoded); err != nil {
|
||||
return strings.TrimSpace(string(raw))
|
||||
}
|
||||
|
||||
var parts []string
|
||||
collectRealtimeTextFragments(decoded, &parts)
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func collectRealtimeTextFragments(value any, parts *[]string) {
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
for key, field := range v {
|
||||
switch key {
|
||||
case "text", "transcript", "input_text", "output_text", "output", "arguments":
|
||||
if text, ok := field.(string); ok {
|
||||
text = strings.TrimSpace(text)
|
||||
if text != "" {
|
||||
*parts = append(*parts, text)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
collectRealtimeTextFragments(field, parts)
|
||||
}
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
collectRealtimeTextFragments(item, parts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func finalizedRealtimeInputSummary(event *schemas.BifrostRealtimeEvent) string {
|
||||
if event == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case schemas.RTEventInputAudioTransCompleted:
|
||||
if transcript := extractRealtimeExtraParamString(event, "transcript"); transcript != "" {
|
||||
return transcript
|
||||
}
|
||||
return realtimeMissingTranscriptText
|
||||
default:
|
||||
if event != nil && event.Type == schemas.RTEventConversationItemDone && schemas.IsRealtimeUserInputEvent(event) {
|
||||
if summary := extractRealtimeItemSummary(event.Item); summary != "" {
|
||||
return summary
|
||||
}
|
||||
if realtimeItemHasMissingAudioTranscript(event.Item) {
|
||||
return realtimeMissingTranscriptText
|
||||
}
|
||||
}
|
||||
if schemas.IsRealtimeUserInputEvent(event) {
|
||||
return extractRealtimeItemSummary(event.Item)
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func pendingRealtimeInputUpdate(event *schemas.BifrostRealtimeEvent) (string, string) {
|
||||
if event == nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case schemas.RTEventConversationItemRetrieved:
|
||||
return "", ""
|
||||
case schemas.RTEventInputAudioTransCompleted:
|
||||
return realtimeEventItemID(event), finalizedRealtimeInputSummary(event)
|
||||
default:
|
||||
if schemas.IsRealtimeUserInputEvent(event) {
|
||||
return realtimeEventItemID(event), finalizedRealtimeInputSummary(event)
|
||||
}
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func realtimeItemHasMissingAudioTranscript(item *schemas.RealtimeItem) bool {
|
||||
if item == nil || len(item.Content) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
var decoded []map[string]any
|
||||
if err := sonic.Unmarshal(item.Content, &decoded); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, part := range decoded {
|
||||
partType, _ := part["type"].(string)
|
||||
if partType != "input_audio" {
|
||||
continue
|
||||
}
|
||||
transcript, exists := part["transcript"]
|
||||
if !exists || transcript == nil {
|
||||
return true
|
||||
}
|
||||
if text, ok := transcript.(string); ok && strings.TrimSpace(text) == "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func finalizedRealtimeToolOutputSummary(event *schemas.BifrostRealtimeEvent) string {
|
||||
if !schemas.IsRealtimeToolOutputEvent(event) {
|
||||
return ""
|
||||
}
|
||||
return extractRealtimeItemSummary(event.Item)
|
||||
}
|
||||
|
||||
func pendingRealtimeToolOutputUpdate(event *schemas.BifrostRealtimeEvent) (string, string) {
|
||||
if event == nil || event.Type == schemas.RTEventConversationItemRetrieved || !schemas.IsRealtimeToolOutputEvent(event) {
|
||||
return "", ""
|
||||
}
|
||||
return realtimeEventItemID(event), finalizedRealtimeToolOutputSummary(event)
|
||||
}
|
||||
|
||||
func extractRealtimeExtraParamString(event *schemas.BifrostRealtimeEvent, key string) string {
|
||||
if event == nil || event.ExtraParams == nil {
|
||||
return ""
|
||||
}
|
||||
raw, ok := event.ExtraParams[key]
|
||||
if !ok || len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
var value string
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
func realtimeEventItemID(event *schemas.BifrostRealtimeEvent) string {
|
||||
if event == nil {
|
||||
return ""
|
||||
}
|
||||
if event.Item != nil && strings.TrimSpace(event.Item.ID) != "" {
|
||||
return strings.TrimSpace(event.Item.ID)
|
||||
}
|
||||
if event.Delta != nil && strings.TrimSpace(event.Delta.ItemID) != "" {
|
||||
return strings.TrimSpace(event.Delta.ItemID)
|
||||
}
|
||||
return extractRealtimeExtraParamString(event, "item_id")
|
||||
}
|
||||
|
||||
func combineRealtimeInputRaw(turnInputs []bfws.RealtimeTurnInput) string {
|
||||
var parts []string
|
||||
for _, turnInput := range turnInputs {
|
||||
if trimmed := strings.TrimSpace(turnInput.Raw); trimmed != "" {
|
||||
parts = append(parts, trimmed)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
type realtimeResponseDoneEnvelope struct {
|
||||
Response struct {
|
||||
Output []realtimeResponseDoneOutput `json:"output"`
|
||||
Usage *realtimeResponseDoneUsage `json:"usage"`
|
||||
} `json:"response"`
|
||||
}
|
||||
|
||||
type realtimeResponseDoneOutput struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
CallID string `json:"call_id"`
|
||||
Arguments string `json:"arguments"`
|
||||
Content []realtimeResponseDoneContent `json:"content"`
|
||||
}
|
||||
|
||||
type realtimeResponseDoneContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
Transcript string `json:"transcript"`
|
||||
Refusal string `json:"refusal"`
|
||||
}
|
||||
|
||||
type realtimeResponseDoneUsage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails *realtimeResponseDoneInputTokenUsage `json:"input_token_details"`
|
||||
OutputTokenDetails *realtimeResponseDoneOutputTokenUsage `json:"output_token_details"`
|
||||
}
|
||||
|
||||
type realtimeResponseDoneInputTokenUsage struct {
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ImageTokens int `json:"image_tokens"`
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
}
|
||||
|
||||
type realtimeResponseDoneOutputTokenUsage struct {
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
ImageTokens *int `json:"image_tokens"`
|
||||
CitationTokens *int `json:"citation_tokens"`
|
||||
NumSearchQueries *int `json:"num_search_queries"`
|
||||
AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
|
||||
RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
|
||||
}
|
||||
|
||||
func extractRealtimeTurnUsage(provider schemas.RealtimeProvider, rawMessage []byte) *schemas.BifrostLLMUsage {
|
||||
if extractor, ok := provider.(schemas.RealtimeUsageExtractor); ok {
|
||||
if usage := extractor.ExtractRealtimeTurnUsage(rawMessage); usage != nil {
|
||||
return usage
|
||||
}
|
||||
}
|
||||
return extractRealtimeResponseDoneUsage(rawMessage)
|
||||
}
|
||||
|
||||
func extractRealtimeTurnOutputMessage(provider schemas.RealtimeProvider, rawMessage []byte, contentSummary string) *schemas.ChatMessage {
|
||||
if extractor, ok := provider.(schemas.RealtimeUsageExtractor); ok {
|
||||
if message := extractor.ExtractRealtimeTurnOutput(rawMessage); message != nil {
|
||||
if strings.TrimSpace(contentSummary) != "" && (message.Content == nil || message.Content.ContentStr == nil || strings.TrimSpace(*message.Content.ContentStr) == "") {
|
||||
message.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(strings.TrimSpace(contentSummary))}
|
||||
}
|
||||
return message
|
||||
}
|
||||
}
|
||||
return buildRealtimeAssistantLogMessage(rawMessage, contentSummary)
|
||||
}
|
||||
|
||||
func buildRealtimeAssistantLogMessage(rawMessage []byte, contentSummary string) *schemas.ChatMessage {
|
||||
contentSummary = strings.TrimSpace(contentSummary)
|
||||
var parsed realtimeResponseDoneEnvelope
|
||||
if len(rawMessage) > 0 && sonic.Unmarshal(rawMessage, &parsed) == nil {
|
||||
message := &schemas.ChatMessage{Role: schemas.ChatMessageRoleAssistant}
|
||||
if contentSummary == "" {
|
||||
contentSummary = extractRealtimeResponseDoneAssistantText(parsed.Response.Output)
|
||||
}
|
||||
if contentSummary != "" {
|
||||
message.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(contentSummary)}
|
||||
}
|
||||
|
||||
toolCalls := extractRealtimeResponseDoneToolCalls(parsed.Response.Output)
|
||||
if len(toolCalls) > 0 {
|
||||
message.ChatAssistantMessage = &schemas.ChatAssistantMessage{
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
if message.Content != nil || message.ChatAssistantMessage != nil {
|
||||
return message
|
||||
}
|
||||
}
|
||||
|
||||
if contentSummary == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr(contentSummary)},
|
||||
}
|
||||
}
|
||||
|
||||
func extractRealtimeResponseDoneAssistantText(outputs []realtimeResponseDoneOutput) string {
|
||||
var parts []string
|
||||
for _, output := range outputs {
|
||||
if output.Type != "message" {
|
||||
continue
|
||||
}
|
||||
for _, block := range output.Content {
|
||||
switch {
|
||||
case strings.TrimSpace(block.Text) != "":
|
||||
parts = append(parts, strings.TrimSpace(block.Text))
|
||||
case strings.TrimSpace(block.Transcript) != "":
|
||||
parts = append(parts, strings.TrimSpace(block.Transcript))
|
||||
case strings.TrimSpace(block.Refusal) != "":
|
||||
parts = append(parts, strings.TrimSpace(block.Refusal))
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func extractRealtimeResponseDoneToolCalls(outputs []realtimeResponseDoneOutput) []schemas.ChatAssistantMessageToolCall {
|
||||
toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0)
|
||||
for _, output := range outputs {
|
||||
if output.Type != "function_call" {
|
||||
continue
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(output.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
toolType := "function"
|
||||
id := strings.TrimSpace(output.CallID)
|
||||
if id == "" {
|
||||
id = strings.TrimSpace(output.ID)
|
||||
}
|
||||
|
||||
toolCall := schemas.ChatAssistantMessageToolCall{
|
||||
Index: uint16(len(toolCalls)),
|
||||
Type: &toolType,
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr(name),
|
||||
Arguments: output.Arguments,
|
||||
},
|
||||
}
|
||||
if id != "" {
|
||||
toolCall.ID = schemas.Ptr(id)
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func extractRealtimeResponseDoneUsage(rawMessage []byte) *schemas.BifrostLLMUsage {
|
||||
if len(rawMessage) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var parsed realtimeResponseDoneEnvelope
|
||||
if err := sonic.Unmarshal(rawMessage, &parsed); err != nil || parsed.Response.Usage == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
totalTokens := parsed.Response.Usage.TotalTokens
|
||||
if totalTokens == 0 && (parsed.Response.Usage.InputTokens > 0 || parsed.Response.Usage.OutputTokens > 0) {
|
||||
totalTokens = parsed.Response.Usage.InputTokens + parsed.Response.Usage.OutputTokens
|
||||
}
|
||||
|
||||
usage := &schemas.BifrostLLMUsage{
|
||||
PromptTokens: parsed.Response.Usage.InputTokens,
|
||||
CompletionTokens: parsed.Response.Usage.OutputTokens,
|
||||
TotalTokens: totalTokens,
|
||||
}
|
||||
|
||||
if parsed.Response.Usage.InputTokenDetails != nil {
|
||||
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
|
||||
TextTokens: parsed.Response.Usage.InputTokenDetails.TextTokens,
|
||||
AudioTokens: parsed.Response.Usage.InputTokenDetails.AudioTokens,
|
||||
ImageTokens: parsed.Response.Usage.InputTokenDetails.ImageTokens,
|
||||
CachedReadTokens: parsed.Response.Usage.InputTokenDetails.CachedTokens,
|
||||
}
|
||||
}
|
||||
|
||||
if parsed.Response.Usage.OutputTokenDetails != nil {
|
||||
usage.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{
|
||||
TextTokens: parsed.Response.Usage.OutputTokenDetails.TextTokens,
|
||||
AudioTokens: parsed.Response.Usage.OutputTokenDetails.AudioTokens,
|
||||
ReasoningTokens: parsed.Response.Usage.OutputTokenDetails.ReasoningTokens,
|
||||
ImageTokens: parsed.Response.Usage.OutputTokenDetails.ImageTokens,
|
||||
CitationTokens: parsed.Response.Usage.OutputTokenDetails.CitationTokens,
|
||||
NumSearchQueries: parsed.Response.Usage.OutputTokenDetails.NumSearchQueries,
|
||||
AcceptedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.AcceptedPredictionTokens,
|
||||
RejectedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.RejectedPredictionTokens,
|
||||
}
|
||||
}
|
||||
|
||||
return usage
|
||||
}
|
||||
435
transports/bifrost-http/handlers/realtime_logging_test.go
Normal file
435
transports/bifrost-http/handlers/realtime_logging_test.go
Normal file
@@ -0,0 +1,435 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/openai"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
|
||||
)
|
||||
|
||||
func TestShouldAccumulateRealtimeOutput(t *testing.T) {
|
||||
provider := &openai.OpenAIProvider{}
|
||||
if !provider.ShouldAccumulateRealtimeOutput(schemas.RTEventResponseTextDelta) {
|
||||
t.Fatal("expected response.text.delta to accumulate output text")
|
||||
}
|
||||
if !provider.ShouldAccumulateRealtimeOutput(schemas.RTEventResponseAudioTransDelta) {
|
||||
t.Fatal("expected response.audio_transcript.delta to accumulate output transcript")
|
||||
}
|
||||
if provider.ShouldAccumulateRealtimeOutput(schemas.RTEventInputAudioTransDelta) {
|
||||
t.Fatal("did not expect input audio transcription delta to accumulate assistant output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRealtimeTurnSummary(t *testing.T) {
|
||||
event := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemCreate,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Content: []byte(`[{"type":"input_text","text":"hello from realtime"}]`),
|
||||
},
|
||||
}
|
||||
|
||||
got := extractRealtimeTurnSummary(event, "")
|
||||
if got != "hello from realtime" {
|
||||
t.Fatalf("extractRealtimeTurnSummary() = %q, want %q", got, "hello from realtime")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizedRealtimeInputSummary(t *testing.T) {
|
||||
userCreate := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemCreate,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Role: "user",
|
||||
Content: []byte(`[{"type":"input_text","text":"hello from browser"}]`),
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(userCreate); got != "hello from browser" {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(user create) = %q, want %q", got, "hello from browser")
|
||||
}
|
||||
|
||||
userRetrieved := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemRetrieved,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Role: "user",
|
||||
Content: []byte(`[{"type":"input_text","text":"hello from retrieved item"}]`),
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(userRetrieved); got != "hello from retrieved item" {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(user retrieved) = %q, want %q", got, "hello from retrieved item")
|
||||
}
|
||||
|
||||
userCreated := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemCreated,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Role: "user",
|
||||
Content: []byte(`[{"type":"input_text","text":"hello from provider created item"}]`),
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(userCreated); got != "hello from provider created item" {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(user created) = %q, want %q", got, "hello from provider created item")
|
||||
}
|
||||
|
||||
userAdded := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemAdded,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Role: "user",
|
||||
Content: []byte(`[{"type":"input_text","text":"hello from provider added item"}]`),
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(userAdded); got != "hello from provider added item" {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(user added) = %q, want %q", got, "hello from provider added item")
|
||||
}
|
||||
|
||||
userCreatedWithoutTranscript := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemCreated,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Role: "user",
|
||||
Type: "message",
|
||||
Content: []byte(`[{"type":"input_audio","audio":null,"transcript":null}]`),
|
||||
},
|
||||
RawData: []byte(`{"type":"conversation.item.created","item":{"type":"message","role":"user","content":[{"type":"input_audio","audio":null,"transcript":null}]}}`),
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(userCreatedWithoutTranscript); got != "" {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(user created without transcript) = %q, want empty", got)
|
||||
}
|
||||
|
||||
userDoneWithoutTranscript := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemDone,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Role: "user",
|
||||
Type: "message",
|
||||
Status: "completed",
|
||||
Content: []byte(`[{"type":"input_audio","audio":null,"transcript":null}]`),
|
||||
},
|
||||
RawData: []byte(`{"type":"conversation.item.done","item":{"type":"message","role":"user","status":"completed","content":[{"type":"input_audio","audio":null,"transcript":null}]}}`),
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(userDoneWithoutTranscript); got != realtimeMissingTranscriptText {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(user done without transcript) = %q, want %q", got, realtimeMissingTranscriptText)
|
||||
}
|
||||
|
||||
inputTranscript := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventInputAudioTransCompleted,
|
||||
ExtraParams: map[string]json.RawMessage{
|
||||
"transcript": json.RawMessage(`"spoken user turn"`),
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(inputTranscript); got != "spoken user turn" {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(input transcript) = %q, want %q", got, "spoken user turn")
|
||||
}
|
||||
|
||||
emptyInputTranscript := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventInputAudioTransCompleted,
|
||||
ExtraParams: map[string]json.RawMessage{
|
||||
"transcript": json.RawMessage(`""`),
|
||||
},
|
||||
RawData: []byte(`{"type":"conversation.item.input_audio_transcription.completed","transcript":"","usage":{"total_tokens":11}}`),
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(emptyInputTranscript); got != realtimeMissingTranscriptText {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(empty input transcript) = %q, want %q", got, realtimeMissingTranscriptText)
|
||||
}
|
||||
|
||||
missingInputTranscript := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventInputAudioTransCompleted,
|
||||
RawData: []byte(`{"type":"conversation.item.input_audio_transcription.completed","usage":{"total_tokens":11}}`),
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(missingInputTranscript); got != realtimeMissingTranscriptText {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(missing input transcript) = %q, want %q", got, realtimeMissingTranscriptText)
|
||||
}
|
||||
|
||||
assistantCreate := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemCreate,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Role: "assistant",
|
||||
Content: []byte(`[{"type":"text","text":"assistant text"}]`),
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeInputSummary(assistantCreate); got != "" {
|
||||
t.Fatalf("finalizedRealtimeInputSummary(assistant create) = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizedRealtimeToolOutputSummary(t *testing.T) {
|
||||
event := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemCreate,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Type: "function_call_output",
|
||||
Output: `{"nextResponse":"tool result"}`,
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeToolOutputSummary(event); got != `{"nextResponse":"tool result"}` {
|
||||
t.Fatalf("finalizedRealtimeToolOutputSummary() = %q, want %q", got, `{"nextResponse":"tool result"}`)
|
||||
}
|
||||
|
||||
retrieved := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemRetrieved,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Type: "function_call_output",
|
||||
Output: `{"nextResponse":"tool result from retrieved"}`,
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeToolOutputSummary(retrieved); got != `{"nextResponse":"tool result from retrieved"}` {
|
||||
t.Fatalf("finalizedRealtimeToolOutputSummary(retrieved) = %q, want %q", got, `{"nextResponse":"tool result from retrieved"}`)
|
||||
}
|
||||
|
||||
created := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemCreated,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Type: "function_call_output",
|
||||
Output: `{"nextResponse":"tool result from created"}`,
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeToolOutputSummary(created); got != `{"nextResponse":"tool result from created"}` {
|
||||
t.Fatalf("finalizedRealtimeToolOutputSummary(created) = %q, want %q", got, `{"nextResponse":"tool result from created"}`)
|
||||
}
|
||||
|
||||
added := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemAdded,
|
||||
Item: &schemas.RealtimeItem{
|
||||
Type: "function_call_output",
|
||||
Output: `{"nextResponse":"tool result from added"}`,
|
||||
},
|
||||
}
|
||||
if got := finalizedRealtimeToolOutputSummary(added); got != `{"nextResponse":"tool result from added"}` {
|
||||
t.Fatalf("finalizedRealtimeToolOutputSummary(added) = %q, want %q", got, `{"nextResponse":"tool result from added"}`)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingRealtimeInputUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
transcriptEvent := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventInputAudioTransCompleted,
|
||||
ExtraParams: map[string]json.RawMessage{
|
||||
"item_id": json.RawMessage(`"item_123"`),
|
||||
"transcript": json.RawMessage(`"Hello."`),
|
||||
},
|
||||
}
|
||||
itemID, summary := pendingRealtimeInputUpdate(transcriptEvent)
|
||||
if itemID != "item_123" || summary != "Hello." {
|
||||
t.Fatalf("pendingRealtimeInputUpdate(transcript) = (%q, %q), want (%q, %q)", itemID, summary, "item_123", "Hello.")
|
||||
}
|
||||
|
||||
retrievedEvent := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemRetrieved,
|
||||
Item: &schemas.RealtimeItem{
|
||||
ID: "item_123",
|
||||
Role: "user",
|
||||
Content: []byte(`[{"type":"input_text","text":"historical hello"}]`),
|
||||
},
|
||||
}
|
||||
itemID, summary = pendingRealtimeInputUpdate(retrievedEvent)
|
||||
if itemID != "" || summary != "" {
|
||||
t.Fatalf("pendingRealtimeInputUpdate(retrieved) = (%q, %q), want empty", itemID, summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingRealtimeToolOutputUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
toolOutputEvent := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemDone,
|
||||
Item: &schemas.RealtimeItem{
|
||||
ID: "item_tool_123",
|
||||
Type: "function_call_output",
|
||||
Output: `{"nextResponse":"tool result"}`,
|
||||
},
|
||||
}
|
||||
itemID, summary := pendingRealtimeToolOutputUpdate(toolOutputEvent)
|
||||
if itemID != "item_tool_123" || summary != `{"nextResponse":"tool result"}` {
|
||||
t.Fatalf("pendingRealtimeToolOutputUpdate(done) = (%q, %q), want (%q, %q)", itemID, summary, "item_tool_123", `{"nextResponse":"tool result"}`)
|
||||
}
|
||||
|
||||
retrievedToolOutputEvent := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventConversationItemRetrieved,
|
||||
Item: &schemas.RealtimeItem{
|
||||
ID: "item_tool_123",
|
||||
Type: "function_call_output",
|
||||
Output: `{"nextResponse":"historical tool result"}`,
|
||||
},
|
||||
}
|
||||
itemID, summary = pendingRealtimeToolOutputUpdate(retrievedToolOutputEvent)
|
||||
if itemID != "" || summary != "" {
|
||||
t.Fatalf("pendingRealtimeToolOutputUpdate(retrieved) = (%q, %q), want empty", itemID, summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRealtimeTurnPostResponseUsesFullResponseDonePayload(t *testing.T) {
|
||||
rawRequest := `{"type":"conversation.item.input_audio_transcription.completed","transcript":""}`
|
||||
rawResponse := []byte(`{
|
||||
"type":"response.done",
|
||||
"response":{
|
||||
"output":[
|
||||
{
|
||||
"id":"item_message_123",
|
||||
"type":"message",
|
||||
"content":[
|
||||
{
|
||||
"type":"audio",
|
||||
"transcript":"assistant turn text"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"usage":{
|
||||
"total_tokens":26,
|
||||
"input_tokens":17,
|
||||
"output_tokens":9,
|
||||
"input_token_details":{
|
||||
"text_tokens":12,
|
||||
"audio_tokens":5,
|
||||
"image_tokens":0,
|
||||
"cached_tokens":4
|
||||
},
|
||||
"output_token_details":{
|
||||
"text_tokens":7,
|
||||
"audio_tokens":2
|
||||
}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
resp := buildRealtimeTurnPostResponse(&openai.OpenAIProvider{}, schemas.OpenAI, "gpt-4o-realtime-preview-2025-06-03", rawRequest, rawResponse, "", 4321)
|
||||
if resp == nil || resp.ResponsesResponse == nil {
|
||||
t.Fatal("expected realtime post response to be built")
|
||||
}
|
||||
if resp.ResponsesResponse.ExtraFields.Latency != 4321 {
|
||||
t.Fatalf("Latency = %d, want %d", resp.ResponsesResponse.ExtraFields.Latency, 4321)
|
||||
}
|
||||
if resp.ResponsesResponse.Usage == nil || resp.ResponsesResponse.Usage.InputTokens != 17 || resp.ResponsesResponse.Usage.OutputTokens != 9 || resp.ResponsesResponse.Usage.TotalTokens != 26 {
|
||||
t.Fatalf("Usage = %+v, want input=17 output=9 total=26", resp.ResponsesResponse.Usage)
|
||||
}
|
||||
if len(resp.ResponsesResponse.Output) != 1 {
|
||||
t.Fatalf("len(Output) = %d, want 1", len(resp.ResponsesResponse.Output))
|
||||
}
|
||||
if resp.ResponsesResponse.Output[0].Content == nil || resp.ResponsesResponse.Output[0].Content.ContentStr == nil || *resp.ResponsesResponse.Output[0].Content.ContentStr != "assistant turn text" {
|
||||
t.Fatalf("Output[0].Content = %+v, want assistant turn text", resp.ResponsesResponse.Output[0].Content)
|
||||
}
|
||||
if got, ok := resp.ResponsesResponse.ExtraFields.RawRequest.(string); !ok || got != rawRequest {
|
||||
t.Fatalf("RawRequest = %#v, want %q", resp.ResponsesResponse.ExtraFields.RawRequest, rawRequest)
|
||||
}
|
||||
if got, ok := resp.ResponsesResponse.ExtraFields.RawResponse.(string); !ok || got == "" {
|
||||
t.Fatalf("RawResponse = %#v, want raw response string", resp.ResponsesResponse.ExtraFields.RawResponse)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeRealtimeTurnHooksWithErrorCompletesActiveHooks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
session := bfws.NewSession(nil)
|
||||
session.SetProviderSessionID("sess_provider_123")
|
||||
session.AddRealtimeInput("hello from user", `{"type":"conversation.item.added"}`)
|
||||
session.AppendRealtimeOutputText("partial assistant output")
|
||||
|
||||
var (
|
||||
capturedResp *schemas.BifrostResponse
|
||||
capturedErr *schemas.BifrostError
|
||||
cleanedUp bool
|
||||
)
|
||||
session.SetRealtimeTurnHooks(&bfws.RealtimeTurnPluginState{
|
||||
RequestID: "req_realtime_123",
|
||||
StartedAt: time.Now().Add(-time.Second),
|
||||
PreHookValues: map[any]any{
|
||||
schemas.BifrostContextKeyGovernanceVirtualKeyID: "vk_123",
|
||||
},
|
||||
PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
||||
capturedResp = result
|
||||
capturedErr = err
|
||||
return result, nil
|
||||
},
|
||||
Cleanup: func() {
|
||||
cleanedUp = true
|
||||
},
|
||||
})
|
||||
|
||||
rawResponse := []byte(`{"type":"error","error":{"type":"server_error","message":"Virtual key is required."}}`)
|
||||
postErr := finalizeRealtimeTurnHooksWithError(
|
||||
nil,
|
||||
nil,
|
||||
session,
|
||||
schemas.OpenAI,
|
||||
"gpt-realtime",
|
||||
nil,
|
||||
schemas.RTEventError,
|
||||
rawResponse,
|
||||
newRealtimeWireBifrostError(401, "server_error", "Virtual key is required."),
|
||||
)
|
||||
if postErr != nil {
|
||||
t.Fatalf("finalizeRealtimeTurnHooksWithError() post error = %v, want nil", postErr)
|
||||
}
|
||||
if capturedResp != nil {
|
||||
t.Fatalf("captured response = %#v, want nil", capturedResp)
|
||||
}
|
||||
if capturedErr == nil {
|
||||
t.Fatal("expected captured error")
|
||||
}
|
||||
if capturedErr.ExtraFields.RequestType != schemas.RealtimeRequest {
|
||||
t.Fatalf("request type = %q, want %q", capturedErr.ExtraFields.RequestType, schemas.RealtimeRequest)
|
||||
}
|
||||
if capturedErr.ExtraFields.Provider != schemas.OpenAI {
|
||||
t.Fatalf("provider = %q, want %q", capturedErr.ExtraFields.Provider, schemas.OpenAI)
|
||||
}
|
||||
if capturedErr.ExtraFields.OriginalModelRequested != "gpt-realtime" {
|
||||
t.Fatalf("model requested = %q, want %q", capturedErr.ExtraFields.OriginalModelRequested, "gpt-realtime")
|
||||
}
|
||||
rawRequest, ok := capturedErr.ExtraFields.RawRequest.(string)
|
||||
if !ok || rawRequest == "" {
|
||||
t.Fatalf("raw request = %#v, want non-empty string", capturedErr.ExtraFields.RawRequest)
|
||||
}
|
||||
rawResp, ok := capturedErr.ExtraFields.RawResponse.(json.RawMessage)
|
||||
if !ok || string(rawResp) != string(rawResponse) {
|
||||
t.Fatalf("raw response = %#v, want %s", capturedErr.ExtraFields.RawResponse, string(rawResponse))
|
||||
}
|
||||
if session.PeekRealtimeTurnHooks() != nil {
|
||||
t.Fatal("expected active hooks to be cleared")
|
||||
}
|
||||
if got := session.ConsumeRealtimeTurnInputs(); len(got) != 0 {
|
||||
t.Fatalf("remaining turn inputs = %d, want 0", len(got))
|
||||
}
|
||||
if got := session.ConsumeRealtimeOutputText(); got != "" {
|
||||
t.Fatalf("remaining output text = %q, want empty", got)
|
||||
}
|
||||
if !cleanedUp {
|
||||
t.Fatal("expected realtime hook cleanup to run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBifrostErrorFromRealtimeErrorCarriesRealtimeMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rawResponse := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request_error","message":"bad request","param":"session.type"}}`)
|
||||
bifrostErr := newBifrostErrorFromRealtimeError(
|
||||
schemas.OpenAI,
|
||||
"gpt-realtime",
|
||||
rawResponse,
|
||||
&schemas.RealtimeError{
|
||||
Type: "invalid_request_error",
|
||||
Code: "invalid_request_error",
|
||||
Message: "bad request",
|
||||
Param: "session.type",
|
||||
},
|
||||
)
|
||||
if bifrostErr == nil {
|
||||
t.Fatal("expected bifrost error")
|
||||
}
|
||||
if bifrostErr.StatusCode == nil || *bifrostErr.StatusCode != 400 {
|
||||
t.Fatalf("status code = %#v, want 400", bifrostErr.StatusCode)
|
||||
}
|
||||
if bifrostErr.ExtraFields.RequestType != schemas.RealtimeRequest {
|
||||
t.Fatalf("request type = %q, want %q", bifrostErr.ExtraFields.RequestType, schemas.RealtimeRequest)
|
||||
}
|
||||
if bifrostErr.ExtraFields.Provider != schemas.OpenAI {
|
||||
t.Fatalf("provider = %q, want %q", bifrostErr.ExtraFields.Provider, schemas.OpenAI)
|
||||
}
|
||||
if bifrostErr.ExtraFields.OriginalModelRequested != "gpt-realtime" {
|
||||
t.Fatalf("model requested = %q, want %q", bifrostErr.ExtraFields.OriginalModelRequested, "gpt-realtime")
|
||||
}
|
||||
rawResp, ok := bifrostErr.ExtraFields.RawResponse.(json.RawMessage)
|
||||
if !ok || string(rawResp) != string(rawResponse) {
|
||||
t.Fatalf("raw response = %#v, want %s", bifrostErr.ExtraFields.RawResponse, string(rawResponse))
|
||||
}
|
||||
if bifrostErr.Error == nil || bifrostErr.Error.Param != "session.type" {
|
||||
t.Fatalf("error param = %#v, want session.type", bifrostErr.Error)
|
||||
}
|
||||
}
|
||||
798
transports/bifrost-http/handlers/realtime_turn_pipeline.go
Normal file
798
transports/bifrost-http/handlers/realtime_turn_pipeline.go
Normal file
@@ -0,0 +1,798 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
|
||||
)
|
||||
|
||||
func newRealtimeTurnContext(
|
||||
baseCtx *schemas.BifrostContext,
|
||||
requestID string,
|
||||
sessionID string,
|
||||
providerSessionID string,
|
||||
source realtimeTurnSource,
|
||||
eventType schemas.RealtimeEventType,
|
||||
key *schemas.Key,
|
||||
) *schemas.BifrostContext {
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
if baseCtx != nil {
|
||||
// Realtime post-hook contexts must preserve plugin-private values written in
|
||||
// pre-hooks (for example telemetry start timestamps), not just public keys.
|
||||
for ctxKey, value := range baseCtx.GetUserValues() {
|
||||
if value != nil {
|
||||
ctx.SetValue(ctxKey, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest)
|
||||
if requestID == "" {
|
||||
requestID = uuid.NewString()
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyRequestID, requestID)
|
||||
resolvedSessionID := strings.TrimSpace(providerSessionID)
|
||||
if resolvedSessionID == "" {
|
||||
resolvedSessionID = strings.TrimSpace(sessionID)
|
||||
}
|
||||
if baseCtx != nil {
|
||||
if externalSessionID, ok := baseCtx.Value(schemas.BifrostContextKeyParentRequestID).(string); ok && strings.TrimSpace(externalSessionID) != "" {
|
||||
resolvedSessionID = strings.TrimSpace(externalSessionID)
|
||||
}
|
||||
}
|
||||
if resolvedSessionID != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyParentRequestID, resolvedSessionID)
|
||||
}
|
||||
if strings.TrimSpace(providerSessionID) != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyRealtimeSessionID, providerSessionID)
|
||||
ctx.SetValue(schemas.BifrostContextKeyRealtimeProviderSessionID, providerSessionID)
|
||||
}
|
||||
if source != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyRealtimeSource, string(source))
|
||||
}
|
||||
if eventType != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyRealtimeEventType, string(eventType))
|
||||
}
|
||||
if key != nil {
|
||||
if strings.TrimSpace(key.ID) != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeySelectedKeyID, key.ID)
|
||||
}
|
||||
if strings.TrimSpace(key.Name) != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeySelectedKeyName, key.Name)
|
||||
}
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func applyRealtimeTurnContextValues(ctx *schemas.BifrostContext, values map[any]any) {
|
||||
if ctx == nil || len(values) == 0 {
|
||||
return
|
||||
}
|
||||
for ctxKey, value := range values {
|
||||
switch ctxKey {
|
||||
case schemas.BifrostContextKeyRequestID,
|
||||
schemas.BifrostContextKeyParentRequestID,
|
||||
schemas.BifrostContextKeyRealtimeSessionID,
|
||||
schemas.BifrostContextKeyRealtimeProviderSessionID,
|
||||
schemas.BifrostContextKeyRealtimeSource,
|
||||
schemas.BifrostContextKeyRealtimeEventType,
|
||||
schemas.BifrostContextKeyStreamStartTime,
|
||||
schemas.BifrostContextKeyStreamEndIndicator:
|
||||
continue
|
||||
}
|
||||
if value != nil {
|
||||
ctx.SetValue(ctxKey, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setRealtimeTurnStreamContext(ctx *schemas.BifrostContext, startedAt time.Time, isFinal bool) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
if startedAt.IsZero() {
|
||||
startedAt = time.Now()
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamStartTime, startedAt)
|
||||
if isFinal {
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
}
|
||||
}
|
||||
|
||||
func buildRealtimeTurnPreRequest(provider schemas.ModelProvider, model string, turnInputs []bfws.RealtimeTurnInput) *schemas.BifrostRequest {
|
||||
input := make([]schemas.ResponsesMessage, 0, len(turnInputs))
|
||||
for _, turnInput := range turnInputs {
|
||||
summary := strings.TrimSpace(turnInput.Summary)
|
||||
if summary == "" {
|
||||
continue
|
||||
}
|
||||
switch turnInput.Role {
|
||||
case string(schemas.ChatMessageRoleTool):
|
||||
itemType := schemas.ResponsesMessageTypeFunctionCallOutput
|
||||
output := &schemas.ResponsesToolMessageOutputStruct{
|
||||
ResponsesToolCallOutputStr: schemas.Ptr(summary),
|
||||
}
|
||||
input = append(input, schemas.ResponsesMessage{
|
||||
Type: &itemType,
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{Output: output},
|
||||
})
|
||||
case string(schemas.ChatMessageRoleUser):
|
||||
itemType := schemas.ResponsesMessageTypeMessage
|
||||
role := schemas.ResponsesInputMessageRoleUser
|
||||
input = append(input, schemas.ResponsesMessage{
|
||||
Type: &itemType,
|
||||
Role: &role,
|
||||
Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(summary)},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return &schemas.BifrostRequest{
|
||||
RequestType: schemas.RealtimeRequest,
|
||||
ResponsesRequest: &schemas.BifrostResponsesRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: input,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func buildRealtimeTurnPostResponse(
|
||||
rtProvider schemas.RealtimeProvider,
|
||||
provider schemas.ModelProvider,
|
||||
model string,
|
||||
rawRequest string,
|
||||
rawResponse []byte,
|
||||
contentOverride string,
|
||||
latency int64,
|
||||
) *schemas.BifrostResponse {
|
||||
output := buildRealtimeTurnOutputMessages(rtProvider, rawResponse, contentOverride)
|
||||
resp := &schemas.BifrostResponsesResponse{
|
||||
Object: "response",
|
||||
Model: model,
|
||||
Output: output,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
RequestType: schemas.RealtimeRequest,
|
||||
Provider: provider,
|
||||
OriginalModelRequested: model,
|
||||
Latency: latency,
|
||||
},
|
||||
}
|
||||
if usage := extractRealtimeTurnUsage(rtProvider, rawResponse); usage != nil {
|
||||
resp.Usage = buildRealtimeResponsesUsage(usage)
|
||||
}
|
||||
if strings.TrimSpace(rawRequest) != "" {
|
||||
resp.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
if len(rawResponse) > 0 {
|
||||
resp.ExtraFields.RawResponse = string(rawResponse)
|
||||
}
|
||||
|
||||
return &schemas.BifrostResponse{ResponsesResponse: resp}
|
||||
}
|
||||
|
||||
func buildRealtimeTurnOutputMessages(rtProvider schemas.RealtimeProvider, rawResponse []byte, contentOverride string) []schemas.ResponsesMessage {
|
||||
outputs := make([]schemas.ResponsesMessage, 0)
|
||||
if outputMessage := extractRealtimeTurnOutputMessage(rtProvider, rawResponse, contentOverride); outputMessage != nil {
|
||||
outputs = append(outputs, buildRealtimeResponsesMessagesFromChat(outputMessage, contentOverride)...)
|
||||
}
|
||||
|
||||
if len(outputs) > 0 {
|
||||
return outputs
|
||||
}
|
||||
|
||||
var parsed realtimeResponseDoneEnvelope
|
||||
if len(rawResponse) > 0 && schemas.Unmarshal(rawResponse, &parsed) == nil {
|
||||
for _, item := range parsed.Response.Output {
|
||||
switch item.Type {
|
||||
case "message":
|
||||
content := strings.TrimSpace(contentOverride)
|
||||
if content == "" {
|
||||
content = extractRealtimeResponseDoneContentText(item.Content)
|
||||
}
|
||||
itemType := schemas.ResponsesMessageTypeMessage
|
||||
role := schemas.ResponsesInputMessageRoleAssistant
|
||||
msg := schemas.ResponsesMessage{
|
||||
Type: &itemType,
|
||||
Role: &role,
|
||||
Status: schemas.Ptr("completed"),
|
||||
}
|
||||
if strings.TrimSpace(item.ID) != "" {
|
||||
msg.ID = schemas.Ptr(strings.TrimSpace(item.ID))
|
||||
}
|
||||
if content != "" {
|
||||
msg.Content = &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(content)}
|
||||
}
|
||||
outputs = append(outputs, msg)
|
||||
case "function_call":
|
||||
itemType := schemas.ResponsesMessageTypeFunctionCall
|
||||
msg := schemas.ResponsesMessage{
|
||||
Type: &itemType,
|
||||
Status: schemas.Ptr("completed"),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Name: schemas.Ptr(strings.TrimSpace(item.Name)),
|
||||
Arguments: schemas.Ptr(item.Arguments),
|
||||
},
|
||||
}
|
||||
if strings.TrimSpace(item.ID) != "" {
|
||||
msg.ID = schemas.Ptr(strings.TrimSpace(item.ID))
|
||||
}
|
||||
if strings.TrimSpace(item.CallID) != "" {
|
||||
msg.CallID = schemas.Ptr(strings.TrimSpace(item.CallID))
|
||||
}
|
||||
outputs = append(outputs, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(outputs) == 0 && strings.TrimSpace(contentOverride) != "" {
|
||||
itemType := schemas.ResponsesMessageTypeMessage
|
||||
role := schemas.ResponsesInputMessageRoleAssistant
|
||||
outputs = append(outputs, schemas.ResponsesMessage{
|
||||
Type: &itemType,
|
||||
Role: &role,
|
||||
Status: schemas.Ptr("completed"),
|
||||
Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(strings.TrimSpace(contentOverride))},
|
||||
})
|
||||
}
|
||||
|
||||
return outputs
|
||||
}
|
||||
|
||||
func buildRealtimeResponsesMessagesFromChat(message *schemas.ChatMessage, contentOverride string) []schemas.ResponsesMessage {
|
||||
if message == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
outputs := make([]schemas.ResponsesMessage, 0, 1)
|
||||
content := strings.TrimSpace(contentOverride)
|
||||
if content == "" && message.Content != nil && message.Content.ContentStr != nil {
|
||||
content = strings.TrimSpace(*message.Content.ContentStr)
|
||||
}
|
||||
if content != "" {
|
||||
itemType := schemas.ResponsesMessageTypeMessage
|
||||
role := schemas.ResponsesInputMessageRoleAssistant
|
||||
outputs = append(outputs, schemas.ResponsesMessage{
|
||||
Type: &itemType,
|
||||
Role: &role,
|
||||
Status: schemas.Ptr("completed"),
|
||||
Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(content)},
|
||||
})
|
||||
}
|
||||
|
||||
if message.ChatAssistantMessage == nil {
|
||||
return outputs
|
||||
}
|
||||
|
||||
for _, toolCall := range message.ChatAssistantMessage.ToolCalls {
|
||||
itemType := schemas.ResponsesMessageTypeFunctionCall
|
||||
msg := schemas.ResponsesMessage{
|
||||
Type: &itemType,
|
||||
Status: schemas.Ptr("completed"),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
Arguments: schemas.Ptr(toolCall.Function.Arguments),
|
||||
},
|
||||
}
|
||||
if toolCall.Function.Name != nil {
|
||||
msg.ResponsesToolMessage.Name = schemas.Ptr(strings.TrimSpace(*toolCall.Function.Name))
|
||||
}
|
||||
if toolCall.ID != nil {
|
||||
msg.CallID = schemas.Ptr(strings.TrimSpace(*toolCall.ID))
|
||||
msg.ID = schemas.Ptr(strings.TrimSpace(*toolCall.ID))
|
||||
}
|
||||
outputs = append(outputs, msg)
|
||||
}
|
||||
|
||||
return outputs
|
||||
}
|
||||
|
||||
func extractRealtimeResponseDoneContentText(content []realtimeResponseDoneContent) string {
|
||||
for _, block := range content {
|
||||
switch {
|
||||
case strings.TrimSpace(block.Text) != "":
|
||||
return strings.TrimSpace(block.Text)
|
||||
case strings.TrimSpace(block.Transcript) != "":
|
||||
return strings.TrimSpace(block.Transcript)
|
||||
case strings.TrimSpace(block.Refusal) != "":
|
||||
return strings.TrimSpace(block.Refusal)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildRealtimeResponsesUsage(usage *schemas.BifrostLLMUsage) *schemas.ResponsesResponseUsage {
|
||||
if usage == nil {
|
||||
return nil
|
||||
}
|
||||
result := &schemas.ResponsesResponseUsage{
|
||||
InputTokens: usage.PromptTokens,
|
||||
OutputTokens: usage.CompletionTokens,
|
||||
TotalTokens: usage.TotalTokens,
|
||||
}
|
||||
if usage.PromptTokensDetails != nil {
|
||||
result.InputTokensDetails = &schemas.ResponsesResponseInputTokens{
|
||||
TextTokens: usage.PromptTokensDetails.TextTokens,
|
||||
AudioTokens: usage.PromptTokensDetails.AudioTokens,
|
||||
ImageTokens: usage.PromptTokensDetails.ImageTokens,
|
||||
CachedReadTokens: usage.PromptTokensDetails.CachedReadTokens,
|
||||
CachedWriteTokens: usage.PromptTokensDetails.CachedWriteTokens,
|
||||
}
|
||||
}
|
||||
if usage.CompletionTokensDetails != nil {
|
||||
result.OutputTokensDetails = &schemas.ResponsesResponseOutputTokens{
|
||||
TextTokens: usage.CompletionTokensDetails.TextTokens,
|
||||
AcceptedPredictionTokens: usage.CompletionTokensDetails.AcceptedPredictionTokens,
|
||||
AudioTokens: usage.CompletionTokensDetails.AudioTokens,
|
||||
ImageTokens: usage.CompletionTokensDetails.ImageTokens,
|
||||
ReasoningTokens: usage.CompletionTokensDetails.ReasoningTokens,
|
||||
RejectedPredictionTokens: usage.CompletionTokensDetails.RejectedPredictionTokens,
|
||||
CitationTokens: usage.CompletionTokensDetails.CitationTokens,
|
||||
NumSearchQueries: usage.CompletionTokensDetails.NumSearchQueries,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func newRealtimeTurnErrorEventPayload(bifrostErr *schemas.BifrostError) []byte {
|
||||
if bifrostErr == nil {
|
||||
return []byte(`{"type":"error","error":{"type":"server_error","message":"internal server error"}}`)
|
||||
}
|
||||
|
||||
errorType, errorCode, errorMessage, errorParam := mapRealtimeWireErrorFields(bifrostErr)
|
||||
payload := schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventError,
|
||||
Error: &schemas.RealtimeError{
|
||||
Type: errorType,
|
||||
Code: errorCode,
|
||||
Message: errorMessage,
|
||||
Param: errorParam,
|
||||
},
|
||||
}
|
||||
if data, err := schemas.Marshal(payload); err == nil {
|
||||
return data
|
||||
}
|
||||
return []byte(`{"type":"error","error":{"type":"server_error","message":"internal server error"}}`)
|
||||
}
|
||||
|
||||
// isBudgetOrBillingError returns true if the lowercased value indicates a budget or billing exhaustion error.
|
||||
// Quota/rate-limit patterns (quota_exceeded, quota exceeded, etc.) are already covered by bifrost.IsRateLimitErrorMessage.
|
||||
func isBudgetOrBillingError(lower string) bool {
|
||||
return strings.Contains(lower, "budget_exceeded") ||
|
||||
strings.Contains(lower, "budget exceeded") ||
|
||||
strings.Contains(lower, "insufficient_quota") ||
|
||||
strings.Contains(lower, "hard limit reached") ||
|
||||
strings.Contains(lower, "billing hard limit")
|
||||
}
|
||||
|
||||
func mapRealtimeWireErrorFields(bifrostErr *schemas.BifrostError) (string, string, string, string) {
|
||||
errorType := "server_error"
|
||||
errorCode := "server_error"
|
||||
errorMessage := "internal server error"
|
||||
errorParam := ""
|
||||
|
||||
if bifrostErr == nil {
|
||||
return errorType, errorCode, errorMessage, errorParam
|
||||
}
|
||||
|
||||
var values []string
|
||||
if bifrostErr.Type != nil {
|
||||
values = append(values, strings.TrimSpace(*bifrostErr.Type))
|
||||
}
|
||||
if bifrostErr.Error != nil {
|
||||
if bifrostErr.Error.Type != nil {
|
||||
values = append(values, strings.TrimSpace(*bifrostErr.Error.Type))
|
||||
}
|
||||
if bifrostErr.Error.Code != nil {
|
||||
values = append(values, strings.TrimSpace(*bifrostErr.Error.Code))
|
||||
}
|
||||
if strings.TrimSpace(bifrostErr.Error.Message) != "" {
|
||||
errorMessage = strings.TrimSpace(bifrostErr.Error.Message)
|
||||
values = append(values, errorMessage)
|
||||
}
|
||||
if bifrostErr.Error.Param != nil {
|
||||
errorParam = strings.TrimSpace(fmt.Sprint(bifrostErr.Error.Param))
|
||||
}
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
lower := strings.ToLower(value)
|
||||
switch {
|
||||
case lower == "":
|
||||
continue
|
||||
case strings.Contains(lower, "invalid_request_error"):
|
||||
return "invalid_request_error", "invalid_request_error", errorMessage, errorParam
|
||||
case isBudgetOrBillingError(lower):
|
||||
return "insufficient_quota", "insufficient_quota", errorMessage, errorParam
|
||||
case bifrost.IsRateLimitErrorMessage(lower):
|
||||
return "rate_limit_exceeded", "rate_limit_exceeded", errorMessage, errorParam
|
||||
}
|
||||
}
|
||||
|
||||
return errorType, errorCode, errorMessage, errorParam
|
||||
}
|
||||
|
||||
func shouldGracefullyDisconnectRealtime(bifrostErr *schemas.BifrostError) bool {
|
||||
if bifrostErr == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var values []string
|
||||
if bifrostErr.Type != nil {
|
||||
values = append(values, strings.TrimSpace(*bifrostErr.Type))
|
||||
}
|
||||
if bifrostErr.Error != nil {
|
||||
if bifrostErr.Error.Type != nil {
|
||||
values = append(values, strings.TrimSpace(*bifrostErr.Error.Type))
|
||||
}
|
||||
if bifrostErr.Error.Code != nil {
|
||||
values = append(values, strings.TrimSpace(*bifrostErr.Error.Code))
|
||||
}
|
||||
values = append(values, strings.TrimSpace(bifrostErr.Error.Message))
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
lower := strings.ToLower(value)
|
||||
if lower == "" {
|
||||
continue
|
||||
}
|
||||
if isBudgetOrBillingError(lower) || bifrost.IsRateLimitErrorMessage(lower) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func startRealtimeTurnHooks(
|
||||
client *bifrost.Bifrost,
|
||||
baseCtx *schemas.BifrostContext,
|
||||
session *bfws.Session,
|
||||
rtProvider schemas.RealtimeProvider,
|
||||
provider schemas.ModelProvider,
|
||||
model string,
|
||||
key *schemas.Key,
|
||||
startEventType schemas.RealtimeEventType,
|
||||
) *schemas.BifrostError {
|
||||
if client == nil || session == nil {
|
||||
return &schemas.BifrostError{
|
||||
Type: schemas.Ptr("server_error"),
|
||||
StatusCode: schemas.Ptr(500),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr("server_error"),
|
||||
Message: "realtime turn pipeline is unavailable",
|
||||
},
|
||||
}
|
||||
}
|
||||
if !session.TryBeginRealtimeTurnHooks() {
|
||||
return &schemas.BifrostError{
|
||||
Type: schemas.Ptr("invalid_request_error"),
|
||||
StatusCode: schemas.Ptr(400),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr("invalid_request_error"),
|
||||
Message: "Conversation already has an active response in progress.",
|
||||
},
|
||||
}
|
||||
}
|
||||
committed := false
|
||||
defer func() {
|
||||
if !committed {
|
||||
session.AbortRealtimeTurnHooks()
|
||||
}
|
||||
}()
|
||||
|
||||
startedAt := time.Now()
|
||||
turnCtx := newRealtimeTurnContext(baseCtx, "", session.ID(), session.ProviderSessionID(), realtimeTurnSourceEI, startEventType, key)
|
||||
setRealtimeTurnStreamContext(turnCtx, startedAt, false)
|
||||
req := buildRealtimeTurnPreRequest(provider, model, session.PeekRealtimeTurnInputs())
|
||||
hooks, bifrostErr := client.RunRealtimeTurnPreHooks(turnCtx, req)
|
||||
if bifrostErr != nil {
|
||||
// RunRealtimeTurnPreHooks already executed post-hooks and flushed the trace
|
||||
// for this turn-start failure. Clear buffered turn state so transport-close
|
||||
// fallback finalization does not emit the same error a second time.
|
||||
session.ConsumeRealtimeTurnInputs()
|
||||
session.ConsumeRealtimeOutputText()
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
requestID, _ := turnCtx.Value(schemas.BifrostContextKeyRequestID).(string)
|
||||
session.SetRealtimeTurnHooks(&bfws.RealtimeTurnPluginState{
|
||||
PostHookRunner: hooks.PostHookRunner,
|
||||
Cleanup: hooks.Cleanup,
|
||||
RequestID: requestID,
|
||||
StartedAt: startedAt,
|
||||
PreHookValues: turnCtx.GetUserValues(),
|
||||
})
|
||||
committed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func finalizeRealtimeTurnHooks(
|
||||
client *bifrost.Bifrost,
|
||||
baseCtx *schemas.BifrostContext,
|
||||
session *bfws.Session,
|
||||
rtProvider schemas.RealtimeProvider,
|
||||
provider schemas.ModelProvider,
|
||||
model string,
|
||||
key *schemas.Key,
|
||||
rawResponse []byte,
|
||||
contentOverride string,
|
||||
) *schemas.BifrostError {
|
||||
if client == nil || session == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
turnInputs := session.ConsumeRealtimeTurnInputs()
|
||||
rawRequest := combineRealtimeInputRaw(turnInputs)
|
||||
|
||||
if activeHooks := session.ConsumeRealtimeTurnHooks(); activeHooks != nil {
|
||||
defer func() {
|
||||
if activeHooks.Cleanup != nil {
|
||||
activeHooks.Cleanup()
|
||||
}
|
||||
}()
|
||||
postResponse := buildRealtimeTurnPostResponse(
|
||||
rtProvider,
|
||||
provider,
|
||||
model,
|
||||
rawRequest,
|
||||
rawResponse,
|
||||
contentOverride,
|
||||
time.Since(activeHooks.StartedAt).Milliseconds(),
|
||||
)
|
||||
postCtx := newRealtimeTurnContext(baseCtx, activeHooks.RequestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, rtProvider.RealtimeTurnFinalEvent(), key)
|
||||
applyRealtimeTurnContextValues(postCtx, activeHooks.PreHookValues)
|
||||
setRealtimeTurnStreamContext(postCtx, activeHooks.StartedAt, true)
|
||||
_, bifrostErr := activeHooks.PostHookRunner(postCtx, postResponse, nil)
|
||||
completeRealtimeTurnTrace(postCtx)
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
preCtx := newRealtimeTurnContext(baseCtx, "", session.ID(), session.ProviderSessionID(), realtimeTurnSourceEI, "", key)
|
||||
setRealtimeTurnStreamContext(preCtx, startedAt, false)
|
||||
preReq := buildRealtimeTurnPreRequest(provider, model, turnInputs)
|
||||
hooks, bifrostErr := client.RunRealtimeTurnPreHooks(preCtx, preReq)
|
||||
if bifrostErr != nil {
|
||||
return bifrostErr
|
||||
}
|
||||
if hooks.Cleanup != nil {
|
||||
defer hooks.Cleanup()
|
||||
}
|
||||
|
||||
requestID, _ := preCtx.Value(schemas.BifrostContextKeyRequestID).(string)
|
||||
postResponse := buildRealtimeTurnPostResponse(
|
||||
rtProvider,
|
||||
provider,
|
||||
model,
|
||||
rawRequest,
|
||||
rawResponse,
|
||||
contentOverride,
|
||||
time.Since(startedAt).Milliseconds(),
|
||||
)
|
||||
postCtx := newRealtimeTurnContext(baseCtx, requestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, rtProvider.RealtimeTurnFinalEvent(), key)
|
||||
applyRealtimeTurnContextValues(postCtx, preCtx.GetUserValues())
|
||||
setRealtimeTurnStreamContext(postCtx, startedAt, true)
|
||||
_, bifrostErr = hooks.PostHookRunner(postCtx, postResponse, nil)
|
||||
completeRealtimeTurnTrace(postCtx)
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
func finalizeRealtimeTurnHooksWithError(
|
||||
client *bifrost.Bifrost,
|
||||
baseCtx *schemas.BifrostContext,
|
||||
session *bfws.Session,
|
||||
provider schemas.ModelProvider,
|
||||
model string,
|
||||
key *schemas.Key,
|
||||
eventType schemas.RealtimeEventType,
|
||||
rawResponse []byte,
|
||||
bifrostErr *schemas.BifrostError,
|
||||
) *schemas.BifrostError {
|
||||
if session == nil || bifrostErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
turnInputs := session.ConsumeRealtimeTurnInputs()
|
||||
rawRequest := combineRealtimeInputRaw(turnInputs)
|
||||
session.ConsumeRealtimeOutputText()
|
||||
|
||||
if activeHooks := session.ConsumeRealtimeTurnHooks(); activeHooks != nil {
|
||||
defer func() {
|
||||
if activeHooks.Cleanup != nil {
|
||||
activeHooks.Cleanup()
|
||||
}
|
||||
}()
|
||||
postErr := buildRealtimeTurnPostError(
|
||||
provider,
|
||||
model,
|
||||
rawRequest,
|
||||
rawResponse,
|
||||
bifrostErr,
|
||||
)
|
||||
postCtx := newRealtimeTurnContext(baseCtx, activeHooks.RequestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, eventType, key)
|
||||
applyRealtimeTurnContextValues(postCtx, activeHooks.PreHookValues)
|
||||
setRealtimeTurnStreamContext(postCtx, activeHooks.StartedAt, true)
|
||||
_, hookErr := activeHooks.PostHookRunner(postCtx, nil, postErr)
|
||||
completeRealtimeTurnTrace(postCtx)
|
||||
return hookErr
|
||||
}
|
||||
|
||||
if len(turnInputs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
preCtx := newRealtimeTurnContext(baseCtx, "", session.ID(), session.ProviderSessionID(), realtimeTurnSourceEI, "", key)
|
||||
setRealtimeTurnStreamContext(preCtx, startedAt, false)
|
||||
preReq := buildRealtimeTurnPreRequest(provider, model, turnInputs)
|
||||
hooks, hookPreErr := client.RunRealtimeTurnPreHooks(preCtx, preReq)
|
||||
if hookPreErr != nil {
|
||||
return hookPreErr
|
||||
}
|
||||
if hooks.Cleanup != nil {
|
||||
defer hooks.Cleanup()
|
||||
}
|
||||
|
||||
requestID, _ := preCtx.Value(schemas.BifrostContextKeyRequestID).(string)
|
||||
postErr := buildRealtimeTurnPostError(
|
||||
provider,
|
||||
model,
|
||||
rawRequest,
|
||||
rawResponse,
|
||||
bifrostErr,
|
||||
)
|
||||
postCtx := newRealtimeTurnContext(baseCtx, requestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, eventType, key)
|
||||
applyRealtimeTurnContextValues(postCtx, preCtx.GetUserValues())
|
||||
setRealtimeTurnStreamContext(postCtx, startedAt, true)
|
||||
_, hookErr := hooks.PostHookRunner(postCtx, nil, postErr)
|
||||
completeRealtimeTurnTrace(postCtx)
|
||||
return hookErr
|
||||
}
|
||||
|
||||
func buildRealtimeTurnPostError(
|
||||
provider schemas.ModelProvider,
|
||||
model string,
|
||||
rawRequest string,
|
||||
rawResponse []byte,
|
||||
bifrostErr *schemas.BifrostError,
|
||||
) *schemas.BifrostError {
|
||||
if bifrostErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
copied := *bifrostErr
|
||||
copied.ExtraFields = bifrostErr.ExtraFields
|
||||
if bifrostErr.Error != nil {
|
||||
errorCopy := *bifrostErr.Error
|
||||
copied.Error = &errorCopy
|
||||
}
|
||||
copied.ExtraFields.RequestType = schemas.RealtimeRequest
|
||||
if copied.ExtraFields.Provider == "" {
|
||||
copied.ExtraFields.Provider = provider
|
||||
}
|
||||
if strings.TrimSpace(copied.ExtraFields.OriginalModelRequested) == "" {
|
||||
copied.ExtraFields.OriginalModelRequested = model
|
||||
}
|
||||
if strings.TrimSpace(rawRequest) != "" && copied.ExtraFields.RawRequest == nil {
|
||||
copied.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
if len(rawResponse) > 0 && copied.ExtraFields.RawResponse == nil {
|
||||
copied.ExtraFields.RawResponse = json.RawMessage(append([]byte(nil), rawResponse...))
|
||||
}
|
||||
return &copied
|
||||
}
|
||||
|
||||
func newBifrostErrorFromRealtimeError(
|
||||
provider schemas.ModelProvider,
|
||||
model string,
|
||||
rawResponse []byte,
|
||||
realtimeErr *schemas.RealtimeError,
|
||||
) *schemas.BifrostError {
|
||||
if realtimeErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
statusCode := 500
|
||||
values := []string{
|
||||
strings.TrimSpace(realtimeErr.Type),
|
||||
strings.TrimSpace(realtimeErr.Code),
|
||||
strings.TrimSpace(realtimeErr.Message),
|
||||
}
|
||||
for _, value := range values {
|
||||
lower := strings.ToLower(value)
|
||||
switch {
|
||||
case lower == "":
|
||||
continue
|
||||
case strings.Contains(lower, "invalid_request_error"):
|
||||
statusCode = 400
|
||||
case isBudgetOrBillingError(lower), bifrost.IsRateLimitErrorMessage(lower):
|
||||
statusCode = 429
|
||||
}
|
||||
}
|
||||
|
||||
errType := strings.TrimSpace(realtimeErr.Type)
|
||||
if errType == "" {
|
||||
errType = "server_error"
|
||||
}
|
||||
errCode := strings.TrimSpace(realtimeErr.Code)
|
||||
if errCode == "" {
|
||||
errCode = errType
|
||||
}
|
||||
message := strings.TrimSpace(realtimeErr.Message)
|
||||
if message == "" {
|
||||
message = "realtime turn failed"
|
||||
}
|
||||
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
StatusCode: schemas.Ptr(statusCode),
|
||||
Type: schemas.Ptr(errType),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr(errType),
|
||||
Code: schemas.Ptr(errCode),
|
||||
Message: message,
|
||||
},
|
||||
ExtraFields: schemas.BifrostErrorExtraFields{
|
||||
Provider: provider,
|
||||
OriginalModelRequested: model,
|
||||
RequestType: schemas.RealtimeRequest,
|
||||
},
|
||||
}
|
||||
if strings.TrimSpace(realtimeErr.Param) != "" {
|
||||
bifrostErr.Error.Param = realtimeErr.Param
|
||||
}
|
||||
if len(rawResponse) > 0 {
|
||||
bifrostErr.ExtraFields.RawResponse = json.RawMessage(append([]byte(nil), rawResponse...))
|
||||
}
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
func completeRealtimeTurnTrace(ctx *schemas.BifrostContext) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
traceID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string)
|
||||
if strings.TrimSpace(traceID) == "" {
|
||||
return
|
||||
}
|
||||
tracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
|
||||
if tracer == nil {
|
||||
return
|
||||
}
|
||||
tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID))
|
||||
}
|
||||
|
||||
func finalizeRealtimeTurnHooksOnTransportError(
|
||||
client *bifrost.Bifrost,
|
||||
baseCtx *schemas.BifrostContext,
|
||||
session *bfws.Session,
|
||||
provider schemas.ModelProvider,
|
||||
model string,
|
||||
key *schemas.Key,
|
||||
status int,
|
||||
code string,
|
||||
message string,
|
||||
) *schemas.BifrostError {
|
||||
return finalizeRealtimeTurnHooksWithError(
|
||||
client,
|
||||
baseCtx,
|
||||
session,
|
||||
provider,
|
||||
model,
|
||||
key,
|
||||
schemas.RTEventError,
|
||||
nil,
|
||||
newRealtimeWireBifrostError(status, code, message),
|
||||
)
|
||||
}
|
||||
230
transports/bifrost-http/handlers/session.go
Normal file
230
transports/bifrost-http/handlers/session.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
"github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
"github.com/maximhq/bifrost/framework/encrypt"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// SessionHandler manages HTTP requests for session operations
|
||||
type SessionHandler struct {
|
||||
configStore configstore.ConfigStore
|
||||
wsTicketStore *WSTicketStore
|
||||
}
|
||||
|
||||
// NewSessionHandler creates a new session handler instance
|
||||
func NewSessionHandler(configStore configstore.ConfigStore, wsTicketStore *WSTicketStore) *SessionHandler {
|
||||
return &SessionHandler{
|
||||
configStore: configStore,
|
||||
wsTicketStore: wsTicketStore,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the session-related routes
|
||||
func (h *SessionHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.POST("/api/session/login", lib.ChainMiddlewares(h.login, middlewares...))
|
||||
r.POST("/api/session/logout", lib.ChainMiddlewares(h.logout, middlewares...))
|
||||
r.GET("/api/session/is-auth-enabled", lib.ChainMiddlewares(h.isAuthEnabled, middlewares...))
|
||||
r.POST("/api/session/ws-ticket", lib.ChainMiddlewares(h.issueWSTicket, middlewares...))
|
||||
}
|
||||
|
||||
// isAuthEnabled handles GET /api/session/is-auth-enabled - Check if auth is enabled
|
||||
func (h *SessionHandler) isAuthEnabled(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
SendJSON(ctx, map[string]any{
|
||||
"is_auth_enabled": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
authConfig, err := h.configStore.GetAuthConfig(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get auth config: %v", err))
|
||||
return
|
||||
}
|
||||
if authConfig == nil {
|
||||
SendJSON(ctx, map[string]any{
|
||||
"is_auth_enabled": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
// Check if the header has a token and is valid (Authorization header or cookie)
|
||||
token := ""
|
||||
if authHeader := string(ctx.Request.Header.Peek("Authorization")); strings.HasPrefix(authHeader, "Bearer ") {
|
||||
token = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
if token == "" {
|
||||
token = string(ctx.Request.Header.Cookie("token"))
|
||||
}
|
||||
hasValidToken := false
|
||||
if token != "" {
|
||||
session, err := h.configStore.GetSession(ctx, token)
|
||||
if err == nil && session != nil && session.ExpiresAt.After(time.Now()) {
|
||||
hasValidToken = true
|
||||
}
|
||||
}
|
||||
SendJSON(ctx, map[string]any{
|
||||
"is_auth_enabled": authConfig.IsEnabled,
|
||||
"has_valid_token": hasValidToken,
|
||||
})
|
||||
}
|
||||
|
||||
// login handles POST /api/session/login - Login a user
|
||||
func (h *SessionHandler) login(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Authentication is not enabled")
|
||||
return
|
||||
}
|
||||
payload := struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{}
|
||||
if err := json.Unmarshal(ctx.PostBody(), &payload); err != nil {
|
||||
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Get auth config
|
||||
authConfig, err := h.configStore.GetAuthConfig(ctx)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get auth config: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Check if auth is enabled
|
||||
if authConfig == nil || !authConfig.IsEnabled {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Authentication is not enabled")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify credentials
|
||||
if payload.Username != authConfig.AdminUserName.GetValue() {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, "Invalid username or password")
|
||||
return
|
||||
}
|
||||
compare, err := encrypt.CompareHash(authConfig.AdminPassword.GetValue(), payload.Password)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
if !compare {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, "Invalid username or password")
|
||||
return
|
||||
}
|
||||
|
||||
// Creating a new session
|
||||
token := uuid.New().String()
|
||||
session := &tables.SessionsTable{
|
||||
Token: token,
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24 * 30), // 30 days
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
err = h.configStore.CreateSession(ctx, session)
|
||||
if err != nil {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create session: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Setting cookies
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
defer fasthttp.ReleaseCookie(cookie)
|
||||
cookie.SetKey("token")
|
||||
cookie.SetValue(token)
|
||||
cookie.SetExpire(time.Now().Add(time.Hour * 24 * 30))
|
||||
cookie.SetPath("/")
|
||||
cookie.SetHTTPOnly(true)
|
||||
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
||||
// Check if source is https then set secure
|
||||
if string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
|
||||
cookie.SetSecure(true)
|
||||
}
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
|
||||
SendJSON(ctx, map[string]any{
|
||||
"message": "Login successful",
|
||||
})
|
||||
}
|
||||
|
||||
// logout handles POST /api/session/logout - Logout a user
|
||||
func (h *SessionHandler) logout(ctx *fasthttp.RequestCtx) {
|
||||
if h.configStore == nil {
|
||||
SendError(ctx, fasthttp.StatusForbidden, "Authentication is not enabled")
|
||||
return
|
||||
}
|
||||
// Get token from Authorization header
|
||||
token := string(ctx.Request.Header.Peek("Authorization"))
|
||||
token = strings.TrimPrefix(token, "Bearer ")
|
||||
|
||||
// If no token in header, try to get from cookie
|
||||
if token == "" {
|
||||
token = string(ctx.Request.Header.Cookie("token"))
|
||||
}
|
||||
|
||||
// clear token from cookies
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
defer fasthttp.ReleaseCookie(cookie)
|
||||
cookie.SetKey("token")
|
||||
cookie.SetValue("")
|
||||
cookie.SetExpire(time.Now().Add(-time.Hour * 24 * 30))
|
||||
cookie.SetPath("/")
|
||||
cookie.SetHTTPOnly(true)
|
||||
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
||||
// Check if source is https then set secure
|
||||
if string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
|
||||
cookie.SetSecure(true)
|
||||
}
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
|
||||
// delete session from database if token exists
|
||||
if token != "" {
|
||||
err := h.configStore.DeleteSession(ctx, token)
|
||||
if err != nil && !errors.Is(err, configstore.ErrNotFound) {
|
||||
logger.Error("failed to delete session during logout: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to invalidate session. Please try again.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
SendJSON(ctx, map[string]any{
|
||||
"message": "Logout successful",
|
||||
})
|
||||
}
|
||||
|
||||
// issueWSTicket handles POST /api/session/ws-ticket - Issue a short-lived ticket for WebSocket auth.
|
||||
// The caller must already be authenticated (via cookie or Authorization header).
|
||||
// Returns a one-time-use ticket that the frontend passes as ?ticket= when opening the WebSocket.
|
||||
func (h *SessionHandler) issueWSTicket(ctx *fasthttp.RequestCtx) {
|
||||
if h.wsTicketStore == nil {
|
||||
SendError(ctx, fasthttp.StatusServiceUnavailable, "WebSocket tickets are not available")
|
||||
return
|
||||
}
|
||||
sessionToken,ok := ctx.UserValue(schemas.BifrostContextKeySessionToken).(string)
|
||||
if !ok {
|
||||
SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
if sessionToken == "" {
|
||||
// This is the case where auth is not configured or not enabled
|
||||
sessionToken = "dummy-session"
|
||||
}
|
||||
ticket, err := h.wsTicketStore.Issue(sessionToken)
|
||||
if err != nil {
|
||||
logger.Error("failed to issue WS ticket: %v", err)
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Failed to issue WebSocket ticket")
|
||||
return
|
||||
}
|
||||
SendJSON(ctx, map[string]any{
|
||||
"ticket": ticket,
|
||||
})
|
||||
}
|
||||
127
transports/bifrost-http/handlers/ssestreaming_test.go
Normal file
127
transports/bifrost-http/handlers/ssestreaming_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// TestSSEStreamReaderNoEventBatching verifies that SSE events are delivered
|
||||
// individually through fasthttp's chunked transfer encoding, not batched
|
||||
// into larger TCP segments. This is the core regression test for the
|
||||
// fasthttputil.PipeConns batching fix.
|
||||
func TestSSEStreamReaderNoEventBatching(t *testing.T) {
|
||||
const numEvents = 20
|
||||
|
||||
// Build expected events
|
||||
events := make([]string, numEvents)
|
||||
for i := range events {
|
||||
events[i] = fmt.Sprintf("data: {\"index\":%d,\"content\":\"chunk-%d\"}\n\n", i, i)
|
||||
}
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetContentType("text/event-stream")
|
||||
ctx.Response.Header.Set("Cache-Control", "no-cache")
|
||||
|
||||
reader := lib.NewSSEStreamReader()
|
||||
|
||||
go func() {
|
||||
defer reader.Done()
|
||||
for _, event := range events {
|
||||
if !reader.Send([]byte(event)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ctx.Response.SetBodyStream(reader, -1)
|
||||
}
|
||||
|
||||
// Use net.Pipe for deterministic in-process testing
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer clientConn.Close()
|
||||
|
||||
// Run fasthttp server on one end of the pipe
|
||||
go func() {
|
||||
_ = fasthttp.ServeConn(serverConn, handler)
|
||||
}()
|
||||
|
||||
// Send HTTP request through the pipe
|
||||
_, err := clientConn.Write([]byte("GET /stream HTTP/1.1\r\nHost: test\r\n\r\n"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write request: %v", err)
|
||||
}
|
||||
|
||||
// Read response using bufio to parse chunked encoding
|
||||
br := bufio.NewReader(clientConn)
|
||||
|
||||
// Read and skip HTTP response headers
|
||||
for {
|
||||
line, err := br.ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response header: %v", err)
|
||||
}
|
||||
if strings.TrimSpace(line) == "" {
|
||||
break // End of headers
|
||||
}
|
||||
}
|
||||
|
||||
// Read chunked transfer-encoded body.
|
||||
// Each HTTP chunk should contain exactly one SSE event.
|
||||
var receivedEvents []string
|
||||
for {
|
||||
// Read chunk size line (hex size + CRLF)
|
||||
sizeLine, err := br.ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read chunk size: %v", err)
|
||||
}
|
||||
sizeLine = strings.TrimSpace(sizeLine)
|
||||
|
||||
var chunkSize int
|
||||
_, err = fmt.Sscanf(sizeLine, "%x", &chunkSize)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse chunk size %q: %v", sizeLine, err)
|
||||
}
|
||||
|
||||
if chunkSize == 0 {
|
||||
break // Terminal chunk
|
||||
}
|
||||
|
||||
// Read exactly chunkSize bytes + trailing CRLF
|
||||
chunkData := make([]byte, chunkSize+2) // +2 for CRLF
|
||||
n := 0
|
||||
for n < len(chunkData) {
|
||||
nn, err := br.Read(chunkData[n:])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read chunk data: %v", err)
|
||||
}
|
||||
n += nn
|
||||
}
|
||||
|
||||
chunk := string(chunkData[:chunkSize])
|
||||
receivedEvents = append(receivedEvents, chunk)
|
||||
}
|
||||
|
||||
// Verify each chunk contains exactly one SSE event
|
||||
if len(receivedEvents) != numEvents {
|
||||
t.Errorf("expected %d individual chunks, got %d (events were batched)", numEvents, len(receivedEvents))
|
||||
for i, chunk := range receivedEvents {
|
||||
eventCount := strings.Count(chunk, "\n\n")
|
||||
t.Logf(" chunk %d: %d SSE events, %d bytes", i, eventCount, len(chunk))
|
||||
}
|
||||
}
|
||||
|
||||
for i, chunk := range receivedEvents {
|
||||
if i >= len(events) {
|
||||
break
|
||||
}
|
||||
if chunk != events[i] {
|
||||
t.Errorf("chunk %d: got %q, want %q", i, chunk, events[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
136
transports/bifrost-http/handlers/ui.go
Normal file
136
transports/bifrost-http/handlers/ui.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"mime"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// UIHandler handles UI routes.
|
||||
type UIHandler struct {
|
||||
uiContent embed.FS
|
||||
}
|
||||
|
||||
// NewUIHandler creates a new UIHandler instance.
|
||||
func NewUIHandler(uiContent embed.FS) *UIHandler {
|
||||
return &UIHandler{
|
||||
uiContent: uiContent,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the UI routes with the provided router.
|
||||
func (h *UIHandler) RegisterRoutes(router *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
router.GET("/", lib.ChainMiddlewares(h.serveDashboard, middlewares...))
|
||||
router.GET("/{filepath:*}", lib.ChainMiddlewares(h.serveDashboard, middlewares...))
|
||||
}
|
||||
|
||||
// ServeDashboard serves the dashboard UI.
|
||||
func (h *UIHandler) serveDashboard(ctx *fasthttp.RequestCtx) {
|
||||
// Get the request path
|
||||
requestPath := string(ctx.Path())
|
||||
|
||||
// Clean the path to prevent directory traversal
|
||||
cleanPath := path.Clean(requestPath)
|
||||
|
||||
// Handle .txt files - map from /{page}.txt to /{page}/index.txt
|
||||
if strings.HasSuffix(cleanPath, ".txt") {
|
||||
// Remove .txt extension and add /index.txt
|
||||
basePath := strings.TrimSuffix(cleanPath, ".txt")
|
||||
if basePath == "/" || basePath == "" {
|
||||
basePath = "/index"
|
||||
}
|
||||
cleanPath = basePath + "/index.txt"
|
||||
}
|
||||
|
||||
// Remove leading slash and add ui prefix
|
||||
if cleanPath == "/" {
|
||||
cleanPath = "ui/index.html"
|
||||
} else {
|
||||
cleanPath = "ui" + cleanPath
|
||||
}
|
||||
|
||||
// Block hidden directories and files (any path segment starting with .)
|
||||
segments := strings.Split(cleanPath, "/")
|
||||
for _, segment := range segments {
|
||||
if strings.HasPrefix(segment, ".") {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetBodyString("404 - Not found")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Block sensitive files
|
||||
baseName := filepath.Base(cleanPath)
|
||||
sensitiveFiles := []string{"package.json", "package-lock.json"}
|
||||
for _, sensitive := range sensitiveFiles {
|
||||
if baseName == sensitive {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetBodyString("404 - Not found")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is a static asset request (has file extension)
|
||||
hasExtension := strings.Contains(filepath.Base(cleanPath), ".")
|
||||
|
||||
// Try to read the file from embedded filesystem
|
||||
data, err := h.uiContent.ReadFile(cleanPath)
|
||||
if err != nil {
|
||||
|
||||
// If it's a static asset (has extension) and not found, return 404
|
||||
if hasExtension {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetBodyString("404 - Static asset not found: " + requestPath)
|
||||
return
|
||||
}
|
||||
|
||||
// For routes without extensions (SPA routing), try {path}/index.html first
|
||||
if !hasExtension {
|
||||
indexPath := cleanPath + "/index.html"
|
||||
data, err = h.uiContent.ReadFile(indexPath)
|
||||
if err == nil {
|
||||
cleanPath = indexPath
|
||||
} else {
|
||||
// If that fails, serve root index.html as fallback
|
||||
data, err = h.uiContent.ReadFile("ui/index.html")
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetBodyString("404 - File not found")
|
||||
return
|
||||
}
|
||||
cleanPath = "ui/index.html"
|
||||
}
|
||||
} else {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetBodyString("404 - File not found")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Set content type based on file extension
|
||||
ext := filepath.Ext(cleanPath)
|
||||
contentType := mime.TypeByExtension(ext)
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
ctx.SetContentType(contentType)
|
||||
|
||||
// Set cache headers for static assets
|
||||
if strings.HasPrefix(cleanPath, "ui/assets/") {
|
||||
ctx.Response.Header.Set("Cache-Control", "public, max-age=31536000, immutable")
|
||||
} else if ext == ".html" {
|
||||
ctx.Response.Header.Set("Cache-Control", "no-cache")
|
||||
} else {
|
||||
ctx.Response.Header.Set("Cache-Control", "public, max-age=3600")
|
||||
}
|
||||
|
||||
// Send the file content
|
||||
ctx.SetBody(data)
|
||||
}
|
||||
225
transports/bifrost-http/handlers/utils.go
Normal file
225
transports/bifrost-http/handlers/utils.go
Normal file
@@ -0,0 +1,225 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file contains common utility functions used across all handlers.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// pluginDisabledKey is a dedicated context key type for marking a plugin as disabled
|
||||
// rather than removed. Using a named type instead of a raw string follows Go best practices.
|
||||
type pluginDisabledKey struct{}
|
||||
|
||||
// PluginDisabledKey is the context key used to indicate a plugin is being disabled.
|
||||
var PluginDisabledKey pluginDisabledKey
|
||||
|
||||
// badRequestError wraps a client input validation error so that outer handlers
|
||||
// can distinguish it from internal server errors and return HTTP 400.
|
||||
type badRequestError struct{ err error }
|
||||
|
||||
func (e *badRequestError) Error() string { return e.err.Error() }
|
||||
func (e *badRequestError) Unwrap() error { return e.err }
|
||||
|
||||
// SendJSON sends a JSON response with 200 OK status
|
||||
func SendJSON(ctx *fasthttp.RequestCtx, data interface{}) {
|
||||
ctx.SetContentType("application/json")
|
||||
if err := json.NewEncoder(ctx).Encode(data); err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to encode JSON response: %v", err))
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// SendJSONWithStatus sends a JSON response with a custom status code
|
||||
func SendJSONWithStatus(ctx *fasthttp.RequestCtx, data interface{}, statusCode int) {
|
||||
ctx.SetContentType("application/json")
|
||||
ctx.SetStatusCode(statusCode)
|
||||
if err := json.NewEncoder(ctx).Encode(data); err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to encode JSON response: %v", err))
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// SendError sends a BifrostError response
|
||||
func SendError(ctx *fasthttp.RequestCtx, statusCode int, message string) {
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
StatusCode: &statusCode,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
SendBifrostError(ctx, bifrostErr)
|
||||
}
|
||||
|
||||
// SendBifrostError sends a BifrostError response
|
||||
func SendBifrostError(ctx *fasthttp.RequestCtx, bifrostErr *schemas.BifrostError) {
|
||||
if bifrostErr.StatusCode != nil {
|
||||
ctx.SetStatusCode(*bifrostErr.StatusCode)
|
||||
} else if !bifrostErr.IsBifrostError {
|
||||
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
} else {
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
}
|
||||
|
||||
ctx.SetContentType("application/json")
|
||||
if encodeErr := json.NewEncoder(ctx).Encode(bifrostErr); encodeErr != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to encode error response: %v", encodeErr))
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
ctx.SetBodyString(fmt.Sprintf("Failed to encode error response: %v", encodeErr))
|
||||
}
|
||||
}
|
||||
|
||||
// streamLargeResponseIfActive checks if large response mode was activated by the provider
|
||||
// and streams the response directly to the client. Returns true if the response was handled
|
||||
// (caller should return), false if normal response handling should continue.
|
||||
func streamLargeResponseIfActive(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) bool {
|
||||
isLargeResponse, ok := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool)
|
||||
if !ok || !isLargeResponse {
|
||||
return false
|
||||
}
|
||||
if !lib.StreamLargeResponseBody(ctx, bifrostCtx) {
|
||||
SendError(ctx, fasthttp.StatusInternalServerError, "Large response reader not available")
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// SendSSEError sends an error in Server-Sent Events format
|
||||
func SendSSEError(ctx *fasthttp.RequestCtx, bifrostErr *schemas.BifrostError) {
|
||||
errorJSON, err := json.Marshal(map[string]interface{}{
|
||||
"error": bifrostErr,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("failed to marshal error for SSE: %v", err)
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(ctx, "data: %s\n\n", errorJSON); err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to write SSE error: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// IsOriginAllowed checks if the given origin is allowed based on localhost rules and configured allowed origins.
|
||||
// Localhost origins are always allowed. Additional origins can be configured in allowedOrigins.
|
||||
// Supports wildcard patterns like *.example.com to match any subdomain.
|
||||
func IsOriginAllowed(origin string, allowedOrigins []string) bool {
|
||||
// Always allow localhost origins
|
||||
if isLocalhostOrigin(origin) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check configured allowed origins
|
||||
for _, allowedOrigin := range allowedOrigins {
|
||||
// Check for exact match first
|
||||
if allowedOrigin == origin {
|
||||
return true
|
||||
}
|
||||
|
||||
if allowedOrigin == "*" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for wildcard pattern
|
||||
if strings.Contains(allowedOrigin, "*") {
|
||||
if matchesWildcardPattern(origin, allowedOrigin) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isLocalhostOrigin checks if the given origin is a localhost origin
|
||||
func isLocalhostOrigin(origin string) bool {
|
||||
return strings.HasPrefix(origin, "http://localhost:") ||
|
||||
strings.HasPrefix(origin, "https://localhost:") ||
|
||||
strings.HasPrefix(origin, "http://127.0.0.1:") ||
|
||||
strings.HasPrefix(origin, "http://0.0.0.0:") ||
|
||||
strings.HasPrefix(origin, "https://127.0.0.1:")
|
||||
}
|
||||
|
||||
// matchesWildcardPattern checks if an origin matches a wildcard pattern.
|
||||
// Supports patterns like *.example.com, https://*.example.com, or http://*.example.com
|
||||
func matchesWildcardPattern(origin string, pattern string) bool {
|
||||
// Convert wildcard pattern to regex pattern
|
||||
// Escape special regex characters except *
|
||||
regexPattern := regexp.QuoteMeta(pattern)
|
||||
// Replace escaped \* with regex pattern for subdomain matching
|
||||
// \* should match one or more characters that are not dots (to match a subdomain)
|
||||
regexPattern = strings.ReplaceAll(regexPattern, `\*`, `[^/.]+`)
|
||||
// Anchor the pattern to match the entire origin
|
||||
regexPattern = "^" + regexPattern + "$"
|
||||
|
||||
// Compile and test the regex
|
||||
re, err := regexp.Compile(regexPattern)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return re.MatchString(origin)
|
||||
}
|
||||
|
||||
// ParseModel parses a model string in the format "provider/model" or "provider/nested/model"
|
||||
// Returns the provider and full model name after the first slash
|
||||
func ParseModel(model string) (string, string, error) {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return "", "", fmt.Errorf("model cannot be empty")
|
||||
}
|
||||
|
||||
parts := strings.SplitN(model, "/", 2)
|
||||
if len(parts) < 2 {
|
||||
return "", "", fmt.Errorf("model must be in the format 'provider/model'")
|
||||
}
|
||||
|
||||
provider := strings.TrimSpace(parts[0])
|
||||
name := strings.TrimSpace(parts[1])
|
||||
if provider == "" || name == "" {
|
||||
return "", "", fmt.Errorf("model must be in the format 'provider/model' with non-empty provider and model")
|
||||
}
|
||||
return provider, name, nil
|
||||
}
|
||||
|
||||
// ClampPaginationParams applies default/max bounds to limit and offset so that
|
||||
// the handler response matches the values the store actually uses.
|
||||
func ClampPaginationParams(limit, offset int) (int, int) {
|
||||
if limit <= 0 {
|
||||
limit = 25
|
||||
} else if limit > 100 {
|
||||
limit = 100
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
return limit, offset
|
||||
}
|
||||
|
||||
// fuzzyMatch checks if all characters in query appear in text in order (case-insensitive)
|
||||
// Example: "gpt4" matches "gpt-4", "gpt-4-turbo", etc.
|
||||
func fuzzyMatch(text, query string) bool {
|
||||
if query == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
text = strings.ToLower(text)
|
||||
query = strings.ToLower(query)
|
||||
|
||||
queryIndex := 0
|
||||
queryRunes := []rune(query)
|
||||
|
||||
for _, textChar := range text {
|
||||
if queryIndex < len(queryRunes) && textChar == queryRunes[queryIndex] {
|
||||
queryIndex++
|
||||
}
|
||||
}
|
||||
|
||||
return queryIndex == len(queryRunes)
|
||||
}
|
||||
1216
transports/bifrost-http/handlers/webrtc_realtime.go
Normal file
1216
transports/bifrost-http/handlers/webrtc_realtime.go
Normal file
File diff suppressed because it is too large
Load Diff
346
transports/bifrost-http/handlers/webrtc_realtime_test.go
Normal file
346
transports/bifrost-http/handlers/webrtc_realtime_test.go
Normal file
@@ -0,0 +1,346 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/kvstore"
|
||||
"github.com/maximhq/bifrost/framework/logstore"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type testHandlerStore struct {
|
||||
kv *kvstore.Store
|
||||
}
|
||||
|
||||
func (s testHandlerStore) ShouldAllowDirectKeys() bool { return true }
|
||||
func (s testHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return nil }
|
||||
func (s testHandlerStore) GetAvailableProviders() []schemas.ModelProvider { return nil }
|
||||
func (s testHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor {
|
||||
return nil
|
||||
}
|
||||
func (s testHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor { return nil }
|
||||
func (s testHandlerStore) GetAsyncJobResultTTL() int { return 0 }
|
||||
func (s testHandlerStore) GetKVStore() *kvstore.Store { return s.kv }
|
||||
func (s testHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { return nil }
|
||||
|
||||
func TestResolveRealtimeSDPTarget_BaseRouteRequiresProviderPrefix(t *testing.T) {
|
||||
_, _, _, err := resolveRealtimeSDPTarget("/v1/realtime", []byte(`{"model":"gpt-4o-realtime-preview"}`))
|
||||
if err == nil {
|
||||
t.Fatal("expected provider/model validation error")
|
||||
}
|
||||
if err.Error == nil || err.Error.Message != "session.model must use provider/model on /v1 realtime routes" {
|
||||
t.Fatalf("unexpected error: %#v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRealtimeSDPTarget_BaseRouteNormalizesModel(t *testing.T) {
|
||||
provider, model, normalized, err := resolveRealtimeSDPTarget("/v1/realtime", []byte(`{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if provider != schemas.OpenAI {
|
||||
t.Fatalf("expected provider %s, got %s", schemas.OpenAI, provider)
|
||||
}
|
||||
if model != "gpt-4o-realtime-preview" {
|
||||
t.Fatalf("unexpected normalized model: %s", model)
|
||||
}
|
||||
|
||||
var root map[string]json.RawMessage
|
||||
if unmarshalErr := json.Unmarshal(normalized, &root); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal normalized session: %v", unmarshalErr)
|
||||
}
|
||||
var sessionModel string
|
||||
if unmarshalErr := json.Unmarshal(root["model"], &sessionModel); unmarshalErr != nil {
|
||||
t.Fatalf("failed to unmarshal model: %v", unmarshalErr)
|
||||
}
|
||||
if sessionModel != "gpt-4o-realtime-preview" {
|
||||
t.Fatalf("unexpected marshaled model: %s", sessionModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRealtimeSDPTarget_OpenAIRouteDefaultsProvider(t *testing.T) {
|
||||
provider, model, _, err := resolveRealtimeSDPTarget("/openai/v1/realtime", []byte(`{"model":"gpt-4o-realtime-preview"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if provider != schemas.OpenAI {
|
||||
t.Fatalf("expected provider %s, got %s", schemas.OpenAI, provider)
|
||||
}
|
||||
if model != "gpt-4o-realtime-preview" {
|
||||
t.Fatalf("unexpected model: %s", model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCallsWebRTCRequest_RawSDPKeepsGARoute(t *testing.T) {
|
||||
var ctx fasthttp.RequestCtx
|
||||
ctx.Request.Header.SetMethod(fasthttp.MethodPost)
|
||||
ctx.Request.SetRequestURI("/openai/v1/realtime/calls?model=gpt-realtime")
|
||||
ctx.Request.Header.SetContentType("application/sdp")
|
||||
ctx.Request.SetBodyString("v=0\r\n")
|
||||
|
||||
sdpOffer, provider, model, session, err := parseCallsWebRTCRequest(&ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if sdpOffer != "v=0\r\n" {
|
||||
t.Fatalf("unexpected sdp offer: %q", sdpOffer)
|
||||
}
|
||||
if provider != schemas.OpenAI {
|
||||
t.Fatalf("expected provider %s, got %s", schemas.OpenAI, provider)
|
||||
}
|
||||
if model != "gpt-realtime" {
|
||||
t.Fatalf("unexpected model: %s", model)
|
||||
}
|
||||
if session != nil {
|
||||
t.Fatalf("expected nil session for raw SDP /calls request, got %s", string(session))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRealtimeRelayContextCopiesValuesWithoutRequestCancellation(t *testing.T) {
|
||||
requestCtx, requestCancel := schemas.NewBifrostContextWithCancel(context.Background())
|
||||
requestCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest)
|
||||
requestCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai")
|
||||
requestCtx.SetValue(schemas.BifrostContextKeyGovernanceVirtualKeyID, "vk_test")
|
||||
|
||||
relayCtx, relayCancel := newRealtimeRelayContext(requestCtx)
|
||||
defer relayCancel()
|
||||
|
||||
requestCancel()
|
||||
|
||||
select {
|
||||
case <-requestCtx.Done():
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected request context to be cancelled")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-relayCtx.Done():
|
||||
t.Fatal("relay context should outlive cancelled request context")
|
||||
default:
|
||||
}
|
||||
|
||||
if got := relayCtx.Value(schemas.BifrostContextKeyHTTPRequestType); got != schemas.RealtimeRequest {
|
||||
t.Fatalf("request type = %v, want %v", got, schemas.RealtimeRequest)
|
||||
}
|
||||
if got := relayCtx.Value(schemas.BifrostContextKeyIntegrationType); got != "openai" {
|
||||
t.Fatalf("integration type = %v, want %q", got, "openai")
|
||||
}
|
||||
if got := relayCtx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID); got != "vk_test" {
|
||||
t.Fatalf("virtual key id = %v, want %q", got, "vk_test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRealtimeEventPreservesExtraParams(t *testing.T) {
|
||||
event, err := schemas.ParseRealtimeEvent([]byte(`{"type":"conversation.item.truncate","item_id":"item_123","content_index":0,"audio_end_ms":640}`))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseRealtimeEvent() error = %v", err)
|
||||
}
|
||||
|
||||
var itemID string
|
||||
if err := json.Unmarshal(event.ExtraParams["item_id"], &itemID); err != nil {
|
||||
t.Fatalf("json.Unmarshal(item_id) error = %v", err)
|
||||
}
|
||||
if itemID != "item_123" {
|
||||
t.Fatalf("item_id = %q, want %q", itemID, "item_123")
|
||||
}
|
||||
|
||||
var contentIndex int
|
||||
if err := json.Unmarshal(event.ExtraParams["content_index"], &contentIndex); err != nil {
|
||||
t.Fatalf("json.Unmarshal(content_index) error = %v", err)
|
||||
}
|
||||
if contentIndex != 0 {
|
||||
t.Fatalf("content_index = %d, want 0", contentIndex)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRealtimeBearerToken(t *testing.T) {
|
||||
var ctx fasthttp.RequestCtx
|
||||
ctx.Request.Header.Set("Authorization", "Bearer ek_test_123")
|
||||
|
||||
if got := extractRealtimeBearerToken(&ctx); got != "ek_test_123" {
|
||||
t.Fatalf("extractRealtimeBearerToken() = %q, want %q", got, "ek_test_123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupRealtimeEphemeralKeyMappingKeepsEntryUntilTTLExpiry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := kvstore.New(kvstore.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("kvstore.New() error = %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
payload, err := json.Marshal(realtimeEphemeralKeyMapping{KeyID: "key_123", VirtualKey: "sk-bf-test"})
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
if err := store.SetWithTTL(buildRealtimeEphemeralKeyMappingKey("ek_test_123"), payload, time.Minute); err != nil {
|
||||
t.Fatalf("store.SetWithTTL() error = %v", err)
|
||||
}
|
||||
|
||||
mapping, ok := lookupRealtimeEphemeralKeyMapping(store, "ek_test_123")
|
||||
if !ok {
|
||||
t.Fatal("expected mapping to be consumed")
|
||||
}
|
||||
if mapping.KeyID != "key_123" {
|
||||
t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_123")
|
||||
}
|
||||
if mapping.VirtualKey != "sk-bf-test" {
|
||||
t.Fatalf("mapping.VirtualKey = %q, want %q", mapping.VirtualKey, "sk-bf-test")
|
||||
}
|
||||
|
||||
raw, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_test_123"))
|
||||
if err != nil {
|
||||
t.Fatalf("expected mapping to remain until TTL expiry: %v", err)
|
||||
}
|
||||
if raw == nil {
|
||||
t.Fatal("expected mapping to remain in KV store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupRealtimeEphemeralKeyMapping_BackwardsCompatibleStringValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := kvstore.New(kvstore.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("kvstore.New() error = %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
if err := store.SetWithTTL(buildRealtimeEphemeralKeyMappingKey("ek_test_legacy"), "key_legacy", time.Minute); err != nil {
|
||||
t.Fatalf("store.SetWithTTL() error = %v", err)
|
||||
}
|
||||
|
||||
mapping, ok := lookupRealtimeEphemeralKeyMapping(store, "ek_test_legacy")
|
||||
if !ok {
|
||||
t.Fatal("expected legacy mapping to be consumed")
|
||||
}
|
||||
if mapping.KeyID != "key_legacy" {
|
||||
t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_legacy")
|
||||
}
|
||||
if mapping.VirtualKey != "" {
|
||||
t.Fatalf("mapping.VirtualKey = %q, want empty", mapping.VirtualKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebRTCRealtimeRelayCloseFinalizesActiveTurnHooks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
session := bfws.NewSession(nil)
|
||||
session.SetProviderSessionID("sess_provider_123")
|
||||
session.AddRealtimeInput("hello from user", `{"type":"conversation.item.added"}`)
|
||||
|
||||
var (
|
||||
capturedErr *schemas.BifrostError
|
||||
cleanedUp bool
|
||||
)
|
||||
session.SetRealtimeTurnHooks(&bfws.RealtimeTurnPluginState{
|
||||
RequestID: "req_realtime_123",
|
||||
StartedAt: time.Now().Add(-time.Second),
|
||||
PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
||||
capturedErr = err
|
||||
return result, nil
|
||||
},
|
||||
Cleanup: func() {
|
||||
cleanedUp = true
|
||||
},
|
||||
})
|
||||
|
||||
relay := &webrtcRealtimeRelay{
|
||||
session: session,
|
||||
providerKey: schemas.OpenAI,
|
||||
model: "gpt-realtime",
|
||||
}
|
||||
|
||||
relay.close()
|
||||
|
||||
if capturedErr == nil {
|
||||
t.Fatal("expected active turn to be finalized with an error on close")
|
||||
}
|
||||
if capturedErr.ExtraFields.RequestType != schemas.RealtimeRequest {
|
||||
t.Fatalf("request type = %q, want %q", capturedErr.ExtraFields.RequestType, schemas.RealtimeRequest)
|
||||
}
|
||||
if capturedErr.Error == nil || capturedErr.Error.Message != "realtime WebRTC session closed before turn completed" {
|
||||
t.Fatalf("error message = %#v, want realtime close message", capturedErr.Error)
|
||||
}
|
||||
if session.PeekRealtimeTurnHooks() != nil {
|
||||
t.Fatal("expected active realtime turn hooks to be cleared")
|
||||
}
|
||||
if !cleanedUp {
|
||||
t.Fatal("expected realtime hook cleanup to run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRealtimeWebRTCKeys_UnmappedEphemeralTokenStaysAnonymous(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := kvstore.New(kvstore.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("kvstore.New() error = %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
handler := &WebRTCRealtimeHandler{
|
||||
handlerStore: testHandlerStore{kv: store},
|
||||
}
|
||||
|
||||
var ctx fasthttp.RequestCtx
|
||||
ctx.Request.Header.Set("Authorization", "Bearer ek_test_unmapped")
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, schemas.Key{ID: "header-provided"})
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeySelectedKeyID, "selected")
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeySelectedKeyName, "selected-name")
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyID, "mapped-id")
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyName, "mapped-name")
|
||||
|
||||
authKey, selectedKey, err := handler.resolveRealtimeWebRTCKeys(&ctx, bifrostCtx, schemas.OpenAI, "gpt-realtime")
|
||||
if err != nil {
|
||||
t.Fatalf("resolveRealtimeWebRTCKeys() error = %v", err)
|
||||
}
|
||||
if got := authKey.Value.GetValue(); got != "ek_test_unmapped" {
|
||||
t.Fatalf("auth key value = %q, want %q", got, "ek_test_unmapped")
|
||||
}
|
||||
if selectedKey != nil {
|
||||
t.Fatalf("selectedKey = %#v, want nil", selectedKey)
|
||||
}
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeyDirectKey); got != nil {
|
||||
t.Fatalf("direct key context = %#v, want nil", got)
|
||||
}
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeySelectedKeyID); got != nil {
|
||||
t.Fatalf("selected key id context = %#v, want nil", got)
|
||||
}
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeySelectedKeyName); got != nil {
|
||||
t.Fatalf("selected key name context = %#v, want nil", got)
|
||||
}
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeyAPIKeyID); got != nil {
|
||||
t.Fatalf("api key id context = %#v, want nil", got)
|
||||
}
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeyAPIKeyName); got != nil {
|
||||
t.Fatalf("api key name context = %#v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyRealtimeEphemeralKeyMapping_RestoresVirtualKeyAndKeyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
applyRealtimeEphemeralKeyMapping(bifrostCtx, realtimeEphemeralKeyMapping{
|
||||
KeyID: "key_123",
|
||||
VirtualKey: "sk-bf-test",
|
||||
})
|
||||
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeyVirtualKey); got != "sk-bf-test" {
|
||||
t.Fatalf("virtual key context = %#v, want %q", got, "sk-bf-test")
|
||||
}
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeyAPIKeyID); got != "key_123" {
|
||||
t.Fatalf("api key id context = %#v, want %q", got, "key_123")
|
||||
}
|
||||
}
|
||||
268
transports/bifrost-http/handlers/websocket.go
Normal file
268
transports/bifrost-http/handlers/websocket.go
Normal file
@@ -0,0 +1,268 @@
|
||||
// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
|
||||
// This file contains WebSocket handlers for real-time log streaming.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/fasthttp/websocket"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// WebSocketClient represents a connected WebSocket client with its own mutex
|
||||
type WebSocketClient struct {
|
||||
conn *websocket.Conn
|
||||
mu sync.Mutex // Per-connection mutex for thread-safe writes
|
||||
}
|
||||
|
||||
// WebSocketHandler manages WebSocket connections for real-time updates
|
||||
type WebSocketHandler struct {
|
||||
ctx context.Context
|
||||
allowedOrigins []string
|
||||
clients map[*websocket.Conn]*WebSocketClient
|
||||
mu sync.RWMutex
|
||||
stopChan chan struct{} // Channel to signal heartbeat goroutine to stop
|
||||
done chan struct{} // Channel to signal when heartbeat goroutine has stopped
|
||||
}
|
||||
|
||||
// NewWebSocketHandler creates a new WebSocket handler instance
|
||||
func NewWebSocketHandler(ctx context.Context, allowedOrigins []string) *WebSocketHandler {
|
||||
return &WebSocketHandler{
|
||||
ctx: ctx,
|
||||
allowedOrigins: allowedOrigins,
|
||||
clients: make(map[*websocket.Conn]*WebSocketClient),
|
||||
stopChan: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers all WebSocket-related routes
|
||||
func (h *WebSocketHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
r.GET("/ws", lib.ChainMiddlewares(h.connectStream, middlewares...))
|
||||
}
|
||||
|
||||
// getUpgrader returns a WebSocket upgrader configured with the current allowed origins
|
||||
func (h *WebSocketHandler) getUpgrader() websocket.FastHTTPUpgrader {
|
||||
return websocket.FastHTTPUpgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(ctx *fasthttp.RequestCtx) bool {
|
||||
origin := string(ctx.Request.Header.Peek("Origin"))
|
||||
if origin == "" {
|
||||
// If no Origin header, check the Host header for direct connections
|
||||
host := string(ctx.Request.Header.Peek("Host"))
|
||||
return isLocalhost(host)
|
||||
}
|
||||
// Check if origin is allowed (localhost always allowed + configured origins)
|
||||
return IsOriginAllowed(origin, h.allowedOrigins)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// isLocalhost checks if the given host is localhost
|
||||
func isLocalhost(host string) bool {
|
||||
// Remove port if present
|
||||
if idx := strings.LastIndex(host, ":"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
|
||||
// Check for localhost variations
|
||||
return host == "localhost" ||
|
||||
host == "127.0.0.1" ||
|
||||
host == "::1" ||
|
||||
host == ""
|
||||
}
|
||||
|
||||
// connectStream handles WebSocket connections for real-time streaming
|
||||
func (h *WebSocketHandler) connectStream(ctx *fasthttp.RequestCtx) {
|
||||
upgrader := h.getUpgrader()
|
||||
err := upgrader.Upgrade(ctx, func(ws *websocket.Conn) {
|
||||
// Read safety & liveness
|
||||
ws.SetReadLimit(50 << 20) // 50 MiB
|
||||
ws.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
ws.SetPongHandler(func(string) error {
|
||||
ws.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
return nil
|
||||
})
|
||||
// Create a new client with its own mutex
|
||||
client := &WebSocketClient{
|
||||
conn: ws,
|
||||
}
|
||||
|
||||
// Register new client
|
||||
h.mu.Lock()
|
||||
h.clients[ws] = client
|
||||
h.mu.Unlock()
|
||||
|
||||
// Clean up on disconnect
|
||||
defer func() {
|
||||
h.mu.Lock()
|
||||
delete(h.clients, ws)
|
||||
h.mu.Unlock()
|
||||
ws.Close()
|
||||
}()
|
||||
|
||||
// Keep connection alive and handle client messages
|
||||
// This loop continuously reads and discards incoming WebSocket messages to:
|
||||
// 1. Keep the connection alive by processing client pings and control frames
|
||||
// 2. Detect when the client disconnects by watching for close frames or errors
|
||||
// 3. Maintain proper WebSocket protocol handling without accumulating messages
|
||||
for {
|
||||
_, _, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
// Only log unexpected close errors
|
||||
if websocket.IsUnexpectedCloseError(err,
|
||||
websocket.CloseNormalClosure,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseAbnormalClosure,
|
||||
websocket.CloseNoStatusReceived) {
|
||||
logger.Error("websocket read error: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Error("websocket upgrade error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// sendMessageSafely sends a message to a client with proper locking and error handling
|
||||
func (h *WebSocketHandler) sendMessageSafely(client *WebSocketClient, messageType int, data []byte) error {
|
||||
client.mu.Lock()
|
||||
defer client.mu.Unlock()
|
||||
|
||||
// Set a write deadline to prevent hanging connections
|
||||
client.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
defer client.conn.SetWriteDeadline(time.Time{}) // Clear the deadline
|
||||
|
||||
err := client.conn.WriteMessage(messageType, data)
|
||||
if err != nil {
|
||||
// Remove the client from the map if write fails
|
||||
go func() {
|
||||
h.mu.Lock()
|
||||
delete(h.clients, client.conn)
|
||||
h.mu.Unlock()
|
||||
client.conn.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// BroadcastUpdatesToClients sends a store update notification to all connected WebSocket clients
|
||||
// The tags parameter should match RTK Query tagTypes (e.g., "Providers", "VirtualKeys", "MCPClients")
|
||||
func (h *WebSocketHandler) BroadcastUpdatesToClients(tags []string) {
|
||||
message := struct {
|
||||
Type string `json:"type"`
|
||||
Tags []string `json:"tags"`
|
||||
}{
|
||||
Type: "store_update",
|
||||
Tags: tags,
|
||||
}
|
||||
|
||||
data, err := sonic.Marshal(message)
|
||||
if err != nil {
|
||||
logger.Error("failed to marshal store update: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
h.BroadcastMarshaledMessage(data)
|
||||
}
|
||||
|
||||
// BroadcastEvent sends a typed event to all connected WebSocket clients.
|
||||
// Any subsystem can use this to push real-time updates to the frontend.
|
||||
func (h *WebSocketHandler) BroadcastEvent(eventType string, data interface{}) {
|
||||
message := struct {
|
||||
Type string `json:"type"`
|
||||
Data interface{} `json:"data"`
|
||||
}{
|
||||
Type: eventType,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
bytes, err := sonic.Marshal(message)
|
||||
if err != nil {
|
||||
logger.Error("failed to marshal event %s: %v", eventType, err)
|
||||
return
|
||||
}
|
||||
|
||||
h.BroadcastMarshaledMessage(bytes)
|
||||
}
|
||||
|
||||
// BroadcastMarshaledMessage sends an adaptive routing update to all connected WebSocket clients
|
||||
func (h *WebSocketHandler) BroadcastMarshaledMessage(data []byte) {
|
||||
// Get a snapshot of clients to avoid holding the lock during writes
|
||||
h.mu.RLock()
|
||||
clients := make([]*WebSocketClient, 0, len(h.clients))
|
||||
for _, client := range h.clients {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
|
||||
// Send message to each client safely
|
||||
for _, client := range clients {
|
||||
if err := h.sendMessageSafely(client, websocket.TextMessage, data); err != nil {
|
||||
logger.Error("failed to send message to client: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StartHeartbeat starts sending periodic heartbeat messages to keep connections alive
|
||||
func (h *WebSocketHandler) StartHeartbeat() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
go func() {
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
close(h.done)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
logger.Info("got context cancel(), stopping webserver")
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Get a snapshot of clients to avoid holding the lock during writes
|
||||
h.mu.RLock()
|
||||
clients := make([]*WebSocketClient, 0, len(h.clients))
|
||||
for _, client := range h.clients {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
|
||||
// Send heartbeat to each client safely
|
||||
for _, client := range clients {
|
||||
if err := h.sendMessageSafely(client, websocket.PingMessage, nil); err != nil {
|
||||
logger.Error("failed to send heartbeat: %v", err)
|
||||
}
|
||||
}
|
||||
case <-h.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the WebSocket handler
|
||||
func (h *WebSocketHandler) Stop() {
|
||||
close(h.stopChan) // Signal heartbeat goroutine to stop
|
||||
<-h.done // Wait for heartbeat goroutine to finish
|
||||
|
||||
// Close all client connections
|
||||
h.mu.Lock()
|
||||
for _, client := range h.clients {
|
||||
client.conn.Close()
|
||||
}
|
||||
h.clients = make(map[*websocket.Conn]*WebSocketClient)
|
||||
h.mu.Unlock()
|
||||
}
|
||||
102
transports/bifrost-http/handlers/ws_ticket.go
Normal file
102
transports/bifrost-http/handlers/ws_ticket.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
wsTicketTTL = 30 * time.Second
|
||||
wsTicketCleanupHz = 60 * time.Second
|
||||
)
|
||||
|
||||
type wsTicketEntry struct {
|
||||
sessionToken string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// WSTicketStore provides short-lived, single-use tickets for WebSocket authentication.
|
||||
// Instead of putting the long-lived session token in the WS URL (visible in logs/history),
|
||||
// clients exchange their session for a 30-second one-time ticket via an authenticated endpoint.
|
||||
type WSTicketStore struct {
|
||||
mu sync.Mutex
|
||||
tickets map[string]wsTicketEntry
|
||||
done chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
// NewWSTicketStore creates a new ticket store and starts a background goroutine
|
||||
// that periodically purges expired tickets.
|
||||
func NewWSTicketStore() *WSTicketStore {
|
||||
s := &WSTicketStore{
|
||||
tickets: make(map[string]wsTicketEntry),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go s.cleanup()
|
||||
return s
|
||||
}
|
||||
|
||||
// Issue generates a cryptographically random ticket bound to the given session token.
|
||||
// The ticket expires after wsTicketTTL (30 seconds).
|
||||
func (s *WSTicketStore) Issue(sessionToken string) (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
ticket := hex.EncodeToString(b)
|
||||
|
||||
s.mu.Lock()
|
||||
s.tickets[ticket] = wsTicketEntry{
|
||||
sessionToken: sessionToken,
|
||||
expiresAt: time.Now().Add(wsTicketTTL),
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
// Consume validates and deletes a ticket, returning the underlying session token.
|
||||
// Returns empty string if the ticket doesn't exist or has expired (single-use).
|
||||
func (s *WSTicketStore) Consume(ticket string) string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
entry, ok := s.tickets[ticket]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
delete(s.tickets, ticket)
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
return ""
|
||||
}
|
||||
return entry.sessionToken
|
||||
}
|
||||
|
||||
// Stop terminates the background cleanup goroutine.
|
||||
func (s *WSTicketStore) Stop() {
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.done)
|
||||
})
|
||||
}
|
||||
|
||||
// cleanup periodically removes expired tickets to prevent unbounded memory growth.
|
||||
func (s *WSTicketStore) cleanup() {
|
||||
ticker := time.NewTicker(wsTicketCleanupHz)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
for k, v := range s.tickets {
|
||||
if now.After(v.expiresAt) {
|
||||
delete(s.tickets, k)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
666
transports/bifrost-http/handlers/wsrealtime.go
Normal file
666
transports/bifrost-http/handlers/wsrealtime.go
Normal file
@@ -0,0 +1,666 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
ws "github.com/fasthttp/websocket"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
|
||||
realtimeWSPingInterval = 15 * time.Second
|
||||
realtimeWSPongTimeout = 45 * time.Second
|
||||
realtimeWSPingWriteTimeout = 10 * time.Second
|
||||
realtimeWSWriteTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// WSRealtimeHandler handles bidirectional WebSocket proxying for the Realtime API.
|
||||
type WSRealtimeHandler struct {
|
||||
client *bifrost.Bifrost
|
||||
config *lib.Config
|
||||
handlerStore lib.HandlerStore
|
||||
pool *bfws.Pool
|
||||
sessions *bfws.SessionManager
|
||||
}
|
||||
|
||||
// NewWSRealtimeHandler creates a new Realtime WebSocket handler.
|
||||
func NewWSRealtimeHandler(client *bifrost.Bifrost, config *lib.Config, pool *bfws.Pool) *WSRealtimeHandler {
|
||||
maxConns := config.WebSocketConfig.MaxConnections
|
||||
|
||||
return &WSRealtimeHandler{
|
||||
client: client,
|
||||
config: config,
|
||||
handlerStore: config,
|
||||
pool: pool,
|
||||
sessions: bfws.NewSessionManager(maxConns),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the Realtime WebSocket endpoint at the base path and OpenAI integration paths.
|
||||
func (h *WSRealtimeHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
handler := lib.ChainMiddlewares(h.handleUpgrade, middlewares...)
|
||||
r.GET("/v1/realtime", handler)
|
||||
for _, path := range integrations.OpenAIRealtimePaths("/openai") {
|
||||
r.GET(path, handler)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WSRealtimeHandler) Close() {
|
||||
if h == nil || h.sessions == nil {
|
||||
return
|
||||
}
|
||||
h.sessions.CloseAll()
|
||||
}
|
||||
|
||||
func (h *WSRealtimeHandler) handleUpgrade(ctx *fasthttp.RequestCtx) {
|
||||
path := string(ctx.Path())
|
||||
modelParam := string(ctx.QueryArgs().Peek("model"))
|
||||
deploymentParam := string(ctx.QueryArgs().Peek("deployment"))
|
||||
auth := captureAuthHeaders(ctx)
|
||||
// OpenAI's SDK sends the API key via WebSocket subprotocol: "openai-insecure-api-key.<key>".
|
||||
// Extract it into the auth headers so downstream processing recognizes it.
|
||||
if auth.authorization == "" {
|
||||
if token := extractRealtimeSubprotocolAPIKey(ctx); token != "" {
|
||||
auth.authorization = "Bearer " + token
|
||||
}
|
||||
}
|
||||
|
||||
providerKey, model, err := resolveRealtimeTarget(path, modelParam, deploymentParam)
|
||||
if err != nil {
|
||||
upgrader := h.websocketUpgrader("")
|
||||
upgradeErr := upgrader.Upgrade(ctx, func(conn *ws.Conn) {
|
||||
defer conn.Close()
|
||||
clientConn := newRealtimeClientConn(conn)
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()))
|
||||
})
|
||||
if upgradeErr != nil {
|
||||
logger.Warn("websocket upgrade failed for %s: %v", path, upgradeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
provider := h.client.GetProviderByKey(providerKey)
|
||||
rtProvider, ok := provider.(schemas.RealtimeProvider)
|
||||
if provider == nil || !ok || !rtProvider.SupportsRealtimeAPI() {
|
||||
upgrader := h.websocketUpgrader("")
|
||||
upgradeErr := upgrader.Upgrade(ctx, func(conn *ws.Conn) {
|
||||
defer conn.Close()
|
||||
clientConn := newRealtimeClientConn(conn)
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider does not support realtime: "+string(providerKey)))
|
||||
})
|
||||
if upgradeErr != nil {
|
||||
logger.Warn("websocket upgrade failed for %s: %v", path, upgradeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
upgrader := h.websocketUpgrader(rtProvider.RealtimeWebSocketSubprotocol())
|
||||
err = upgrader.Upgrade(ctx, func(conn *ws.Conn) {
|
||||
defer conn.Close()
|
||||
clientConn := newRealtimeClientConn(conn)
|
||||
|
||||
session, sessionErr := h.sessions.Create(conn)
|
||||
if sessionErr != nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(429, "rate_limit_exceeded", sessionErr.Error()))
|
||||
return
|
||||
}
|
||||
defer h.sessions.Remove(conn)
|
||||
|
||||
h.runRealtimeSession(clientConn, session, auth, path, providerKey, model)
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn("websocket upgrade failed for %s: %v", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WSRealtimeHandler) websocketUpgrader(subprotocol string) ws.FastHTTPUpgrader {
|
||||
upgrader := ws.FastHTTPUpgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(ctx *fasthttp.RequestCtx) bool {
|
||||
origin := string(ctx.Request.Header.Peek("Origin"))
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
return IsOriginAllowed(origin, h.config.ClientConfig.AllowedOrigins)
|
||||
},
|
||||
}
|
||||
if strings.TrimSpace(subprotocol) != "" {
|
||||
upgrader.Subprotocols = []string{subprotocol}
|
||||
}
|
||||
return upgrader
|
||||
}
|
||||
|
||||
func (h *WSRealtimeHandler) runRealtimeSession(
|
||||
clientConn *realtimeClientConn,
|
||||
session *bfws.Session,
|
||||
auth *authHeaders,
|
||||
path string,
|
||||
providerKey schemas.ModelProvider,
|
||||
model string,
|
||||
) {
|
||||
clientConn.startHeartbeat()
|
||||
defer clientConn.stopHeartbeat()
|
||||
|
||||
bifrostCtx, cancel := createBifrostContextFromAuth(h.handlerStore, auth)
|
||||
if bifrostCtx == nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(500, "server_error", "failed to create request context"))
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
// Resolve ephemeral key mapping to restore virtual key context.
|
||||
token := extractRealtimeBearerTokenFromHeader(auth.authorization)
|
||||
if isRealtimeEphemeralToken(token) {
|
||||
mapping, ok := lookupRealtimeEphemeralKeyMapping(h.handlerStore.GetKVStore(), token)
|
||||
if ok {
|
||||
applyRealtimeEphemeralKeyMapping(bifrostCtx, mapping)
|
||||
}
|
||||
}
|
||||
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest)
|
||||
if strings.HasPrefix(path, "/openai") {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai")
|
||||
}
|
||||
|
||||
provider := h.client.GetProviderByKey(providerKey)
|
||||
if provider == nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider not found: "+string(providerKey)))
|
||||
return
|
||||
}
|
||||
|
||||
rtProvider, ok := provider.(schemas.RealtimeProvider)
|
||||
if !ok || !rtProvider.SupportsRealtimeAPI() {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider does not support realtime: "+string(providerKey)))
|
||||
return
|
||||
}
|
||||
|
||||
key, err := h.client.SelectKeyForProviderRequestType(bifrostCtx, schemas.RealtimeRequest, providerKey, model)
|
||||
if err != nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve model alias so the provider receives the actual model identifier.
|
||||
model = key.Aliases.Resolve(model)
|
||||
|
||||
wsURL := rtProvider.RealtimeWebSocketURL(key, model)
|
||||
upstream, err := h.pool.Get(bfws.PoolKey{
|
||||
Provider: providerKey,
|
||||
KeyID: key.ID,
|
||||
Endpoint: wsURL,
|
||||
}, rtProvider.RealtimeHeaders(key))
|
||||
if err != nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", err.Error()))
|
||||
return
|
||||
}
|
||||
defer h.pool.Discard(upstream)
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
go func() {
|
||||
errCh <- h.relayClientToRealtimeProvider(clientConn, session, upstream, rtProvider, bifrostCtx, providerKey, model, key)
|
||||
}()
|
||||
go func() {
|
||||
errCh <- h.relayRealtimeProviderToClient(clientConn, session, upstream, rtProvider, bifrostCtx, providerKey, model, key)
|
||||
}()
|
||||
|
||||
firstErr := <-errCh
|
||||
_ = upstream.Close()
|
||||
_ = clientConn.Close()
|
||||
secondErr := <-errCh
|
||||
|
||||
if logErr := selectRealtimeRelayError(firstErr, secondErr); logErr != nil {
|
||||
logger.Warn("realtime websocket relay ended for %s/%s on %s: %v", providerKey, model, path, logErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WSRealtimeHandler) relayClientToRealtimeProvider(
|
||||
clientConn *realtimeClientConn,
|
||||
session *bfws.Session,
|
||||
upstream *bfws.UpstreamConn,
|
||||
provider schemas.RealtimeProvider,
|
||||
bifrostCtx *schemas.BifrostContext,
|
||||
providerKey schemas.ModelProvider,
|
||||
model string,
|
||||
key schemas.Key,
|
||||
) error {
|
||||
for {
|
||||
messageType, message, err := clientConn.ReadMessage()
|
||||
if err != nil {
|
||||
finalizeRealtimeTurnHooksOnTransportError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
499,
|
||||
"client_closed_request",
|
||||
"client realtime websocket disconnected before turn completed",
|
||||
)
|
||||
if isNormalWebSocketClosure(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if messageType != ws.TextMessage {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "realtime websocket only accepts text messages"))
|
||||
return nil
|
||||
}
|
||||
|
||||
event, err := schemas.ParseRealtimeEvent(message)
|
||||
if err != nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "failed to parse realtime event JSON"))
|
||||
continue
|
||||
}
|
||||
// Extract pending tool/input summaries but defer recording until the event
|
||||
// passes validation — rejected events must not pollute session state.
|
||||
toolItemID, toolSummary := pendingRealtimeToolOutputUpdate(event)
|
||||
inputItemID, inputSummary := pendingRealtimeInputUpdate(event)
|
||||
|
||||
startsTurn := provider.ShouldStartRealtimeTurn(event)
|
||||
if startsTurn {
|
||||
if session.PeekRealtimeTurnHooks() != nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "Conversation already has an active response in progress."))
|
||||
continue
|
||||
}
|
||||
if toolSummary != "" {
|
||||
session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(message))
|
||||
}
|
||||
if inputSummary != "" {
|
||||
session.RecordRealtimeInput(inputItemID, inputSummary, string(message))
|
||||
}
|
||||
if bifrostErr := startRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, event.Type); bifrostErr != nil {
|
||||
clientConn.writeRealtimeError(bifrostErr)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
providerEvent, err := provider.ToProviderRealtimeEvent(event)
|
||||
if err != nil {
|
||||
if startsTurn {
|
||||
if finalizeErr := finalizeRealtimeTurnHooksWithError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
schemas.RTEventError,
|
||||
nil,
|
||||
newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()),
|
||||
); finalizeErr != nil {
|
||||
clientConn.writeRealtimeError(finalizeErr)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()))
|
||||
continue
|
||||
}
|
||||
|
||||
// Record tool output / input only after the event passed validation.
|
||||
if !startsTurn {
|
||||
if toolSummary != "" {
|
||||
session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(message))
|
||||
}
|
||||
if inputSummary != "" {
|
||||
session.RecordRealtimeInput(inputItemID, inputSummary, string(message))
|
||||
}
|
||||
}
|
||||
|
||||
if err := upstream.WriteMessage(ws.TextMessage, providerEvent); err != nil {
|
||||
finalizeRealtimeTurnHooksWithError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
schemas.RTEventError,
|
||||
nil,
|
||||
newRealtimeWireBifrostError(502, "server_error", "failed to write realtime event upstream"),
|
||||
)
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to write realtime event upstream"))
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WSRealtimeHandler) relayRealtimeProviderToClient(
|
||||
clientConn *realtimeClientConn,
|
||||
session *bfws.Session,
|
||||
upstream *bfws.UpstreamConn,
|
||||
provider schemas.RealtimeProvider,
|
||||
bifrostCtx *schemas.BifrostContext,
|
||||
providerKey schemas.ModelProvider,
|
||||
model string,
|
||||
key schemas.Key,
|
||||
) error {
|
||||
for {
|
||||
disconnectAfterWrite := false
|
||||
messageType, message, err := upstream.ReadMessage()
|
||||
if err != nil {
|
||||
finalizeRealtimeTurnHooksOnTransportError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
502,
|
||||
"upstream_connection_error",
|
||||
"upstream realtime websocket closed before turn completed",
|
||||
)
|
||||
if isNormalWebSocketClosure(err) {
|
||||
return nil
|
||||
}
|
||||
finalizeRealtimeTurnHooksWithError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
schemas.RTEventError,
|
||||
nil,
|
||||
newRealtimeWireBifrostError(502, "server_error", "upstream realtime websocket stream interrupted"),
|
||||
)
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "upstream realtime websocket stream interrupted"))
|
||||
return err
|
||||
}
|
||||
|
||||
if messageType == ws.TextMessage {
|
||||
event, err := provider.ToBifrostRealtimeEvent(message)
|
||||
if err != nil {
|
||||
finalizeRealtimeTurnHooksWithError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
schemas.RTEventError,
|
||||
message,
|
||||
newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event"),
|
||||
)
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event"))
|
||||
return err
|
||||
}
|
||||
if event != nil {
|
||||
if event.Session != nil && event.Session.ID != "" {
|
||||
session.SetProviderSessionID(event.Session.ID)
|
||||
}
|
||||
if event.Delta != nil && provider.ShouldAccumulateRealtimeOutput(event.Type) {
|
||||
session.AppendRealtimeOutputText(event.Delta.Text)
|
||||
session.AppendRealtimeOutputText(event.Delta.Transcript)
|
||||
}
|
||||
if provider.ShouldStartRealtimeTurn(event) && session.PeekRealtimeTurnHooks() == nil {
|
||||
if bifrostErr := startRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, event.Type); bifrostErr != nil {
|
||||
clientConn.writeRealtimeError(bifrostErr)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if event != nil {
|
||||
inputItemID, inputSummary := pendingRealtimeInputUpdate(event)
|
||||
if !provider.ShouldForwardRealtimeEvent(event) {
|
||||
continue
|
||||
}
|
||||
if event.Type == provider.RealtimeTurnFinalEvent() {
|
||||
contentOverride := session.ConsumeRealtimeOutputText()
|
||||
if bifrostErr := finalizeRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, message, contentOverride); bifrostErr != nil {
|
||||
clientConn.writeRealtimeError(bifrostErr)
|
||||
return nil
|
||||
}
|
||||
} else if event.Error != nil {
|
||||
turnErr := newBifrostErrorFromRealtimeError(providerKey, model, message, event.Error)
|
||||
finalizeErr := finalizeRealtimeTurnHooksWithError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
event.Type,
|
||||
message,
|
||||
turnErr,
|
||||
)
|
||||
if finalizeErr != nil {
|
||||
clientConn.writeRealtimeError(finalizeErr)
|
||||
return nil
|
||||
}
|
||||
// Defer the disconnect so the normal translated-write path
|
||||
// below still runs — otherwise terminal errors from translated
|
||||
// providers would reach the client in provider-native format.
|
||||
disconnectAfterWrite = shouldGracefullyDisconnectRealtime(turnErr)
|
||||
} else if inputSummary != "" {
|
||||
session.RecordRealtimeInput(inputItemID, inputSummary, string(message))
|
||||
}
|
||||
if len(event.RawData) == 0 {
|
||||
message, err = provider.ToProviderRealtimeEvent(event)
|
||||
if err != nil {
|
||||
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to encode translated realtime event"))
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := clientConn.WriteMessage(messageType, message); err != nil {
|
||||
finalizeRealtimeTurnHooksOnTransportError(
|
||||
h.client,
|
||||
bifrostCtx,
|
||||
session,
|
||||
providerKey,
|
||||
model,
|
||||
&key,
|
||||
499,
|
||||
"client_closed_request",
|
||||
"client realtime websocket disconnected before turn completed",
|
||||
)
|
||||
if isNormalWebSocketClosure(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if disconnectAfterWrite {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func resolveRealtimeTarget(path, modelParam, deploymentParam string) (schemas.ModelProvider, string, error) {
|
||||
defaultProvider := realtimeDefaultProviderForPath(path)
|
||||
|
||||
switch {
|
||||
case strings.TrimSpace(modelParam) != "":
|
||||
provider, model := schemas.ParseModelString(strings.TrimSpace(modelParam), defaultProvider)
|
||||
if provider == "" || strings.TrimSpace(model) == "" {
|
||||
return "", "", errRealtimeModelFormat
|
||||
}
|
||||
return provider, strings.TrimSpace(model), nil
|
||||
case strings.TrimSpace(deploymentParam) != "":
|
||||
provider, model := schemas.ParseModelString(strings.TrimSpace(deploymentParam), defaultProvider)
|
||||
if provider == "" || strings.TrimSpace(model) == "" {
|
||||
return "", "", errRealtimeDeploymentFormat
|
||||
}
|
||||
return provider, strings.TrimSpace(model), nil
|
||||
default:
|
||||
return "", "", errRealtimeModelRequired
|
||||
}
|
||||
}
|
||||
|
||||
func realtimeDefaultProviderForPath(path string) schemas.ModelProvider {
|
||||
if strings.HasPrefix(path, "/openai/") {
|
||||
return schemas.OpenAI
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isNormalWebSocketClosure(err error) bool {
|
||||
return ws.IsCloseError(err, ws.CloseNormalClosure, ws.CloseGoingAway, ws.CloseNoStatusReceived)
|
||||
}
|
||||
|
||||
func isExpectedRealtimeRelayShutdown(err error) bool {
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
if isNormalWebSocketClosure(err) || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
|
||||
return true
|
||||
}
|
||||
// Relay teardown closes the opposite socket after the first side exits, which can
|
||||
// surface as a plain network-close read error instead of a websocket close frame.
|
||||
return strings.Contains(err.Error(), "use of closed network connection")
|
||||
}
|
||||
|
||||
func selectRealtimeRelayError(errs ...error) error {
|
||||
for _, err := range errs {
|
||||
if err != nil && !isExpectedRealtimeRelayShutdown(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
errRealtimeModelRequired = errorf("model or deployment query parameter is required for realtime websocket")
|
||||
errRealtimeModelFormat = errorf("model query parameter must resolve to provider/model for realtime websocket")
|
||||
errRealtimeDeploymentFormat = errorf("deployment query parameter must resolve to provider/model for realtime websocket")
|
||||
)
|
||||
|
||||
type realtimeClientConn struct {
|
||||
conn *ws.Conn
|
||||
writeMu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newRealtimeClientConn(conn *ws.Conn) *realtimeClientConn {
|
||||
return &realtimeClientConn{
|
||||
conn: conn,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) ReadMessage() (messageType int, p []byte, err error) {
|
||||
messageType, p, err = c.conn.ReadMessage()
|
||||
if err == nil {
|
||||
c.refreshReadDeadline()
|
||||
}
|
||||
return messageType, p, err
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) WriteMessage(messageType int, data []byte) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
if err := c.conn.SetWriteDeadline(time.Now().Add(realtimeWSWriteTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.conn.WriteMessage(messageType, data); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.conn.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) startHeartbeat() {
|
||||
c.installPongHandler()
|
||||
c.refreshReadDeadline()
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(realtimeWSPingInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := c.writePing(); err != nil {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) stopHeartbeat() {
|
||||
c.closeDone()
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) installPongHandler() {
|
||||
c.conn.SetPongHandler(func(string) error {
|
||||
return c.refreshReadDeadline()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) refreshReadDeadline() error {
|
||||
return c.conn.SetReadDeadline(time.Now().Add(realtimeWSPongTimeout))
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) writePing() error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
if err := c.conn.SetWriteDeadline(time.Now().Add(realtimeWSPingWriteTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.conn.WriteMessage(ws.PingMessage, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.conn.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) closeDone() {
|
||||
c.closeOnce.Do(func() {
|
||||
close(c.done)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) writeRealtimeError(bifrostErr *schemas.BifrostError) {
|
||||
payload := newRealtimeTurnErrorEventPayload(bifrostErr)
|
||||
_ = c.WriteMessage(ws.TextMessage, payload)
|
||||
}
|
||||
|
||||
func (c *realtimeClientConn) Close() error {
|
||||
c.closeDone()
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
const realtimeSubprotocolAPIKeyPrefix = "openai-insecure-api-key."
|
||||
|
||||
// extractRealtimeSubprotocolAPIKey extracts an API key from the Sec-WebSocket-Protocol
|
||||
// header. The OpenAI SDK sends: "realtime, openai-insecure-api-key.<key>".
|
||||
func extractRealtimeSubprotocolAPIKey(ctx *fasthttp.RequestCtx) string {
|
||||
header := string(ctx.Request.Header.Peek("Sec-WebSocket-Protocol"))
|
||||
for _, proto := range strings.Split(header, ",") {
|
||||
proto = strings.TrimSpace(proto)
|
||||
if strings.HasPrefix(proto, realtimeSubprotocolAPIKeyPrefix) {
|
||||
return strings.TrimPrefix(proto, realtimeSubprotocolAPIKeyPrefix)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func newRealtimeWireBifrostError(status int, code, message string) *schemas.BifrostError {
|
||||
errType := code
|
||||
return &schemas.BifrostError{
|
||||
StatusCode: &status,
|
||||
Type: &errType,
|
||||
Error: &schemas.ErrorField{
|
||||
Type: &errType,
|
||||
Code: &errType,
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
}
|
||||
702
transports/bifrost-http/handlers/wsresponses.go
Normal file
702
transports/bifrost-http/handlers/wsresponses.go
Normal file
@@ -0,0 +1,702 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/fasthttp/router"
|
||||
ws "github.com/fasthttp/websocket"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// wsWriter abstracts a WebSocket write target. Both *ws.Conn (pre-session)
|
||||
// and *bfws.Session (post-session, mutex-protected) satisfy this interface.
|
||||
type wsWriter interface {
|
||||
WriteMessage(messageType int, data []byte) error
|
||||
}
|
||||
|
||||
// WSResponsesHandler handles WebSocket connections for the Responses API WebSocket Mode.
|
||||
// Clients connect via `GET /v1/responses` with a WS upgrade and send `response.create` events.
|
||||
// Each event is routed through the standard Bifrost inference pipeline (PreLLMHook, key selection,
|
||||
// provider call, PostLLMHook) via the HTTP bridge, with native WS upstream as an optimization.
|
||||
type WSResponsesHandler struct {
|
||||
client *bifrost.Bifrost
|
||||
config *lib.Config
|
||||
handlerStore lib.HandlerStore
|
||||
pool *bfws.Pool
|
||||
sessions *bfws.SessionManager
|
||||
upgrader ws.FastHTTPUpgrader
|
||||
}
|
||||
|
||||
// NewWSResponsesHandler creates a new WebSocket Responses handler.
|
||||
func NewWSResponsesHandler(client *bifrost.Bifrost, config *lib.Config, pool *bfws.Pool) *WSResponsesHandler {
|
||||
maxConns := config.WebSocketConfig.MaxConnections
|
||||
|
||||
return &WSResponsesHandler{
|
||||
client: client,
|
||||
config: config,
|
||||
handlerStore: config,
|
||||
pool: pool,
|
||||
sessions: bfws.NewSessionManager(maxConns),
|
||||
upgrader: ws.FastHTTPUpgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(ctx *fasthttp.RequestCtx) bool {
|
||||
origin := string(ctx.Request.Header.Peek("Origin"))
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
return IsOriginAllowed(origin, config.ClientConfig.AllowedOrigins)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down all active WebSocket responses sessions.
|
||||
func (h *WSResponsesHandler) Close() {
|
||||
if h == nil || h.sessions == nil {
|
||||
return
|
||||
}
|
||||
h.sessions.CloseAll()
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the WebSocket Responses endpoint at the base path
|
||||
// and all OpenAI integration paths.
|
||||
func (h *WSResponsesHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
handler := lib.ChainMiddlewares(h.handleUpgrade, middlewares...)
|
||||
// Base path (outside integration prefix)
|
||||
r.GET("/v1/responses", handler)
|
||||
// OpenAI integration paths (/openai/v1/responses, /openai/responses, /openai/openai/responses)
|
||||
for _, path := range integrations.OpenAIWSResponsesPaths("/openai") {
|
||||
r.GET(path, handler)
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpgrade upgrades the HTTP connection to WebSocket and starts the event loop.
|
||||
func (h *WSResponsesHandler) handleUpgrade(ctx *fasthttp.RequestCtx) {
|
||||
err := h.upgrader.Upgrade(ctx, func(conn *ws.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
session, sessionErr := h.sessions.Create(conn)
|
||||
if sessionErr != nil {
|
||||
writeWSError(conn, 429, "websocket_connection_limit_reached", sessionErr.Error())
|
||||
return
|
||||
}
|
||||
defer h.sessions.Remove(conn)
|
||||
|
||||
// Capture auth headers from the upgrade request for per-event context creation
|
||||
authHeaders := captureAuthHeaders(ctx)
|
||||
|
||||
h.eventLoop(conn, session, authHeaders)
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn("websocket upgrade failed for /v1/responses: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// authHeaders holds auth-related headers captured during the WS upgrade.
|
||||
type authHeaders struct {
|
||||
authorization string
|
||||
virtualKey string
|
||||
apiKey string
|
||||
googAPIKey string
|
||||
baggage string
|
||||
extraHeaders map[string]string
|
||||
}
|
||||
|
||||
// captureAuthHeaders captures the auth headers from the request.
|
||||
func captureAuthHeaders(ctx *fasthttp.RequestCtx) *authHeaders {
|
||||
ah := &authHeaders{
|
||||
authorization: string(ctx.Request.Header.Peek("Authorization")),
|
||||
virtualKey: string(ctx.Request.Header.Peek("x-bf-vk")),
|
||||
apiKey: string(ctx.Request.Header.Peek("x-api-key")),
|
||||
googAPIKey: string(ctx.Request.Header.Peek("x-goog-api-key")),
|
||||
baggage: string(ctx.Request.Header.Peek("baggage")),
|
||||
extraHeaders: make(map[string]string),
|
||||
}
|
||||
|
||||
for key, value := range ctx.Request.Header.All() {
|
||||
k := string(key)
|
||||
lk := strings.ToLower(k)
|
||||
if strings.HasPrefix(lk, "x-bf-") {
|
||||
ah.extraHeaders[k] = string(value)
|
||||
}
|
||||
}
|
||||
return ah
|
||||
}
|
||||
|
||||
// eventLoop reads events from the client WebSocket and processes them.
|
||||
func (h *WSResponsesHandler) eventLoop(conn *ws.Conn, session *bfws.Session, auth *authHeaders) {
|
||||
for {
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if ws.IsUnexpectedCloseError(err, ws.CloseGoingAway, ws.CloseNormalClosure) {
|
||||
logger.Warn("websocket read error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the event type
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := sonic.Unmarshal(message, &envelope); err != nil {
|
||||
writeWSError(session, 400, "invalid_request_error", "failed to parse event JSON")
|
||||
continue
|
||||
}
|
||||
|
||||
switch schemas.WebSocketEventType(envelope.Type) {
|
||||
case schemas.WSEventResponseCreate:
|
||||
h.handleResponseCreate(session, auth, message)
|
||||
default:
|
||||
writeWSError(session, 400, "invalid_request_error", "unsupported event type: "+envelope.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleResponseCreate processes a response.create event.
|
||||
// Strategy: try native WS upstream for providers that support it, otherwise use HTTP bridge.
|
||||
// If native WS upstream fails mid-stream, falls back to HTTP bridge.
|
||||
func (h *WSResponsesHandler) handleResponseCreate(session *bfws.Session, auth *authHeaders, message []byte) {
|
||||
var event schemas.WebSocketResponsesEvent
|
||||
|
||||
if err := sonic.Unmarshal(message, &event); err != nil {
|
||||
writeWSError(session, 400, "invalid_request_error", "failed to parse response.create event")
|
||||
return
|
||||
}
|
||||
|
||||
// Store override: default to store=true (Codex sends false by default but expects true).
|
||||
// If DisableStore is set in provider config, force store=false.
|
||||
// If client explicitly sets store, respect that value unless DisableStore overrides it.
|
||||
provider, modelName := schemas.ParseModelString(event.Model, "")
|
||||
if provider == "" || modelName == "" {
|
||||
writeWSError(session, 400, "invalid_request_error", "failed to parse model string")
|
||||
return
|
||||
}
|
||||
|
||||
if providerCfg, cfgErr := h.config.GetProviderConfigRaw(provider); cfgErr == nil &&
|
||||
providerCfg.OpenAIConfig != nil && providerCfg.OpenAIConfig.DisableStore {
|
||||
event.Store = schemas.Ptr(false)
|
||||
} else {
|
||||
event.Store = schemas.Ptr(true)
|
||||
}
|
||||
|
||||
bifrostReq, err := h.convertEventToRequest(&event)
|
||||
if err != nil {
|
||||
writeWSError(session, 400, "invalid_request_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Extract extra params (unknown fields) and forward them, matching the HTTP path behavior
|
||||
extraParams, extractErr := extractExtraParams(message, wsResponsesKnownFields)
|
||||
if extractErr == nil && len(extraParams) > 0 {
|
||||
if bifrostReq.Params == nil {
|
||||
bifrostReq.Params = &schemas.ResponsesParameters{}
|
||||
}
|
||||
bifrostReq.Params.ExtraParams = extraParams
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := createBifrostContextFromAuth(h.handlerStore, auth)
|
||||
if bifrostCtx == nil {
|
||||
writeWSError(session, 500, "server_error", "failed to create request context")
|
||||
return
|
||||
}
|
||||
|
||||
// Try native WS upstream first
|
||||
if h.tryNativeWSUpstream(session, bifrostCtx, bifrostReq, message) {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
// Fall back to HTTP bridge
|
||||
h.executeHTTPBridge(session, bifrostCtx, cancel, bifrostReq)
|
||||
}
|
||||
|
||||
// tryNativeWSUpstream attempts to forward the event to a native WS upstream connection.
|
||||
// Returns true if the event was handled (successfully or with error sent to client).
|
||||
// Returns false if the provider doesn't support WS and we should fall back to HTTP bridge.
|
||||
func (h *WSResponsesHandler) tryNativeWSUpstream(
|
||||
session *bfws.Session,
|
||||
ctx *schemas.BifrostContext,
|
||||
req *schemas.BifrostResponsesRequest,
|
||||
rawEvent []byte,
|
||||
) bool {
|
||||
provider := h.client.GetProviderByKey(req.Provider)
|
||||
if provider == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
wsProvider, ok := provider.(schemas.WebSocketCapableProvider)
|
||||
if !ok || !wsProvider.SupportsWebSocketMode() {
|
||||
return false
|
||||
}
|
||||
|
||||
key, err := h.client.SelectKeyForProviderRequestType(ctx, schemas.WebSocketResponsesRequest, req.Provider, req.Model)
|
||||
if err != nil {
|
||||
writeWSError(session, 400, "invalid_request_error", err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
wsURL := wsProvider.WebSocketResponsesURL(key)
|
||||
upstream := session.Upstream()
|
||||
|
||||
// Validate the pinned upstream matches the current request's provider/key
|
||||
if upstream != nil && !upstream.IsClosed() &&
|
||||
(upstream.Provider() != req.Provider || upstream.KeyID() != key.ID) {
|
||||
h.pool.Discard(upstream)
|
||||
session.SetUpstream(nil)
|
||||
upstream = nil
|
||||
}
|
||||
|
||||
// If no upstream connection pinned, get one from the pool or dial
|
||||
if upstream == nil || upstream.IsClosed() {
|
||||
headers := wsProvider.WebSocketHeaders(key)
|
||||
poolKey := bfws.PoolKey{
|
||||
Provider: req.Provider,
|
||||
KeyID: key.ID,
|
||||
Endpoint: wsURL,
|
||||
}
|
||||
|
||||
upstream, err = h.pool.Get(poolKey, headers)
|
||||
if err != nil {
|
||||
logger.Warn("failed to get upstream WS connection for %s: %v, falling back to HTTP bridge", req.Provider, err)
|
||||
return false
|
||||
}
|
||||
session.SetUpstream(upstream)
|
||||
}
|
||||
|
||||
// Run plugin pre-hooks before forwarding to upstream
|
||||
bifrostReq := &schemas.BifrostRequest{
|
||||
RequestType: schemas.WebSocketResponsesRequest,
|
||||
ResponsesRequest: req,
|
||||
}
|
||||
|
||||
hooks, preErr := h.client.RunStreamPreHooks(ctx, bifrostReq)
|
||||
if preErr != nil {
|
||||
writeWSBifrostError(session, preErr)
|
||||
return true
|
||||
}
|
||||
defer hooks.Cleanup()
|
||||
|
||||
// If a plugin short-circuited with a cached response, write it and skip upstream
|
||||
if hooks.ShortCircuitResponse != nil {
|
||||
writeWSShortCircuitResponse(session, hooks.ShortCircuitResponse)
|
||||
return true
|
||||
}
|
||||
|
||||
// Forward the raw event to upstream
|
||||
if err := upstream.WriteMessage(ws.TextMessage, rawEvent); err != nil {
|
||||
logger.Warn("upstream WS write failed for %s: %v, falling back to HTTP bridge", req.Provider, err)
|
||||
h.pool.Discard(upstream)
|
||||
session.SetUpstream(nil)
|
||||
return false
|
||||
}
|
||||
|
||||
// Retrieve tracer and traceID for chunk accumulation
|
||||
tracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
|
||||
traceID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string)
|
||||
|
||||
// Read response events from upstream and relay to client, running post-hooks per chunk
|
||||
forwardedAny := false
|
||||
for {
|
||||
msgType, data, readErr := upstream.ReadMessage()
|
||||
if readErr != nil {
|
||||
logger.Warn("upstream WS read failed for %s: %v, falling back to HTTP bridge", req.Provider, readErr)
|
||||
h.pool.Discard(upstream)
|
||||
session.SetUpstream(nil)
|
||||
if !forwardedAny {
|
||||
return false
|
||||
}
|
||||
writeWSError(session, 502, "upstream_connection_error", "upstream websocket stream interrupted")
|
||||
return true
|
||||
}
|
||||
|
||||
streamResp := parseUpstreamWSEvent(data, req.Provider, req.Model)
|
||||
isTerminal := streamResp != nil && isTerminalStreamType(streamResp.Type)
|
||||
|
||||
if isTerminal {
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
}
|
||||
|
||||
if streamResp != nil {
|
||||
resp := &schemas.BifrostResponse{ResponsesStreamResponse: streamResp}
|
||||
|
||||
if tracer != nil && traceID != "" {
|
||||
tracer.AddStreamingChunk(traceID, resp)
|
||||
}
|
||||
|
||||
_, postErr := hooks.PostHookRunner(ctx, resp, nil)
|
||||
if postErr != nil {
|
||||
h.pool.Discard(upstream)
|
||||
session.SetUpstream(nil)
|
||||
writeWSBifrostError(session, postErr)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if writeErr := session.WriteMessage(msgType, data); writeErr != nil {
|
||||
h.pool.Discard(upstream)
|
||||
session.SetUpstream(nil)
|
||||
return true
|
||||
}
|
||||
forwardedAny = true
|
||||
|
||||
if isTerminal {
|
||||
h.trackResponseID(session, data)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeWSShortCircuitResponse writes a short-circuited plugin response as WS events.
|
||||
func writeWSShortCircuitResponse(session *bfws.Session, resp *schemas.BifrostResponse) {
|
||||
if resp.ResponsesResponse != nil {
|
||||
data, err := sonic.Marshal(resp.ResponsesResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := session.WriteMessage(ws.TextMessage, data); err != nil {
|
||||
return
|
||||
}
|
||||
if resp.ResponsesResponse.ID != nil && *resp.ResponsesResponse.ID != "" {
|
||||
session.SetLastResponseID(*resp.ResponsesResponse.ID)
|
||||
}
|
||||
} else if resp.ResponsesStreamResponse != nil {
|
||||
data, err := sonic.Marshal(resp.ResponsesStreamResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
session.WriteMessage(ws.TextMessage, data)
|
||||
}
|
||||
}
|
||||
|
||||
// parseUpstreamWSEvent attempts to parse a raw upstream WS event into a BifrostResponsesStreamResponse.
|
||||
// It populates ExtraFields so downstream plugins (logging, tracing) can identify the request type.
|
||||
// Returns nil if the data cannot be parsed (non-fatal, the raw bytes are still relayed).
|
||||
func parseUpstreamWSEvent(data []byte, provider schemas.ModelProvider, model string) *schemas.BifrostResponsesStreamResponse {
|
||||
var streamResp schemas.BifrostResponsesStreamResponse
|
||||
if err := sonic.Unmarshal(data, &streamResp); err != nil {
|
||||
return nil
|
||||
}
|
||||
if streamResp.Type == "" {
|
||||
return nil
|
||||
}
|
||||
streamResp.ExtraFields.RequestType = schemas.ResponsesStreamRequest
|
||||
streamResp.ExtraFields.Provider = provider
|
||||
streamResp.ExtraFields.OriginalModelRequested = model
|
||||
return &streamResp
|
||||
}
|
||||
|
||||
// isTerminalStreamType returns true if the event type signals the end of a response stream.
|
||||
func isTerminalStreamType(t schemas.ResponsesStreamResponseType) bool {
|
||||
switch t {
|
||||
case schemas.ResponsesStreamResponseTypeCompleted,
|
||||
schemas.ResponsesStreamResponseTypeFailed,
|
||||
schemas.ResponsesStreamResponseTypeIncomplete,
|
||||
schemas.ResponsesStreamResponseTypeError:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// trackResponseID extracts and stores the response ID from terminal events.
|
||||
func (h *WSResponsesHandler) trackResponseID(session *bfws.Session, data []byte) {
|
||||
var envelope struct {
|
||||
Response struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"response"`
|
||||
}
|
||||
if err := sonic.Unmarshal(data, &envelope); err == nil && envelope.Response.ID != "" {
|
||||
session.SetLastResponseID(envelope.Response.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// convertEventToRequest converts a WebSocket response.create event to a BifrostResponsesRequest.
|
||||
func (h *WSResponsesHandler) convertEventToRequest(event *schemas.WebSocketResponsesEvent) (*schemas.BifrostResponsesRequest, error) {
|
||||
provider, modelName := schemas.ParseModelString(event.Model, "")
|
||||
if provider == "" || modelName == "" {
|
||||
return nil, errModelFormat
|
||||
}
|
||||
|
||||
var input []schemas.ResponsesMessage
|
||||
if event.Input != nil {
|
||||
// Try parsing as array first
|
||||
if err := sonic.Unmarshal(event.Input, &input); err != nil {
|
||||
// Try as string
|
||||
var inputStr string
|
||||
if strErr := sonic.Unmarshal(event.Input, &inputStr); strErr != nil {
|
||||
return nil, errInputRequired
|
||||
}
|
||||
input = []schemas.ResponsesMessage{
|
||||
{
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{ContentStr: &inputStr},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(input) == 0 {
|
||||
return nil, errInputRequired
|
||||
}
|
||||
|
||||
params := &schemas.ResponsesParameters{}
|
||||
if event.Temperature != nil {
|
||||
params.Temperature = event.Temperature
|
||||
}
|
||||
if event.TopP != nil {
|
||||
params.TopP = event.TopP
|
||||
}
|
||||
if event.MaxOutputTokens != nil {
|
||||
params.MaxOutputTokens = event.MaxOutputTokens
|
||||
}
|
||||
if event.Instructions != "" {
|
||||
params.Instructions = &event.Instructions
|
||||
}
|
||||
if event.PreviousResponseID != "" {
|
||||
params.PreviousResponseID = &event.PreviousResponseID
|
||||
}
|
||||
if event.Store != nil {
|
||||
params.Store = event.Store
|
||||
}
|
||||
if event.Tools != nil {
|
||||
var tools []schemas.ResponsesTool
|
||||
if err := sonic.Unmarshal(event.Tools, &tools); err == nil {
|
||||
params.Tools = tools
|
||||
}
|
||||
}
|
||||
if event.ToolChoice != nil {
|
||||
var tc schemas.ResponsesToolChoice
|
||||
if err := sonic.Unmarshal(event.ToolChoice, &tc); err == nil {
|
||||
params.ToolChoice = &tc
|
||||
}
|
||||
}
|
||||
if event.Reasoning != nil {
|
||||
var reasoning schemas.ResponsesParametersReasoning
|
||||
if err := sonic.Unmarshal(event.Reasoning, &reasoning); err == nil {
|
||||
params.Reasoning = &reasoning
|
||||
}
|
||||
}
|
||||
if event.Text != nil {
|
||||
var text schemas.ResponsesTextConfig
|
||||
if err := sonic.Unmarshal(event.Text, &text); err == nil {
|
||||
params.Text = &text
|
||||
}
|
||||
}
|
||||
if event.Metadata != nil {
|
||||
var metadata map[string]any
|
||||
if err := sonic.Unmarshal(event.Metadata, &metadata); err == nil {
|
||||
params.Metadata = &metadata
|
||||
}
|
||||
}
|
||||
if event.Truncation != "" {
|
||||
params.Truncation = &event.Truncation
|
||||
}
|
||||
|
||||
return &schemas.BifrostResponsesRequest{
|
||||
Provider: schemas.ModelProvider(provider),
|
||||
Model: modelName,
|
||||
Input: input,
|
||||
Params: params,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createBifrostContextFromAuth builds a BifrostContext from the auth headers captured during upgrade.
|
||||
func createBifrostContextFromAuth(handlerStore lib.HandlerStore, auth *authHeaders) (*schemas.BifrostContext, context.CancelFunc) {
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(context.Background())
|
||||
|
||||
if sessionID := lib.ParseSessionIDFromBaggage(auth.baggage); sessionID != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyParentRequestID, sessionID)
|
||||
}
|
||||
|
||||
if auth.virtualKey != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyVirtualKey, auth.virtualKey)
|
||||
}
|
||||
|
||||
// Handle Bearer token with sk-bf- prefix (virtual key via Authorization header)
|
||||
if auth.authorization != "" {
|
||||
if strings.HasPrefix(auth.authorization, "Bearer ") {
|
||||
token := strings.TrimPrefix(auth.authorization, "Bearer ")
|
||||
if strings.HasPrefix(token, "sk-bf-") {
|
||||
ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(token, "sk-bf-"))
|
||||
} else if handlerStore.ShouldAllowDirectKeys() {
|
||||
key := schemas.Key{
|
||||
ID: "header-provided",
|
||||
Value: *schemas.NewEnvVar(token),
|
||||
Models: schemas.WhiteList{"*"},
|
||||
Weight: 1.0,
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyDirectKey, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
if auth.apiKey != "" {
|
||||
if strings.HasPrefix(auth.apiKey, "sk-bf-") {
|
||||
ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(auth.apiKey, "sk-bf-"))
|
||||
} else if handlerStore.ShouldAllowDirectKeys() {
|
||||
key := schemas.Key{
|
||||
ID: "header-provided",
|
||||
Value: *schemas.NewEnvVar(auth.apiKey),
|
||||
Models: schemas.WhiteList{"*"},
|
||||
Weight: 1.0,
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyDirectKey, key)
|
||||
}
|
||||
}
|
||||
if auth.googAPIKey != "" {
|
||||
if strings.HasPrefix(auth.googAPIKey, "sk-bf-") {
|
||||
ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(auth.googAPIKey, "sk-bf-"))
|
||||
} else if handlerStore.ShouldAllowDirectKeys() {
|
||||
key := schemas.Key{
|
||||
ID: "header-provided",
|
||||
Value: *schemas.NewEnvVar(auth.googAPIKey),
|
||||
Models: schemas.WhiteList{"*"},
|
||||
Weight: 1.0,
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyDirectKey, key)
|
||||
}
|
||||
}
|
||||
|
||||
// Forward x-bf-* headers
|
||||
for k, v := range auth.extraHeaders {
|
||||
lk := strings.ToLower(k)
|
||||
switch {
|
||||
case lk == "x-bf-vk":
|
||||
// Already handled above
|
||||
case lk == "x-bf-api-key":
|
||||
ctx.SetValue(schemas.BifrostContextKeyAPIKeyName, v)
|
||||
case strings.HasPrefix(lk, "x-bf-eh-"):
|
||||
suffix := strings.TrimPrefix(lk, "x-bf-eh-")
|
||||
existing, _ := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string]string)
|
||||
if existing == nil {
|
||||
existing = make(map[string]string)
|
||||
}
|
||||
existing[suffix] = v
|
||||
ctx.SetValue(schemas.BifrostContextKeyExtraHeaders, existing)
|
||||
}
|
||||
}
|
||||
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
// executeHTTPBridge runs the response through the existing streaming inference pipeline.
|
||||
func (h *WSResponsesHandler) executeHTTPBridge(
|
||||
session *bfws.Session,
|
||||
ctx *schemas.BifrostContext,
|
||||
cancel context.CancelFunc,
|
||||
req *schemas.BifrostResponsesRequest,
|
||||
) {
|
||||
defer cancel()
|
||||
|
||||
stream, bifrostErr := h.client.ResponsesStreamRequest(ctx, req)
|
||||
if bifrostErr != nil {
|
||||
writeWSBifrostError(session, bifrostErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Relay streaming chunks as WS messages
|
||||
for chunk := range stream {
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
chunkJSON, err := sonic.Marshal(chunk)
|
||||
if err != nil {
|
||||
logger.Warn("failed to marshal stream chunk: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if writeErr := session.WriteMessage(ws.TextMessage, chunkJSON); writeErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Track last response ID for session chaining
|
||||
if chunk.BifrostResponsesStreamResponse != nil &&
|
||||
chunk.BifrostResponsesStreamResponse.Response != nil &&
|
||||
chunk.BifrostResponsesStreamResponse.Response.ID != nil &&
|
||||
*chunk.BifrostResponsesStreamResponse.Response.ID != "" {
|
||||
session.SetLastResponseID(*chunk.BifrostResponsesStreamResponse.Response.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeWSError sends a JSON error event to a WebSocket write target.
|
||||
// Accepts either a raw *ws.Conn (pre-session) or a *bfws.Session (mutex-protected).
|
||||
func writeWSError(w wsWriter, status int, code, message string) {
|
||||
event := schemas.WebSocketErrorEvent{
|
||||
Type: schemas.WSEventError,
|
||||
Status: status,
|
||||
Error: &schemas.WebSocketErrorBody{
|
||||
Code: code,
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
data, err := sonic.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.WriteMessage(ws.TextMessage, data)
|
||||
}
|
||||
|
||||
// writeWSBifrostError converts a BifrostError to a WS error event.
|
||||
func writeWSBifrostError(w wsWriter, bifrostErr *schemas.BifrostError) {
|
||||
status := 500
|
||||
if bifrostErr.StatusCode != nil && *bifrostErr.StatusCode > 0 {
|
||||
status = *bifrostErr.StatusCode
|
||||
}
|
||||
code := "server_error"
|
||||
msg := "internal server error"
|
||||
if bifrostErr.Error != nil {
|
||||
if bifrostErr.Error.Code != nil && *bifrostErr.Error.Code != "" {
|
||||
code = *bifrostErr.Error.Code
|
||||
} else if bifrostErr.Error.Type != nil && *bifrostErr.Error.Type != "" {
|
||||
code = *bifrostErr.Error.Type
|
||||
}
|
||||
if bifrostErr.Error.Message != "" {
|
||||
msg = bifrostErr.Error.Message
|
||||
}
|
||||
}
|
||||
writeWSError(w, status, code, msg)
|
||||
}
|
||||
|
||||
// wsResponsesKnownFields lists the fields explicitly handled by WebSocketResponsesEvent.
|
||||
// Anything not in this set is treated as an extra param and forwarded as-is to the provider.
|
||||
var wsResponsesKnownFields = map[string]bool{
|
||||
"type": true,
|
||||
"model": true,
|
||||
"store": true,
|
||||
"input": true,
|
||||
"instructions": true,
|
||||
"previous_response_id": true,
|
||||
"tools": true,
|
||||
"tool_choice": true,
|
||||
"temperature": true,
|
||||
"top_p": true,
|
||||
"max_output_tokens": true,
|
||||
"reasoning": true,
|
||||
"metadata": true,
|
||||
"text": true,
|
||||
"truncation": true,
|
||||
}
|
||||
|
||||
var (
|
||||
errModelFormat = errorf("model should be in provider/model format")
|
||||
errInputRequired = errorf("input is required for responses")
|
||||
)
|
||||
|
||||
func errorf(msg string) error {
|
||||
return &simpleError{msg: msg}
|
||||
}
|
||||
|
||||
type simpleError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *simpleError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
68
transports/bifrost-http/handlers/wsresponses_test.go
Normal file
68
transports/bifrost-http/handlers/wsresponses_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/kvstore"
|
||||
"github.com/maximhq/bifrost/framework/logstore"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
type testWSHandlerStore struct {
|
||||
allowDirectKeys bool
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) ShouldAllowDirectKeys() bool {
|
||||
return s.allowDirectKeys
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) GetAvailableProviders() []schemas.ModelProvider {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) GetAsyncJobResultTTL() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) GetKVStore() *kvstore.Store {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s testWSHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestCreateBifrostContextFromAuth_BaggageSessionIDSetsGrouping(t *testing.T) {
|
||||
ctx, cancel := createBifrostContextFromAuth(testWSHandlerStore{}, &authHeaders{
|
||||
baggage: "foo=bar, session-id=rt-ws-123, baz=qux",
|
||||
})
|
||||
defer cancel()
|
||||
|
||||
if got, _ := ctx.Value(schemas.BifrostContextKeyParentRequestID).(string); got != "rt-ws-123" {
|
||||
t.Fatalf("parent request id = %q, want %q", got, "rt-ws-123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateBifrostContextFromAuth_EmptyBaggageSessionIDIgnored(t *testing.T) {
|
||||
ctx, cancel := createBifrostContextFromAuth(testWSHandlerStore{}, &authHeaders{
|
||||
baggage: "session-id= ",
|
||||
})
|
||||
defer cancel()
|
||||
|
||||
if got := ctx.Value(schemas.BifrostContextKeyParentRequestID); got != nil {
|
||||
t.Fatalf("parent request id should be unset, got %#v", got)
|
||||
}
|
||||
}
|
||||
1081
transports/bifrost-http/integrations/anthropic.go
Normal file
1081
transports/bifrost-http/integrations/anthropic.go
Normal file
File diff suppressed because it is too large
Load Diff
87
transports/bifrost-http/integrations/anthropic_test.go
Normal file
87
transports/bifrost-http/integrations/anthropic_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFilterVertexUnsupportedBetaHeaders(t *testing.T) {
|
||||
t.Run("filters known exact header values", func(t *testing.T) {
|
||||
headers := map[string][]string{
|
||||
"anthropic-beta": {"advanced-tool-use-2025-11-20,structured-outputs-2025-11-13,mcp-client-2025-04-04,prompt-caching-scope-2026-01-05"},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
_, ok := result["anthropic-beta"]
|
||||
assert.False(t, ok, "all unsupported beta headers should be removed, leaving no anthropic-beta key")
|
||||
})
|
||||
|
||||
t.Run("filters bumped date variants", func(t *testing.T) {
|
||||
// Simulate Anthropic bumping version dates in the future
|
||||
headers := map[string][]string{
|
||||
"anthropic-beta": {"structured-outputs-2025-12-15,advanced-tool-use-2026-03-01,mcp-client-2026-01-01,prompt-caching-scope-2027-06-30"},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
_, ok := result["anthropic-beta"]
|
||||
assert.False(t, ok, "bumped-date variants of unsupported headers should also be filtered")
|
||||
})
|
||||
|
||||
t.Run("passes through unrelated beta headers", func(t *testing.T) {
|
||||
headers := map[string][]string{
|
||||
"anthropic-beta": {"interleaved-thinking-2025-05-14,files-api-2025-04-14"},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
vals, ok := result["anthropic-beta"]
|
||||
assert.True(t, ok, "unrelated beta headers should be preserved")
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14,files-api-2025-04-14"}, vals)
|
||||
})
|
||||
|
||||
t.Run("filters unsupported and keeps supported in mixed list", func(t *testing.T) {
|
||||
headers := map[string][]string{
|
||||
"anthropic-beta": {"interleaved-thinking-2025-05-14,structured-outputs-2025-11-13,files-api-2025-04-14,mcp-client-2025-04-04"},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
vals, ok := result["anthropic-beta"]
|
||||
assert.True(t, ok, "supported beta headers should be preserved")
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14,files-api-2025-04-14"}, vals)
|
||||
})
|
||||
|
||||
t.Run("filters bumped unsupported mixed with supported", func(t *testing.T) {
|
||||
// Future-proof: bumped dates should still be filtered
|
||||
headers := map[string][]string{
|
||||
"anthropic-beta": {"structured-outputs-2026-01-01,interleaved-thinking-2025-05-14,advanced-tool-use-2026-06-15"},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
vals, ok := result["anthropic-beta"]
|
||||
assert.True(t, ok, "supported beta headers should be preserved even when mixed with bumped unsupported ones")
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, vals)
|
||||
})
|
||||
|
||||
t.Run("returns headers unchanged when no anthropic-beta key present", func(t *testing.T) {
|
||||
headers := map[string][]string{
|
||||
"content-type": {"application/json"},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
assert.Equal(t, headers, result)
|
||||
})
|
||||
|
||||
t.Run("handles empty anthropic-beta value gracefully", func(t *testing.T) {
|
||||
headers := map[string][]string{
|
||||
"anthropic-beta": {""},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
// Empty string after trimming is not an unsupported header, but it is also empty — key should be removed
|
||||
_, ok := result["anthropic-beta"]
|
||||
assert.False(t, ok, "empty beta header list should result in key removal")
|
||||
})
|
||||
|
||||
t.Run("case-insensitive key matching for Anthropic-Beta header", func(t *testing.T) {
|
||||
headers := map[string][]string{
|
||||
"Anthropic-Beta": {"structured-outputs-2025-11-13,interleaved-thinking-2025-05-14"},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
vals, ok := result["Anthropic-Beta"]
|
||||
assert.True(t, ok, "header key casing should be preserved and matching should be case-insensitive")
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, vals)
|
||||
})
|
||||
}
|
||||
1315
transports/bifrost-http/integrations/bedrock.go
Normal file
1315
transports/bifrost-http/integrations/bedrock.go
Normal file
File diff suppressed because it is too large
Load Diff
921
transports/bifrost-http/integrations/bedrock_test.go
Normal file
921
transports/bifrost-http/integrations/bedrock_test.go
Normal file
@@ -0,0 +1,921 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/bedrock"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/kvstore"
|
||||
"github.com/maximhq/bifrost/framework/logstore"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// mockHandlerStore implements lib.HandlerStore for testing
|
||||
type mockHandlerStore struct {
|
||||
allowDirectKeys bool
|
||||
headerMatcher *lib.HeaderMatcher
|
||||
availableProviders []schemas.ModelProvider
|
||||
mcpHeaderCombinedAllowlist schemas.WhiteList
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) ShouldAllowDirectKeys() bool {
|
||||
return m.allowDirectKeys
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher {
|
||||
return m.headerMatcher
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) GetAvailableProviders() []schemas.ModelProvider {
|
||||
return m.availableProviders
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) GetAsyncJobResultTTL() int {
|
||||
return 3600
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) GetKVStore() *kvstore.Store {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList {
|
||||
return m.mcpHeaderCombinedAllowlist
|
||||
}
|
||||
|
||||
// Ensure mockHandlerStore implements lib.HandlerStore
|
||||
var _ lib.HandlerStore = (*mockHandlerStore)(nil)
|
||||
|
||||
func Test_parseS3URI(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uri string
|
||||
wantBucket string
|
||||
wantKey string
|
||||
}{
|
||||
{
|
||||
name: "full S3 URI with key",
|
||||
uri: "s3://my-bucket/path/to/file.jsonl",
|
||||
wantBucket: "my-bucket",
|
||||
wantKey: "path/to/file.jsonl",
|
||||
},
|
||||
{
|
||||
name: "S3 URI with bucket only",
|
||||
uri: "s3://my-bucket/",
|
||||
wantBucket: "my-bucket",
|
||||
wantKey: "",
|
||||
},
|
||||
{
|
||||
name: "S3 URI with bucket no trailing slash",
|
||||
uri: "s3://my-bucket",
|
||||
wantBucket: "my-bucket",
|
||||
wantKey: "",
|
||||
},
|
||||
{
|
||||
name: "plain bucket name",
|
||||
uri: "my-bucket",
|
||||
wantBucket: "my-bucket",
|
||||
wantKey: "",
|
||||
},
|
||||
{
|
||||
name: "S3 URI with nested key",
|
||||
uri: "s3://bucket-name/folder1/folder2/file.txt",
|
||||
wantBucket: "bucket-name",
|
||||
wantKey: "folder1/folder2/file.txt",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
uri: "",
|
||||
wantBucket: "",
|
||||
wantKey: "",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotBucket, gotKey := parseS3URI(tt.uri)
|
||||
assert.Equal(t, tt.wantBucket, gotBucket, "bucket mismatch")
|
||||
assert.Equal(t, tt.wantKey, gotKey, "key mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_createBedrockRouteConfigs(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
routes := CreateBedrockRouteConfigs("/bedrock", handlerStore)
|
||||
|
||||
assert.Len(t, routes, 6, "should have 6 bedrock routes")
|
||||
|
||||
expectedRoutes := []struct {
|
||||
path string
|
||||
method string
|
||||
}{
|
||||
{"/bedrock/model/{modelId}/converse", "POST"},
|
||||
{"/bedrock/model/{modelId}/converse-stream", "POST"},
|
||||
{"/bedrock/model/{modelId}/invoke-with-response-stream", "POST"},
|
||||
{"/bedrock/model/{modelId}/invoke", "POST"},
|
||||
{"/bedrock/rerank", "POST"},
|
||||
{"/bedrock/model/{modelId}/count-tokens", "POST"},
|
||||
}
|
||||
|
||||
for i, expected := range expectedRoutes {
|
||||
assert.Equal(t, expected.path, routes[i].Path, "route %d path mismatch", i)
|
||||
assert.Equal(t, expected.method, routes[i].Method, "route %d method mismatch", i)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, routes[i].Type, "route %d type mismatch", i)
|
||||
assert.NotNil(t, routes[i].GetRequestTypeInstance, "route %d GetRequestTypeInstance should not be nil", i)
|
||||
assert.NotNil(t, routes[i].ErrorConverter, "route %d ErrorConverter should not be nil", i)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_createBedrockConverseRouteConfig(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
route := createBedrockConverseRouteConfig("/bedrock", handlerStore)
|
||||
|
||||
assert.Equal(t, "/bedrock/model/{modelId}/converse", route.Path)
|
||||
assert.Equal(t, "POST", route.Method)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, route.Type)
|
||||
assert.NotNil(t, route.GetRequestTypeInstance)
|
||||
assert.NotNil(t, route.RequestConverter)
|
||||
assert.NotNil(t, route.ResponsesResponseConverter)
|
||||
assert.NotNil(t, route.ErrorConverter)
|
||||
assert.NotNil(t, route.PreCallback)
|
||||
|
||||
// Verify request instance type
|
||||
reqInstance := route.GetRequestTypeInstance(context.Background())
|
||||
_, ok := reqInstance.(*bedrock.BedrockConverseRequest)
|
||||
assert.True(t, ok, "GetRequestTypeInstance should return *bedrock.BedrockConverseRequest")
|
||||
}
|
||||
|
||||
func Test_createBedrockConverseStreamRouteConfig(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
route := createBedrockConverseStreamRouteConfig("/bedrock", handlerStore)
|
||||
|
||||
assert.Equal(t, "/bedrock/model/{modelId}/converse-stream", route.Path)
|
||||
assert.Equal(t, "POST", route.Method)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, route.Type)
|
||||
assert.NotNil(t, route.StreamConfig)
|
||||
assert.NotNil(t, route.StreamConfig.ResponsesStreamResponseConverter)
|
||||
|
||||
// Verify request instance type
|
||||
reqInstance := route.GetRequestTypeInstance(context.Background())
|
||||
_, ok := reqInstance.(*bedrock.BedrockConverseRequest)
|
||||
assert.True(t, ok, "GetRequestTypeInstance should return *bedrock.BedrockConverseRequest")
|
||||
}
|
||||
|
||||
func Test_createBedrockInvokeRouteConfig(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
route := createBedrockInvokeRouteConfig("/bedrock", handlerStore)
|
||||
|
||||
assert.Equal(t, "/bedrock/model/{modelId}/invoke", route.Path)
|
||||
assert.Equal(t, "POST", route.Method)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, route.Type)
|
||||
assert.NotNil(t, route.TextResponseConverter)
|
||||
assert.NotNil(t, route.ResponsesResponseConverter)
|
||||
|
||||
// Verify request instance type
|
||||
reqInstance := route.GetRequestTypeInstance(context.Background())
|
||||
_, ok := reqInstance.(*bedrock.BedrockInvokeRequest)
|
||||
assert.True(t, ok, "GetRequestTypeInstance should return *bedrock.BedrockInvokeRequest")
|
||||
}
|
||||
|
||||
func Test_createBedrockInvokeWithResponseStreamRouteConfig(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
route := createBedrockInvokeWithResponseStreamRouteConfig("/bedrock", handlerStore)
|
||||
|
||||
assert.Equal(t, "/bedrock/model/{modelId}/invoke-with-response-stream", route.Path)
|
||||
assert.Equal(t, "POST", route.Method)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, route.Type)
|
||||
assert.NotNil(t, route.StreamConfig)
|
||||
assert.NotNil(t, route.StreamConfig.TextStreamResponseConverter)
|
||||
assert.NotNil(t, route.StreamConfig.ResponsesStreamResponseConverter)
|
||||
|
||||
// Verify request instance type
|
||||
reqInstance := route.GetRequestTypeInstance(context.Background())
|
||||
_, ok := reqInstance.(*bedrock.BedrockInvokeRequest)
|
||||
assert.True(t, ok, "GetRequestTypeInstance should return *bedrock.BedrockInvokeRequest")
|
||||
}
|
||||
|
||||
func Test_createBedrockRerankRouteConfig(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
route := createBedrockRerankRouteConfig("/bedrock", handlerStore)
|
||||
|
||||
assert.Equal(t, "/bedrock/rerank", route.Path)
|
||||
assert.Equal(t, "POST", route.Method)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, route.Type)
|
||||
assert.NotNil(t, route.GetHTTPRequestType)
|
||||
assert.Equal(t, schemas.RerankRequest, route.GetHTTPRequestType(nil))
|
||||
assert.NotNil(t, route.GetRequestTypeInstance)
|
||||
assert.NotNil(t, route.RequestConverter)
|
||||
assert.NotNil(t, route.RerankResponseConverter)
|
||||
assert.NotNil(t, route.ErrorConverter)
|
||||
assert.NotNil(t, route.PreCallback)
|
||||
|
||||
// Verify request instance type
|
||||
reqInstance := route.GetRequestTypeInstance(context.Background())
|
||||
_, ok := reqInstance.(*bedrock.BedrockRerankRequest)
|
||||
assert.True(t, ok, "GetRequestTypeInstance should return *bedrock.BedrockRerankRequest")
|
||||
}
|
||||
|
||||
func Test_createBedrockRerankResponseConverterUsesRawResponse(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
route := createBedrockRerankRouteConfig("/bedrock", handlerStore)
|
||||
require.NotNil(t, route.RerankResponseConverter)
|
||||
|
||||
raw := map[string]interface{}{"results": []interface{}{}}
|
||||
resp := &schemas.BifrostRerankResponse{
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Provider: schemas.Bedrock,
|
||||
RawResponse: raw,
|
||||
},
|
||||
}
|
||||
converted, err := route.RerankResponseConverter(nil, resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, raw, converted)
|
||||
}
|
||||
|
||||
func Test_createBedrockRerankRouteRequestConverter(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
route := createBedrockRerankRouteConfig("/bedrock", handlerStore)
|
||||
require.NotNil(t, route.RequestConverter)
|
||||
|
||||
topN := 1
|
||||
req := &bedrock.BedrockRerankRequest{
|
||||
Queries: []bedrock.BedrockRerankQuery{
|
||||
{
|
||||
Type: "TEXT",
|
||||
TextQuery: bedrock.BedrockRerankTextRef{Text: "capital of france"},
|
||||
},
|
||||
},
|
||||
Sources: []bedrock.BedrockRerankSource{
|
||||
{
|
||||
Type: "INLINE",
|
||||
InlineDocumentSource: bedrock.BedrockRerankInlineSource{
|
||||
Type: "TEXT",
|
||||
TextDocument: bedrock.BedrockRerankTextValue{Text: "Paris is capital of France"},
|
||||
},
|
||||
},
|
||||
},
|
||||
RerankingConfiguration: bedrock.BedrockRerankingConfiguration{
|
||||
Type: "BEDROCK_RERANKING_MODEL",
|
||||
BedrockRerankingConfiguration: bedrock.BedrockRerankingModelConfiguration{
|
||||
NumberOfResults: &topN,
|
||||
ModelConfiguration: bedrock.BedrockRerankModelConfiguration{
|
||||
ModelARN: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostReq, err := route.RequestConverter(bifrostCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, bifrostReq)
|
||||
require.NotNil(t, bifrostReq.RerankRequest)
|
||||
assert.Equal(t, schemas.Bedrock, bifrostReq.RerankRequest.Provider)
|
||||
assert.Equal(t, "capital of france", bifrostReq.RerankRequest.Query)
|
||||
require.Len(t, bifrostReq.RerankRequest.Documents, 1)
|
||||
assert.Equal(t, "Paris is capital of France", bifrostReq.RerankRequest.Documents[0].Text)
|
||||
require.NotNil(t, bifrostReq.RerankRequest.Params)
|
||||
require.NotNil(t, bifrostReq.RerankRequest.Params.TopN)
|
||||
assert.Equal(t, 1, *bifrostReq.RerankRequest.Params.TopN)
|
||||
}
|
||||
|
||||
func Test_createBedrockRouteConfigsIncludesRerankForCompositePrefixes(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
prefixes := []string{"/litellm", "/langchain", "/pydanticai"}
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
routes := CreateBedrockRouteConfigs(prefix, handlerStore)
|
||||
found := false
|
||||
for _, route := range routes {
|
||||
if route.Path == prefix+"/rerank" && route.Method == "POST" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Truef(t, found, "expected rerank route for prefix %s", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_createBedrockBatchRouteConfigs(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
routes := createBedrockBatchRouteConfigs("/bedrock", handlerStore)
|
||||
|
||||
assert.Len(t, routes, 4, "should have 4 batch routes")
|
||||
|
||||
expectedRoutes := []struct {
|
||||
path string
|
||||
method string
|
||||
}{
|
||||
{"/bedrock/model-invocation-job", "POST"},
|
||||
{"/bedrock/model-invocation-jobs", "GET"},
|
||||
{"/bedrock/model-invocation-job/{job_arn}", "GET"},
|
||||
{"/bedrock/model-invocation-job/{job_arn}/stop", "POST"},
|
||||
}
|
||||
|
||||
for i, expected := range expectedRoutes {
|
||||
assert.Equal(t, expected.path, routes[i].Path, "batch route %d path mismatch", i)
|
||||
assert.Equal(t, expected.method, routes[i].Method, "batch route %d method mismatch", i)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, routes[i].Type, "batch route %d type mismatch", i)
|
||||
assert.NotNil(t, routes[i].GetRequestTypeInstance, "batch route %d GetRequestTypeInstance should not be nil", i)
|
||||
assert.NotNil(t, routes[i].BatchRequestConverter, "batch route %d BatchCreateRequestConverter should not be nil", i)
|
||||
assert.NotNil(t, routes[i].ErrorConverter, "batch route %d ErrorConverter should not be nil", i)
|
||||
assert.NotNil(t, routes[i].PreCallback, "batch route %d PreCallback should not be nil", i)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_createBedrockFilesRouteConfigs(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
routes := createBedrockFilesRouteConfigs("/bedrock/files", handlerStore)
|
||||
|
||||
assert.Len(t, routes, 5, "should have 5 file routes")
|
||||
|
||||
expectedRoutes := []struct {
|
||||
path string
|
||||
method string
|
||||
}{
|
||||
{"/bedrock/files/{bucket}/{key:*}", "PUT"},
|
||||
{"/bedrock/files/{bucket}/{key:*}", "GET"},
|
||||
{"/bedrock/files/{bucket}/{key:*}", "HEAD"},
|
||||
{"/bedrock/files/{bucket}/{key:*}", "DELETE"},
|
||||
{"/bedrock/files/{bucket}", "GET"},
|
||||
}
|
||||
|
||||
for i, expected := range expectedRoutes {
|
||||
assert.Equal(t, expected.path, routes[i].Path, "file route %d path mismatch", i)
|
||||
assert.Equal(t, expected.method, routes[i].Method, "file route %d method mismatch", i)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, routes[i].Type, "file route %d type mismatch", i)
|
||||
assert.NotNil(t, routes[i].GetRequestTypeInstance, "file route %d GetRequestTypeInstance should not be nil", i)
|
||||
assert.NotNil(t, routes[i].ErrorConverter, "file route %d ErrorConverter should not be nil", i)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseS3PutObjectRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bucket string
|
||||
key string
|
||||
body []byte
|
||||
wantErr bool
|
||||
wantBucket string
|
||||
wantKey string
|
||||
wantFilename string
|
||||
}{
|
||||
{
|
||||
name: "valid request",
|
||||
bucket: "my-bucket",
|
||||
key: "folder/file.jsonl",
|
||||
body: []byte(`{"test": "data"}`),
|
||||
wantErr: false,
|
||||
wantBucket: "my-bucket",
|
||||
wantKey: "folder/file.jsonl",
|
||||
wantFilename: "file.jsonl",
|
||||
},
|
||||
{
|
||||
name: "simple key without folder",
|
||||
bucket: "bucket",
|
||||
key: "file.txt",
|
||||
body: []byte("content"),
|
||||
wantErr: false,
|
||||
wantBucket: "bucket",
|
||||
wantKey: "file.txt",
|
||||
wantFilename: "file.txt",
|
||||
},
|
||||
{
|
||||
name: "missing bucket",
|
||||
bucket: "",
|
||||
key: "file.txt",
|
||||
body: []byte("content"),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing key",
|
||||
bucket: "bucket",
|
||||
key: "",
|
||||
body: []byte("content"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetBody(tt.body)
|
||||
|
||||
if tt.bucket != "" {
|
||||
ctx.SetUserValue("bucket", tt.bucket)
|
||||
}
|
||||
if tt.key != "" {
|
||||
ctx.SetUserValue("key", tt.key)
|
||||
}
|
||||
|
||||
req := &bedrock.BedrockFileUploadRequest{}
|
||||
err := parseS3PutObjectRequest(ctx, req)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantBucket, req.Bucket)
|
||||
assert.Equal(t, tt.wantKey, req.Key)
|
||||
assert.Equal(t, tt.wantFilename, req.Filename)
|
||||
assert.Equal(t, "batch", req.Purpose)
|
||||
assert.Equal(t, tt.body, req.Body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseS3PutObjectRequest_invalidType(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue("bucket", "bucket")
|
||||
ctx.SetUserValue("key", "key")
|
||||
|
||||
// Pass wrong type
|
||||
err := parseS3PutObjectRequest(ctx, "invalid type")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid request type")
|
||||
}
|
||||
|
||||
func Test_s3PutObjectPostCallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
wantStatus int
|
||||
wantETag string
|
||||
}{
|
||||
{
|
||||
name: "valid response with ID",
|
||||
response: &schemas.BifrostFileUploadResponse{
|
||||
ID: "file-123",
|
||||
},
|
||||
wantStatus: 200,
|
||||
wantETag: "\"file-123\"",
|
||||
},
|
||||
{
|
||||
name: "nil response",
|
||||
response: nil,
|
||||
wantStatus: 200,
|
||||
wantETag: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
err := s3PutObjectPostCallback(ctx, nil, tt.response)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantStatus, ctx.Response.StatusCode())
|
||||
assert.Equal(t, "application/xml", string(ctx.Response.Header.ContentType()))
|
||||
assert.Equal(t, "bifrost", string(ctx.Response.Header.Peek("x-amz-request-id")))
|
||||
|
||||
if tt.wantETag != "" {
|
||||
assert.Equal(t, tt.wantETag, string(ctx.Response.Header.Peek("ETag")))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_s3GetObjectPostCallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
wantContentType string
|
||||
wantLength string
|
||||
wantETag string
|
||||
}{
|
||||
{
|
||||
name: "valid response",
|
||||
response: &schemas.BifrostFileContentResponse{
|
||||
Content: []byte("test content"),
|
||||
ContentType: "application/json",
|
||||
FileID: "file-456",
|
||||
},
|
||||
wantContentType: "application/json",
|
||||
wantLength: "12",
|
||||
wantETag: "\"file-456\"",
|
||||
},
|
||||
{
|
||||
name: "nil response",
|
||||
response: nil,
|
||||
wantContentType: "",
|
||||
wantLength: "",
|
||||
wantETag: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
err := s3GetObjectPostCallback(ctx, nil, tt.response)
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
if tt.wantContentType != "" {
|
||||
assert.Equal(t, tt.wantContentType, string(ctx.Response.Header.Peek("Content-Type")))
|
||||
assert.Equal(t, tt.wantLength, string(ctx.Response.Header.Peek("Content-Length")))
|
||||
assert.Equal(t, "bifrost", string(ctx.Response.Header.Peek("x-amz-request-id")))
|
||||
}
|
||||
|
||||
if tt.wantETag != "" {
|
||||
assert.Equal(t, tt.wantETag, string(ctx.Response.Header.Peek("ETag")))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_s3HeadObjectPostCallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
wantStatus int
|
||||
wantLength string
|
||||
wantETag string
|
||||
}{
|
||||
{
|
||||
name: "valid response",
|
||||
response: &schemas.BifrostFileRetrieveResponse{
|
||||
ID: "file-789",
|
||||
Bytes: 1024,
|
||||
},
|
||||
wantStatus: 200,
|
||||
wantLength: "1024",
|
||||
wantETag: "\"file-789\"",
|
||||
},
|
||||
{
|
||||
name: "nil response",
|
||||
response: nil,
|
||||
wantStatus: 200,
|
||||
wantLength: "",
|
||||
wantETag: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
err := s3HeadObjectPostCallback(ctx, nil, tt.response)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantStatus, ctx.Response.StatusCode())
|
||||
|
||||
if tt.wantLength != "" {
|
||||
assert.Equal(t, "application/octet-stream", string(ctx.Response.Header.Peek("Content-Type")))
|
||||
assert.Equal(t, tt.wantLength, string(ctx.Response.Header.Peek("Content-Length")))
|
||||
assert.Equal(t, "bifrost", string(ctx.Response.Header.Peek("x-amz-request-id")))
|
||||
assert.Equal(t, tt.wantETag, string(ctx.Response.Header.Peek("ETag")))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_s3DeleteObjectPostCallback(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
err := s3DeleteObjectPostCallback(ctx, nil, nil)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 204, ctx.Response.StatusCode())
|
||||
assert.Equal(t, "bifrost", string(ctx.Response.Header.Peek("x-amz-request-id")))
|
||||
}
|
||||
|
||||
func Test_s3ListObjectsV2PostCallback(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
err := s3ListObjectsV2PostCallback(ctx, nil, nil)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "application/xml", string(ctx.Response.Header.ContentType()))
|
||||
assert.Equal(t, "bifrost", string(ctx.Response.Header.Peek("x-amz-request-id")))
|
||||
}
|
||||
|
||||
func Test_extractBedrockBatchListQueryParams(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: false}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams map[string]string
|
||||
wantMaxResults int
|
||||
wantNextToken string
|
||||
wantStatus string
|
||||
wantName string
|
||||
}{
|
||||
{
|
||||
name: "all params",
|
||||
queryParams: map[string]string{
|
||||
"maxResults": "50",
|
||||
"nextToken": "token123",
|
||||
"statusEquals": "InProgress",
|
||||
"nameContains": "test-job",
|
||||
},
|
||||
wantMaxResults: 50,
|
||||
wantNextToken: "token123",
|
||||
wantStatus: "InProgress",
|
||||
wantName: "test-job",
|
||||
},
|
||||
{
|
||||
name: "no params",
|
||||
queryParams: map[string]string{},
|
||||
wantMaxResults: 0,
|
||||
wantNextToken: "",
|
||||
wantStatus: "",
|
||||
wantName: "",
|
||||
},
|
||||
{
|
||||
name: "invalid maxResults",
|
||||
queryParams: map[string]string{
|
||||
"maxResults": "invalid",
|
||||
},
|
||||
wantMaxResults: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
for k, v := range tt.queryParams {
|
||||
ctx.QueryArgs().Add(k, v)
|
||||
}
|
||||
|
||||
req := &bedrock.BedrockBatchListRequest{}
|
||||
callback := extractBedrockBatchListQueryParams(handlerStore)
|
||||
|
||||
bifrostCtx := createTestBifrostContext()
|
||||
err := callback(ctx, bifrostCtx, req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantMaxResults, req.MaxResults)
|
||||
assert.Equal(t, tt.wantStatus, req.StatusEquals)
|
||||
assert.Equal(t, tt.wantName, req.NameContains)
|
||||
|
||||
if tt.wantNextToken != "" {
|
||||
assert.NotNil(t, req.NextToken)
|
||||
assert.Equal(t, tt.wantNextToken, *req.NextToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_extractBedrockJobArnFromPath(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: false}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
jobArn interface{}
|
||||
provider schemas.ModelProvider
|
||||
wantErr bool
|
||||
wantJobArn string
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid job ARN for Bedrock",
|
||||
jobArn: "arn:aws:bedrock:us-east-1:123456789012:batch:job-123",
|
||||
provider: schemas.Bedrock,
|
||||
wantErr: false,
|
||||
wantJobArn: "arn:aws:bedrock:us-east-1:123456789012:batch:job-123",
|
||||
},
|
||||
{
|
||||
name: "URL encoded job ARN",
|
||||
jobArn: "arn%3Aaws%3Abedrock%3Aus-east-1%3A123456789012%3Abatch%3Ajob-123",
|
||||
provider: schemas.Bedrock,
|
||||
wantErr: false,
|
||||
wantJobArn: "arn:aws:bedrock:us-east-1:123456789012:batch:job-123",
|
||||
},
|
||||
{
|
||||
name: "non-Bedrock provider strips ARN prefix",
|
||||
jobArn: "arn:aws:bedrock:us-east-1:444444444444:batch:job-456",
|
||||
provider: schemas.OpenAI,
|
||||
wantErr: false,
|
||||
wantJobArn: "job-456",
|
||||
},
|
||||
{
|
||||
name: "missing job_arn",
|
||||
jobArn: nil,
|
||||
provider: schemas.Bedrock,
|
||||
wantErr: true,
|
||||
errContains: "job_arn is required",
|
||||
},
|
||||
{
|
||||
name: "empty job_arn",
|
||||
jobArn: "",
|
||||
provider: schemas.Bedrock,
|
||||
wantErr: true,
|
||||
errContains: "job_arn must be a non-empty string",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
if tt.jobArn != nil {
|
||||
ctx.SetUserValue("job_arn", tt.jobArn)
|
||||
}
|
||||
|
||||
req := &bedrock.BedrockBatchRetrieveRequest{}
|
||||
callback := extractBedrockJobArnFromPath(handlerStore)
|
||||
|
||||
bifrostCtx := createTestBifrostContextWithProvider(tt.provider)
|
||||
err := callback(ctx, bifrostCtx, req)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantJobArn, req.JobIdentifier)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_extractS3ListObjectsV2Params(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: false}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
bucket string
|
||||
queryParams map[string]string
|
||||
wantErr bool
|
||||
wantBucket string
|
||||
wantPrefix string
|
||||
wantMaxKeys int
|
||||
wantContinuationToken string
|
||||
}{
|
||||
{
|
||||
name: "all params",
|
||||
bucket: "my-bucket",
|
||||
queryParams: map[string]string{
|
||||
"prefix": "folder/",
|
||||
"max-keys": "100",
|
||||
"continuation-token": "token-abc",
|
||||
},
|
||||
wantErr: false,
|
||||
wantBucket: "my-bucket",
|
||||
wantPrefix: "folder/",
|
||||
wantMaxKeys: 100,
|
||||
wantContinuationToken: "token-abc",
|
||||
},
|
||||
{
|
||||
name: "bucket only",
|
||||
bucket: "simple-bucket",
|
||||
queryParams: map[string]string{},
|
||||
wantErr: false,
|
||||
wantBucket: "simple-bucket",
|
||||
wantPrefix: "",
|
||||
wantMaxKeys: 1000,
|
||||
},
|
||||
{
|
||||
name: "missing bucket",
|
||||
bucket: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
if tt.bucket != "" {
|
||||
ctx.SetUserValue("bucket", tt.bucket)
|
||||
}
|
||||
for k, v := range tt.queryParams {
|
||||
ctx.QueryArgs().Add(k, v)
|
||||
}
|
||||
|
||||
req := &bedrock.BedrockFileListRequest{}
|
||||
callback := extractS3ListObjectsV2Params(handlerStore)
|
||||
|
||||
bifrostCtx := createTestBifrostContext()
|
||||
err := callback(ctx, bifrostCtx, req)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantBucket, req.Bucket)
|
||||
assert.Equal(t, tt.wantPrefix, req.Prefix)
|
||||
assert.Equal(t, tt.wantMaxKeys, req.MaxKeys)
|
||||
assert.Equal(t, tt.wantContinuationToken, req.ContinuationToken)
|
||||
|
||||
// Verify context values
|
||||
assert.Equal(t, tt.wantBucket, bifrostCtx.Value(s3ContextKeyBucket))
|
||||
assert.Equal(t, tt.wantPrefix, bifrostCtx.Value(s3ContextKeyPrefix))
|
||||
assert.Equal(t, tt.wantMaxKeys, bifrostCtx.Value(s3ContextKeyMaxKeys))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_extractS3BucketKeyFromPath(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: false}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
bucket string
|
||||
key string
|
||||
fileID string
|
||||
opType string
|
||||
wantErr bool
|
||||
wantBucket string
|
||||
wantKey string
|
||||
wantS3URI string
|
||||
}{
|
||||
{
|
||||
name: "content operation",
|
||||
bucket: "my-bucket",
|
||||
key: "path/to/file.txt",
|
||||
fileID: "file-123",
|
||||
opType: "content",
|
||||
wantErr: false,
|
||||
wantBucket: "my-bucket",
|
||||
wantKey: "path/to/file.txt",
|
||||
wantS3URI: "s3://my-bucket/path/to/file.txt",
|
||||
},
|
||||
{
|
||||
name: "missing bucket",
|
||||
bucket: "",
|
||||
key: "file.txt",
|
||||
opType: "content",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing key",
|
||||
bucket: "bucket",
|
||||
key: "",
|
||||
opType: "content",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
if tt.bucket != "" {
|
||||
ctx.SetUserValue("bucket", tt.bucket)
|
||||
}
|
||||
if tt.key != "" {
|
||||
ctx.SetUserValue("key", tt.key)
|
||||
}
|
||||
if tt.fileID != "" {
|
||||
ctx.Request.Header.Set("If-Match", tt.fileID)
|
||||
}
|
||||
|
||||
callback := extractS3BucketKeyFromPath(handlerStore, tt.opType)
|
||||
bifrostCtx := createTestBifrostContext()
|
||||
|
||||
var req interface{}
|
||||
switch tt.opType {
|
||||
case "content":
|
||||
req = &bedrock.BedrockFileContentRequest{}
|
||||
case "retrieve":
|
||||
req = &bedrock.BedrockFileRetrieveRequest{}
|
||||
case "delete":
|
||||
req = &bedrock.BedrockFileDeleteRequest{}
|
||||
}
|
||||
|
||||
err := callback(ctx, bifrostCtx, req)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
switch r := req.(type) {
|
||||
case *bedrock.BedrockFileContentRequest:
|
||||
assert.Equal(t, tt.wantBucket, r.Bucket)
|
||||
assert.Equal(t, tt.wantKey, r.Prefix)
|
||||
assert.Equal(t, tt.wantS3URI, r.S3Uri)
|
||||
assert.Equal(t, tt.fileID, r.ETag)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for creating test contexts
|
||||
|
||||
func createTestBifrostContext() *schemas.BifrostContext {
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.Bedrock)
|
||||
return bifrostCtx
|
||||
}
|
||||
|
||||
func createTestBifrostContextWithProvider(provider schemas.ModelProvider) *schemas.BifrostContext {
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(bifrostContextKeyProvider, provider)
|
||||
return bifrostCtx
|
||||
}
|
||||
222
transports/bifrost-http/integrations/cohere.go
Normal file
222
transports/bifrost-http/integrations/cohere.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/providers/cohere"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// hydrateCohereRequestFromLargePayloadMetadata populates model + stream from
|
||||
// LargePayloadMetadata when body parsing is skipped under large payload mode.
|
||||
func hydrateCohereRequestFromLargePayloadMetadata(bifrostCtx *schemas.BifrostContext, req interface{}) {
|
||||
if bifrostCtx == nil {
|
||||
return
|
||||
}
|
||||
isLargePayload, _ := bifrostCtx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool)
|
||||
if !isLargePayload {
|
||||
return
|
||||
}
|
||||
metadata := resolveLargePayloadMetadata(bifrostCtx)
|
||||
if metadata == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch r := req.(type) {
|
||||
case *cohere.CohereChatRequest:
|
||||
if r.Model == "" {
|
||||
r.Model = metadata.Model
|
||||
}
|
||||
if metadata.StreamRequested != nil && r.Stream == nil {
|
||||
r.Stream = schemas.Ptr(*metadata.StreamRequested)
|
||||
}
|
||||
case *cohere.CohereEmbeddingRequest:
|
||||
if r.Model == "" {
|
||||
r.Model = metadata.Model
|
||||
}
|
||||
case *cohere.CohereRerankRequest:
|
||||
if r.Model == "" {
|
||||
r.Model = metadata.Model
|
||||
}
|
||||
case *cohere.CohereCountTokensRequest:
|
||||
if r.Model == "" {
|
||||
r.Model = metadata.Model
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cohereLargePayloadPreHook populates model + stream from LargePayloadMetadata
|
||||
// when body parsing is skipped under large payload mode.
|
||||
func cohereLargePayloadPreHook(_ *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error {
|
||||
hydrateCohereRequestFromLargePayloadMetadata(bifrostCtx, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CohereRouter holds route registrations for Cohere endpoints.
|
||||
// It supports Cohere's v2 chat, embeddings, and rerank APIs.
|
||||
type CohereRouter struct {
|
||||
*GenericRouter
|
||||
}
|
||||
|
||||
// NewCohereRouter creates a new CohereRouter with the given bifrost client.
|
||||
func NewCohereRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *CohereRouter {
|
||||
return &CohereRouter{
|
||||
GenericRouter: NewGenericRouter(client, handlerStore, CreateCohereRouteConfigs("/cohere"), nil, logger),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateCohereRouteConfigs creates route configurations for Cohere API endpoints.
|
||||
func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig {
|
||||
var routes []RouteConfig
|
||||
|
||||
// Chat completions endpoint (v2/chat)
|
||||
routes = append(routes, RouteConfig{
|
||||
Type: RouteConfigTypeCohere,
|
||||
Path: pathPrefix + "/v2/chat",
|
||||
Method: "POST",
|
||||
PreCallback: cohereLargePayloadPreHook,
|
||||
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
|
||||
return schemas.ChatCompletionRequest
|
||||
},
|
||||
GetRequestTypeInstance: func(ctx context.Context) interface{} {
|
||||
return &cohere.CohereChatRequest{}
|
||||
},
|
||||
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
|
||||
if cohereReq, ok := req.(*cohere.CohereChatRequest); ok {
|
||||
return &schemas.BifrostRequest{
|
||||
ChatRequest: cohereReq.ToBifrostChatRequest(ctx),
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("invalid request type")
|
||||
},
|
||||
ChatResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostChatResponse) (interface{}, error) {
|
||||
if resp.ExtraFields.Provider == schemas.Cohere {
|
||||
if resp.ExtraFields.RawResponse != nil {
|
||||
return resp.ExtraFields.RawResponse, nil
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
},
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
StreamConfig: &StreamConfig{
|
||||
ChatStreamResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostChatResponse) (string, interface{}, error) {
|
||||
if resp.ExtraFields.Provider == schemas.Cohere {
|
||||
if resp.ExtraFields.RawResponse != nil {
|
||||
return "", resp.ExtraFields.RawResponse, nil
|
||||
}
|
||||
}
|
||||
return "", resp, nil
|
||||
},
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// Embeddings endpoint (v2/embed)
|
||||
routes = append(routes, RouteConfig{
|
||||
Type: RouteConfigTypeCohere,
|
||||
Path: pathPrefix + "/v2/embed",
|
||||
Method: "POST",
|
||||
PreCallback: cohereLargePayloadPreHook,
|
||||
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
|
||||
return schemas.EmbeddingRequest
|
||||
},
|
||||
GetRequestTypeInstance: func(ctx context.Context) interface{} {
|
||||
return &cohere.CohereEmbeddingRequest{}
|
||||
},
|
||||
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
|
||||
if cohereReq, ok := req.(*cohere.CohereEmbeddingRequest); ok {
|
||||
return &schemas.BifrostRequest{
|
||||
EmbeddingRequest: cohereReq.ToBifrostEmbeddingRequest(ctx),
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("invalid embedding request type")
|
||||
},
|
||||
EmbeddingResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostEmbeddingResponse) (interface{}, error) {
|
||||
if resp.ExtraFields.Provider == schemas.Cohere {
|
||||
if resp.ExtraFields.RawResponse != nil {
|
||||
return resp.ExtraFields.RawResponse, nil
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
},
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
})
|
||||
|
||||
// Rerank endpoint (v2/rerank)
|
||||
routes = append(routes, RouteConfig{
|
||||
Type: RouteConfigTypeCohere,
|
||||
Path: pathPrefix + "/v2/rerank",
|
||||
Method: "POST",
|
||||
PreCallback: cohereLargePayloadPreHook,
|
||||
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
|
||||
return schemas.RerankRequest
|
||||
},
|
||||
GetRequestTypeInstance: func(ctx context.Context) interface{} {
|
||||
return &cohere.CohereRerankRequest{}
|
||||
},
|
||||
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
|
||||
if cohereReq, ok := req.(*cohere.CohereRerankRequest); ok {
|
||||
return &schemas.BifrostRequest{
|
||||
RerankRequest: cohereReq.ToBifrostRerankRequest(ctx),
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("invalid rerank request type")
|
||||
},
|
||||
RerankResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostRerankResponse) (interface{}, error) {
|
||||
if resp.ExtraFields.Provider == schemas.Cohere {
|
||||
if resp.ExtraFields.RawResponse != nil {
|
||||
return resp.ExtraFields.RawResponse, nil
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
},
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
})
|
||||
|
||||
// Tokenize endpoint (v1/tokenize)
|
||||
routes = append(routes, RouteConfig{
|
||||
Type: RouteConfigTypeCohere,
|
||||
Path: pathPrefix + "/v1/tokenize",
|
||||
Method: "POST",
|
||||
PreCallback: cohereLargePayloadPreHook,
|
||||
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
|
||||
return schemas.CountTokensRequest
|
||||
},
|
||||
GetRequestTypeInstance: func(ctx context.Context) interface{} {
|
||||
return &cohere.CohereCountTokensRequest{}
|
||||
},
|
||||
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
|
||||
if cohereReq, ok := req.(*cohere.CohereCountTokensRequest); ok {
|
||||
return &schemas.BifrostRequest{
|
||||
CountTokensRequest: cohereReq.ToBifrostResponsesRequest(ctx),
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("invalid count tokens request type")
|
||||
},
|
||||
CountTokensResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostCountTokensResponse) (interface{}, error) {
|
||||
if resp.ExtraFields.Provider == schemas.Cohere {
|
||||
if resp.ExtraFields.RawResponse != nil {
|
||||
return resp.ExtraFields.RawResponse, nil
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
},
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
})
|
||||
|
||||
return routes
|
||||
}
|
||||
102
transports/bifrost-http/integrations/cohere_test.go
Normal file
102
transports/bifrost-http/integrations/cohere_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/cohere"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateCohereRouteConfigsIncludesRerank(t *testing.T) {
|
||||
routes := CreateCohereRouteConfigs("/cohere")
|
||||
|
||||
assert.Len(t, routes, 4, "should have 4 cohere routes")
|
||||
|
||||
var rerankRoute *RouteConfig
|
||||
for i := range routes {
|
||||
if routes[i].Path == "/cohere/v2/rerank" && routes[i].Method == "POST" {
|
||||
rerankRoute = &routes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, rerankRoute, "rerank route should exist")
|
||||
assert.Equal(t, RouteConfigTypeCohere, rerankRoute.Type)
|
||||
assert.NotNil(t, rerankRoute.GetHTTPRequestType)
|
||||
assert.Equal(t, schemas.RerankRequest, rerankRoute.GetHTTPRequestType(nil))
|
||||
assert.NotNil(t, rerankRoute.GetRequestTypeInstance)
|
||||
assert.NotNil(t, rerankRoute.RequestConverter)
|
||||
assert.NotNil(t, rerankRoute.RerankResponseConverter)
|
||||
assert.NotNil(t, rerankRoute.ErrorConverter)
|
||||
|
||||
reqInstance := rerankRoute.GetRequestTypeInstance(context.Background())
|
||||
_, ok := reqInstance.(*cohere.CohereRerankRequest)
|
||||
assert.True(t, ok, "rerank request instance should be CohereRerankRequest")
|
||||
}
|
||||
|
||||
func TestCohereRerankRouteRequestConverter(t *testing.T) {
|
||||
routes := CreateCohereRouteConfigs("/cohere")
|
||||
|
||||
var rerankRoute *RouteConfig
|
||||
for i := range routes {
|
||||
if routes[i].Path == "/cohere/v2/rerank" {
|
||||
rerankRoute = &routes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, rerankRoute)
|
||||
require.NotNil(t, rerankRoute.RequestConverter)
|
||||
|
||||
topN := 1
|
||||
req := &cohere.CohereRerankRequest{
|
||||
Model: "rerank-v3.5",
|
||||
Query: "what is bifrost?",
|
||||
Documents: []string{"doc1", "doc2"},
|
||||
TopN: &topN,
|
||||
}
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostReq, err := rerankRoute.RequestConverter(bifrostCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, bifrostReq)
|
||||
require.NotNil(t, bifrostReq.RerankRequest)
|
||||
|
||||
assert.Equal(t, schemas.Cohere, bifrostReq.RerankRequest.Provider)
|
||||
assert.Equal(t, "rerank-v3.5", bifrostReq.RerankRequest.Model)
|
||||
assert.Equal(t, "what is bifrost?", bifrostReq.RerankRequest.Query)
|
||||
require.Len(t, bifrostReq.RerankRequest.Documents, 2)
|
||||
assert.Equal(t, "doc1", bifrostReq.RerankRequest.Documents[0].Text)
|
||||
assert.Equal(t, "doc2", bifrostReq.RerankRequest.Documents[1].Text)
|
||||
require.NotNil(t, bifrostReq.RerankRequest.Params)
|
||||
require.NotNil(t, bifrostReq.RerankRequest.Params.TopN)
|
||||
assert.Equal(t, 1, *bifrostReq.RerankRequest.Params.TopN)
|
||||
}
|
||||
|
||||
func TestCohereRerankResponseConverterUsesRawResponse(t *testing.T) {
|
||||
routes := CreateCohereRouteConfigs("/cohere")
|
||||
|
||||
var rerankRoute *RouteConfig
|
||||
for i := range routes {
|
||||
if routes[i].Path == "/cohere/v2/rerank" {
|
||||
rerankRoute = &routes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, rerankRoute)
|
||||
require.NotNil(t, rerankRoute.RerankResponseConverter)
|
||||
|
||||
raw := map[string]interface{}{"id": "r-123", "results": []interface{}{}}
|
||||
resp := &schemas.BifrostRerankResponse{
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Provider: schemas.Cohere,
|
||||
RawResponse: raw,
|
||||
},
|
||||
}
|
||||
|
||||
converted, err := rerankRoute.RerankResponseConverter(nil, resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, raw, converted)
|
||||
}
|
||||
1067
transports/bifrost-http/integrations/cursor.go
Normal file
1067
transports/bifrost-http/integrations/cursor.go
Normal file
File diff suppressed because it is too large
Load Diff
1347
transports/bifrost-http/integrations/genai.go
Normal file
1347
transports/bifrost-http/integrations/genai.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,47 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestExtractModelAndRequestType_LargePayloadUsesMetadataWithoutBodyParse(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue("model", "gemini-2.5-pro:generateContent")
|
||||
// Intentionally invalid JSON: detection must rely on large-payload metadata, not body parse.
|
||||
ctx.Request.SetBodyString(`{"contents":[INVALID`)
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyLargePayloadMode, true)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyLargePayloadMetadata, &schemas.LargePayloadMetadata{
|
||||
ResponseModalities: []string{"AUDIO"},
|
||||
})
|
||||
ctx.SetUserValue(lib.FastHTTPUserValueBifrostContext, bifrostCtx)
|
||||
|
||||
model, reqType := extractModelAndRequestType(ctx)
|
||||
if model != "gemini-2.5-pro" {
|
||||
t.Fatalf("expected normalized model gemini-2.5-pro, got %q", model)
|
||||
}
|
||||
if reqType != schemas.SpeechRequest {
|
||||
t.Fatalf("expected speech request type from metadata, got %q", reqType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractModelAndRequestType_LargeBodyHeuristicSkipsParse(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue("model", "gemini-2.5-pro:generateContent")
|
||||
ctx.Request.SetBodyStream(strings.NewReader(`{"contents":[INVALID`), schemas.DefaultLargePayloadRequestThresholdBytes+1)
|
||||
|
||||
model, reqType := extractModelAndRequestType(ctx)
|
||||
if model != "gemini-2.5-pro" {
|
||||
t.Fatalf("expected normalized model gemini-2.5-pro, got %q", model)
|
||||
}
|
||||
if reqType != schemas.ResponsesRequest {
|
||||
t.Fatalf("expected responses request type from large-body heuristic, got %q", reqType)
|
||||
}
|
||||
}
|
||||
208
transports/bifrost-http/integrations/genai_test.go
Normal file
208
transports/bifrost-http/integrations/genai_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/gemini"
|
||||
"github.com/maximhq/bifrost/core/providers/vertex"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestCreateGenAIRerankRouteConfig(t *testing.T) {
|
||||
route := createGenAIRerankRouteConfig("/genai")
|
||||
|
||||
assert.Equal(t, "/genai/v1/rank", route.Path)
|
||||
assert.Equal(t, "POST", route.Method)
|
||||
assert.Equal(t, RouteConfigTypeGenAI, route.Type)
|
||||
assert.NotNil(t, route.GetHTTPRequestType)
|
||||
assert.Equal(t, schemas.RerankRequest, route.GetHTTPRequestType(nil))
|
||||
assert.NotNil(t, route.GetRequestTypeInstance)
|
||||
assert.NotNil(t, route.RequestConverter)
|
||||
assert.NotNil(t, route.RerankResponseConverter)
|
||||
assert.NotNil(t, route.ErrorConverter)
|
||||
assert.Nil(t, route.PreCallback)
|
||||
|
||||
// Verify request instance type
|
||||
reqInstance := route.GetRequestTypeInstance(context.Background())
|
||||
_, ok := reqInstance.(*vertex.VertexRankRequest)
|
||||
assert.True(t, ok, "GetRequestTypeInstance should return *vertex.VertexRankRequest")
|
||||
}
|
||||
|
||||
func TestCreateGenAIRouteConfigsIncludesRerank(t *testing.T) {
|
||||
routes := CreateGenAIRouteConfigs("/genai")
|
||||
|
||||
found := false
|
||||
for _, route := range routes {
|
||||
if route.Path == "/genai/v1/rank" && route.Method == "POST" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected rerank route in genai route configs")
|
||||
}
|
||||
|
||||
func TestCreateGenAIRouteConfigsIncludesRerankForCompositePrefixes(t *testing.T) {
|
||||
prefixes := []string{"/litellm", "/langchain", "/pydanticai"}
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
routes := CreateGenAIRouteConfigs(prefix)
|
||||
found := false
|
||||
for _, route := range routes {
|
||||
if route.Path == prefix+"/v1/rank" && route.Method == "POST" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Truef(t, found, "expected rerank route for prefix %s", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenAIRerankRequestConverter(t *testing.T) {
|
||||
route := createGenAIRerankRouteConfig("/genai")
|
||||
require.NotNil(t, route.RequestConverter)
|
||||
|
||||
model := "semantic-ranker-default@latest"
|
||||
topN := 2
|
||||
content1 := "Paris is capital of France"
|
||||
content2 := "Berlin is capital of Germany"
|
||||
req := &vertex.VertexRankRequest{
|
||||
Model: &model,
|
||||
Query: "capital of france",
|
||||
Records: []vertex.VertexRankRecord{
|
||||
{ID: "rec-1", Content: &content1},
|
||||
{ID: "rec-2", Content: &content2},
|
||||
},
|
||||
TopN: &topN,
|
||||
}
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostReq, err := route.RequestConverter(bifrostCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, bifrostReq)
|
||||
require.NotNil(t, bifrostReq.RerankRequest)
|
||||
assert.Equal(t, schemas.Vertex, bifrostReq.RerankRequest.Provider)
|
||||
assert.Equal(t, "semantic-ranker-default@latest", bifrostReq.RerankRequest.Model)
|
||||
assert.Equal(t, "capital of france", bifrostReq.RerankRequest.Query)
|
||||
require.Len(t, bifrostReq.RerankRequest.Documents, 2)
|
||||
assert.Equal(t, "Paris is capital of France", bifrostReq.RerankRequest.Documents[0].Text)
|
||||
assert.Equal(t, "Berlin is capital of Germany", bifrostReq.RerankRequest.Documents[1].Text)
|
||||
require.NotNil(t, bifrostReq.RerankRequest.Params)
|
||||
require.NotNil(t, bifrostReq.RerankRequest.Params.TopN)
|
||||
assert.Equal(t, 2, *bifrostReq.RerankRequest.Params.TopN)
|
||||
}
|
||||
|
||||
func TestGenAIRerankResponseConverterUsesRawResponse(t *testing.T) {
|
||||
route := createGenAIRerankRouteConfig("/genai")
|
||||
require.NotNil(t, route.RerankResponseConverter)
|
||||
|
||||
raw := map[string]interface{}{"records": []interface{}{}}
|
||||
resp := &schemas.BifrostRerankResponse{
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Provider: schemas.Vertex,
|
||||
RawResponse: raw,
|
||||
},
|
||||
}
|
||||
converted, err := route.RerankResponseConverter(nil, resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, raw, converted)
|
||||
}
|
||||
|
||||
func TestGenAIRerankResponseConverterFallsBackWhenNotVertex(t *testing.T) {
|
||||
route := createGenAIRerankRouteConfig("/genai")
|
||||
require.NotNil(t, route.RerankResponseConverter)
|
||||
|
||||
resp := &schemas.BifrostRerankResponse{
|
||||
Results: []schemas.RerankResult{
|
||||
{Index: 0, RelevanceScore: 0.9},
|
||||
},
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Provider: schemas.Cohere,
|
||||
},
|
||||
}
|
||||
converted, err := route.RerankResponseConverter(nil, resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, resp, converted)
|
||||
}
|
||||
|
||||
func TestCreateGenAIRouteConfigsIncludesModelMetadataRoute(t *testing.T) {
|
||||
routes := CreateGenAIRouteConfigs("/genai")
|
||||
|
||||
found := false
|
||||
for _, route := range routes {
|
||||
if route.Path == "/genai/v1beta/models/{model}" && route.Method == "GET" {
|
||||
found = true
|
||||
assert.Equal(t, schemas.ListModelsRequest, route.GetHTTPRequestType(nil))
|
||||
require.NotNil(t, route.PreCallback)
|
||||
require.NotNil(t, route.ListModelsResponseConverter)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, found, "expected model metadata route in genai route configs")
|
||||
}
|
||||
|
||||
func TestExtractGeminiModelMetadataParams(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue("model", "models/gemini-3-pro-preview")
|
||||
|
||||
listReq := &schemas.BifrostListModelsRequest{}
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
|
||||
err := extractGeminiModelMetadataParams(ctx, bifrostCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, schemas.Gemini, listReq.Provider)
|
||||
assert.Equal(t, "/models/gemini-3-pro-preview", bifrostCtx.Value(schemas.BifrostContextKeyURLPath))
|
||||
assert.Equal(t, "gemini-3-pro-preview", bifrostCtx.Value(requestedGeminiModelMetadataContextKey))
|
||||
}
|
||||
|
||||
func TestConvertGeminiModelMetadataResponse(t *testing.T) {
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(requestedGeminiModelMetadataContextKey, "gemini-2.5-pro")
|
||||
|
||||
resp := &schemas.BifrostListModelsResponse{
|
||||
Data: []schemas.Model{{ID: "gemini/gemini-2.5-pro", Name: schemas.Ptr("Gemini 2.5 Pro")}},
|
||||
}
|
||||
|
||||
converted, err := convertGeminiModelMetadataResponse(bifrostCtx, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, ok := converted.(gemini.GeminiModel)
|
||||
require.True(t, ok, "expected gemini.GeminiModel")
|
||||
assert.Equal(t, "models/gemini-2.5-pro", model.Name)
|
||||
assert.Equal(t, "Gemini 2.5 Pro", model.DisplayName)
|
||||
}
|
||||
|
||||
func TestConvertGeminiModelMetadataResponse_MatchesRequestedModelNotFirst(t *testing.T) {
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(requestedGeminiModelMetadataContextKey, "gemini-3-pro-preview")
|
||||
|
||||
resp := &schemas.BifrostListModelsResponse{
|
||||
Data: []schemas.Model{
|
||||
{ID: "gemini/gemini-1.5-pro", Name: schemas.Ptr("Gemini 1.5 Pro")},
|
||||
{ID: "gemini/gemini-3-pro-preview", Name: schemas.Ptr("Gemini 3 Pro Preview")},
|
||||
},
|
||||
}
|
||||
|
||||
converted, err := convertGeminiModelMetadataResponse(bifrostCtx, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, ok := converted.(gemini.GeminiModel)
|
||||
require.True(t, ok, "expected gemini.GeminiModel")
|
||||
assert.Equal(t, "models/gemini-3-pro-preview", model.Name)
|
||||
assert.Equal(t, "Gemini 3 Pro Preview", model.DisplayName)
|
||||
}
|
||||
|
||||
func TestConvertGeminiModelMetadataResponse_EmptyReturnsMinimalModel(t *testing.T) {
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(requestedGeminiModelMetadataContextKey, "gemini-3-pro-preview")
|
||||
|
||||
converted, err := convertGeminiModelMetadataResponse(bifrostCtx, &schemas.BifrostListModelsResponse{Data: []schemas.Model{}})
|
||||
require.NoError(t, err)
|
||||
model, ok := converted.(gemini.GeminiModel)
|
||||
require.True(t, ok, "expected gemini.GeminiModel")
|
||||
assert.Equal(t, "models/gemini-3-pro-preview", model.Name)
|
||||
}
|
||||
42
transports/bifrost-http/integrations/langchain.go
Normal file
42
transports/bifrost-http/integrations/langchain.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
// LangChainRouter holds route registrations for LangChain endpoints.
|
||||
// It supports standard chat completions and image-enabled vision capabilities.
|
||||
// LangChain is fully OpenAI-compatible, so we reuse OpenAI types
|
||||
// with aliases for clarity and minimal LangChain-specific extensions
|
||||
type LangChainRouter struct {
|
||||
*GenericRouter
|
||||
}
|
||||
|
||||
// NewLangChainRouter creates a new LangChainRouter with the given bifrost client.
|
||||
func NewLangChainRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *LangChainRouter {
|
||||
routes := []RouteConfig{}
|
||||
|
||||
// Add OpenAI routes to LangChain for OpenAI API compatibility
|
||||
routes = append(routes, CreateOpenAIRouteConfigs("/langchain", handlerStore)...)
|
||||
|
||||
// Add Anthropic routes to LangChain for Anthropic API compatibility
|
||||
routes = append(routes, CreateAnthropicRouteConfigs("/langchain", logger)...)
|
||||
|
||||
// Add Anthropic count tokens route for LangChain to ensure token counting uses the dedicated endpoint
|
||||
routes = append(routes, CreateAnthropicCountTokensRouteConfigs("/langchain", handlerStore)...)
|
||||
|
||||
// Add GenAI routes to LangChain for Vertex AI compatibility
|
||||
routes = append(routes, CreateGenAIRouteConfigs("/langchain")...)
|
||||
|
||||
// Add Bedrock routes to LangChain for AWS Bedrock API compatibility
|
||||
routes = append(routes, CreateBedrockRouteConfigs("/langchain", handlerStore)...)
|
||||
|
||||
// Add Cohere routes to LangChain for Cohere API compatibility
|
||||
routes = append(routes, CreateCohereRouteConfigs("/langchain")...)
|
||||
|
||||
return &LangChainRouter{
|
||||
GenericRouter: NewGenericRouter(client, handlerStore, routes, nil, logger),
|
||||
}
|
||||
}
|
||||
39
transports/bifrost-http/integrations/litellm.go
Normal file
39
transports/bifrost-http/integrations/litellm.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
// LiteLLMRouter holds route registrations for LiteLLM endpoints.
|
||||
// It supports standard chat completions and image-enabled vision capabilities.
|
||||
// LiteLLM is fully OpenAI-compatible, so we reuse OpenAI types
|
||||
// with aliases for clarity and minimal LiteLLM-specific extensions
|
||||
type LiteLLMRouter struct {
|
||||
*GenericRouter
|
||||
}
|
||||
|
||||
// NewLiteLLMRouter creates a new LiteLLMRouter with the given bifrost client.
|
||||
func NewLiteLLMRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *LiteLLMRouter {
|
||||
routes := []RouteConfig{}
|
||||
|
||||
// Add OpenAI routes to LiteLLM for OpenAI API compatibility
|
||||
routes = append(routes, CreateOpenAIRouteConfigs("/litellm", handlerStore)...)
|
||||
|
||||
// Add Anthropic routes to LiteLLM for Anthropic API compatibility
|
||||
routes = append(routes, CreateAnthropicRouteConfigs("/litellm", logger)...)
|
||||
|
||||
// Add GenAI routes to LiteLLM for Vertex AI compatibility
|
||||
routes = append(routes, CreateGenAIRouteConfigs("/litellm")...)
|
||||
|
||||
// Add Bedrock routes to LiteLLM for AWS Bedrock API compatibility
|
||||
routes = append(routes, CreateBedrockRouteConfigs("/litellm", handlerStore)...)
|
||||
|
||||
// Add Cohere routes to LiteLLM for Cohere API compatibility
|
||||
routes = append(routes, CreateCohereRouteConfigs("/litellm")...)
|
||||
|
||||
return &LiteLLMRouter{
|
||||
GenericRouter: NewGenericRouter(client, handlerStore, routes, nil, logger),
|
||||
}
|
||||
}
|
||||
3349
transports/bifrost-http/integrations/openai.go
Normal file
3349
transports/bifrost-http/integrations/openai.go
Normal file
File diff suppressed because it is too large
Load Diff
72
transports/bifrost-http/integrations/passthrough.go
Normal file
72
transports/bifrost-http/integrations/passthrough.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
// PassthroughRouter is a catch-all router that forwards all requests directly
|
||||
// to the provider without matching against known route patterns.
|
||||
type PassthroughRouter struct {
|
||||
*GenericRouter
|
||||
}
|
||||
|
||||
// NewPassthroughRouter creates a passthrough-only router for any prefix/provider combo.
|
||||
func NewPassthroughRouter(
|
||||
client *bifrost.Bifrost,
|
||||
handlerStore lib.HandlerStore,
|
||||
logger schemas.Logger,
|
||||
cfg *PassthroughConfig,
|
||||
) *PassthroughRouter {
|
||||
if cfg == nil {
|
||||
cfg = &PassthroughConfig{}
|
||||
}
|
||||
return &PassthroughRouter{
|
||||
GenericRouter: NewGenericRouter(client, handlerStore, nil, cfg, logger),
|
||||
}
|
||||
}
|
||||
|
||||
// NewAnthropicPassthroughRouter creates a passthrough router for /anthropic_passthrough.
|
||||
func NewAnthropicPassthroughRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *PassthroughRouter {
|
||||
return NewPassthroughRouter(client, handlerStore, logger, &PassthroughConfig{
|
||||
Provider: schemas.Anthropic,
|
||||
StripPrefix: []string{
|
||||
"/anthropic_passthrough",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// NewOpenAIPassthroughRouter creates a passthrough router for /openai_passthrough.
|
||||
func NewOpenAIPassthroughRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *PassthroughRouter {
|
||||
return NewPassthroughRouter(client, handlerStore, logger, &PassthroughConfig{
|
||||
Provider: schemas.OpenAI,
|
||||
StripPrefix: []string{
|
||||
"/openai_passthrough",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// NewAzurePassthroughRouter creates a passthrough router for /azure_passthrough.
|
||||
func NewAzurePassthroughRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *PassthroughRouter {
|
||||
return NewPassthroughRouter(client, handlerStore, logger, &PassthroughConfig{
|
||||
Provider: schemas.Azure,
|
||||
StripPrefix: []string{
|
||||
"/azure_passthrough",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// NewGenAIPassthroughRouter creates a passthrough router for /genai_passthrough.
|
||||
func NewGenAIPassthroughRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *PassthroughRouter {
|
||||
return NewPassthroughRouter(client, handlerStore, logger, &PassthroughConfig{
|
||||
Provider: schemas.Gemini,
|
||||
ProviderDetector: detectProviderFromGenAIRequest,
|
||||
StripPrefix: []string{
|
||||
"/genai_passthrough/v1beta1",
|
||||
"/genai_passthrough/v1beta",
|
||||
"/genai_passthrough/v1",
|
||||
"/genai_passthrough",
|
||||
},
|
||||
})
|
||||
}
|
||||
135
transports/bifrost-http/integrations/pydanticai.go
Normal file
135
transports/bifrost-http/integrations/pydanticai.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
// PydanticAIRouter holds route registrations for Pydantic AI endpoints.
|
||||
// It supports standard chat completions, tool calling, streaming, and multi-provider capabilities.
|
||||
// Pydantic AI uses standard provider SDKs (OpenAI, Anthropic, Google GenAI), so we reuse
|
||||
// existing route configurations with aliases for clarity and Pydantic AI-specific extensions.
|
||||
type PydanticAIRouter struct {
|
||||
*GenericRouter
|
||||
}
|
||||
|
||||
// NewPydanticAIRouter creates a new PydanticAIRouter with the given bifrost client.
|
||||
func NewPydanticAIRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *PydanticAIRouter {
|
||||
routes := []RouteConfig{}
|
||||
// Add OpenAI routes to Pydantic AI for OpenAI API compatibility
|
||||
// Supports: chat completions, embeddings, speech, transcriptions, responses
|
||||
routes = append(routes, withPydanticResponsesNullNormalization(CreateOpenAIRouteConfigs("/pydanticai", handlerStore))...)
|
||||
// Add Anthropic routes to Pydantic AI for Anthropic API compatibility
|
||||
// Supports: messages API (Claude models)
|
||||
routes = append(routes, CreateAnthropicRouteConfigs("/pydanticai", logger)...)
|
||||
// Add GenAI routes to Pydantic AI for Google Gemini API compatibility
|
||||
// Supports: generateContent, streamGenerateContent, embedContent
|
||||
routes = append(routes, CreateGenAIRouteConfigs("/pydanticai")...)
|
||||
// Add Cohere routes to Pydantic AI for Cohere API compatibility
|
||||
// Supports: v2/chat (chat completions with streaming), v2/embed (embeddings)
|
||||
routes = append(routes, CreateCohereRouteConfigs("/pydanticai")...)
|
||||
// Add Bedrock routes to Pydantic AI for AWS Bedrock API compatibility
|
||||
// Supports: converse, converse-stream, invoke, invoke-with-response-stream
|
||||
routes = append(routes, CreateBedrockRouteConfigs("/pydanticai", handlerStore)...)
|
||||
return &PydanticAIRouter{
|
||||
GenericRouter: NewGenericRouter(client, handlerStore, routes, nil, logger),
|
||||
}
|
||||
}
|
||||
|
||||
func withPydanticResponsesNullNormalization(routes []RouteConfig) []RouteConfig {
|
||||
for i := range routes {
|
||||
if !strings.Contains(routes[i].Path, "/responses") {
|
||||
continue
|
||||
}
|
||||
|
||||
if routes[i].ResponsesResponseConverter != nil {
|
||||
routes[i].ResponsesResponseConverter = func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesResponse) (interface{}, error) {
|
||||
// For pydantic responses endpoint, prefer normalized bifrost output
|
||||
// instead of raw passthrough, to keep null handling consistent.
|
||||
return resp.WithDefaults(), nil
|
||||
}
|
||||
}
|
||||
|
||||
if routes[i].StreamConfig != nil && routes[i].StreamConfig.ResponsesStreamResponseConverter != nil {
|
||||
// Match non-stream behavior: prefer normalized output (raw->normalizePydanticResponsesRawStreamChunk, typed->resp.WithDefaults()+ensurePydanticResponsesStreamTextFields).
|
||||
routes[i].StreamConfig.ResponsesStreamResponseConverter = func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) {
|
||||
if resp == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
if resp.ExtraFields.RawResponse != nil {
|
||||
normalizedRaw := normalizePydanticResponsesRawStreamChunk(resp.ExtraFields.RawResponse)
|
||||
if normalizedRawString, ok := normalizedRaw.(string); ok {
|
||||
return string(resp.Type), normalizedRawString, nil
|
||||
}
|
||||
}
|
||||
|
||||
normalized := resp.WithDefaults()
|
||||
ensurePydanticResponsesStreamTextFields(normalized)
|
||||
return string(resp.Type), normalized, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func ensurePydanticResponsesStreamTextFields(resp *schemas.BifrostResponsesStreamResponse) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch resp.Type {
|
||||
case schemas.ResponsesStreamResponseTypeOutputTextDelta:
|
||||
if resp.Delta == nil {
|
||||
resp.Delta = bifrost.Ptr("")
|
||||
}
|
||||
case schemas.ResponsesStreamResponseTypeOutputTextDone:
|
||||
if resp.Text == nil {
|
||||
resp.Text = bifrost.Ptr("")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizePydanticResponsesRawStreamChunk(raw interface{}) interface{} {
|
||||
rawString, ok := raw.(string)
|
||||
if !ok {
|
||||
return raw
|
||||
}
|
||||
|
||||
var chunk map[string]interface{}
|
||||
if err := sonic.UnmarshalString(rawString, &chunk); err != nil {
|
||||
return raw
|
||||
}
|
||||
|
||||
changed := false
|
||||
if chunkType, ok := chunk["type"].(string); ok {
|
||||
switch schemas.ResponsesStreamResponseType(chunkType) {
|
||||
case schemas.ResponsesStreamResponseTypeOutputTextDelta:
|
||||
if value, exists := chunk["delta"]; exists && value == nil {
|
||||
chunk["delta"] = ""
|
||||
changed = true
|
||||
}
|
||||
case schemas.ResponsesStreamResponseTypeOutputTextDone:
|
||||
if value, exists := chunk["text"]; exists && value == nil {
|
||||
chunk["text"] = ""
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return raw
|
||||
}
|
||||
|
||||
normalized, err := sonic.MarshalString(chunk)
|
||||
if err != nil {
|
||||
return raw
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
2862
transports/bifrost-http/integrations/router.go
Normal file
2862
transports/bifrost-http/integrations/router.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,152 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestCreateHandler_SkipsRequestParserInLargePayloadMode(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
parserCalls := 0
|
||||
|
||||
route := RouteConfig{
|
||||
Type: RouteConfigTypeOpenAI,
|
||||
Path: "/openai/v1/chat/completions",
|
||||
Method: "POST",
|
||||
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
|
||||
return schemas.ChatCompletionRequest
|
||||
},
|
||||
GetRequestTypeInstance: func(ctx context.Context) interface{} {
|
||||
return &struct{}{}
|
||||
},
|
||||
RequestParser: func(ctx *fasthttp.RequestCtx, req interface{}) error {
|
||||
parserCalls++
|
||||
return nil
|
||||
},
|
||||
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
|
||||
return nil, errors.New("stop after parse phase")
|
||||
},
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
router := NewGenericRouter(nil, handlerStore, nil, nil, nil)
|
||||
router.SetLargePayloadHook(func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, routeType RouteConfigType) (bool, error) {
|
||||
return true, nil
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fasthttp.MethodPost)
|
||||
ctx.Request.SetBodyString(`{"model":"openai/gpt-4o","messages":[]}`)
|
||||
ctx.SetUserValue(schemas.BifrostContextKeyHTTPRequestType, schemas.ChatCompletionRequest)
|
||||
|
||||
handler := router.createHandler(route)
|
||||
handler(ctx)
|
||||
|
||||
assert.Equal(t, 0, parserCalls)
|
||||
}
|
||||
|
||||
func TestCreateHandler_UsesRequestParserWhenNotInLargePayloadMode(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
parserCalls := 0
|
||||
|
||||
route := RouteConfig{
|
||||
Type: RouteConfigTypeOpenAI,
|
||||
Path: "/openai/v1/chat/completions",
|
||||
Method: "POST",
|
||||
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
|
||||
return schemas.ChatCompletionRequest
|
||||
},
|
||||
GetRequestTypeInstance: func(ctx context.Context) interface{} {
|
||||
return &struct{}{}
|
||||
},
|
||||
RequestParser: func(ctx *fasthttp.RequestCtx, req interface{}) error {
|
||||
parserCalls++
|
||||
return nil
|
||||
},
|
||||
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
|
||||
return nil, errors.New("stop after parse phase")
|
||||
},
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
router := NewGenericRouter(nil, handlerStore, nil, nil, nil)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fasthttp.MethodPost)
|
||||
ctx.Request.SetBodyString(`{"model":"openai/gpt-4o","messages":[]}`)
|
||||
ctx.SetUserValue(schemas.BifrostContextKeyHTTPRequestType, schemas.ChatCompletionRequest)
|
||||
|
||||
handler := router.createHandler(route)
|
||||
handler(ctx)
|
||||
|
||||
assert.Equal(t, 1, parserCalls)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// resolveLargePayloadMetadata tests
|
||||
// ============================================================================
|
||||
|
||||
func TestResolveLargePayloadMetadata_NilContext(t *testing.T) {
|
||||
assert.Nil(t, resolveLargePayloadMetadata(nil))
|
||||
}
|
||||
|
||||
func TestResolveLargePayloadMetadata_SyncPath(t *testing.T) {
|
||||
ctx := schemas.NewBifrostContext(nil, time.Time{})
|
||||
meta := &schemas.LargePayloadMetadata{Model: "gpt-4o"}
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargePayloadMetadata, meta)
|
||||
|
||||
result := resolveLargePayloadMetadata(ctx)
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "gpt-4o", result.Model)
|
||||
}
|
||||
|
||||
func TestResolveLargePayloadMetadata_DeferredReady(t *testing.T) {
|
||||
ctx := schemas.NewBifrostContext(nil, time.Time{})
|
||||
ch := make(chan *schemas.LargePayloadMetadata, 1)
|
||||
ch <- &schemas.LargePayloadMetadata{Model: "claude-4"}
|
||||
ctx.SetValue(schemas.BifrostContextKeyDeferredLargePayloadMetadata, (<-chan *schemas.LargePayloadMetadata)(ch))
|
||||
|
||||
result := resolveLargePayloadMetadata(ctx)
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "claude-4", result.Model)
|
||||
|
||||
// Verify it was cached in the sync key.
|
||||
cached, ok := ctx.Value(schemas.BifrostContextKeyLargePayloadMetadata).(*schemas.LargePayloadMetadata)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "claude-4", cached.Model)
|
||||
}
|
||||
|
||||
func TestResolveLargePayloadMetadata_DeferredNotReady(t *testing.T) {
|
||||
ctx := schemas.NewBifrostContext(nil, time.Time{})
|
||||
ch := make(chan *schemas.LargePayloadMetadata, 1) // empty, not ready
|
||||
ctx.SetValue(schemas.BifrostContextKeyDeferredLargePayloadMetadata, (<-chan *schemas.LargePayloadMetadata)(ch))
|
||||
|
||||
// Non-blocking: should return nil when channel has no value yet.
|
||||
result := resolveLargePayloadMetadata(ctx)
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestResolveLargePayloadMetadata_SyncTakesPrecedence(t *testing.T) {
|
||||
ctx := schemas.NewBifrostContext(nil, time.Time{})
|
||||
syncMeta := &schemas.LargePayloadMetadata{Model: "sync-model"}
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargePayloadMetadata, syncMeta)
|
||||
|
||||
ch := make(chan *schemas.LargePayloadMetadata, 1)
|
||||
ch <- &schemas.LargePayloadMetadata{Model: "deferred-model"}
|
||||
ctx.SetValue(schemas.BifrostContextKeyDeferredLargePayloadMetadata, (<-chan *schemas.LargePayloadMetadata)(ch))
|
||||
|
||||
result := resolveLargePayloadMetadata(ctx)
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "sync-model", result.Model)
|
||||
}
|
||||
373
transports/bifrost-http/integrations/router_test.go
Normal file
373
transports/bifrost-http/integrations/router_test.go
Normal file
@@ -0,0 +1,373 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"mime/multipart"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/providers/openai"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParsePassthroughBody_MultipartExtractsModelAfterFilePart(t *testing.T) {
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
|
||||
fileWriter, err := writer.CreateFormFile("file", "sample.mp3")
|
||||
require.NoError(t, err)
|
||||
_, err = fileWriter.Write([]byte("audio-bytes"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, writer.WriteField("model", "openai/whisper-1"))
|
||||
require.NoError(t, writer.WriteField("stream", "true"))
|
||||
require.NoError(t, writer.Close())
|
||||
|
||||
model, stream := parsePassthroughBody(writer.FormDataContentType(), body.Bytes())
|
||||
assert.Equal(t, "openai/whisper-1", model)
|
||||
assert.True(t, stream)
|
||||
}
|
||||
|
||||
func TestRequestWithSettableExtraParams_OpenAIChatRequest(t *testing.T) {
|
||||
t.Run("SetExtraParams populates both standalone and embedded ExtraParams", func(t *testing.T) {
|
||||
req := &openai.OpenAIChatRequest{}
|
||||
extra := map[string]interface{}{
|
||||
"guardrailConfig": map[string]interface{}{
|
||||
"guardrailIdentifier": "xxx",
|
||||
"guardrailVersion": "1",
|
||||
},
|
||||
}
|
||||
|
||||
rws, ok := interface{}(req).(RequestWithSettableExtraParams)
|
||||
require.True(t, ok, "OpenAIChatRequest should implement RequestWithSettableExtraParams")
|
||||
|
||||
rws.SetExtraParams(extra)
|
||||
|
||||
assert.Equal(t, extra, req.GetExtraParams())
|
||||
assert.Equal(t, extra, req.ChatParameters.ExtraParams, "embedded ChatParameters.ExtraParams should also be set")
|
||||
})
|
||||
|
||||
t.Run("extra params propagate through ToBifrostChatRequest", func(t *testing.T) {
|
||||
req := &openai.OpenAIChatRequest{
|
||||
Model: "bedrock/claude-4-5-sonnet-global",
|
||||
Messages: []openai.OpenAIMessage{},
|
||||
}
|
||||
extra := map[string]interface{}{
|
||||
"guardrailConfig": map[string]interface{}{
|
||||
"guardrailIdentifier": "test-id",
|
||||
"guardrailVersion": "1",
|
||||
},
|
||||
}
|
||||
|
||||
rws := interface{}(req).(RequestWithSettableExtraParams)
|
||||
rws.SetExtraParams(extra)
|
||||
|
||||
ctx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
|
||||
bifrostReq := req.ToBifrostChatRequest(ctx)
|
||||
|
||||
require.NotNil(t, bifrostReq)
|
||||
require.NotNil(t, bifrostReq.Params)
|
||||
assert.Contains(t, bifrostReq.Params.ExtraParams, "guardrailConfig")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestWithSettableExtraParams_AllOpenAIRequestTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req interface{}
|
||||
}{
|
||||
{"OpenAIChatRequest", &openai.OpenAIChatRequest{}},
|
||||
{"OpenAITextCompletionRequest", &openai.OpenAITextCompletionRequest{}},
|
||||
{"OpenAIResponsesRequest", &openai.OpenAIResponsesRequest{}},
|
||||
{"OpenAIEmbeddingRequest", &openai.OpenAIEmbeddingRequest{}},
|
||||
{"OpenAISpeechRequest", &openai.OpenAISpeechRequest{}},
|
||||
{"OpenAIImageGenerationRequest", &openai.OpenAIImageGenerationRequest{}},
|
||||
{"OpenAIImageEditRequest", &openai.OpenAIImageEditRequest{}},
|
||||
{"OpenAIImageVariationRequest", &openai.OpenAIImageVariationRequest{}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name+" implements RequestWithSettableExtraParams", func(t *testing.T) {
|
||||
rws, ok := tt.req.(RequestWithSettableExtraParams)
|
||||
require.True(t, ok, "%s should implement RequestWithSettableExtraParams", tt.name)
|
||||
|
||||
extra := map[string]interface{}{"test_key": "test_value"}
|
||||
rws.SetExtraParams(extra)
|
||||
|
||||
getter, ok := tt.req.(interface{ GetExtraParams() map[string]interface{} })
|
||||
require.True(t, ok, "%s should implement GetExtraParams", tt.name)
|
||||
assert.Equal(t, extra, getter.GetExtraParams())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtraParamsRequiresPassthroughHeader(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
routes := CreateOpenAIRouteConfigs("/openai", handlerStore)
|
||||
|
||||
var chatRoute *RouteConfig
|
||||
for i := range routes {
|
||||
if routes[i].Path == "/openai/v1/chat/completions" {
|
||||
chatRoute = &routes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, chatRoute, "should find /openai/v1/chat/completions route")
|
||||
|
||||
rawBody := []byte(`{
|
||||
"model": "bedrock/claude-4-5-sonnet-global",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}],
|
||||
"extra_params": {
|
||||
"guardrailConfig": {
|
||||
"guardrailIdentifier": "my-guardrail",
|
||||
"guardrailVersion": "1",
|
||||
"trace": "disabled"
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
t.Run("extra_params NOT extracted without passthrough header", func(t *testing.T) {
|
||||
req := chatRoute.GetRequestTypeInstance(context.Background())
|
||||
err := sonic.Unmarshal(rawBody, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
|
||||
// Header not set -- simulate router logic
|
||||
if bifrostCtx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
|
||||
if rws, ok := req.(RequestWithSettableExtraParams); ok {
|
||||
var wrapper struct {
|
||||
ExtraParams map[string]interface{} `json:"extra_params"`
|
||||
}
|
||||
if err := sonic.Unmarshal(rawBody, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
|
||||
rws.SetExtraParams(wrapper.ExtraParams)
|
||||
}
|
||||
_ = rws
|
||||
}
|
||||
}
|
||||
|
||||
openaiReq, ok := req.(*openai.OpenAIChatRequest)
|
||||
require.True(t, ok)
|
||||
assert.Empty(t, openaiReq.ChatParameters.ExtraParams,
|
||||
"ExtraParams should be empty when passthrough header is not set")
|
||||
})
|
||||
|
||||
t.Run("extra_params extracted with passthrough header", func(t *testing.T) {
|
||||
req := chatRoute.GetRequestTypeInstance(context.Background())
|
||||
err := sonic.Unmarshal(rawBody, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
|
||||
if bifrostCtx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
|
||||
if rws, ok := req.(RequestWithSettableExtraParams); ok {
|
||||
var wrapper struct {
|
||||
ExtraParams map[string]interface{} `json:"extra_params"`
|
||||
}
|
||||
if err := sonic.Unmarshal(rawBody, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
|
||||
rws.SetExtraParams(wrapper.ExtraParams)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
openaiReq, ok := req.(*openai.OpenAIChatRequest)
|
||||
require.True(t, ok)
|
||||
require.Contains(t, openaiReq.ChatParameters.ExtraParams, "guardrailConfig",
|
||||
"guardrailConfig should be in ExtraParams when passthrough header is set")
|
||||
|
||||
gc, ok := openaiReq.ChatParameters.ExtraParams["guardrailConfig"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "my-guardrail", gc["guardrailIdentifier"])
|
||||
assert.Equal(t, "1", gc["guardrailVersion"])
|
||||
assert.Equal(t, "disabled", gc["trace"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtraParamsPassthrough_NestedStructures(t *testing.T) {
|
||||
rawBody := []byte(`{
|
||||
"model": "openai/gpt-4o-mini",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}],
|
||||
"extra_params": {
|
||||
"custom_param": "value",
|
||||
"another_param": 123,
|
||||
"nested": {
|
||||
"deep_field": "deep_value",
|
||||
"deeper": {"level": 3}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
req := &openai.OpenAIChatRequest{}
|
||||
err := sonic.Unmarshal(rawBody, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
|
||||
if bifrostCtx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
|
||||
if rws, ok := interface{}(req).(RequestWithSettableExtraParams); ok {
|
||||
var wrapper struct {
|
||||
ExtraParams map[string]interface{} `json:"extra_params"`
|
||||
}
|
||||
if err := sonic.Unmarshal(rawBody, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
|
||||
rws.SetExtraParams(wrapper.ExtraParams)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.Len(t, req.ChatParameters.ExtraParams, 3)
|
||||
assert.Equal(t, "value", req.ChatParameters.ExtraParams["custom_param"])
|
||||
assert.Equal(t, float64(123), req.ChatParameters.ExtraParams["another_param"])
|
||||
|
||||
nested, ok := req.ChatParameters.ExtraParams["nested"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "deep_value", nested["deep_field"])
|
||||
}
|
||||
|
||||
func TestExtraParamsPassthrough_EndToEnd(t *testing.T) {
|
||||
rawJSON := []byte(`{
|
||||
"model": "bedrock/claude-4-5-sonnet-global",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}],
|
||||
"stream": false,
|
||||
"temperature": 0.7,
|
||||
"extra_params": {
|
||||
"guardrailConfig": {
|
||||
"guardrailIdentifier": "my-guardrail",
|
||||
"guardrailVersion": "1",
|
||||
"trace": "disabled"
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
req := &openai.OpenAIChatRequest{}
|
||||
err := sonic.Unmarshal(rawJSON, req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "bedrock/claude-4-5-sonnet-global", req.Model)
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
|
||||
if bifrostCtx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
|
||||
if rws, ok := interface{}(req).(RequestWithSettableExtraParams); ok {
|
||||
var wrapper struct {
|
||||
ExtraParams map[string]interface{} `json:"extra_params"`
|
||||
}
|
||||
if err := sonic.Unmarshal(rawJSON, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
|
||||
rws.SetExtraParams(wrapper.ExtraParams)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bifrostReq := req.ToBifrostChatRequest(bifrostCtx)
|
||||
|
||||
require.NotNil(t, bifrostReq)
|
||||
require.NotNil(t, bifrostReq.Params)
|
||||
require.Contains(t, bifrostReq.Params.ExtraParams, "guardrailConfig")
|
||||
|
||||
gc, ok := bifrostReq.Params.ExtraParams["guardrailConfig"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "my-guardrail", gc["guardrailIdentifier"])
|
||||
assert.Equal(t, "1", gc["guardrailVersion"])
|
||||
assert.Equal(t, "disabled", gc["trace"])
|
||||
|
||||
assert.NotContains(t, bifrostReq.Params.ExtraParams, "model")
|
||||
assert.NotContains(t, bifrostReq.Params.ExtraParams, "messages")
|
||||
assert.NotContains(t, bifrostReq.Params.ExtraParams, "stream")
|
||||
assert.NotContains(t, bifrostReq.Params.ExtraParams, "temperature")
|
||||
}
|
||||
|
||||
func TestExtraParamsPassthrough_NoExtraParamsKey(t *testing.T) {
|
||||
rawBody := []byte(`{
|
||||
"model": "openai/gpt-4o-mini",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]
|
||||
}`)
|
||||
|
||||
req := &openai.OpenAIChatRequest{}
|
||||
err := sonic.Unmarshal(rawBody, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
|
||||
if bifrostCtx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
|
||||
if rws, ok := interface{}(req).(RequestWithSettableExtraParams); ok {
|
||||
var wrapper struct {
|
||||
ExtraParams map[string]interface{} `json:"extra_params"`
|
||||
}
|
||||
if err := sonic.Unmarshal(rawBody, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
|
||||
rws.SetExtraParams(wrapper.ExtraParams)
|
||||
}
|
||||
_ = rws
|
||||
}
|
||||
}
|
||||
|
||||
assert.Empty(t, req.ChatParameters.ExtraParams,
|
||||
"ExtraParams should be empty when extra_params key is absent from JSON")
|
||||
}
|
||||
|
||||
// TestExtraParamsSetViaInterfaceMutatesOriginalReq verifies that setting extra
|
||||
// params through the RequestWithSettableExtraParams interface assertion mutates
|
||||
// the original req (interface{}) value. This matters because createHandler
|
||||
// passes req to config.RequestConverter after the extra params block -- both
|
||||
// variables must reference the same underlying struct via pointer semantics.
|
||||
func TestExtraParamsSetViaInterfaceMutatesOriginalReq(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
routes := CreateOpenAIRouteConfigs("/openai", handlerStore)
|
||||
|
||||
var chatRoute *RouteConfig
|
||||
for i := range routes {
|
||||
if routes[i].Path == "/openai/v1/chat/completions" {
|
||||
chatRoute = &routes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, chatRoute)
|
||||
|
||||
rawBody := []byte(`{
|
||||
"model": "bedrock/claude-4-5-sonnet-global",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}],
|
||||
"extra_params": {
|
||||
"guardrailConfig": {
|
||||
"guardrailIdentifier": "my-guardrail",
|
||||
"guardrailVersion": "1"
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
// Simulate the exact flow in createHandler:
|
||||
// 1. req is created via GetRequestTypeInstance (returns interface{})
|
||||
// 2. JSON is unmarshalled into req
|
||||
// 3. rws type assertion is used to call SetExtraParams
|
||||
// 4. req (not rws) is passed to RequestConverter downstream
|
||||
req := chatRoute.GetRequestTypeInstance(context.Background()) // returns interface{}
|
||||
err := sonic.Unmarshal(rawBody, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Type-assert and set extra params (same as router code)
|
||||
if rws, ok := req.(RequestWithSettableExtraParams); ok {
|
||||
var wrapper struct {
|
||||
ExtraParams map[string]interface{} `json:"extra_params"`
|
||||
}
|
||||
if err := sonic.Unmarshal(rawBody, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
|
||||
rws.SetExtraParams(wrapper.ExtraParams)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that req (the original interface{} variable) was mutated
|
||||
openaiReq, ok := req.(*openai.OpenAIChatRequest)
|
||||
require.True(t, ok)
|
||||
require.Contains(t, openaiReq.ChatParameters.ExtraParams, "guardrailConfig",
|
||||
"original req should be mutated via pointer semantics")
|
||||
|
||||
// Verify the full downstream path: RequestConverter uses req
|
||||
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
|
||||
bifrostReq, err := chatRoute.RequestConverter(bifrostCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, bifrostReq)
|
||||
require.NotNil(t, bifrostReq.ChatRequest)
|
||||
require.NotNil(t, bifrostReq.ChatRequest.Params)
|
||||
assert.Contains(t, bifrostReq.ChatRequest.Params.ExtraParams, "guardrailConfig",
|
||||
"extra params should propagate through RequestConverter to BifrostChatRequest")
|
||||
}
|
||||
502
transports/bifrost-http/integrations/utils.go
Normal file
502
transports/bifrost-http/integrations/utils.go
Normal file
@@ -0,0 +1,502 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/providers/gemini"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/kvstore"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
var bifrostContextKeyProvider = schemas.BifrostContextKey("provider")
|
||||
|
||||
var availableIntegrations = []string{
|
||||
"openai",
|
||||
"anthropic",
|
||||
"genai",
|
||||
"litellm",
|
||||
"langchain",
|
||||
"bedrock",
|
||||
"pydantic",
|
||||
"cohere",
|
||||
}
|
||||
|
||||
// newBifrostErrorWithCode is like newBifrostError but sets an explicit HTTP status code.
|
||||
func newBifrostErrorWithCode(err error, message string, statusCode int) *schemas.BifrostError {
|
||||
e := newBifrostError(err, message)
|
||||
e.StatusCode = &statusCode
|
||||
return e
|
||||
}
|
||||
|
||||
// newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false.
|
||||
// This helper function reduces code duplication when handling non-Bifrost errors.
|
||||
func newBifrostError(err error, message string) *schemas.BifrostError {
|
||||
if err == nil {
|
||||
return &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: message,
|
||||
Error: err,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// safeGetRequestType safely obtains the request type from a BifrostStreamChunk chunk.
|
||||
// It checks multiple sources in order of preference:
|
||||
// 1. Response ExtraFields if any response is available
|
||||
// 2. BifrostError ExtraFields if error is available and not nil
|
||||
// 3. Falls back to "unknown" if no source is available
|
||||
func safeGetRequestType(chunk *schemas.BifrostStreamChunk) string {
|
||||
if chunk == nil {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// Try to get RequestType from response ExtraFields (preferred source)
|
||||
switch {
|
||||
case chunk.BifrostTextCompletionResponse != nil:
|
||||
return string(chunk.BifrostTextCompletionResponse.ExtraFields.RequestType)
|
||||
case chunk.BifrostChatResponse != nil:
|
||||
return string(chunk.BifrostChatResponse.ExtraFields.RequestType)
|
||||
case chunk.BifrostResponsesStreamResponse != nil:
|
||||
return string(chunk.BifrostResponsesStreamResponse.ExtraFields.RequestType)
|
||||
case chunk.BifrostSpeechStreamResponse != nil:
|
||||
return string(chunk.BifrostSpeechStreamResponse.ExtraFields.RequestType)
|
||||
case chunk.BifrostTranscriptionStreamResponse != nil:
|
||||
return string(chunk.BifrostTranscriptionStreamResponse.ExtraFields.RequestType)
|
||||
}
|
||||
|
||||
// Try to get RequestType from error ExtraFields (fallback)
|
||||
if chunk.BifrostError != nil && chunk.BifrostError.ExtraFields.RequestType != "" {
|
||||
return string(chunk.BifrostError.ExtraFields.RequestType)
|
||||
}
|
||||
|
||||
// Final fallback
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// extractHeadersFromRequest extracts headers from the request and returns them as a map.
|
||||
// It uses the fasthttp.RequestCtx.Header.All() method to iterate over all headers.
|
||||
func extractHeadersFromRequest(ctx *fasthttp.RequestCtx) map[string][]string {
|
||||
headers := make(map[string][]string)
|
||||
|
||||
for key, value := range ctx.Request.Header.All() {
|
||||
keyStr := string(key)
|
||||
headers[keyStr] = append(headers[keyStr], string(value))
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
// extractExactPath returns the request path *after* the integration prefix,
|
||||
// preserving the original query string exactly as sent by the client.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// /openai/v1/chat/completions?model=gpt-4o -> v1/chat/completions?model=gpt-4o
|
||||
func extractExactPath(ctx *fasthttp.RequestCtx) string {
|
||||
// ctx.Path() returns only the path (no query) as a []byte backed by fasthttp’s internal buffers.
|
||||
// Treat it as read-only; don’t append to it directly.
|
||||
path := ctx.Path() // e.g. "/openai/v1/chat/completions"
|
||||
|
||||
// Strip the integration prefix only if it’s at the start.
|
||||
for _, integration := range availableIntegrations {
|
||||
if bytes.HasPrefix(path, []byte("/"+integration+"/")) {
|
||||
path = path[len("/"+integration+"/"):]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Raw query string as sent by client (unparsed, preserves ordering/duplicates/encoding).
|
||||
q := ctx.URI().QueryString() // e.g. "model=gpt-4o&stream=true"
|
||||
|
||||
if len(q) == 0 {
|
||||
// No query → just return the (possibly trimmed) path.
|
||||
return string(path)
|
||||
}
|
||||
|
||||
// --- Build "<path>?<query>" efficiently and safely ---
|
||||
//
|
||||
// Why not do: return string(path) + "?" + string(q) ?
|
||||
// - That allocates multiple temporary strings and may copy data more than necessary.
|
||||
//
|
||||
// Why not append into 'path' directly?
|
||||
// - 'path' may alias fasthttp’s internal buffers; mutating/expanding it could corrupt request state.
|
||||
//
|
||||
// We instead allocate a new buffer with exact capacity and copy into it,
|
||||
// staying in []byte until the final string conversion (1 allocation for the new slice).
|
||||
out := make([]byte, 0, len(path)+1+len(q)) // pre-size: path + "?" + query
|
||||
out = append(out, path...) // copy path bytes
|
||||
out = append(out, '?') // separator
|
||||
out = append(out, q...) // copy raw query bytes
|
||||
|
||||
return string(out)
|
||||
}
|
||||
|
||||
// sendStreamError sends an error response for a streaming request that failed before streaming started.
|
||||
// It propagates the provider's HTTP status code and returns a JSON error body (not SSE format),
|
||||
// since no streaming has begun and clients should receive a standard error response.
|
||||
func (g *GenericRouter) sendStreamError(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, config RouteConfig, bifrostErr *schemas.BifrostError) {
|
||||
// Forward provider response headers from context so streaming error responses include them
|
||||
if bifrostCtx != nil {
|
||||
if headers, ok := bifrostCtx.Value(schemas.BifrostContextKeyProviderResponseHeaders).(map[string]string); ok {
|
||||
for key, value := range headers {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set the HTTP status code from the provider error
|
||||
if bifrostErr.StatusCode != nil {
|
||||
ctx.SetStatusCode(*bifrostErr.StatusCode)
|
||||
} else {
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
}
|
||||
ctx.SetContentType("application/json")
|
||||
|
||||
// Always use the route-level ErrorConverter (not StreamConfig.ErrorConverter) because
|
||||
// sendStreamError returns JSON, not SSE. StreamConfig.ErrorConverter is designed for
|
||||
// in-stream SSE errors (e.g., Anthropic's returns a raw SSE string that would be
|
||||
// double-escaped by JSON marshaling).
|
||||
errorResponse := config.ErrorConverter(bifrostCtx, bifrostErr)
|
||||
|
||||
errorJSON, err := sonic.Marshal(errorResponse)
|
||||
if err != nil {
|
||||
g.logger.Error("failed to marshal error response", "err", err, "path", extractExactPath(ctx))
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
ctx.SetContentType("text/plain; charset=utf-8")
|
||||
ctx.SetBodyString(fmt.Sprintf("failed to encode error response: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.SetBody(errorJSON)
|
||||
}
|
||||
|
||||
// sendError sends an error response with the appropriate status code and JSON body.
|
||||
// It handles different error types (string, error interface, or arbitrary objects).
|
||||
func (g *GenericRouter) sendError(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, errorConverter ErrorConverter, bifrostErr *schemas.BifrostError) {
|
||||
// Forward provider response headers from context so error responses include them
|
||||
if bifrostCtx != nil {
|
||||
if headers, ok := bifrostCtx.Value(schemas.BifrostContextKeyProviderResponseHeaders).(map[string]string); ok {
|
||||
for key, value := range headers {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if bifrostErr.StatusCode != nil {
|
||||
ctx.SetStatusCode(*bifrostErr.StatusCode)
|
||||
} else {
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
}
|
||||
ctx.SetContentType("application/json")
|
||||
|
||||
// Marshal the error for response and log the error for diagnostics
|
||||
responseObj := errorConverter(bifrostCtx, bifrostErr)
|
||||
errorBody, err := sonic.Marshal(responseObj)
|
||||
if err != nil {
|
||||
// Log the marshal failure and return a plain text error
|
||||
g.logger.Error("failed to marshal error response", "err", err, "path", extractExactPath(ctx))
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
ctx.SetContentType("text/plain; charset=utf-8")
|
||||
ctx.SetBodyString(fmt.Sprintf("failed to encode error response: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.SetBody(errorBody)
|
||||
}
|
||||
|
||||
// sendSuccess sends a successful response with HTTP 200 status and JSON body.
|
||||
func (g *GenericRouter) sendSuccess(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, errorConverter ErrorConverter, response interface{}, extraHeaders map[string]string) {
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.SetContentType("application/json")
|
||||
|
||||
if extraHeaders != nil {
|
||||
for key, value := range extraHeaders {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
responseBody, err := sonic.Marshal(response)
|
||||
if err != nil {
|
||||
g.sendError(ctx, bifrostCtx, errorConverter, newBifrostError(err, "failed to encode response"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.SetBody(responseBody)
|
||||
}
|
||||
|
||||
// tryStreamLargeResponse checks if large response mode was activated by the provider,
|
||||
// sets the transport marker, and streams the response directly to the client.
|
||||
// Returns true if the response was handled (caller should return).
|
||||
func (g *GenericRouter) tryStreamLargeResponse(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) bool {
|
||||
isLargeResponse, ok := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool)
|
||||
if !ok || !isLargeResponse {
|
||||
return false
|
||||
}
|
||||
// Forward provider response headers before streaming — providers store them in
|
||||
// context via BifrostContextKeyProviderResponseHeaders, but some early-return
|
||||
// branches in the router skip the common footer that normally forwards them.
|
||||
if headers, ok := bifrostCtx.Value(schemas.BifrostContextKeyProviderResponseHeaders).(map[string]string); ok {
|
||||
for key, value := range headers {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
if g.streamLargeResponse(ctx, bifrostCtx) {
|
||||
ctx.SetUserValue(lib.FastHTTPUserValueLargeResponseMode, true)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// streamLargeResponse streams the large response body directly from the upstream provider to the client.
|
||||
// This bypasses the normal serialize → set body path, piping the response bytes unchanged.
|
||||
func (g *GenericRouter) streamLargeResponse(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) bool {
|
||||
// Enterprise hook: wrap the reader with Phase B scanning (e.g., usage extraction
|
||||
// from the full response stream) before streaming to client.
|
||||
if g.largeResponseHook != nil {
|
||||
g.largeResponseHook(ctx, bifrostCtx)
|
||||
}
|
||||
|
||||
if !lib.StreamLargeResponseBody(ctx, bifrostCtx) {
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
ctx.SetBodyString("large response reader not available")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// extractAndParseFallbacks extracts fallbacks from the integration request and adds them to the BifrostRequest
|
||||
func (g *GenericRouter) extractAndParseFallbacks(req interface{}, bifrostReq *schemas.BifrostRequest) error {
|
||||
// Check if the request has a fallbacks field ([]string)
|
||||
fallbacks, err := g.extractFallbacksFromRequest(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extract fallbacks: %w", err)
|
||||
}
|
||||
|
||||
if len(fallbacks) == 0 {
|
||||
return nil // No fallbacks to process
|
||||
}
|
||||
|
||||
provider, _, _ := bifrostReq.GetRequestFields()
|
||||
|
||||
// Parse fallbacks from strings to Fallback structs
|
||||
parsedFallbacks := make([]schemas.Fallback, 0, len(fallbacks))
|
||||
for _, fallbackStr := range fallbacks {
|
||||
if fallbackStr == "" {
|
||||
continue // Skip empty strings
|
||||
}
|
||||
|
||||
// Use ParseModelString to extract provider and model
|
||||
provider, model := schemas.ParseModelString(fallbackStr, provider)
|
||||
|
||||
parsedFallback := schemas.Fallback{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
}
|
||||
parsedFallbacks = append(parsedFallbacks, parsedFallback)
|
||||
}
|
||||
|
||||
if len(parsedFallbacks) == 0 {
|
||||
return nil // No valid fallbacks found
|
||||
}
|
||||
|
||||
// Add fallbacks to the main BifrostRequest
|
||||
bifrostReq.SetFallbacks(parsedFallbacks)
|
||||
|
||||
// Also add fallbacks to the specific request type if it exists
|
||||
switch bifrostReq.RequestType {
|
||||
case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest:
|
||||
if bifrostReq.TextCompletionRequest != nil {
|
||||
bifrostReq.TextCompletionRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest:
|
||||
if bifrostReq.ChatRequest != nil {
|
||||
bifrostReq.ChatRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
case schemas.ResponsesRequest, schemas.ResponsesStreamRequest:
|
||||
if bifrostReq.ResponsesRequest != nil {
|
||||
bifrostReq.ResponsesRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
case schemas.EmbeddingRequest:
|
||||
if bifrostReq.EmbeddingRequest != nil {
|
||||
bifrostReq.EmbeddingRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
case schemas.RerankRequest:
|
||||
if bifrostReq.RerankRequest != nil {
|
||||
bifrostReq.RerankRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
case schemas.SpeechRequest, schemas.SpeechStreamRequest:
|
||||
if bifrostReq.SpeechRequest != nil {
|
||||
bifrostReq.SpeechRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest:
|
||||
if bifrostReq.TranscriptionRequest != nil {
|
||||
bifrostReq.TranscriptionRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest:
|
||||
if bifrostReq.ImageGenerationRequest != nil {
|
||||
bifrostReq.ImageGenerationRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractFallbacksFromRequest uses reflection to extract fallbacks field from any request type
|
||||
func (g *GenericRouter) extractFallbacksFromRequest(req interface{}) ([]string, error) {
|
||||
if req == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Try to use reflection to find a "fallbacks" field
|
||||
reqValue := reflect.ValueOf(req)
|
||||
if reqValue.Kind() == reflect.Ptr {
|
||||
reqValue = reqValue.Elem()
|
||||
}
|
||||
|
||||
if reqValue.Kind() != reflect.Struct {
|
||||
return nil, nil // Not a struct, no fallbacks
|
||||
}
|
||||
|
||||
// Look for the "fallbacks" field
|
||||
fallbacksField := reqValue.FieldByName("fallbacks")
|
||||
if !fallbacksField.IsValid() {
|
||||
return nil, nil // No fallbacks field found
|
||||
}
|
||||
|
||||
// Handle different types of fallbacks field
|
||||
switch fallbacksField.Kind() {
|
||||
case reflect.Slice:
|
||||
if fallbacksField.Type().Elem().Kind() == reflect.String {
|
||||
// []string case
|
||||
fallbacks := make([]string, fallbacksField.Len())
|
||||
for i := 0; i < fallbacksField.Len(); i++ {
|
||||
fallbacks[i] = fallbacksField.Index(i).String()
|
||||
}
|
||||
return fallbacks, nil
|
||||
}
|
||||
case reflect.String:
|
||||
// Single string case - treat as one fallback
|
||||
return []string{fallbacksField.String()}, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// getVirtualKeyFromBifrostContext extracts the virtual key value from bifrost context.
|
||||
// Returns nil if no VK is present (e.g., direct key mode or no governance).
|
||||
func getVirtualKeyFromBifrostContext(ctx *schemas.BifrostContext) *string {
|
||||
vkValue := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)
|
||||
if vkValue == "" {
|
||||
return nil
|
||||
}
|
||||
return &vkValue
|
||||
}
|
||||
|
||||
// getResultTTLFromHeaderWithDefault extracts the result TTL from the x-bf-async-job-result-ttl header.
|
||||
// Returns the default TTL if the header is not present or invalid.
|
||||
func getResultTTLFromHeaderWithDefault(ctx *fasthttp.RequestCtx, defaultTTL int) int {
|
||||
resultTTL := string(ctx.Request.Header.Peek(schemas.AsyncHeaderResultTTL))
|
||||
if resultTTL == "" {
|
||||
return defaultTTL
|
||||
}
|
||||
resultTTLInt, err := strconv.Atoi(resultTTL)
|
||||
if err != nil || resultTTLInt < 0 {
|
||||
return defaultTTL
|
||||
}
|
||||
return resultTTLInt
|
||||
}
|
||||
|
||||
// isAnthropicAPIKeyAuth checks if the request uses standard API key authentication.
|
||||
// Returns true for API key auth (x-api-key header), false for OAuth (Bearer sk-ant-oat*).
|
||||
// This is required for Claude Code specifically, which may use OAuth authentication.
|
||||
// Default behavior is to assume API mode when neither x-api-key nor OAuth token is present.
|
||||
func isAnthropicAPIKeyAuth(ctx *fasthttp.RequestCtx) bool {
|
||||
// If x-api-key header is present - this is definitely API mode
|
||||
if apiKey := string(ctx.Request.Header.Peek("x-api-key")); apiKey != "" {
|
||||
return true
|
||||
}
|
||||
// Check for OAuth token in Authorization header
|
||||
if authHeader := string(ctx.Request.Header.Peek("Authorization")); authHeader != "" {
|
||||
if strings.HasPrefix(strings.ToLower(authHeader), "bearer sk-ant-oat") {
|
||||
return false // OAuth mode, NOT API
|
||||
}
|
||||
}
|
||||
// Default to API mode
|
||||
return true
|
||||
}
|
||||
|
||||
// resolveLargePayloadMetadata returns metadata from the sync context key,
|
||||
// falling back to a non-blocking read from the deferred channel.
|
||||
// If deferred metadata is resolved, it is cached in the sync key for later readers.
|
||||
func resolveLargePayloadMetadata(bifrostCtx *schemas.BifrostContext) *schemas.LargePayloadMetadata {
|
||||
if bifrostCtx == nil {
|
||||
return nil
|
||||
}
|
||||
if metadata, ok := bifrostCtx.Value(schemas.BifrostContextKeyLargePayloadMetadata).(*schemas.LargePayloadMetadata); ok && metadata != nil {
|
||||
return metadata
|
||||
}
|
||||
ch, ok := bifrostCtx.Value(schemas.BifrostContextKeyDeferredLargePayloadMetadata).(<-chan *schemas.LargePayloadMetadata)
|
||||
if !ok || ch == nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case metadata := <-ch:
|
||||
if metadata != nil {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyLargePayloadMetadata, metadata)
|
||||
}
|
||||
return metadata
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ParseProviderScopedVideoID parses a provider-scoped video ID in the form "id:provider".
|
||||
// The ID portion is automatically URL-decoded to restore the original ID.
|
||||
func ParseProviderScopedVideoID(videoID string) (schemas.ModelProvider, string, error) {
|
||||
parts := strings.SplitN(videoID, ":", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return "", "", fmt.Errorf("video_id must be in id:provider format")
|
||||
}
|
||||
provider := schemas.ModelProvider(parts[1])
|
||||
rawID := parts[0]
|
||||
|
||||
// URL decode the ID to restore original characters (e.g., %2F -> /)
|
||||
// This handles IDs from all providers that may contain special characters
|
||||
if decoded, err := url.PathUnescape(rawID); err == nil {
|
||||
rawID = decoded
|
||||
}
|
||||
|
||||
return provider, rawID, nil
|
||||
}
|
||||
|
||||
func getProviderFromHeader(ctx *fasthttp.RequestCtx, defaultProvider schemas.ModelProvider) schemas.ModelProvider {
|
||||
providerHeader := string(ctx.Request.Header.Peek("x-model-provider"))
|
||||
if providerHeader == "" {
|
||||
return defaultProvider
|
||||
}
|
||||
return schemas.ModelProvider(providerHeader)
|
||||
}
|
||||
|
||||
func RegisterKVDecoders(store *kvstore.Store) {
|
||||
store.RegisterDecoder("genai_upload_session:", func(data []byte) (any, error) {
|
||||
var v gemini.GeminiResumableUploadSession
|
||||
return &v, sonic.Unmarshal(data, &v)
|
||||
})
|
||||
}
|
||||
277
transports/bifrost-http/integrations/utils_test.go
Normal file
277
transports/bifrost-http/integrations/utils_test.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/providers/anthropic"
|
||||
"github.com/maximhq/bifrost/core/providers/bedrock"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// testLogger implements schemas.Logger for testing (all no-ops)
|
||||
type testLogger struct{}
|
||||
|
||||
func (t *testLogger) Debug(msg string, args ...any) {}
|
||||
func (t *testLogger) Info(msg string, args ...any) {}
|
||||
func (t *testLogger) Warn(msg string, args ...any) {}
|
||||
func (t *testLogger) Error(msg string, args ...any) {}
|
||||
func (t *testLogger) Fatal(msg string, args ...any) {}
|
||||
func (t *testLogger) SetLevel(level schemas.LogLevel) {}
|
||||
func (t *testLogger) SetOutputType(outputType schemas.LoggerOutputType) {}
|
||||
func (t *testLogger) LogHTTPRequest(level schemas.LogLevel, msg string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
var _ schemas.Logger = (*testLogger)(nil)
|
||||
|
||||
func ptr(i int) *int {
|
||||
return &i
|
||||
}
|
||||
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func newTestGenericRouter() *GenericRouter {
|
||||
return NewGenericRouter(nil, &mockHandlerStore{}, nil, nil, &testLogger{})
|
||||
}
|
||||
|
||||
func newTestBifrostContext() *schemas.BifrostContext {
|
||||
return schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
}
|
||||
|
||||
// TestSendStreamError_PropagatesProviderStatusCode verifies that sendStreamError
|
||||
// sets the HTTP status code from the provider's BifrostError.StatusCode field.
|
||||
// All three providers (OpenAI, Anthropic, Bedrock) return actual HTTP error codes
|
||||
// for pre-stream errors, so Bifrost must propagate them faithfully.
|
||||
func TestSendStreamError_PropagatesProviderStatusCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode *int
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "provider 400 - Bedrock ValidationException / OpenAI invalid_request_error",
|
||||
statusCode: ptr(400),
|
||||
expectedStatusCode: 400,
|
||||
},
|
||||
{
|
||||
name: "provider 429 - rate limiting (all providers)",
|
||||
statusCode: ptr(429),
|
||||
expectedStatusCode: 429,
|
||||
},
|
||||
{
|
||||
name: "provider 503 - Bedrock ServiceUnavailableException",
|
||||
statusCode: ptr(503),
|
||||
expectedStatusCode: 503,
|
||||
},
|
||||
{
|
||||
name: "provider 529 - Anthropic overloaded_error",
|
||||
statusCode: ptr(529),
|
||||
expectedStatusCode: 529,
|
||||
},
|
||||
{
|
||||
name: "nil StatusCode defaults to 500",
|
||||
statusCode: nil,
|
||||
expectedStatusCode: 500,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := newTestGenericRouter()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
bifrostCtx := newTestBifrostContext()
|
||||
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
StatusCode: tt.statusCode,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "test error",
|
||||
},
|
||||
}
|
||||
|
||||
config := RouteConfig{
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
router.sendStreamError(ctx, bifrostCtx, config, bifrostErr)
|
||||
|
||||
assert.Equal(t, tt.expectedStatusCode, ctx.Response.StatusCode())
|
||||
assert.Equal(t, "application/json", string(ctx.Response.Header.ContentType()))
|
||||
|
||||
body := string(ctx.Response.Body())
|
||||
assert.True(t, sonic.Valid(ctx.Response.Body()), "response body should be valid JSON, got: %s", body)
|
||||
assert.False(t, strings.HasPrefix(body, "data: "), "response should not be SSE format")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendStreamError_OpenAIErrorFormat verifies the response body matches the
|
||||
// OpenAI error format. OpenAI's ErrorConverter returns *schemas.BifrostError directly,
|
||||
// which serializes to {"is_bifrost_error":false,"status_code":400,"error":{...}}.
|
||||
func TestSendStreamError_OpenAIErrorFormat(t *testing.T) {
|
||||
router := newTestGenericRouter()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
bifrostCtx := newTestBifrostContext()
|
||||
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
StatusCode: ptr(400),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: strPtr("invalid_request_error"),
|
||||
Message: "content is empty",
|
||||
},
|
||||
}
|
||||
|
||||
config := RouteConfig{
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
router.sendStreamError(ctx, bifrostCtx, config, bifrostErr)
|
||||
|
||||
assert.Equal(t, 400, ctx.Response.StatusCode())
|
||||
|
||||
// Unmarshal and verify the structure
|
||||
var result map[string]interface{}
|
||||
err := sonic.Unmarshal(ctx.Response.Body(), &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, result, "is_bifrost_error")
|
||||
assert.Contains(t, result, "status_code")
|
||||
assert.Contains(t, result, "error")
|
||||
assert.Equal(t, false, result["is_bifrost_error"])
|
||||
|
||||
errorObj, ok := result["error"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "invalid_request_error", errorObj["type"])
|
||||
assert.Equal(t, "content is empty", errorObj["message"])
|
||||
}
|
||||
|
||||
// TestSendStreamError_AnthropicErrorFormat verifies the response body matches the
|
||||
// Anthropic error format: {"type":"error","error":{"type":"...","message":"..."}}.
|
||||
// Critically, it also verifies that the StreamConfig.ErrorConverter (which returns
|
||||
// raw SSE strings) is NOT used — sendStreamError must use the route-level ErrorConverter.
|
||||
func TestSendStreamError_AnthropicErrorFormat(t *testing.T) {
|
||||
router := newTestGenericRouter()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
bifrostCtx := newTestBifrostContext()
|
||||
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
StatusCode: ptr(429),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: strPtr("rate_limit_error"),
|
||||
Message: "rate limited",
|
||||
},
|
||||
}
|
||||
|
||||
config := RouteConfig{
|
||||
// Route-level: returns JSON-marshallable *AnthropicMessageError
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return anthropic.ToAnthropicChatCompletionError(err)
|
||||
},
|
||||
// Stream-level: returns raw SSE string — should NOT be used by sendStreamError
|
||||
StreamConfig: &StreamConfig{
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return anthropic.ToAnthropicResponsesStreamError(err)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
router.sendStreamError(ctx, bifrostCtx, config, bifrostErr)
|
||||
|
||||
assert.Equal(t, 429, ctx.Response.StatusCode())
|
||||
assert.Equal(t, "application/json", string(ctx.Response.Header.ContentType()))
|
||||
|
||||
body := string(ctx.Response.Body())
|
||||
|
||||
// Must NOT contain SSE markers — that would mean StreamConfig.ErrorConverter was used
|
||||
assert.NotContains(t, body, "event: error", "response should not contain SSE event markers")
|
||||
|
||||
// Unmarshal and verify Anthropic error structure
|
||||
var result anthropic.AnthropicMessageError
|
||||
err := sonic.Unmarshal(ctx.Response.Body(), &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "error", result.Type)
|
||||
assert.Equal(t, "rate_limit_error", result.Error.Type)
|
||||
assert.Equal(t, "rate limited", result.Error.Message)
|
||||
}
|
||||
|
||||
// TestSendStreamError_BedrockErrorFormat verifies the response body matches the
|
||||
// Bedrock error format: {"__type":"ValidationException","message":"..."}.
|
||||
func TestSendStreamError_BedrockErrorFormat(t *testing.T) {
|
||||
router := newTestGenericRouter()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
bifrostCtx := newTestBifrostContext()
|
||||
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
StatusCode: ptr(400),
|
||||
Error: &schemas.ErrorField{
|
||||
Code: strPtr("ValidationException"),
|
||||
Message: "validation error",
|
||||
},
|
||||
}
|
||||
|
||||
config := RouteConfig{
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return bedrock.ToBedrockError(err)
|
||||
},
|
||||
}
|
||||
|
||||
router.sendStreamError(ctx, bifrostCtx, config, bifrostErr)
|
||||
|
||||
assert.Equal(t, 400, ctx.Response.StatusCode())
|
||||
|
||||
// Unmarshal and verify Bedrock error structure
|
||||
var result bedrock.BedrockError
|
||||
err := sonic.Unmarshal(ctx.Response.Body(), &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "ValidationException", result.Type)
|
||||
assert.Equal(t, "validation error", result.Message)
|
||||
}
|
||||
|
||||
// TestSendStreamError_ForwardsProviderHeaders verifies that provider response headers
|
||||
// stored in the BifrostContext are forwarded to the HTTP response. This ensures
|
||||
// clients receive provider-specific headers (e.g., x-amzn-requestid for Bedrock,
|
||||
// x-request-id for Anthropic) even in error scenarios.
|
||||
func TestSendStreamError_ForwardsProviderHeaders(t *testing.T) {
|
||||
router := newTestGenericRouter()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
bifrostCtx := newTestBifrostContext()
|
||||
|
||||
// Set provider response headers on the context
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, map[string]string{
|
||||
"x-amzn-requestid": "req-123",
|
||||
"x-amzn-errortype": "ValidationException",
|
||||
})
|
||||
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
StatusCode: ptr(400),
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "validation error",
|
||||
},
|
||||
}
|
||||
|
||||
config := RouteConfig{
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
router.sendStreamError(ctx, bifrostCtx, config, bifrostErr)
|
||||
|
||||
assert.Equal(t, 400, ctx.Response.StatusCode())
|
||||
assert.Equal(t, "req-123", string(ctx.Response.Header.Peek("x-amzn-requestid")))
|
||||
assert.Equal(t, "ValidationException", string(ctx.Response.Header.Peek("x-amzn-errortype")))
|
||||
}
|
||||
105
transports/bifrost-http/lib/account.go
Normal file
105
transports/bifrost-http/lib/account.go
Normal file
@@ -0,0 +1,105 @@
|
||||
// Package lib provides core functionality for the Bifrost HTTP service,
|
||||
// including context propagation, header management, and integration with monitoring systems.
|
||||
package lib
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// BaseAccount implements the Account interface for Bifrost.
|
||||
// It manages provider configurations using a in-memory store for persistent storage.
|
||||
// All data processing (environment variables, key configs) is done upfront in the store.
|
||||
type BaseAccount struct {
|
||||
store *Config // store for in-memory configuration
|
||||
}
|
||||
|
||||
// NewBaseAccount creates a new BaseAccount with the given store
|
||||
func NewBaseAccount(store *Config) *BaseAccount {
|
||||
return &BaseAccount{
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfiguredProviders returns a list of all configured providers.
|
||||
// Implements the Account interface.
|
||||
func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
|
||||
if baseAccount.store == nil {
|
||||
return nil, fmt.Errorf("store not initialized")
|
||||
}
|
||||
return baseAccount.store.GetAllProviders()
|
||||
}
|
||||
|
||||
// GetKeysForProvider returns the API keys configured for a specific provider.
|
||||
// Keys are already processed (environment variables resolved) by the store.
|
||||
// Implements the Account interface.
|
||||
func (baseAccount *BaseAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) {
|
||||
if baseAccount.store == nil {
|
||||
return nil, fmt.Errorf("store not initialized")
|
||||
}
|
||||
config, err := baseAccount.store.GetProviderConfigRaw(providerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys := config.Keys
|
||||
if v := ctx.Value(schemas.BifrostContextKeyGovernanceIncludeOnlyKeys); v != nil {
|
||||
if includeOnlyKeys, ok := v.([]string); ok {
|
||||
if len(includeOnlyKeys) == 0 {
|
||||
// header present but empty means "no keys allowed"
|
||||
keys = nil
|
||||
} else {
|
||||
set := make(map[string]struct{}, len(includeOnlyKeys))
|
||||
for _, id := range includeOnlyKeys {
|
||||
set[id] = struct{}{}
|
||||
}
|
||||
filtered := make([]schemas.Key, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
if _, ok := set[key.ID]; ok {
|
||||
filtered = append(filtered, key)
|
||||
}
|
||||
}
|
||||
keys = filtered
|
||||
}
|
||||
}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// GetConfigForProvider returns the complete configuration for a specific provider.
|
||||
// Configuration is already fully processed (environment variables, key configs) by the store.
|
||||
// Implements the Account interface.
|
||||
func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) {
|
||||
if baseAccount.store == nil {
|
||||
return nil, fmt.Errorf("store not initialized")
|
||||
}
|
||||
config, err := baseAccount.store.GetProviderConfigRaw(providerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providerConfig := &schemas.ProviderConfig{}
|
||||
if config.ProxyConfig != nil {
|
||||
providerConfig.ProxyConfig = config.ProxyConfig
|
||||
}
|
||||
if config.NetworkConfig != nil {
|
||||
providerConfig.NetworkConfig = *config.NetworkConfig
|
||||
} else {
|
||||
providerConfig.NetworkConfig = schemas.DefaultNetworkConfig
|
||||
}
|
||||
if config.ConcurrencyAndBufferSize != nil {
|
||||
providerConfig.ConcurrencyAndBufferSize = *config.ConcurrencyAndBufferSize
|
||||
} else {
|
||||
providerConfig.ConcurrencyAndBufferSize = schemas.DefaultConcurrencyAndBufferSize
|
||||
}
|
||||
providerConfig.SendBackRawRequest = config.SendBackRawRequest
|
||||
providerConfig.SendBackRawResponse = config.SendBackRawResponse
|
||||
providerConfig.StoreRawRequestResponse = config.StoreRawRequestResponse
|
||||
if config.CustomProviderConfig != nil {
|
||||
providerConfig.CustomProviderConfig = config.CustomProviderConfig
|
||||
}
|
||||
if config.OpenAIConfig != nil {
|
||||
providerConfig.OpenAIConfig = config.OpenAIConfig
|
||||
}
|
||||
return providerConfig, nil
|
||||
}
|
||||
4310
transports/bifrost-http/lib/config.go
Normal file
4310
transports/bifrost-http/lib/config.go
Normal file
File diff suppressed because it is too large
Load Diff
17838
transports/bifrost-http/lib/config_test.go
Normal file
17838
transports/bifrost-http/lib/config_test.go
Normal file
File diff suppressed because it is too large
Load Diff
644
transports/bifrost-http/lib/ctx.go
Normal file
644
transports/bifrost-http/lib/ctx.go
Normal file
@@ -0,0 +1,644 @@
|
||||
// Package lib provides core functionality for the Bifrost HTTP service,
|
||||
// including context propagation, header management, and integration with monitoring systems.
|
||||
//
|
||||
// This package handles the conversion of FastHTTP request contexts to Bifrost contexts,
|
||||
// ensuring that important metadata and tracking information is preserved across the system.
|
||||
// It supports propagation of both Prometheus metrics and Maxim tracing data through HTTP headers.
|
||||
package lib
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/maximhq/bifrost/plugins/maxim"
|
||||
"github.com/maximhq/bifrost/plugins/semanticcache"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
|
||||
// FastHTTPUserValueBifrostContext stores the active *schemas.BifrostContext on fasthttp.RequestCtx.
|
||||
// This allows transport middleware and request handlers to share the same context instance.
|
||||
FastHTTPUserValueBifrostContext = "__bifrost_context"
|
||||
// FastHTTPUserValueBifrostCancel stores the cancel func for the active shared Bifrost context.
|
||||
FastHTTPUserValueBifrostCancel = "__bifrost_context_cancel"
|
||||
// FastHTTPUserValueLargeResponseMode marks requests that streamed a large response body.
|
||||
// It is used by transport middleware to avoid re-buffering response bodies for post-hooks.
|
||||
FastHTTPUserValueLargeResponseMode = "__bifrost_large_response_mode"
|
||||
)
|
||||
|
||||
// ParseSessionIDFromBaggage extracts the session-id baggage member value.
|
||||
// It supports simple W3C baggage parsing sufficient for log grouping.
|
||||
func ParseSessionIDFromBaggage(header string) string {
|
||||
for _, member := range strings.Split(header, ",") {
|
||||
member = strings.TrimSpace(member)
|
||||
if member == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(member, ";", 2)
|
||||
kv := strings.SplitN(strings.TrimSpace(parts[0]), "=", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.ToLower(strings.TrimSpace(kv[0]))
|
||||
value := strings.TrimSpace(kv[1])
|
||||
if key != "session-id" || value == "" {
|
||||
continue
|
||||
}
|
||||
if len(value) > 255 {
|
||||
if logger != nil {
|
||||
logger.Warn("session-id exceeds 255 chars, ignoring: length=%d, prefix=%s", len(value), value[:255])
|
||||
}
|
||||
continue
|
||||
}
|
||||
return value
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ConvertToBifrostContext converts a FastHTTP RequestCtx to a Bifrost context,
|
||||
// preserving important header values for monitoring and tracing purposes.
|
||||
//
|
||||
// The function processes several types of special headers:
|
||||
// 1. Prometheus Headers (x-bf-prom-*):
|
||||
// - All headers prefixed with 'x-bf-prom-' are copied to the context
|
||||
// - The prefix is stripped and the remainder becomes the context key
|
||||
// - Example: 'x-bf-prom-latency' becomes 'latency' in the context
|
||||
//
|
||||
// 2. Maxim Tracing Headers (x-bf-maxim-*):
|
||||
// - Specifically handles 'x-bf-maxim-traceID' and 'x-bf-maxim-generationID'
|
||||
// - These headers enable trace correlation across service boundaries
|
||||
// - Values are stored using Maxim's context keys for consistency
|
||||
//
|
||||
// 3. MCP Headers (x-bf-mcp-*):
|
||||
// - Specifically handles 'x-bf-mcp-include-clients' and 'x-bf-mcp-include-tools' (include-only filtering)
|
||||
// - These headers enable MCP client and tool filtering
|
||||
// - Values are stored using MCP context keys for consistency
|
||||
//
|
||||
// 4. Governance Headers:
|
||||
// - x-bf-vk: Virtual key for governance (required for governance to work)
|
||||
//
|
||||
// 5. API Key Headers:
|
||||
// - Authorization: Bearer token format only (e.g., "Bearer sk-...") - OpenAI style
|
||||
// - x-api-key: Direct API key value - Anthropic style
|
||||
// - x-goog-api-key: Direct API key value - Google Gemini style
|
||||
// - x-bf-api-key references a stored API key name rather than the raw secret.
|
||||
// - Keys are extracted and stored in the context using schemas.BifrostContextKey
|
||||
// - This enables explicit key usage for requests via headers
|
||||
//
|
||||
// 6. Cancellable Context:
|
||||
// - Creates a cancellable context that can be used to cancel upstream requests when clients disconnect
|
||||
// - This is critical for streaming requests where write errors indicate client disconnects
|
||||
// - Also useful for non-streaming requests to allow provider-level cancellation
|
||||
//
|
||||
// 7. Extra Headers (x-bf-eh-*):
|
||||
// - Any header starting with 'x-bf-eh-' is collected and added to the map stored under schemas.BifrostContextKeyExtraHeaders
|
||||
// - The prefix is stripped, the remainder is lower-cased, and duplicate names append values
|
||||
// - This allows callers to send arbitrary context metadata without needing to extend the public schema
|
||||
//
|
||||
// 8. Session Stickiness Headers:
|
||||
// - x-bf-session-id: Session identifier for key binding (reuse same key across requests)
|
||||
// - x-bf-session-ttl: Per-request TTL override (duration string e.g. "30m" or seconds integer)
|
||||
//
|
||||
// 9. Raw Capture Headers (per-request override of provider config; accepts "true" or "false"):
|
||||
// - x-bf-send-back-raw-request: include raw provider request in the BifrostResponse returned to the caller
|
||||
// - x-bf-send-back-raw-response: include raw provider response in the BifrostResponse returned to the caller
|
||||
// - x-bf-store-raw-request-response: capture raw request/response for logging only (stripped from client response)
|
||||
|
||||
// Parameters:
|
||||
// - ctx: The FastHTTP request context containing the original headers
|
||||
// - allowDirectKeys: Whether to allow direct API key usage from headers
|
||||
//
|
||||
// Returns:
|
||||
// - *context.Context: A new cancellable context.Context containing the propagated values
|
||||
// - context.CancelFunc: Function to cancel the context (should be called when request completes)
|
||||
//
|
||||
// Example Usage:
|
||||
//
|
||||
// fastCtx := &fasthttp.RequestCtx{...}
|
||||
// bifrostCtx, cancel := ConvertToBifrostContext(fastCtx, true, nil)
|
||||
// defer cancel() // Ensure cleanup
|
||||
// // bifrostCtx now contains propagated header values including Prometheus metrics,
|
||||
// // Maxim tracing data, MCP filters, governance keys, API keys, cache settings,
|
||||
// // session stickiness, and extra headers
|
||||
|
||||
func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, matcher *HeaderMatcher, mcpHeaderCombinedAllowlist schemas.WhiteList) (*schemas.BifrostContext, context.CancelFunc) {
|
||||
// Reuse a shared request-scoped context when available.
|
||||
var bifrostCtx *schemas.BifrostContext
|
||||
var cancel context.CancelFunc
|
||||
if existing, ok := ctx.UserValue(FastHTTPUserValueBifrostContext).(*schemas.BifrostContext); ok && existing != nil {
|
||||
if existingCancel, ok := ctx.UserValue(FastHTTPUserValueBifrostCancel).(context.CancelFunc); ok && existingCancel != nil {
|
||||
bifrostCtx = existing
|
||||
cancel = existingCancel
|
||||
} else {
|
||||
// Create one cancellable child context and promote it as the shared context.
|
||||
bifrostCtx, cancel = schemas.NewBifrostContextWithCancel(existing)
|
||||
ctx.SetUserValue(FastHTTPUserValueBifrostContext, bifrostCtx)
|
||||
ctx.SetUserValue(FastHTTPUserValueBifrostCancel, cancel)
|
||||
}
|
||||
}
|
||||
if bifrostCtx == nil {
|
||||
// Create cancellable context for requests that don't have a shared context yet.
|
||||
parent := context.Context(ctx)
|
||||
func() {
|
||||
// Zero-value fasthttp.RequestCtx can panic on Done(); fall back safely.
|
||||
defer func() {
|
||||
if recover() != nil {
|
||||
parent = context.Background()
|
||||
}
|
||||
}()
|
||||
_ = ctx.Done()
|
||||
}()
|
||||
bifrostCtx, cancel = schemas.NewBifrostContextWithCancel(parent)
|
||||
ctx.SetUserValue(FastHTTPUserValueBifrostContext, bifrostCtx)
|
||||
ctx.SetUserValue(FastHTTPUserValueBifrostCancel, cancel)
|
||||
}
|
||||
|
||||
// Preserve existing request-id if already present on the shared context.
|
||||
if existingRequestID, ok := bifrostCtx.Value(schemas.BifrostContextKeyRequestID).(string); !ok || existingRequestID == "" {
|
||||
// First, check if x-request-id header exists
|
||||
requestID := string(ctx.Request.Header.Peek("x-request-id"))
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
}
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyRequestID, requestID)
|
||||
}
|
||||
// Populating all user values from the request context
|
||||
ctx.VisitUserValuesAll(func(key, value any) {
|
||||
bifrostCtx.SetValue(key, value)
|
||||
})
|
||||
// Initialize tags map for collecting maxim tags
|
||||
maximTags := make(map[string]string)
|
||||
// Initialize extra headers map for headers prefixed with x-bf-eh-
|
||||
extraHeaders := make(map[string][]string)
|
||||
// Initialize extra headers map for headers in the mcp header combined allowlist
|
||||
mcpExtraHeaders := make(map[string][]string)
|
||||
// Security denylist of header names that should never be accepted (case-insensitive)
|
||||
// This denylist is always enforced regardless of user configuration
|
||||
securityDenylist := map[string]bool{
|
||||
"proxy-authorization": true,
|
||||
"cookie": true,
|
||||
"host": true,
|
||||
"content-length": true,
|
||||
"connection": true,
|
||||
"transfer-encoding": true,
|
||||
|
||||
// prevent auth/key overrides via x-bf-eh-*
|
||||
"x-api-key": true,
|
||||
"x-goog-api-key": true,
|
||||
"x-bf-api-key": true,
|
||||
"x-bf-api-key-id": true,
|
||||
"x-bf-vk": true,
|
||||
}
|
||||
|
||||
// Debug: Log header matcher state
|
||||
if logger != nil {
|
||||
if matcher != nil {
|
||||
logger.Debug("headerMatcher hasAllowlist=%v, hasDenylist=%v", matcher.HasAllowlist(), matcher.hasDenylist)
|
||||
} else {
|
||||
logger.Debug("headerMatcher is nil (allow all)")
|
||||
}
|
||||
}
|
||||
|
||||
// Then process other headers
|
||||
ctx.Request.Header.All()(func(key, value []byte) bool {
|
||||
keyStr := strings.ToLower(string(key))
|
||||
if keyStr == "baggage" {
|
||||
if sessionID := ParseSessionIDFromBaggage(string(value)); sessionID != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyParentRequestID, sessionID)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if labelName, ok := strings.CutPrefix(keyStr, "x-bf-prom-"); ok {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
|
||||
return true
|
||||
}
|
||||
// Checking for maxim headers
|
||||
if labelName, ok := strings.CutPrefix(keyStr, "x-bf-maxim-"); ok {
|
||||
switch labelName {
|
||||
case string(maxim.GenerationIDKey):
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
|
||||
case string(maxim.TraceIDKey):
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
|
||||
case string(maxim.SessionIDKey):
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
|
||||
case string(maxim.TraceNameKey):
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
|
||||
case string(maxim.GenerationNameKey):
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
|
||||
case string(maxim.LogRepoIDKey):
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
|
||||
default:
|
||||
// apart from these all headers starting with x-bf-maxim- are keys for tags
|
||||
// collect them in the maximTags map
|
||||
maximTags[labelName] = string(value)
|
||||
}
|
||||
return true
|
||||
}
|
||||
// MCP control headers (include-only filtering)
|
||||
if labelName, ok := strings.CutPrefix(keyStr, "x-bf-mcp-"); ok {
|
||||
switch labelName {
|
||||
case "include-clients":
|
||||
fallthrough
|
||||
case "include-tools":
|
||||
// Parse comma-separated values into []string
|
||||
valueStr := string(value)
|
||||
var parsedValues []string
|
||||
if valueStr != "" {
|
||||
// Split by comma and trim whitespace
|
||||
for _, v := range strings.Split(valueStr, ",") {
|
||||
if trimmed := strings.TrimSpace(v); trimmed != "" {
|
||||
parsedValues = append(parsedValues, trimmed)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
parsedValues = []string{""}
|
||||
}
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey("mcp-"+labelName), parsedValues)
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Handle virtual key header (x-bf-vk, authorization, x-api-key, x-goog-api-key headers)
|
||||
if keyStr == string(schemas.BifrostContextKeyVirtualKey) {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, string(value))
|
||||
return true
|
||||
}
|
||||
if keyStr == "authorization" {
|
||||
valueStr := string(value)
|
||||
// Only accept Bearer token format: "Bearer ..."
|
||||
if strings.HasPrefix(strings.ToLower(valueStr), "bearer ") {
|
||||
authHeaderValue := strings.TrimSpace(valueStr[7:]) // Remove "Bearer " prefix
|
||||
if authHeaderValue != "" && strings.HasPrefix(strings.ToLower(authHeaderValue), governance.VirtualKeyPrefix) {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, authHeaderValue)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
if keyStr == "x-api-key" && strings.HasPrefix(strings.ToLower(string(value)), governance.VirtualKeyPrefix) {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, string(value))
|
||||
return true
|
||||
}
|
||||
if keyStr == "x-goog-api-key" && strings.HasPrefix(strings.ToLower(string(value)), governance.VirtualKeyPrefix) {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, string(value))
|
||||
return true
|
||||
}
|
||||
if keyStr == "x-bf-api-key" {
|
||||
if keyName := strings.TrimSpace(string(value)); keyName != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyName, keyName)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if keyStr == "x-bf-api-key-id" {
|
||||
if keyID := strings.TrimSpace(string(value)); keyID != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyID, keyID)
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Handle cache key header (x-bf-cache-key)
|
||||
if keyStr == "x-bf-cache-key" {
|
||||
bifrostCtx.SetValue(semanticcache.CacheKey, string(value))
|
||||
return true
|
||||
}
|
||||
// Handle cache TTL header (x-bf-cache-ttl)
|
||||
if keyStr == "x-bf-cache-ttl" {
|
||||
valueStr := string(value)
|
||||
var ttlDuration time.Duration
|
||||
var err error
|
||||
|
||||
// First try to parse as duration (e.g., "30s", "5m", "1h")
|
||||
if ttlDuration, err = time.ParseDuration(valueStr); err != nil {
|
||||
// If that fails, try to parse as plain number and treat as seconds
|
||||
if seconds, parseErr := strconv.Atoi(valueStr); parseErr == nil && seconds > 0 {
|
||||
ttlDuration = time.Duration(seconds) * time.Second
|
||||
err = nil // Reset error since we successfully parsed as seconds
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
bifrostCtx.SetValue(semanticcache.CacheTTLKey, ttlDuration)
|
||||
}
|
||||
// If both parsing attempts fail, we silently ignore the header and use default TTL
|
||||
return true
|
||||
}
|
||||
// Cache threshold header
|
||||
if keyStr == "x-bf-cache-threshold" {
|
||||
threshold, err := strconv.ParseFloat(string(value), 64)
|
||||
if err == nil {
|
||||
// Clamp threshold to the inclusive range [0.0, 1.0]
|
||||
if threshold < 0.0 {
|
||||
threshold = 0.0
|
||||
} else if threshold > 1.0 {
|
||||
threshold = 1.0
|
||||
}
|
||||
bifrostCtx.SetValue(semanticcache.CacheThresholdKey, threshold)
|
||||
}
|
||||
// If parsing fails, silently ignore the header (no context value set)
|
||||
return true
|
||||
}
|
||||
// Cache type header
|
||||
if keyStr == "x-bf-cache-type" {
|
||||
bifrostCtx.SetValue(semanticcache.CacheTypeKey, semanticcache.CacheType(string(value)))
|
||||
return true
|
||||
}
|
||||
// Cache no store header
|
||||
if keyStr == "x-bf-cache-no-store" {
|
||||
if valueStr := string(value); valueStr == "true" {
|
||||
bifrostCtx.SetValue(semanticcache.CacheNoStoreKey, true)
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Session stickiness: session ID for key binding
|
||||
if keyStr == "x-bf-session-id" {
|
||||
if valueStr := strings.TrimSpace(string(value)); valueStr != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeySessionID, valueStr)
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Session stickiness: per-request TTL override (duration string or seconds integer)
|
||||
if keyStr == "x-bf-session-ttl" {
|
||||
valueStr := strings.TrimSpace(string(value))
|
||||
var ttlDuration time.Duration
|
||||
var err error
|
||||
if ttlDuration, err = time.ParseDuration(valueStr); err != nil {
|
||||
if seconds, parseErr := strconv.Atoi(valueStr); parseErr == nil && seconds > 0 {
|
||||
ttlDuration = time.Duration(seconds) * time.Second
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
if err == nil && ttlDuration > 0 {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeySessionTTL, ttlDuration)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if labelName, ok := strings.CutPrefix(keyStr, "x-bf-eh-"); ok {
|
||||
// Skip empty header names after prefix removal
|
||||
if labelName == "" {
|
||||
return true
|
||||
}
|
||||
// Normalize header name to lowercase
|
||||
labelName = strings.ToLower(labelName)
|
||||
// Validate against security denylist (always enforced)
|
||||
if securityDenylist[labelName] {
|
||||
return true
|
||||
}
|
||||
// Apply configurable header filter
|
||||
if !matcher.ShouldAllow(labelName) {
|
||||
return true
|
||||
}
|
||||
// Append header value (allow multiple values for the same header)
|
||||
extraHeaders[labelName] = append(extraHeaders[labelName], string(value))
|
||||
return true
|
||||
}
|
||||
// Direct header forwarding: when allowlist is configured, any header explicitly
|
||||
// in the allowlist can be forwarded directly without the x-bf-eh- prefix.
|
||||
// This enables forwarding arbitrary headers like "anthropic-beta" directly.
|
||||
// Only applies when allowlist is non-empty (backward compatible).
|
||||
if matcher.HasAllowlist() {
|
||||
if matcher.MatchesAllow(keyStr) {
|
||||
// Skip reserved x-bf-* headers (handled separately)
|
||||
if strings.HasPrefix(keyStr, "x-bf-") {
|
||||
return true
|
||||
}
|
||||
// Validate against security denylist (always enforced)
|
||||
if securityDenylist[keyStr] {
|
||||
return true
|
||||
}
|
||||
// Check denylist
|
||||
if matcher.MatchesDeny(keyStr) {
|
||||
return true
|
||||
}
|
||||
// Forward the header directly with its original name
|
||||
if logger != nil {
|
||||
logger.Debug("forwarding header via allowlist: %s", keyStr)
|
||||
}
|
||||
extraHeaders[keyStr] = append(extraHeaders[keyStr], string(value))
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Handle MCP extra headers
|
||||
if mcpHeaderCombinedAllowlist.IsAllowed(keyStr) {
|
||||
mcpExtraHeaders[keyStr] = append(mcpExtraHeaders[keyStr], string(value))
|
||||
return true
|
||||
}
|
||||
// Raw capture headers — all three support "true"/"false" to fully override the
|
||||
// provider-level config for this request.
|
||||
if keyStr == "x-bf-send-back-raw-request" {
|
||||
if b, err := strconv.ParseBool(string(value)); err == nil {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeySendBackRawRequest, b)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if keyStr == "x-bf-send-back-raw-response" {
|
||||
if b, err := strconv.ParseBool(string(value)); err == nil {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeySendBackRawResponse, b)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if keyStr == "x-bf-store-raw-request-response" {
|
||||
if b, err := strconv.ParseBool(string(value)); err == nil {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyStoreRawRequestResponse, b)
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Parent request ID header (for linking MCP tool calls to parent LLM requests)
|
||||
if keyStr == "x-bf-parent-request-id" {
|
||||
if valueStr := strings.TrimSpace(string(value)); valueStr != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostMCPAgentOriginalRequestID, valueStr)
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Add passthrough extra params header support
|
||||
if keyStr == "x-bf-passthrough-extra-params" {
|
||||
if valueStr := string(value); valueStr == "true" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Compat header: per-request override of compat plugin settings.
|
||||
// Accepts: "true" (enable all), JSON array of feature names, or ["*"] (enable all).
|
||||
// An empty array [] or absent header means no overrides.
|
||||
if keyStr == "x-bf-compat" {
|
||||
bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatConvertTextToChat)
|
||||
bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatConvertChatToResponses)
|
||||
bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatShouldDropParams)
|
||||
bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatShouldConvertParams)
|
||||
valueStr := strings.TrimSpace(string(value))
|
||||
if valueStr == "true" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertTextToChat, true)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertChatToResponses, true)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldDropParams, true)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldConvertParams, true)
|
||||
} else if strings.HasPrefix(valueStr, "[") {
|
||||
var features []string
|
||||
if err := json.Unmarshal([]byte(valueStr), &features); err == nil {
|
||||
if len(features) == 1 && features[0] == "*" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertTextToChat, true)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertChatToResponses, true)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldDropParams, true)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldConvertParams, true)
|
||||
} else {
|
||||
for _, f := range features {
|
||||
switch f {
|
||||
case "convert_text_to_chat":
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertTextToChat, true)
|
||||
case "convert_chat_to_responses":
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertChatToResponses, true)
|
||||
case "should_drop_params":
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldDropParams, true)
|
||||
case "should_convert_params":
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldConvertParams, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Store the collected maxim tags in the context
|
||||
if len(maximTags) > 0 {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(maxim.TagsKey), maximTags)
|
||||
}
|
||||
|
||||
// Store collected extra headers in the context if any were found
|
||||
if len(extraHeaders) > 0 {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders)
|
||||
}
|
||||
|
||||
// Store collected MCP extra headers in the context if any were found
|
||||
if len(mcpExtraHeaders) > 0 {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyMCPExtraHeaders, mcpExtraHeaders)
|
||||
}
|
||||
|
||||
// Collect all request headers for downstream use (e.g., governance required headers check)
|
||||
// Keys are lowercased for case-insensitive lookup
|
||||
allHeaders := make(map[string]string)
|
||||
ctx.Request.Header.All()(func(key, value []byte) bool {
|
||||
allHeaders[strings.ToLower(string(key))] = string(value)
|
||||
return true
|
||||
})
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyRequestHeaders, allHeaders)
|
||||
|
||||
// Extract per-user MCP OAuth user identifier from X-Bf-User-Id header
|
||||
if mcpUserID := string(ctx.Request.Header.Peek("X-Bf-User-Id")); mcpUserID != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyMCPUserID, mcpUserID)
|
||||
}
|
||||
|
||||
// Build and set OAuth redirect URI for per-user OAuth flows
|
||||
scheme := "http"
|
||||
if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
host := string(ctx.Host())
|
||||
if host != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyOAuthRedirectURI, fmt.Sprintf("%s://%s/api/oauth/callback", scheme, host))
|
||||
}
|
||||
|
||||
if allowDirectKeys {
|
||||
// Extract API key from Authorization header (Bearer format), x-api-key, or x-goog-api-key header
|
||||
var apiKey string
|
||||
|
||||
// TODO: fix plugin data leak
|
||||
// Check Authorization header (Bearer format only - OpenAI style)
|
||||
authHeader := string(ctx.Request.Header.Peek("Authorization"))
|
||||
if authHeader != "" {
|
||||
// Only accept Bearer token format: "Bearer ..."
|
||||
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
authHeaderValue := strings.TrimSpace(authHeader[7:]) // Remove "Bearer " prefix
|
||||
if authHeaderValue != "" && !strings.HasPrefix(strings.ToLower(authHeaderValue), governance.VirtualKeyPrefix) {
|
||||
apiKey = authHeaderValue
|
||||
}
|
||||
} else {
|
||||
apiKey = authHeader
|
||||
}
|
||||
}
|
||||
|
||||
if apiKey == "" {
|
||||
// Check x-api-key (Anthropic style) header if no valid Authorization header found
|
||||
xAPIKey := string(ctx.Request.Header.Peek("x-api-key"))
|
||||
if xAPIKey != "" && !strings.HasPrefix(strings.ToLower(xAPIKey), governance.VirtualKeyPrefix) {
|
||||
apiKey = strings.TrimSpace(xAPIKey)
|
||||
} else {
|
||||
// Check x-goog-api-key (Google Gemini style) header if no valid Authorization header found
|
||||
xGoogleAPIKey := string(ctx.Request.Header.Peek("x-goog-api-key"))
|
||||
if xGoogleAPIKey != "" && !strings.HasPrefix(strings.ToLower(xGoogleAPIKey), governance.VirtualKeyPrefix) {
|
||||
apiKey = strings.TrimSpace(xGoogleAPIKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we found an API key, create a Key object and store it in context
|
||||
if apiKey != "" {
|
||||
key := schemas.Key{
|
||||
ID: "header-provided", // Identifier for header-provided keys
|
||||
Value: *schemas.NewEnvVar(apiKey),
|
||||
Models: schemas.WhiteList{"*"}, // Allow all models
|
||||
Weight: 1.0, // Default weight
|
||||
}
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, key)
|
||||
}
|
||||
}
|
||||
return bifrostCtx, cancel
|
||||
}
|
||||
|
||||
// BuildHTTPRequestFromFastHTTP creates an HTTPRequest from fasthttp context for streaming handlers.
|
||||
// The returned request should be released with schemas.ReleaseHTTPRequest when done.
|
||||
// Note: Body is not copied for streaming (body was already consumed for the request).
|
||||
func BuildHTTPRequestFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.HTTPRequest {
|
||||
req := schemas.AcquireHTTPRequest()
|
||||
req.Method = string(ctx.Method())
|
||||
req.Path = string(ctx.Path())
|
||||
|
||||
// Copy headers
|
||||
for key, value := range ctx.Request.Header.All() {
|
||||
req.Headers[string(key)] = string(value)
|
||||
}
|
||||
|
||||
// Copy query params
|
||||
for key, value := range ctx.Request.URI().QueryArgs().All() {
|
||||
req.Query[string(key)] = string(value)
|
||||
}
|
||||
|
||||
// Copy path parameters from user values
|
||||
ctx.VisitUserValuesAll(func(key, value any) {
|
||||
keyStr, keyIsString := key.(string)
|
||||
valueStr, valueIsString := value.(string)
|
||||
if !keyIsString || !valueIsString {
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(keyStr, "bifrost-") ||
|
||||
keyStr == "BifrostContextKeyRequestID" ||
|
||||
keyStr == "trace_id" ||
|
||||
keyStr == "span_id" {
|
||||
return
|
||||
}
|
||||
req.PathParams[keyStr] = valueStr
|
||||
})
|
||||
|
||||
// Note: Body not copied - for streaming, body was already consumed
|
||||
return req
|
||||
}
|
||||
|
||||
// BuildHTTPResponseFromFastHTTP creates an HTTPResponse snapshot from fasthttp context.
|
||||
// Only captures status code and headers — body is skipped because for streaming
|
||||
// responses it is an active io.Reader that cannot be materialized.
|
||||
// The returned response should be released with schemas.ReleaseHTTPResponse when done.
|
||||
func BuildHTTPResponseFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.HTTPResponse {
|
||||
resp := schemas.AcquireHTTPResponse()
|
||||
resp.StatusCode = ctx.Response.StatusCode()
|
||||
for key, value := range ctx.Response.Header.All() {
|
||||
resp.Headers[string(key)] = string(value)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
275
transports/bifrost-http/lib/ctx_test.go
Normal file
275
transports/bifrost-http/lib/ctx_test.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestParseSessionIDFromBaggage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
want string
|
||||
}{
|
||||
{name: "single member", header: "session-id=abc", want: "abc"},
|
||||
{name: "multiple members", header: "foo=bar, session-id=abc, baz=qux", want: "abc"},
|
||||
{name: "member with properties", header: "session-id=abc;ttl=60", want: "abc"},
|
||||
{name: "spaces preserved around parsing", header: " foo=bar , session-id = abc123 ;ttl=60 ", want: "abc123"},
|
||||
{name: "missing member", header: "foo=bar", want: ""},
|
||||
{name: "malformed ignored", header: "session-id, foo=bar", want: ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ParseSessionIDFromBaggage(tt.header); got != tt.want {
|
||||
t.Fatalf("ParseSessionIDFromBaggage(%q) = %q, want %q", tt.header, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToBifrostContext_ReusesSharedContext(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
base := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
base.SetValue(schemas.BifrostContextKeyRequestID, "req-shared")
|
||||
ctx.SetUserValue(FastHTTPUserValueBifrostContext, base)
|
||||
|
||||
converted, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
if converted == nil {
|
||||
t.Fatal("expected non-nil converted context")
|
||||
}
|
||||
if got, _ := converted.Value(schemas.BifrostContextKeyRequestID).(string); got != "req-shared" {
|
||||
t.Fatalf("expected converted context to preserve parent values, got request-id=%q", got)
|
||||
}
|
||||
if stored, ok := ctx.UserValue(FastHTTPUserValueBifrostContext).(*schemas.BifrostContext); !ok || stored == nil {
|
||||
t.Fatal("expected shared context pointer to be stored on fasthttp user values")
|
||||
}
|
||||
if ctx.UserValue(FastHTTPUserValueBifrostCancel) == nil {
|
||||
t.Fatal("expected shared cancel function to be stored on fasthttp user values")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToBifrostContext_SecondCallReturnsSameSharedContext(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
first, cancelFirst := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
|
||||
defer cancelFirst()
|
||||
if first == nil {
|
||||
t.Fatal("expected first context to be non-nil")
|
||||
}
|
||||
|
||||
second, cancelSecond := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
|
||||
defer cancelSecond()
|
||||
if second == nil {
|
||||
t.Fatal("expected second context to be non-nil")
|
||||
}
|
||||
if first != second {
|
||||
t.Fatal("expected ConvertToBifrostContext to reuse the shared context on repeated calls")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToBifrostContext_StarAllowlistSecurityHeadersBlocked verifies that
|
||||
// even with a "*" allowlist (allow all), the hardcoded security denylist in
|
||||
// ConvertToBifrostContext still blocks security-sensitive headers.
|
||||
func TestConvertToBifrostContext_StarAllowlistSecurityHeadersBlocked(t *testing.T) {
|
||||
matcher := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"*"},
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
// x-bf-eh-* prefixed headers
|
||||
ctx.Request.Header.Set("x-bf-eh-custom-header", "allowed-value")
|
||||
ctx.Request.Header.Set("x-bf-eh-cookie", "should-be-blocked")
|
||||
ctx.Request.Header.Set("x-bf-eh-x-api-key", "should-be-blocked")
|
||||
ctx.Request.Header.Set("x-bf-eh-host", "should-be-blocked")
|
||||
ctx.Request.Header.Set("x-bf-eh-connection", "should-be-blocked")
|
||||
ctx.Request.Header.Set("x-bf-eh-proxy-authorization", "should-be-blocked")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
|
||||
|
||||
// custom-header should be forwarded
|
||||
if _, ok := extraHeaders["custom-header"]; !ok {
|
||||
t.Error("expected custom-header to be forwarded via x-bf-eh- prefix")
|
||||
}
|
||||
|
||||
// Security headers should be blocked even with * allowlist
|
||||
securityHeaders := []string{"cookie", "x-api-key", "host", "connection", "proxy-authorization"}
|
||||
for _, h := range securityHeaders {
|
||||
if _, ok := extraHeaders[h]; ok {
|
||||
t.Errorf("expected security header %q to be blocked even with * allowlist", h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToBifrostContext_StarAllowlistDirectForwardingSecurityBlocked verifies
|
||||
// that direct header forwarding with "*" allowlist forwards non-security headers
|
||||
// but still blocks security headers.
|
||||
func TestConvertToBifrostContext_StarAllowlistDirectForwardingSecurityBlocked(t *testing.T) {
|
||||
matcher := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"*"},
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
// Direct headers (not prefixed with x-bf-eh-)
|
||||
ctx.Request.Header.Set("custom-header", "allowed-value")
|
||||
ctx.Request.Header.Set("anthropic-beta", "some-beta-feature")
|
||||
// Security headers sent directly — should be blocked
|
||||
ctx.Request.Header.Set("proxy-authorization", "should-be-blocked")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
|
||||
|
||||
// Direct non-security headers should be forwarded when allowlist has *
|
||||
if _, ok := extraHeaders["custom-header"]; !ok {
|
||||
t.Error("expected custom-header to be forwarded directly")
|
||||
}
|
||||
if _, ok := extraHeaders["anthropic-beta"]; !ok {
|
||||
t.Error("expected anthropic-beta to be forwarded directly")
|
||||
}
|
||||
|
||||
// Security headers should still be blocked in direct forwarding path
|
||||
directSecurityHeaders := []string{"proxy-authorization", "cookie", "host", "connection"}
|
||||
for _, h := range directSecurityHeaders {
|
||||
if _, ok := extraHeaders[h]; ok {
|
||||
t.Errorf("expected security header %q to be blocked in direct forwarding even with * allowlist", h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToBifrostContext_PrefixWildcardDirectForwarding verifies that
|
||||
// prefix wildcard patterns like "anthropic-*" work for direct header forwarding
|
||||
// (without x-bf-eh- prefix).
|
||||
func TestConvertToBifrostContext_PrefixWildcardDirectForwarding(t *testing.T) {
|
||||
matcher := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-*"},
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
// Direct headers matching the wildcard pattern
|
||||
ctx.Request.Header.Set("anthropic-beta", "beta-value")
|
||||
ctx.Request.Header.Set("anthropic-version", "2024-01-01")
|
||||
// Header not matching the pattern
|
||||
ctx.Request.Header.Set("openai-version", "should-not-forward")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
|
||||
|
||||
if _, ok := extraHeaders["anthropic-beta"]; !ok {
|
||||
t.Error("expected anthropic-beta to be forwarded directly via wildcard allowlist")
|
||||
}
|
||||
if _, ok := extraHeaders["anthropic-version"]; !ok {
|
||||
t.Error("expected anthropic-version to be forwarded directly via wildcard allowlist")
|
||||
}
|
||||
if _, ok := extraHeaders["openai-version"]; ok {
|
||||
t.Error("expected openai-version to NOT be forwarded (doesn't match anthropic-*)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToBifrostContext_WildcardAllowlistFiltering verifies wildcard patterns
|
||||
// correctly filter headers via the x-bf-eh- prefix path.
|
||||
func TestConvertToBifrostContext_WildcardAllowlistFiltering(t *testing.T) {
|
||||
matcher := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-*"},
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("x-bf-eh-anthropic-beta", "beta-value")
|
||||
ctx.Request.Header.Set("x-bf-eh-anthropic-version", "2024-01-01")
|
||||
ctx.Request.Header.Set("x-bf-eh-openai-version", "should-be-blocked")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
|
||||
|
||||
if _, ok := extraHeaders["anthropic-beta"]; !ok {
|
||||
t.Error("expected anthropic-beta to be forwarded")
|
||||
}
|
||||
if _, ok := extraHeaders["anthropic-version"]; !ok {
|
||||
t.Error("expected anthropic-version to be forwarded")
|
||||
}
|
||||
if _, ok := extraHeaders["openai-version"]; ok {
|
||||
t.Error("expected openai-version to be blocked (not matching anthropic-*)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToBifrostContext_WildcardDenylistBlocking verifies wildcard denylist
|
||||
// patterns block matching headers.
|
||||
func TestConvertToBifrostContext_WildcardDenylistBlocking(t *testing.T) {
|
||||
matcher := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"x-internal-*"},
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("x-bf-eh-x-internal-id", "blocked-value")
|
||||
ctx.Request.Header.Set("x-bf-eh-x-internal-secret", "blocked-value")
|
||||
ctx.Request.Header.Set("x-bf-eh-custom-header", "allowed-value")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
|
||||
|
||||
if _, ok := extraHeaders["x-internal-id"]; ok {
|
||||
t.Error("expected x-internal-id to be blocked by denylist")
|
||||
}
|
||||
if _, ok := extraHeaders["x-internal-secret"]; ok {
|
||||
t.Error("expected x-internal-secret to be blocked by denylist")
|
||||
}
|
||||
if _, ok := extraHeaders["custom-header"]; !ok {
|
||||
t.Error("expected custom-header to be forwarded")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToBifrostContext_NilMatcher verifies nil matcher allows all headers.
|
||||
func TestConvertToBifrostContext_NilMatcher(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("x-bf-eh-custom-header", "allowed-value")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
|
||||
|
||||
if _, ok := extraHeaders["custom-header"]; !ok {
|
||||
t.Error("expected custom-header to be forwarded with nil matcher")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToBifrostContext_BaggageSessionIDSetsGrouping(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("baggage", "foo=bar, session-id=rt-123, baz=qux")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
if got, _ := bifrostCtx.Value(schemas.BifrostContextKeyParentRequestID).(string); got != "rt-123" {
|
||||
t.Fatalf("parent request id = %q, want %q", got, "rt-123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToBifrostContext_EmptyBaggageSessionIDIgnored(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("baggage", "session-id= ")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeyParentRequestID); got != nil {
|
||||
t.Fatalf("parent request id should be unset, got %#v", got)
|
||||
}
|
||||
}
|
||||
6
transports/bifrost-http/lib/errors.go
Normal file
6
transports/bifrost-http/lib/errors.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package lib
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrNotFound = errors.New("not found")
|
||||
var ErrAlreadyExists = errors.New("already exists")
|
||||
136
transports/bifrost-http/lib/headermatcher.go
Normal file
136
transports/bifrost-http/lib/headermatcher.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
// HeaderMatchesPattern returns true if headerName matches the pattern.
|
||||
// Patterns support trailing wildcard: "anthropic-*" matches "anthropic-beta".
|
||||
// A bare "*" matches everything. All comparisons are case-insensitive.
|
||||
func HeaderMatchesPattern(pattern, headerName string) bool {
|
||||
pattern = strings.ToLower(strings.TrimSpace(pattern))
|
||||
headerName = strings.ToLower(strings.TrimSpace(headerName))
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
if strings.HasSuffix(pattern, "*") {
|
||||
return strings.HasPrefix(headerName, pattern[:len(pattern)-1])
|
||||
}
|
||||
return pattern == headerName
|
||||
}
|
||||
|
||||
// HeaderMatcher holds precomputed header filter data for O(1) exact-match lookups
|
||||
// and fast prefix matching. Compiled once on config change, safe for concurrent reads.
|
||||
type HeaderMatcher struct {
|
||||
allowExact map[string]bool
|
||||
allowPrefixes []string // lowercased prefixes (without trailing *)
|
||||
allowAll bool
|
||||
hasAllowlist bool
|
||||
denyExact map[string]bool
|
||||
denyPrefixes []string
|
||||
denyAll bool
|
||||
hasDenylist bool
|
||||
}
|
||||
|
||||
// NewHeaderMatcher compiles a GlobalHeaderFilterConfig into an optimized HeaderMatcher.
|
||||
// Returns nil if config is nil (callers should treat nil as "allow all").
|
||||
func NewHeaderMatcher(config *configstoreTables.GlobalHeaderFilterConfig) *HeaderMatcher {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
m := &HeaderMatcher{
|
||||
allowExact: make(map[string]bool, len(config.Allowlist)),
|
||||
denyExact: make(map[string]bool, len(config.Denylist)),
|
||||
}
|
||||
for _, p := range config.Allowlist {
|
||||
lp := strings.ToLower(strings.TrimSpace(p))
|
||||
if lp == "" {
|
||||
continue
|
||||
}
|
||||
if lp == "*" {
|
||||
m.allowAll = true
|
||||
} else if strings.HasSuffix(lp, "*") {
|
||||
m.allowPrefixes = append(m.allowPrefixes, lp[:len(lp)-1])
|
||||
} else {
|
||||
m.allowExact[lp] = true
|
||||
}
|
||||
}
|
||||
for _, p := range config.Denylist {
|
||||
lp := strings.ToLower(strings.TrimSpace(p))
|
||||
if lp == "" {
|
||||
continue
|
||||
}
|
||||
if lp == "*" {
|
||||
m.denyAll = true
|
||||
} else if strings.HasSuffix(lp, "*") {
|
||||
m.denyPrefixes = append(m.denyPrefixes, lp[:len(lp)-1])
|
||||
} else {
|
||||
m.denyExact[lp] = true
|
||||
}
|
||||
}
|
||||
m.hasAllowlist = m.allowAll || len(m.allowExact) > 0 || len(m.allowPrefixes) > 0
|
||||
m.hasDenylist = m.denyAll || len(m.denyExact) > 0 || len(m.denyPrefixes) > 0
|
||||
return m
|
||||
}
|
||||
|
||||
// HasAllowlist returns true if the matcher has a non-empty allowlist.
|
||||
func (m *HeaderMatcher) HasAllowlist() bool {
|
||||
if m == nil {
|
||||
return false
|
||||
}
|
||||
return m.hasAllowlist
|
||||
}
|
||||
|
||||
// MatchesAllow returns true if headerName matches any allowlist entry.
|
||||
// headerName must be lowercased by the caller.
|
||||
func (m *HeaderMatcher) MatchesAllow(headerName string) bool {
|
||||
if m.allowAll {
|
||||
return true
|
||||
}
|
||||
if m.allowExact[headerName] {
|
||||
return true
|
||||
}
|
||||
for _, prefix := range m.allowPrefixes {
|
||||
if strings.HasPrefix(headerName, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MatchesDeny returns true if headerName matches any denylist entry.
|
||||
// headerName must be lowercased by the caller.
|
||||
func (m *HeaderMatcher) MatchesDeny(headerName string) bool {
|
||||
if m.denyAll {
|
||||
return true
|
||||
}
|
||||
if m.denyExact[headerName] {
|
||||
return true
|
||||
}
|
||||
for _, prefix := range m.denyPrefixes {
|
||||
if strings.HasPrefix(headerName, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ShouldAllow determines if a header should be forwarded based on the
|
||||
// configurable header filter config (separate from the security denylist).
|
||||
// Returns true if the header passes both allowlist and denylist checks.
|
||||
// headerName is lowercased internally for case-insensitive matching.
|
||||
func (m *HeaderMatcher) ShouldAllow(headerName string) bool {
|
||||
if m == nil {
|
||||
return true
|
||||
}
|
||||
headerName = strings.ToLower(headerName)
|
||||
if m.hasAllowlist && !m.MatchesAllow(headerName) {
|
||||
return false
|
||||
}
|
||||
if m.hasDenylist && m.MatchesDeny(headerName) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
251
transports/bifrost-http/lib/headermatcher_test.go
Normal file
251
transports/bifrost-http/lib/headermatcher_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
func TestHeaderMatchesPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
pattern string
|
||||
headerName string
|
||||
want bool
|
||||
}{
|
||||
// Exact match
|
||||
{"anthropic-beta", "anthropic-beta", true},
|
||||
{"anthropic-beta", "anthropic-alpha", false},
|
||||
|
||||
// Case insensitive exact match
|
||||
{"Anthropic-Beta", "anthropic-beta", true},
|
||||
{"anthropic-beta", "Anthropic-Beta", true},
|
||||
|
||||
// Star matches all
|
||||
{"*", "anything", true},
|
||||
{"*", "", true},
|
||||
|
||||
// Prefix wildcard
|
||||
{"anthropic-*", "anthropic-beta", true},
|
||||
{"anthropic-*", "anthropic-version", true},
|
||||
{"anthropic-*", "anthropic-", true},
|
||||
{"anthropic-*", "openai-version", false},
|
||||
{"anthropic-*", "anthropic", false},
|
||||
|
||||
// Case insensitive prefix wildcard
|
||||
{"Anthropic-*", "anthropic-beta", true},
|
||||
{"anthropic-*", "Anthropic-Beta", true},
|
||||
|
||||
// No match
|
||||
{"foo", "bar", false},
|
||||
{"", "foo", false},
|
||||
|
||||
// Pattern without wildcard doesn't prefix match
|
||||
{"anthropic-", "anthropic-beta", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.pattern+"_"+tt.headerName, func(t *testing.T) {
|
||||
got := HeaderMatchesPattern(tt.pattern, tt.headerName)
|
||||
if got != tt.want {
|
||||
t.Errorf("HeaderMatchesPattern(%q, %q) = %v, want %v", tt.pattern, tt.headerName, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHeaderMatcher_Nil(t *testing.T) {
|
||||
m := NewHeaderMatcher(nil)
|
||||
if m != nil {
|
||||
t.Fatal("expected nil matcher for nil config")
|
||||
}
|
||||
// nil matcher should allow everything
|
||||
if !m.ShouldAllow("anything") {
|
||||
t.Error("nil matcher should allow all headers")
|
||||
}
|
||||
if m.HasAllowlist() {
|
||||
t.Error("nil matcher should have no allowlist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHeaderMatcher_Empty(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{})
|
||||
if m == nil {
|
||||
t.Fatal("expected non-nil matcher for empty config")
|
||||
}
|
||||
if m.HasAllowlist() {
|
||||
t.Error("empty config should have no allowlist")
|
||||
}
|
||||
if !m.ShouldAllow("anything") {
|
||||
t.Error("empty config should allow all headers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_ExactAllowlist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-beta", "custom-id"},
|
||||
})
|
||||
if !m.ShouldAllow("anthropic-beta") {
|
||||
t.Error("should allow anthropic-beta")
|
||||
}
|
||||
if !m.ShouldAllow("custom-id") {
|
||||
t.Error("should allow custom-id")
|
||||
}
|
||||
if m.ShouldAllow("openai-version") {
|
||||
t.Error("should not allow openai-version")
|
||||
}
|
||||
// Case insensitive
|
||||
if !m.ShouldAllow("Anthropic-Beta") {
|
||||
t.Error("should allow Anthropic-Beta (case insensitive)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_WildcardAllowlist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-*"},
|
||||
})
|
||||
if !m.ShouldAllow("anthropic-beta") {
|
||||
t.Error("should allow anthropic-beta")
|
||||
}
|
||||
if !m.ShouldAllow("anthropic-version") {
|
||||
t.Error("should allow anthropic-version")
|
||||
}
|
||||
if m.ShouldAllow("openai-version") {
|
||||
t.Error("should not allow openai-version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_StarAllowlist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"*"},
|
||||
})
|
||||
if !m.ShouldAllow("anything") {
|
||||
t.Error("* should allow anything")
|
||||
}
|
||||
if !m.ShouldAllow("") {
|
||||
t.Error("* should allow empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_ExactDenylist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"secret-token"},
|
||||
})
|
||||
if m.ShouldAllow("secret-token") {
|
||||
t.Error("should deny secret-token")
|
||||
}
|
||||
if !m.ShouldAllow("public-key") {
|
||||
t.Error("should allow public-key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_WildcardDenylist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"x-internal-*"},
|
||||
})
|
||||
if m.ShouldAllow("x-internal-id") {
|
||||
t.Error("should deny x-internal-id")
|
||||
}
|
||||
if m.ShouldAllow("x-internal-secret") {
|
||||
t.Error("should deny x-internal-secret")
|
||||
}
|
||||
if !m.ShouldAllow("x-external-id") {
|
||||
t.Error("should allow x-external-id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_StarDenylist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"*"},
|
||||
})
|
||||
if m.ShouldAllow("anything") {
|
||||
t.Error("* denylist should deny everything")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_AllowlistWithDenylist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"*"},
|
||||
Denylist: []string{"x-internal-*"},
|
||||
})
|
||||
if !m.ShouldAllow("anthropic-beta") {
|
||||
t.Error("should allow anthropic-beta")
|
||||
}
|
||||
if m.ShouldAllow("x-internal-id") {
|
||||
t.Error("should deny x-internal-id (denylist overrides)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_AllowlistPrefixWithDenylistExact(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-*"},
|
||||
Denylist: []string{"anthropic-dangerous"},
|
||||
})
|
||||
if !m.ShouldAllow("anthropic-beta") {
|
||||
t.Error("should allow anthropic-beta")
|
||||
}
|
||||
if m.ShouldAllow("anthropic-dangerous") {
|
||||
t.Error("should deny anthropic-dangerous")
|
||||
}
|
||||
if m.ShouldAllow("openai-version") {
|
||||
t.Error("should not allow openai-version (not in allowlist)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_CaseInsensitive(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"Anthropic-*"},
|
||||
Denylist: []string{"X-Internal-*"},
|
||||
})
|
||||
if !m.ShouldAllow("anthropic-beta") {
|
||||
t.Error("should allow anthropic-beta (case insensitive)")
|
||||
}
|
||||
if m.ShouldAllow("x-internal-id") {
|
||||
t.Error("should deny x-internal-id (case insensitive)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_MatchesAllow(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-*", "custom-id"},
|
||||
})
|
||||
if !m.MatchesAllow("anthropic-beta") {
|
||||
t.Error("should match anthropic-beta")
|
||||
}
|
||||
if !m.MatchesAllow("custom-id") {
|
||||
t.Error("should match custom-id")
|
||||
}
|
||||
if m.MatchesAllow("openai-version") {
|
||||
t.Error("should not match openai-version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_MatchesDeny(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"secret-*", "blocked"},
|
||||
})
|
||||
if !m.MatchesDeny("secret-token") {
|
||||
t.Error("should match secret-token")
|
||||
}
|
||||
if !m.MatchesDeny("blocked") {
|
||||
t.Error("should match blocked")
|
||||
}
|
||||
if m.MatchesDeny("allowed") {
|
||||
t.Error("should not match allowed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_HasAllowlist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"foo"},
|
||||
})
|
||||
if !m.HasAllowlist() {
|
||||
t.Error("should have allowlist")
|
||||
}
|
||||
|
||||
m2 := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"bar"},
|
||||
})
|
||||
if m2.HasAllowlist() {
|
||||
t.Error("should not have allowlist")
|
||||
}
|
||||
}
|
||||
58
transports/bifrost-http/lib/lib.go
Normal file
58
transports/bifrost-http/lib/lib.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
var logger schemas.Logger
|
||||
|
||||
// SetLogger sets the logger for the application.
|
||||
func SetLogger(l schemas.Logger) {
|
||||
logger = l
|
||||
}
|
||||
|
||||
// StreamLargeResponseBody extracts the large response reader from context and streams
|
||||
// it directly to the client. Sets status 200, content-type, and content-length headers.
|
||||
// Returns false if the reader is not available (caller should send an error response).
|
||||
func StreamLargeResponseBody(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) bool {
|
||||
if bifrostCtx == nil {
|
||||
return false
|
||||
}
|
||||
reader, ok := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseReader).(io.ReadCloser)
|
||||
if !ok || reader == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
contentLength, _ := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseContentLength).(int)
|
||||
contentType, _ := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseContentType).(string)
|
||||
contentDisposition, _ := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseContentDisposition).(string)
|
||||
|
||||
// Mirror large-response-mode to fasthttp UserValue so post-hook middleware
|
||||
// (which only sees ctx.UserValue, not bifrostCtx) can skip body materialization.
|
||||
ctx.SetUserValue(FastHTTPUserValueLargeResponseMode, true)
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
if contentType != "" {
|
||||
ctx.SetContentType(contentType)
|
||||
} else {
|
||||
ctx.SetContentType("application/json")
|
||||
}
|
||||
if contentDisposition != "" {
|
||||
ctx.Response.Header.Set("Content-Disposition", contentDisposition)
|
||||
}
|
||||
// bodySize for SetBodyStream: positive = known size, -1 = unknown (read until EOF).
|
||||
// fasthttp treats 0 as "known empty", so default to -1 when CL is unavailable.
|
||||
bodySize := contentLength
|
||||
if bodySize > 0 {
|
||||
ctx.Response.Header.Set("Content-Length", strconv.Itoa(contentLength))
|
||||
} else {
|
||||
bodySize = -1
|
||||
}
|
||||
|
||||
ctx.Response.SetBodyStream(reader, bodySize)
|
||||
return true
|
||||
}
|
||||
23
transports/bifrost-http/lib/middleware.go
Normal file
23
transports/bifrost-http/lib/middleware.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ChainMiddlewares chains multiple middlewares together
|
||||
// Middlewares are applied in order: the first middleware wraps the second, etc.
|
||||
// This allows earlier middlewares to short-circuit by not calling next(ctx)
|
||||
func ChainMiddlewares(handler fasthttp.RequestHandler, middlewares ...schemas.BifrostHTTPMiddleware) fasthttp.RequestHandler {
|
||||
// If no middlewares, return the original handler
|
||||
if len(middlewares) == 0 {
|
||||
return handler
|
||||
}
|
||||
// Build the chain from right to left (last middleware wraps the handler)
|
||||
// This ensures execution order is left to right (first middleware executes first)
|
||||
chained := handler
|
||||
for i := len(middlewares) - 1; i >= 0; i-- {
|
||||
chained = middlewares[i](chained)
|
||||
}
|
||||
return chained
|
||||
}
|
||||
1077
transports/bifrost-http/lib/pricing_integration_test.go
Normal file
1077
transports/bifrost-http/lib/pricing_integration_test.go
Normal file
File diff suppressed because it is too large
Load Diff
174
transports/bifrost-http/lib/semantic_cache_config_test.go
Normal file
174
transports/bifrost-http/lib/semantic_cache_config_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
"github.com/maximhq/bifrost/plugins/semanticcache"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyMode(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"dimension": 1,
|
||||
"ttl": "5m",
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
configMap, ok := pluginConfig.Config.(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
_, hasKeys := configMap["keys"]
|
||||
require.False(t, hasKeys, "direct-only mode should not inject provider keys")
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyModeRemovesStaleProviderBackedFields(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"dimension": 1,
|
||||
"keys": []schemas.Key{{Name: "stale-key"}},
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
configMap, ok := pluginConfig.Config.(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
_, hasKeys := configMap["keys"]
|
||||
require.False(t, hasKeys, "direct-only mode should remove stale provider keys")
|
||||
_, hasEmbeddingModel := configMap["embedding_model"]
|
||||
require.False(t, hasEmbeddingModel, "direct-only mode should remove stale embedding_model")
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_InjectsProviderKeys(t *testing.T) {
|
||||
config := &Config{
|
||||
Providers: map[schemas.ModelProvider]configstore.ProviderConfig{
|
||||
schemas.OpenAI: {
|
||||
Keys: []schemas.Key{
|
||||
{
|
||||
Name: "openai-key",
|
||||
Value: *schemas.NewEnvVar("sk-test"),
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"provider": "openai",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"dimension": 1536,
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
configMap, ok := pluginConfig.Config.(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
keys, ok := configMap["keys"].([]schemas.Key)
|
||||
require.True(t, ok, "provider-backed mode should inject provider keys")
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, "openai-key", keys[0].Name)
|
||||
require.Equal(t, "openai", configMap["provider"])
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_SemanticModeMissingProvider(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"dimension": 1536,
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "requires 'provider' for semantic mode")
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingDimension(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"provider": "openai",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "requires 'dimension' for provider-backed semantic mode")
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeDimensionOne(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"provider": "openai",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"dimension": 1,
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "requires 'dimension' > 1")
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingEmbeddingModel(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"provider": "openai",
|
||||
"dimension": 1536,
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "requires 'embedding_model'")
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionZero(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"dimension": 0,
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "'dimension' must be >= 1")
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionNegative(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"dimension": -1,
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "'dimension' must be >= 1")
|
||||
}
|
||||
114
transports/bifrost-http/lib/streamreader.go
Normal file
114
transports/bifrost-http/lib/streamreader.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SSEStreamReader is an io.ReadCloser that delivers one event per Read call,
|
||||
// bypassing fasthttp's internal pipe mechanism (fasthttputil.PipeConns) which
|
||||
// batches multiple events into single TCP segments.
|
||||
//
|
||||
// Usage:
|
||||
// 1. Create with NewSSEStreamReader()
|
||||
// 2. Pass to ctx.Response.SetBodyStream(reader, -1)
|
||||
// 3. Start a producer goroutine that calls Send()/SendEvent()/SendError() for each event
|
||||
// 4. Producer calls Done() when finished (closes the event channel)
|
||||
// 5. fasthttp calls Close() on write errors (signals producer to stop)
|
||||
type SSEStreamReader struct {
|
||||
eventCh chan []byte
|
||||
closeCh chan struct{}
|
||||
closeOnce sync.Once
|
||||
current []byte // remaining bytes from a partial read
|
||||
}
|
||||
|
||||
// NewSSEStreamReader creates a new SSEStreamReader with a buffered event channel.
|
||||
// Channel capacity of 1 allows one event of pipeline parallelism between
|
||||
// the producer goroutine and fasthttp's writeBodyChunked loop.
|
||||
func NewSSEStreamReader() *SSEStreamReader {
|
||||
return &SSEStreamReader{
|
||||
eventCh: make(chan []byte, 1),
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Read implements io.Reader. It blocks until an event is available, then returns
|
||||
// that event's bytes. If the caller's buffer is smaller than the event, remaining
|
||||
// bytes are stored and returned on subsequent calls. Returns io.EOF when Done()
|
||||
// has been called and all events have been consumed.
|
||||
func (r *SSEStreamReader) Read(p []byte) (int, error) {
|
||||
if len(r.current) == 0 {
|
||||
event, ok := <-r.eventCh
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
r.current = event
|
||||
}
|
||||
n := copy(p, r.current)
|
||||
r.current = r.current[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Close implements io.Closer. Called by fasthttp when writeBodyChunked encounters
|
||||
// a write error (client disconnect). Signals the producer goroutine to stop via closeCh.
|
||||
// Safe to call multiple times.
|
||||
func (r *SSEStreamReader) Close() error {
|
||||
r.closeOnce.Do(func() {
|
||||
close(r.closeCh)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send delivers a pre-formatted event to the reader. Returns false if the reader
|
||||
// has been closed (client disconnected), in which case the producer should stop.
|
||||
func (r *SSEStreamReader) Send(event []byte) bool {
|
||||
// Check closeCh first (non-blocking) to avoid sending after Close
|
||||
select {
|
||||
case <-r.closeCh:
|
||||
return false
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case r.eventCh <- event:
|
||||
return true
|
||||
case <-r.closeCh:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// SendEvent sends an SSE-framed event. If eventType is empty, it sends "data: <data>\n\n".
|
||||
// If eventType is non-empty, it sends "event: <eventType>\ndata: <data>\n\n".
|
||||
// Returns false if the reader has been closed (client disconnected).
|
||||
func (r *SSEStreamReader) SendEvent(eventType string, data []byte) bool {
|
||||
var buf []byte
|
||||
if eventType != "" {
|
||||
buf = make([]byte, 0, 7+len(eventType)+7+len(data)+2)
|
||||
buf = append(buf, "event: "...)
|
||||
buf = append(buf, eventType...)
|
||||
buf = append(buf, "\ndata: "...)
|
||||
} else {
|
||||
buf = make([]byte, 0, 6+len(data)+2)
|
||||
buf = append(buf, "data: "...)
|
||||
}
|
||||
buf = append(buf, data...)
|
||||
buf = append(buf, '\n', '\n')
|
||||
return r.Send(buf)
|
||||
}
|
||||
|
||||
// SendError sends an SSE error event: "event: error\ndata: <data>\n\n".
|
||||
// Returns false if the reader has been closed (client disconnected).
|
||||
func (r *SSEStreamReader) SendError(data []byte) bool {
|
||||
return r.SendEvent("error", data)
|
||||
}
|
||||
|
||||
// SendDone sends the standard SSE done marker: "data: [DONE]\n\n".
|
||||
// Returns false if the reader has been closed (client disconnected).
|
||||
func (r *SSEStreamReader) SendDone() bool {
|
||||
return r.Send([]byte("data: [DONE]\n\n"))
|
||||
}
|
||||
|
||||
// Done closes the event channel, signaling to Read that the stream is finished.
|
||||
// Must be called exactly once by the producer goroutine when streaming is complete.
|
||||
func (r *SSEStreamReader) Done() {
|
||||
close(r.eventCh)
|
||||
}
|
||||
714
transports/bifrost-http/lib/streamreader_test.go
Normal file
714
transports/bifrost-http/lib/streamreader_test.go
Normal file
@@ -0,0 +1,714 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSSEStreamReaderSingleEventPerRead(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
events := [][]byte{
|
||||
[]byte("data: {\"chunk\":1}\n\n"),
|
||||
[]byte("data: {\"chunk\":2}\n\n"),
|
||||
[]byte("data: {\"chunk\":3}\n\n"),
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
for _, e := range events {
|
||||
if !r.Send(e) {
|
||||
select {
|
||||
case errCh <- fmt.Errorf("Send returned false unexpectedly"):
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
for i, want := range events {
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("event %d: unexpected error: %v", i, err)
|
||||
}
|
||||
got := string(buf[:n])
|
||||
if got != string(want) {
|
||||
t.Errorf("event %d: got %q, want %q", i, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Next read should return EOF
|
||||
n, err := r.Read(buf)
|
||||
if err != io.EOF {
|
||||
t.Errorf("expected io.EOF, got err=%v n=%d", err, n)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Error(err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderPartialRead(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
event := []byte("data: {\"content\":\"hello world\"}\n\n")
|
||||
|
||||
go func() {
|
||||
r.Send(event)
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
// Read with a small buffer (5 bytes at a time)
|
||||
var result []byte
|
||||
buf := make([]byte, 5)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
result = append(result, buf[:n]...)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if string(result) != string(event) {
|
||||
t.Errorf("reassembled data: got %q, want %q", result, event)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderEOFOnDone(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
r.Done() // Close immediately
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != io.EOF {
|
||||
t.Errorf("expected io.EOF, got err=%v n=%d", err, n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderCloseSignalsProducer(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
r.Close()
|
||||
|
||||
if r.Send([]byte("data: test\n\n")) {
|
||||
t.Error("Send should return false after Close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderIdempotentClose(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
// Should not panic
|
||||
r.Close()
|
||||
r.Close()
|
||||
r.Close()
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderConcurrent(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
const numEvents = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
// Producer
|
||||
go func() {
|
||||
for i := 0; i < numEvents; i++ {
|
||||
if !r.Send([]byte("data: event\n\n")) {
|
||||
break
|
||||
}
|
||||
}
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
// Consumer
|
||||
errCh := make(chan error, 2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buf := make([]byte, 4096)
|
||||
count := 0
|
||||
for {
|
||||
_, err := r.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
select {
|
||||
case errCh <- fmt.Errorf("unexpected error: %v", err):
|
||||
default:
|
||||
}
|
||||
break
|
||||
}
|
||||
count++
|
||||
}
|
||||
if count != numEvents {
|
||||
select {
|
||||
case errCh <- fmt.Errorf("got %d events, want %d", count, numEvents):
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
for err := range errCh {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderSendEvent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
eventType string
|
||||
data []byte
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "data only",
|
||||
eventType: "",
|
||||
data: []byte(`{"chunk":1}`),
|
||||
want: "data: {\"chunk\":1}\n\n",
|
||||
},
|
||||
{
|
||||
name: "with event type",
|
||||
eventType: "response.delta",
|
||||
data: []byte(`{"delta":"hi"}`),
|
||||
want: "event: response.delta\ndata: {\"delta\":\"hi\"}\n\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
go func() {
|
||||
r.SendEvent(tt.eventType, tt.data)
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got := string(buf[:n]); got != tt.want {
|
||||
t.Errorf("got %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderSendError(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
go func() {
|
||||
r.SendError([]byte(`{"error":"bad"}`))
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
want := "event: error\ndata: {\"error\":\"bad\"}\n\n"
|
||||
if got := string(buf[:n]); got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderSendDone(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
go func() {
|
||||
r.SendDone()
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
want := "data: [DONE]\n\n"
|
||||
if got := string(buf[:n]); got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderSendEventAfterClose(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
r.Close()
|
||||
|
||||
if r.SendEvent("test", []byte("data")) {
|
||||
t.Error("SendEvent should return false after Close")
|
||||
}
|
||||
if r.SendError([]byte("err")) {
|
||||
t.Error("SendError should return false after Close")
|
||||
}
|
||||
if r.SendDone() {
|
||||
t.Error("SendDone should return false after Close")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderSendEventByteAccuracy verifies that SendEvent produces
|
||||
// the exact same bytes that the old manual buffer assembly in the handlers did.
|
||||
func TestSSEStreamReaderSendEventByteAccuracy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
eventType string
|
||||
data []byte
|
||||
want []byte
|
||||
}{
|
||||
{
|
||||
name: "standard SSE data (old inference.go pattern)",
|
||||
eventType: "",
|
||||
data: []byte(`{"id":"chatcmpl-123","choices":[{"delta":{"content":"Hello"}}]}`),
|
||||
want: func() []byte {
|
||||
// Old code: buf = append(buf, "data: "...); buf = append(buf, chunkJSON...); buf = append(buf, '\n', '\n')
|
||||
data := []byte(`{"id":"chatcmpl-123","choices":[{"delta":{"content":"Hello"}}]}`)
|
||||
buf := make([]byte, 0, len(data)+8)
|
||||
buf = append(buf, "data: "...)
|
||||
buf = append(buf, data...)
|
||||
buf = append(buf, '\n', '\n')
|
||||
return buf
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "OpenAI responses format with event type (old inference.go pattern)",
|
||||
eventType: "response.output_item.added",
|
||||
data: []byte(`{"type":"response.output_item.added","item":{"id":"item_1"}}`),
|
||||
want: func() []byte {
|
||||
// Old code: buf = append(buf, "event: "...); buf = append(buf, eventType...); buf = append(buf, "\ndata: "...); ...
|
||||
eventType := "response.output_item.added"
|
||||
data := []byte(`{"type":"response.output_item.added","item":{"id":"item_1"}}`)
|
||||
buf := make([]byte, 0, len(eventType)+len(data)+16)
|
||||
buf = append(buf, "event: "...)
|
||||
buf = append(buf, eventType...)
|
||||
buf = append(buf, "\ndata: "...)
|
||||
buf = append(buf, data...)
|
||||
buf = append(buf, '\n', '\n')
|
||||
return buf
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "error event (old interceptor pattern)",
|
||||
eventType: "error",
|
||||
data: []byte(`{"error":"stream interrupted"}`),
|
||||
want: func() []byte {
|
||||
data := []byte(`{"error":"stream interrupted"}`)
|
||||
buf := make([]byte, 0, len(data)+24)
|
||||
buf = append(buf, "event: error\ndata: "...)
|
||||
buf = append(buf, data...)
|
||||
buf = append(buf, '\n', '\n')
|
||||
return buf
|
||||
}(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
go func() {
|
||||
r.SendEvent(tt.eventType, tt.data)
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
got := buf[:n]
|
||||
if string(got) != string(tt.want) {
|
||||
t.Errorf("byte mismatch:\n got: %q\n want: %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderSendErrorByteAccuracy verifies SendError matches
|
||||
// the old "event: error\ndata: ..." manual assembly.
|
||||
func TestSSEStreamReaderSendErrorByteAccuracy(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
errorJSON := []byte(`{"error":{"type":"internal_error","message":"An error occurred"}}`)
|
||||
|
||||
go func() {
|
||||
r.SendError(errorJSON)
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Must match the old pattern exactly:
|
||||
// buf = append(buf, "event: error\ndata: "...)
|
||||
// buf = append(buf, errorJSON...)
|
||||
// buf = append(buf, '\n', '\n')
|
||||
want := "event: error\ndata: " + string(errorJSON) + "\n\n"
|
||||
if got := string(buf[:n]); got != want {
|
||||
t.Errorf("byte mismatch:\n got: %q\n want: %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderMixedMethodStream simulates a realistic stream
|
||||
// that uses multiple methods (like router.go does): data events,
|
||||
// typed events, error, and done marker.
|
||||
func TestSSEStreamReaderMixedMethodStream(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
expected := []string{
|
||||
"data: {\"chunk\":1}\n\n",
|
||||
"event: response.delta\ndata: {\"delta\":\"hi\"}\n\n",
|
||||
"data: {\"chunk\":2}\n\n",
|
||||
"event: error\ndata: {\"error\":\"timeout\"}\n\n",
|
||||
"data: [DONE]\n\n",
|
||||
}
|
||||
|
||||
go func() {
|
||||
r.SendEvent("", []byte(`{"chunk":1}`))
|
||||
r.SendEvent("response.delta", []byte(`{"delta":"hi"}`))
|
||||
r.SendEvent("", []byte(`{"chunk":2}`))
|
||||
r.SendError([]byte(`{"error":"timeout"}`))
|
||||
r.SendDone()
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
for i, want := range expected {
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("event %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if got := string(buf[:n]); got != want {
|
||||
t.Errorf("event %d:\n got: %q\n want: %q", i, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Should be EOF after all events
|
||||
n, err := r.Read(buf)
|
||||
if err != io.EOF {
|
||||
t.Errorf("expected EOF, got err=%v n=%d", err, n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderRawAndWrapperMixed simulates the router.go pattern
|
||||
// where raw Send (for Bedrock/passthrough) is mixed with wrapper methods.
|
||||
func TestSSEStreamReaderRawAndWrapperMixed(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
// Simulate: Bedrock binary event (raw), followed by SSE events, then done
|
||||
bedrockBinary := []byte{0x00, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00, 0x2A} // fake binary
|
||||
preformattedSSE := []byte("event: content_block_delta\ndata: {\"delta\":\"test\"}\n\n")
|
||||
|
||||
expected := [][]byte{
|
||||
bedrockBinary,
|
||||
preformattedSSE,
|
||||
[]byte("data: {\"final\":true}\n\n"),
|
||||
}
|
||||
|
||||
go func() {
|
||||
r.Send(bedrockBinary) // raw binary passthrough
|
||||
r.Send(preformattedSSE) // pre-formatted SSE string
|
||||
r.SendEvent("", []byte(`{"final":true}`)) // wrapper method
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
for i, want := range expected {
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("event %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if string(buf[:n]) != string(want) {
|
||||
t.Errorf("event %d:\n got: %q\n want: %q", i, buf[:n], want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderSendEventEmptyData verifies behavior with empty data payload.
|
||||
func TestSSEStreamReaderSendEventEmptyData(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
eventType string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty data no event type",
|
||||
eventType: "",
|
||||
want: "data: \n\n",
|
||||
},
|
||||
{
|
||||
name: "empty data with event type",
|
||||
eventType: "heartbeat",
|
||||
want: "event: heartbeat\ndata: \n\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
go func() {
|
||||
r.SendEvent(tt.eventType, []byte{})
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got := string(buf[:n]); got != tt.want {
|
||||
t.Errorf("got %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderSendEventNilData verifies behavior with nil data payload.
|
||||
func TestSSEStreamReaderSendEventNilData(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
go func() {
|
||||
r.SendEvent("", nil)
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// nil data should produce same as empty: "data: \n\n"
|
||||
if got := string(buf[:n]); got != "data: \n\n" {
|
||||
t.Errorf("got %q, want %q", got, "data: \n\n")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderSendEventLargePayload verifies no corruption with large JSON payloads.
|
||||
func TestSSEStreamReaderSendEventLargePayload(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
// Build a large JSON payload (~64KB, larger than typical ReadBufferSize)
|
||||
largeContent := make([]byte, 65536)
|
||||
for i := range largeContent {
|
||||
largeContent[i] = 'A' + byte(i%26)
|
||||
}
|
||||
data := append([]byte(`{"content":"`), largeContent...)
|
||||
data = append(data, '"', '}')
|
||||
|
||||
go func() {
|
||||
r.SendEvent("response.delta", data)
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
// Read the entire event using small buffer to exercise partial reads
|
||||
var result []byte
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
result = append(result, buf[:n]...)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
want := "event: response.delta\ndata: " + string(data) + "\n\n"
|
||||
if string(result) != want {
|
||||
t.Errorf("large payload mismatch: got len=%d, want len=%d", len(result), len(want))
|
||||
// Check prefix and suffix for debugging
|
||||
if len(result) > 40 {
|
||||
t.Errorf(" got prefix: %q", result[:40])
|
||||
t.Errorf(" want prefix: %q", want[:40])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderMidStreamDisconnect simulates a client disconnecting
|
||||
// mid-stream while the producer is using SendEvent.
|
||||
func TestSSEStreamReaderMidStreamDisconnect(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
producerDone := make(chan int) // reports how many events were sent
|
||||
go func() {
|
||||
sent := 0
|
||||
for i := 0; i < 100; i++ {
|
||||
if !r.SendEvent("", []byte(fmt.Sprintf(`{"chunk":%d}`, i))) {
|
||||
break
|
||||
}
|
||||
sent++
|
||||
}
|
||||
close(producerDone)
|
||||
}()
|
||||
|
||||
// Read a few events then simulate client disconnect
|
||||
buf := make([]byte, 4096)
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("event %d: unexpected error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Client disconnects
|
||||
r.Close()
|
||||
|
||||
// Producer should stop promptly
|
||||
<-producerDone
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderSendErrorThenDone verifies the handler pattern
|
||||
// of sending an error event and immediately closing the stream.
|
||||
func TestSSEStreamReaderSendErrorThenDone(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
go func() {
|
||||
// Send a few normal events
|
||||
r.SendEvent("", []byte(`{"chunk":1}`))
|
||||
r.SendEvent("", []byte(`{"chunk":2}`))
|
||||
// Error occurs, send error and stop
|
||||
r.SendError([]byte(`{"error":"rate_limit"}`))
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
expected := []string{
|
||||
"data: {\"chunk\":1}\n\n",
|
||||
"data: {\"chunk\":2}\n\n",
|
||||
"event: error\ndata: {\"error\":\"rate_limit\"}\n\n",
|
||||
}
|
||||
|
||||
for i, want := range expected {
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("event %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if got := string(buf[:n]); got != want {
|
||||
t.Errorf("event %d: got %q, want %q", i, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Should be EOF (stream ended after error, no [DONE] marker)
|
||||
n, err := r.Read(buf)
|
||||
if err != io.EOF {
|
||||
t.Errorf("expected EOF after error event, got err=%v n=%d data=%q", err, n, buf[:n])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderSendDoneByteExact verifies SendDone produces
|
||||
// exactly "data: [DONE]\n\n" — the standard OpenAI SSE terminator.
|
||||
func TestSSEStreamReaderSendDoneByteExact(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
go func() {
|
||||
r.SendDone()
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
// Use exact-size buffer to verify no extra bytes
|
||||
want := []byte("data: [DONE]\n\n")
|
||||
buf := make([]byte, len(want))
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if n != len(want) {
|
||||
t.Errorf("expected %d bytes, got %d", len(want), n)
|
||||
}
|
||||
if string(buf[:n]) != string(want) {
|
||||
t.Errorf("got %q, want %q", buf[:n], want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderConcurrentSendEvent verifies thread safety of SendEvent
|
||||
// with multiple concurrent producers (not a real pattern but validates safety).
|
||||
func TestSSEStreamReaderConcurrentSendEvent(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
const numProducers = 5
|
||||
const eventsPerProducer = 20
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numProducers)
|
||||
|
||||
// Launch multiple producers
|
||||
for p := 0; p < numProducers; p++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < eventsPerProducer; i++ {
|
||||
if !r.SendEvent("", []byte(fmt.Sprintf(`{"p":%d,"i":%d}`, id, i))) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}(p)
|
||||
}
|
||||
|
||||
// Close after all producers finish
|
||||
go func() {
|
||||
wg.Wait()
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
// Consume all events
|
||||
buf := make([]byte, 4096)
|
||||
count := 0
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Every event must be a valid SSE data line
|
||||
got := string(buf[:n])
|
||||
if len(got) < 8 || got[:6] != "data: " || got[len(got)-2:] != "\n\n" {
|
||||
t.Errorf("event %d: invalid SSE format: %q", count, got)
|
||||
}
|
||||
count++
|
||||
}
|
||||
|
||||
if count != numProducers*eventsPerProducer {
|
||||
t.Errorf("got %d events, want %d", count, numProducers*eventsPerProducer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderCloseUnblocksProducer(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
done := make(chan struct{})
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(done)
|
||||
// Fill the channel buffer (cap=1)
|
||||
r.Send([]byte("data: first\n\n"))
|
||||
// This Send should block until Close is called
|
||||
r.Send([]byte("data: second\n\n"))
|
||||
// After Close, the next Send should return false
|
||||
if r.Send([]byte("data: third\n\n")) {
|
||||
select {
|
||||
case errCh <- fmt.Errorf("Send should return false after Close"):
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Close unblocks the blocked Send
|
||||
r.Close()
|
||||
<-done
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Error(err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
96
transports/bifrost-http/lib/validator.go
Normal file
96
transports/bifrost-http/lib/validator.go
Normal file
@@ -0,0 +1,96 @@
|
||||
// Package lib provides core functionality for the Bifrost HTTP service.
|
||||
// This file contains JSON schema validation for config files.
|
||||
package lib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/santhosh-tekuri/jsonschema/v6"
|
||||
)
|
||||
|
||||
// localSchemaCandidates lists paths (relative to CWD) where config.schema.json may be found
|
||||
// when running from a source checkout. Checked in order before falling back to the remote URL.
|
||||
var localSchemaCandidates = []string{
|
||||
"config.schema.json", // running from transports/
|
||||
"../config.schema.json", // running from transports/bifrost-http/
|
||||
"transports/config.schema.json", // running from repo root
|
||||
}
|
||||
|
||||
// tryLoadLocalSchema attempts to read config.schema.json from known local paths.
|
||||
// Returns nil if none are found.
|
||||
func tryLoadLocalSchema() []byte {
|
||||
for _, p := range localSchemaCandidates {
|
||||
data, err := os.ReadFile(p)
|
||||
if err == nil {
|
||||
return data
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConfigSchema validates config data against the JSON schema.
|
||||
// Returns nil if valid, or a formatted error describing all validation failures.
|
||||
// An optional schemaOverride can be provided to use a local schema instead of fetching from the remote URL.
|
||||
func ValidateConfigSchema(data []byte, schemaOverride ...[]byte) error {
|
||||
var configSchemaJSONBytes []byte
|
||||
if len(schemaOverride) > 0 && len(schemaOverride[0]) > 0 {
|
||||
configSchemaJSONBytes = schemaOverride[0]
|
||||
} else if localSchema := tryLoadLocalSchema(); localSchema != nil {
|
||||
// Prefer the local schema file from the source checkout when available.
|
||||
// This avoids validating against a potentially stale remote schema.
|
||||
configSchemaJSONBytes = localSchema
|
||||
} else {
|
||||
// Pulling config.schema from https://www.getbifrost.ai/schema
|
||||
configSchemaJSON, err := http.Get("https://www.getbifrost.ai/schema")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get config schema: %w", err)
|
||||
}
|
||||
defer configSchemaJSON.Body.Close()
|
||||
var readErr error
|
||||
configSchemaJSONBytes, readErr = io.ReadAll(configSchemaJSON.Body)
|
||||
if readErr != nil {
|
||||
logger.Warn("failed to download config schema: %v. running without config.json schema validation", readErr)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// Parse the schema JSON
|
||||
schemaDoc, err := jsonschema.UnmarshalJSON(bytes.NewReader(configSchemaJSONBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse config schema JSON: %w", err)
|
||||
}
|
||||
c := jsonschema.NewCompiler()
|
||||
if err := c.AddResource("config.schema.json", schemaDoc); err != nil {
|
||||
return fmt.Errorf("failed to add config schema resource: %w", err)
|
||||
}
|
||||
// Compile the schema
|
||||
compiledSchema, err := c.Compile("config.schema.json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to compile config schema: %w", err)
|
||||
}
|
||||
var v any
|
||||
if err := json.Unmarshal(data, &v); err != nil {
|
||||
return fmt.Errorf("invalid JSON: %w", err)
|
||||
}
|
||||
err = compiledSchema.Validate(v)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
// Format validation errors for better readability
|
||||
return formatValidationError(err)
|
||||
}
|
||||
|
||||
// formatValidationError converts jsonschema validation errors into user-friendly messages
|
||||
func formatValidationError(err error) error {
|
||||
validationErr, ok := err.(*jsonschema.ValidationError)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
// Use the GoString format which provides detailed hierarchical output
|
||||
return fmt.Errorf("schema validation failed:\n%s", validationErr.GoString())
|
||||
}
|
||||
1476
transports/bifrost-http/lib/validator_test.go
Normal file
1476
transports/bifrost-http/lib/validator_test.go
Normal file
File diff suppressed because it is too large
Load Diff
162
transports/bifrost-http/main.go
Normal file
162
transports/bifrost-http/main.go
Normal file
@@ -0,0 +1,162 @@
|
||||
// Package main provides an HTTP service using FastHTTP that exposes endpoints
|
||||
// for text and chat completions using various AI model providers (OpenAI, Anthropic, Bedrock, Mistral, Ollama, etc.).
|
||||
//
|
||||
// The HTTP service provides the following main endpoints:
|
||||
// - /v1/completions: For text completion requests
|
||||
// - /v1/chat/completions: For chat completion requests
|
||||
// - /v1/mcp/tool/execute: For MCP tool execution requests
|
||||
// - /providers/*: For provider configuration management
|
||||
//
|
||||
// Configuration is handled through a JSON config file, high-performance ConfigStore, and environment variables:
|
||||
// - Use -app-dir flag to specify the application data directory (contains config.json and logs)
|
||||
// - Use -port flag to specify the server port (default: 8080)
|
||||
// - When no config file exists, common environment variables are auto-detected (OPENAI_API_KEY, ANTHROPIC_API_KEY, MISTRAL_API_KEY)
|
||||
//
|
||||
// ConfigStore Features:
|
||||
// - Pure in-memory storage for ultra-fast config access
|
||||
// - Environment variable processing for secure configuration management
|
||||
// - Real-time configuration updates via HTTP API
|
||||
// - Explicit persistence control via POST /config/save endpoint
|
||||
// - Provider-specific key config support (Azure, Bedrock, Vertex)
|
||||
// - Thread-safe operations with concurrent request handling
|
||||
// - Statistics and monitoring endpoints for operational insights
|
||||
//
|
||||
// Performance Optimizations:
|
||||
// - Configuration data is processed once during startup and stored in memory
|
||||
// - Ultra-fast memory access eliminates I/O overhead on every request
|
||||
// - All environment variable processing done upfront during configuration loading
|
||||
// - Thread-safe concurrent access with read-write mutex protection
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// go run main.go -app-dir ./data -port 8080 -host 0.0.0.0
|
||||
// after setting provider API keys like OPENAI_API_KEY in the environment.
|
||||
//
|
||||
// To bind to all interfaces for container usage, set BIFROST_HOST=0.0.0.0 or use -host 0.0.0.0
|
||||
//
|
||||
// Integration Support:
|
||||
// Bifrost supports multiple AI provider integrations through dedicated HTTP endpoints.
|
||||
// Each integration exposes API-compatible endpoints that accept the provider's native request format,
|
||||
// automatically convert it to Bifrost's unified format, process it, and return the expected response format.
|
||||
//
|
||||
// Integration endpoints follow the pattern: /{provider}/{provider_api_path}
|
||||
// Examples:
|
||||
// - OpenAI: POST /openai/v1/chat/completions (accepts OpenAI ChatCompletion requests)
|
||||
// - GenAI: POST /genai/v1beta/models/{model} (accepts Google GenAI requests)
|
||||
// - Anthropic: POST /anthropic/v1/messages (accepts Anthropic Messages requests)
|
||||
//
|
||||
// This allows clients to use their existing integration code without modification while benefiting
|
||||
// from Bifrost's unified model routing, fallbacks, monitoring capabilities, and high-performance configuration management.
|
||||
//
|
||||
// NOTE: Streaming is supported for chat completions via Server-Sent Events (SSE)
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "go.uber.org/automaxprocs" // Automatically set GOMAXPROCS based on container cgroup limits
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/handlers"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
bifrostServer "github.com/maximhq/bifrost/transports/bifrost-http/server"
|
||||
)
|
||||
|
||||
//go:embed all:ui
|
||||
var uiContent embed.FS
|
||||
|
||||
var Version string
|
||||
|
||||
var logger = bifrost.NewDefaultLogger(schemas.LogLevelInfo)
|
||||
var server *bifrostServer.BifrostHTTPServer
|
||||
|
||||
// init initializes command line flags (but does not parse them).
|
||||
// Flag parsing is deferred to main() to avoid conflicts with test flags.
|
||||
// It sets up the following flags:
|
||||
// - host: Host to bind the server to (default: localhost, can be overridden with BIFROST_HOST env var)
|
||||
// - port: Server port (default: 8080)
|
||||
// - app-dir: Application data directory (default: current directory)
|
||||
// - log-level: Logger level (debug, info, warn, error). Default is info.
|
||||
// - log-style: Logger output type (json or pretty). Default is JSON.
|
||||
|
||||
func init() {
|
||||
if Version == "" {
|
||||
Version = "v1.0.0"
|
||||
}
|
||||
// Set default host from environment variable or use localhost
|
||||
defaultHost := os.Getenv("BIFROST_HOST")
|
||||
if defaultHost == "" {
|
||||
defaultHost = bifrostServer.DefaultHost
|
||||
}
|
||||
defaultLogLevel := strings.ToLower(os.Getenv("LOG_LEVEL"))
|
||||
if defaultLogLevel == "" {
|
||||
defaultLogLevel = bifrostServer.DefaultLogLevel
|
||||
}
|
||||
// Initializing server
|
||||
server = bifrostServer.NewBifrostHTTPServer(Version, uiContent)
|
||||
// Updating server properties from flags
|
||||
flag.StringVar(&server.Port, "port", bifrostServer.DefaultPort, "Port to run the server on")
|
||||
flag.StringVar(&server.Host, "host", defaultHost, "Host to bind the server to (default: localhost, override with BIFROST_HOST env var)")
|
||||
flag.StringVar(&server.AppDir, "app-dir", bifrostServer.DefaultAppDir, "Application data directory (contains config.json and logs)")
|
||||
flag.StringVar(&server.LogLevel, "log-level", defaultLogLevel, "Logger level (debug, info, warn, error). Default is info.")
|
||||
flag.StringVar(&server.LogOutputStyle, "log-style", bifrostServer.DefaultLogOutputStyle, "Logger output type (json or pretty). Default is JSON.")
|
||||
}
|
||||
|
||||
// main is the entry point of the application.
|
||||
func main() {
|
||||
// Parse command line flags
|
||||
flag.Parse()
|
||||
|
||||
// Printing version
|
||||
versionLine := fmt.Sprintf("║%s%s%s║", strings.Repeat(" ", (61-2-len(Version))/2), Version, strings.Repeat(" ", (61-2-len(Version)+1)/2))
|
||||
// Welcome to bifrost!
|
||||
fmt.Printf(`
|
||||
╔═══════════════════════════════════════════════════════════╗
|
||||
║ ║
|
||||
║ ██████╗ ██╗███████╗██████╗ ██████╗ ███████╗████████╗ ║
|
||||
║ ██╔══██╗██║██╔════╝██╔══██╗██╔═══██╗██╔════╝╚══██╔══╝ ║
|
||||
║ ██████╔╝██║█████╗ ██████╔╝██║ ██║███████╗ ██║ ║
|
||||
║ ██╔══██╗██║██╔══╝ ██╔══██╗██║ ██║╚════██║ ██║ ║
|
||||
║ ██████╔╝██║██║ ██║ ██║╚██████╔╝███████║ ██║ ║
|
||||
║ ╚═════╝ ╚═╝╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚══════╝ ╚═╝ ║
|
||||
║ ║
|
||||
║═══════════════════════════════════════════════════════════║
|
||||
%s
|
||||
║═══════════════════════════════════════════════════════════║
|
||||
║ The Fastest LLM Gateway ║
|
||||
║═══════════════════════════════════════════════════════════║
|
||||
║ https://github.com/maximhq/bifrost ║
|
||||
╚═══════════════════════════════════════════════════════════╝
|
||||
|
||||
`, versionLine)
|
||||
|
||||
// Configure logger from flags
|
||||
logger.SetOutputType(schemas.LoggerOutputType(server.LogOutputStyle))
|
||||
logger.SetLevel(schemas.LogLevel(server.LogLevel))
|
||||
// Setting up logger
|
||||
lib.SetLogger(logger)
|
||||
bifrostServer.SetLogger(logger)
|
||||
handlers.SetLogger(logger)
|
||||
|
||||
ctx := context.Background()
|
||||
t := time.Now()
|
||||
err := server.Bootstrap(ctx)
|
||||
if err != nil {
|
||||
logger.Error("failed to bootstrap server: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
logger.Info("Time spent in Bifrost server bootstrap %d ms", time.Since(t).Milliseconds())
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
logger.Error("failed to start server: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
logger.Info("🏁 server stopped")
|
||||
}
|
||||
298
transports/bifrost-http/server/plugins.go
Normal file
298
transports/bifrost-http/server/plugins.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/plugins/compat"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/maximhq/bifrost/plugins/logging"
|
||||
"github.com/maximhq/bifrost/plugins/maxim"
|
||||
"github.com/maximhq/bifrost/plugins/otel"
|
||||
"github.com/maximhq/bifrost/plugins/prompts"
|
||||
"github.com/maximhq/bifrost/plugins/semanticcache"
|
||||
"github.com/maximhq/bifrost/plugins/telemetry"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/handlers"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
// InferPluginTypes determines which interface types a plugin implements
|
||||
func InferPluginTypes(plugin schemas.BasePlugin) []schemas.PluginType {
|
||||
var types []schemas.PluginType
|
||||
if _, ok := plugin.(schemas.LLMPlugin); ok {
|
||||
types = append(types, schemas.PluginTypeLLM)
|
||||
}
|
||||
if _, ok := plugin.(schemas.MCPPlugin); ok {
|
||||
types = append(types, schemas.PluginTypeMCP)
|
||||
}
|
||||
if _, ok := plugin.(schemas.HTTPTransportPlugin); ok {
|
||||
types = append(types, schemas.PluginTypeHTTP)
|
||||
}
|
||||
return types
|
||||
}
|
||||
|
||||
// Single-plugin methods used plugin create/update
|
||||
|
||||
// InstantiatePlugin creates a plugin instance but does NOT register it
|
||||
// Registration is done separately via Config.RegisterPlugin()
|
||||
func InstantiatePlugin(ctx context.Context, name string, path *string, pluginConfig any, bifrostConfig *lib.Config) (schemas.BasePlugin, error) {
|
||||
// Custom plugin (has path)
|
||||
if path != nil {
|
||||
return loadCustomPlugin(ctx, path, pluginConfig, bifrostConfig)
|
||||
}
|
||||
|
||||
// Built-in plugin (by name)
|
||||
return loadBuiltinPlugin(ctx, name, pluginConfig, bifrostConfig)
|
||||
}
|
||||
|
||||
// loadBuiltinPlugin instantiates a built-in plugin by name
|
||||
func loadBuiltinPlugin(ctx context.Context, name string, pluginConfig any, bifrostConfig *lib.Config) (schemas.BasePlugin, error) {
|
||||
switch name {
|
||||
case telemetry.PluginName:
|
||||
telConfig := &telemetry.Config{
|
||||
CustomLabels: bifrostConfig.ClientConfig.PrometheusLabels,
|
||||
}
|
||||
// Merge push gateway config if provided (e.g., from config file or UI update)
|
||||
if pluginConfig != nil {
|
||||
extraConfig, err := MarshalPluginConfig[telemetry.Config](pluginConfig)
|
||||
if err == nil && extraConfig != nil && extraConfig.PushGateway != nil {
|
||||
telConfig.PushGateway = extraConfig.PushGateway
|
||||
}
|
||||
}
|
||||
return telemetry.Init(telConfig, bifrostConfig.ModelCatalog, logger)
|
||||
|
||||
case prompts.PluginName:
|
||||
return prompts.Init(ctx, bifrostConfig.ConfigStore, logger)
|
||||
|
||||
case logging.PluginName:
|
||||
loggingConfig, err := MarshalPluginConfig[logging.Config](pluginConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal logging plugin config: %w", err)
|
||||
}
|
||||
return logging.Init(ctx, loggingConfig, logger, bifrostConfig.LogsStore,
|
||||
bifrostConfig.ModelCatalog, bifrostConfig.MCPCatalog)
|
||||
|
||||
case governance.PluginName:
|
||||
governanceConfig, err := MarshalPluginConfig[governance.Config](pluginConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal governance plugin config: %w", err)
|
||||
}
|
||||
inMemoryStore := &GovernanceInMemoryStore{Config: bifrostConfig}
|
||||
return governance.Init(ctx, governanceConfig, logger, bifrostConfig.ConfigStore,
|
||||
bifrostConfig.GovernanceConfig, bifrostConfig.ModelCatalog,
|
||||
bifrostConfig.MCPCatalog, inMemoryStore)
|
||||
|
||||
case maxim.PluginName:
|
||||
maximConfig, err := MarshalPluginConfig[maxim.Config](pluginConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal maxim plugin config: %w", err)
|
||||
}
|
||||
return maxim.Init(maximConfig, logger)
|
||||
|
||||
case semanticcache.PluginName:
|
||||
semanticConfig, err := MarshalPluginConfig[semanticcache.Config](pluginConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal semantic cache plugin config: %w", err)
|
||||
}
|
||||
return semanticcache.Init(ctx, semanticConfig, logger, bifrostConfig.VectorStore)
|
||||
|
||||
case otel.PluginName:
|
||||
otelConfig, err := MarshalPluginConfig[otel.Config](pluginConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal otel plugin config: %w", err)
|
||||
}
|
||||
return otel.Init(ctx, otelConfig, logger, bifrostConfig.ModelCatalog, handlers.GetVersion())
|
||||
|
||||
case compat.PluginName:
|
||||
compatConfig, err := MarshalPluginConfig[compat.Config](pluginConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal compat plugin config: %w", err)
|
||||
}
|
||||
return compat.Init(*compatConfig, logger, bifrostConfig.ModelCatalog)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown built-in plugin: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
// loadCustomPlugin loads a plugin from a shared object file
|
||||
func loadCustomPlugin(ctx context.Context, path *string, pluginConfig any, bifrostConfig *lib.Config) (schemas.BasePlugin, error) {
|
||||
logger.Info("loading custom plugin from path %s", *path)
|
||||
|
||||
plugin, err := bifrostConfig.PluginLoader.LoadPlugin(*path, pluginConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load custom plugin: %w", err)
|
||||
}
|
||||
return plugin, nil
|
||||
}
|
||||
|
||||
// LoadPlugins loads the plugins for the server.
|
||||
func (s *BifrostHTTPServer) LoadPlugins(ctx context.Context) error {
|
||||
// Load built-in plugins first (order matters)
|
||||
if err := s.loadBuiltinPlugins(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
// Load custom plugins from config
|
||||
if err := s.loadCustomPlugins(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
// Sort all plugins by placement group and order
|
||||
s.Config.SortAndRebuildPlugins()
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPluginConfig retrieves a plugin's config from PluginConfigs by name
|
||||
func (s *BifrostHTTPServer) getPluginConfig(name string) *schemas.PluginConfig {
|
||||
for _, cfg := range s.Config.PluginConfigs {
|
||||
if cfg.Name == name {
|
||||
return cfg
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadBuiltinPlugins loads required built-in plugins in specific order
|
||||
func (s *BifrostHTTPServer) loadBuiltinPlugins(ctx context.Context) error {
|
||||
builtinPlacement := schemas.Ptr(schemas.PluginPlacementBuiltin)
|
||||
|
||||
// 1. Telemetry (always first - tracks everything)
|
||||
if err := s.registerPluginWithStatus(ctx, telemetry.PluginName, nil, nil, true); err != nil {
|
||||
return err
|
||||
}
|
||||
s.Config.SetPluginOrderInfo(telemetry.PluginName, builtinPlacement, schemas.Ptr(1))
|
||||
|
||||
// 2. Prompts (requires config store for prompt repository; disabled in enterprise)
|
||||
if s.Config.ConfigStore != nil && ctx.Value(schemas.BifrostContextKeyIsEnterprise) == nil {
|
||||
s.registerPluginWithStatus(ctx, prompts.PluginName, nil, nil, false)
|
||||
} else {
|
||||
s.markPluginDisabled(prompts.PluginName)
|
||||
}
|
||||
s.Config.SetPluginOrderInfo(prompts.PluginName, builtinPlacement, schemas.Ptr(2))
|
||||
|
||||
// 3. Logging (if enabled)
|
||||
if (s.Config.ClientConfig.EnableLogging == nil || *s.Config.ClientConfig.EnableLogging) && s.Config.LogsStore != nil {
|
||||
config := &logging.Config{
|
||||
DisableContentLogging: &s.Config.ClientConfig.DisableContentLogging,
|
||||
LoggingHeaders: &s.Config.ClientConfig.LoggingHeaders,
|
||||
}
|
||||
s.registerPluginWithStatus(ctx, logging.PluginName, nil, config, false)
|
||||
} else {
|
||||
s.markPluginDisabled(logging.PluginName)
|
||||
}
|
||||
s.Config.SetPluginOrderInfo(logging.PluginName, builtinPlacement, schemas.Ptr(3))
|
||||
|
||||
// 4. Governance (if enabled and not enterprise)
|
||||
if ctx.Value(schemas.BifrostContextKeyIsEnterprise) == nil {
|
||||
config := &governance.Config{
|
||||
IsVkMandatory: &s.Config.ClientConfig.EnforceAuthOnInference,
|
||||
RequiredHeaders: &s.Config.ClientConfig.RequiredHeaders,
|
||||
DisableAutoToolInject: &s.Config.ClientConfig.MCPDisableAutoToolInject,
|
||||
RoutingChainMaxDepth: &s.Config.ClientConfig.RoutingChainMaxDepth,
|
||||
}
|
||||
s.registerPluginWithStatus(ctx, governance.PluginName, nil, config, false)
|
||||
} else {
|
||||
s.markPluginDisabled(governance.PluginName)
|
||||
}
|
||||
s.Config.SetPluginOrderInfo(governance.PluginName, builtinPlacement, schemas.Ptr(4))
|
||||
|
||||
// 5. OTEL (if configured in PluginConfigs)
|
||||
otelConfig := s.getPluginConfig(otel.PluginName)
|
||||
if otelConfig != nil && otelConfig.Enabled {
|
||||
s.registerPluginWithStatus(ctx, otel.PluginName, nil, otelConfig.Config, false)
|
||||
} else {
|
||||
s.markPluginDisabled(otel.PluginName)
|
||||
}
|
||||
s.Config.SetPluginOrderInfo(otel.PluginName, builtinPlacement, schemas.Ptr(5))
|
||||
|
||||
// 6. Semantic Cache (if configured in PluginConfigs)
|
||||
semanticCacheConfig := s.getPluginConfig(semanticcache.PluginName)
|
||||
if semanticCacheConfig != nil && semanticCacheConfig.Enabled {
|
||||
s.registerPluginWithStatus(ctx, semanticcache.PluginName, nil, semanticCacheConfig.Config, false)
|
||||
} else {
|
||||
s.markPluginDisabled(semanticcache.PluginName)
|
||||
}
|
||||
s.Config.SetPluginOrderInfo(semanticcache.PluginName, builtinPlacement, schemas.Ptr(6))
|
||||
|
||||
// 7. Compat (if any compat feature is enabled in ClientConfig)
|
||||
cc := s.Config.ClientConfig.Compat
|
||||
compatCfg := &compat.Config{
|
||||
ConvertTextToChat: cc.ConvertTextToChat,
|
||||
ConvertChatToResponses: cc.ConvertChatToResponses,
|
||||
ShouldDropParams: cc.ShouldDropParams,
|
||||
ShouldConvertParams: cc.ShouldConvertParams,
|
||||
}
|
||||
s.registerPluginWithStatus(ctx, compat.PluginName, nil, compatCfg, false)
|
||||
s.Config.SetPluginOrderInfo(compat.PluginName, builtinPlacement, schemas.Ptr(7))
|
||||
|
||||
// 8. Maxim (if configured in PluginConfigs)
|
||||
maximConfig := s.getPluginConfig(maxim.PluginName)
|
||||
if maximConfig != nil && maximConfig.Enabled {
|
||||
s.registerPluginWithStatus(ctx, maxim.PluginName, nil, maximConfig.Config, false)
|
||||
} else {
|
||||
s.markPluginDisabled(maxim.PluginName)
|
||||
}
|
||||
s.Config.SetPluginOrderInfo(maxim.PluginName, builtinPlacement, schemas.Ptr(8))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadCustomPlugins loads plugins from PluginConfigs
|
||||
func (s *BifrostHTTPServer) loadCustomPlugins(ctx context.Context) error {
|
||||
for _, cfg := range s.Config.PluginConfigs {
|
||||
// Skip built-ins (already loaded)
|
||||
if lib.IsBuiltinPlugin(cfg.Name) {
|
||||
continue
|
||||
}
|
||||
// Handle disabled plugins
|
||||
if !cfg.Enabled {
|
||||
// For custom plugins with a path, verify to get the real plugin name
|
||||
if cfg.Path != nil {
|
||||
pluginName, err := s.Config.PluginLoader.VerifyBasePlugin(*cfg.Path)
|
||||
if err != nil {
|
||||
logger.Error("failed to verify disabled plugin %s: %v", cfg.Name, err)
|
||||
continue
|
||||
}
|
||||
// Store plugin status without instantiating (no Init() call, no resource usage)
|
||||
// Note: We can't determine types without instantiating, so pass empty slice
|
||||
s.Config.UpdatePluginOverallStatus(pluginName, cfg.Name, schemas.PluginStatusDisabled,
|
||||
[]string{fmt.Sprintf("plugin %s is disabled", cfg.Name)}, []schemas.PluginType{})
|
||||
} else {
|
||||
// Built-in plugin - use cfg.Name directly
|
||||
s.Config.UpdatePluginOverallStatus(cfg.Name, cfg.Name, schemas.PluginStatusDisabled,
|
||||
[]string{fmt.Sprintf("plugin %s is disabled", cfg.Name)}, []schemas.PluginType{})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Plugin is enabled - instantiate it
|
||||
plugin, err := InstantiatePlugin(ctx, cfg.Name, cfg.Path, cfg.Config, s.Config)
|
||||
if err != nil {
|
||||
// Skip enterprise plugins silently
|
||||
if slices.Contains(enterprisePlugins, cfg.Name) {
|
||||
continue
|
||||
}
|
||||
logger.Error("failed to load plugin %s: %v", cfg.Name, err)
|
||||
// Use cfg.Name since plugin may be nil when InstantiatePlugin returns an error
|
||||
s.Config.UpdatePluginOverallStatus(cfg.Name, cfg.Name, schemas.PluginStatusError,
|
||||
[]string{fmt.Sprintf("error loading plugin %s: %v", cfg.Name, err)}, []schemas.PluginType{})
|
||||
continue
|
||||
}
|
||||
|
||||
// Ensure plugin is not nil before using it (defensive check)
|
||||
if plugin == nil {
|
||||
logger.Error("plugin %s instantiated but returned nil", cfg.Name)
|
||||
s.Config.UpdatePluginOverallStatus(cfg.Name, cfg.Name, schemas.PluginStatusError,
|
||||
[]string{fmt.Sprintf("plugin %s instantiated but returned nil", cfg.Name)}, []schemas.PluginType{})
|
||||
continue
|
||||
}
|
||||
|
||||
// Register enabled plugin and mark as active
|
||||
s.Config.ReloadPlugin(plugin)
|
||||
s.Config.SetPluginOrderInfo(plugin.GetName(), cfg.Placement, cfg.Order)
|
||||
s.Config.UpdatePluginOverallStatus(plugin.GetName(), cfg.Name, schemas.PluginStatusActive,
|
||||
[]string{fmt.Sprintf("plugin %s initialized successfully", cfg.Name)}, InferPluginTypes(plugin))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
1544
transports/bifrost-http/server/server.go
Normal file
1544
transports/bifrost-http/server/server.go
Normal file
File diff suppressed because it is too large
Load Diff
393
transports/bifrost-http/server/server_test.go
Normal file
393
transports/bifrost-http/server/server_test.go
Normal file
@@ -0,0 +1,393 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
// TestConfig is a sample config struct for testing
|
||||
type TestConfig struct {
|
||||
Name string `json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
type updateStatusOnlyConfigStore struct {
|
||||
configstore.ConfigStore
|
||||
calls []schemas.KeyStatus
|
||||
}
|
||||
|
||||
type noopTestLogger struct{}
|
||||
|
||||
func (noopTestLogger) Debug(string, ...any) {}
|
||||
func (noopTestLogger) Info(string, ...any) {}
|
||||
func (noopTestLogger) Warn(string, ...any) {}
|
||||
func (noopTestLogger) Error(string, ...any) {}
|
||||
func (noopTestLogger) Fatal(string, ...any) {}
|
||||
func (noopTestLogger) SetLevel(schemas.LogLevel) {}
|
||||
func (noopTestLogger) SetOutputType(schemas.LoggerOutputType) {}
|
||||
func (noopTestLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
func (s *updateStatusOnlyConfigStore) UpdateStatus(ctx context.Context, provider schemas.ModelProvider, keyID string, status, errorMsg string) error {
|
||||
s.calls = append(s.calls, schemas.KeyStatus{
|
||||
Provider: provider,
|
||||
KeyID: keyID,
|
||||
Status: schemas.KeyStatusType(status),
|
||||
Error: &schemas.BifrostError{Error: &schemas.ErrorField{Message: errorMsg}},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUpdateKeyStatus_KeylessProviderUpdatesProviderStatusInMemory(t *testing.T) {
|
||||
prevLogger := logger
|
||||
logger = noopTestLogger{}
|
||||
defer func() { logger = prevLogger }()
|
||||
|
||||
store := &updateStatusOnlyConfigStore{}
|
||||
server := &BifrostHTTPServer{
|
||||
Config: &lib.Config{
|
||||
ConfigStore: store,
|
||||
Providers: map[schemas.ModelProvider]configstore.ProviderConfig{
|
||||
"mock-openai": {
|
||||
CustomProviderConfig: &schemas.CustomProviderConfig{IsKeyLess: true},
|
||||
Status: "unknown",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
server.updateKeyStatus(context.Background(), []schemas.KeyStatus{{
|
||||
Provider: "mock-openai",
|
||||
KeyID: "",
|
||||
Status: schemas.KeyStatusListModelsFailed,
|
||||
Error: &schemas.BifrostError{Error: &schemas.ErrorField{Message: "preview missing model"}},
|
||||
}})
|
||||
|
||||
provider := server.Config.Providers["mock-openai"]
|
||||
if provider.Status != string(schemas.KeyStatusListModelsFailed) {
|
||||
t.Fatalf("expected provider status %q, got %q", schemas.KeyStatusListModelsFailed, provider.Status)
|
||||
}
|
||||
if provider.Description != "preview missing model" {
|
||||
t.Fatalf("expected provider description to be updated, got %q", provider.Description)
|
||||
}
|
||||
if len(store.calls) != 1 {
|
||||
t.Fatalf("expected one status update call, got %d", len(store.calls))
|
||||
}
|
||||
if store.calls[0].Provider != "mock-openai" || store.calls[0].KeyID != "" {
|
||||
t.Fatalf("expected provider-level status update, got provider=%q keyID=%q", store.calls[0].Provider, store.calls[0].KeyID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateKeyStatus_EmptyKeyIDDoesNotOverwriteKeyedProviderStatus(t *testing.T) {
|
||||
prevLogger := logger
|
||||
logger = noopTestLogger{}
|
||||
defer func() { logger = prevLogger }()
|
||||
|
||||
store := &updateStatusOnlyConfigStore{}
|
||||
server := &BifrostHTTPServer{
|
||||
Config: &lib.Config{
|
||||
ConfigStore: store,
|
||||
Providers: map[schemas.ModelProvider]configstore.ProviderConfig{
|
||||
"openai": {
|
||||
Keys: []schemas.Key{{ID: "key-1"}},
|
||||
Status: "healthy",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
server.updateKeyStatus(context.Background(), []schemas.KeyStatus{{
|
||||
Provider: "openai",
|
||||
KeyID: "",
|
||||
Status: schemas.KeyStatusListModelsFailed,
|
||||
Error: &schemas.BifrostError{Error: &schemas.ErrorField{Message: "malformed status"}},
|
||||
}})
|
||||
|
||||
provider := server.Config.Providers["openai"]
|
||||
if provider.Status != "healthy" {
|
||||
t.Fatalf("expected keyed provider status to remain unchanged, got %q", provider.Status)
|
||||
}
|
||||
if provider.Description != "" {
|
||||
t.Fatalf("expected keyed provider description to remain unchanged, got %q", provider.Description)
|
||||
}
|
||||
if len(store.calls) != 1 {
|
||||
t.Fatalf("expected one status update call, got %d", len(store.calls))
|
||||
}
|
||||
if store.calls[0].Provider != "openai" || store.calls[0].KeyID != "" {
|
||||
t.Fatalf("expected DB status update to retain empty key ID, got provider=%q keyID=%q", store.calls[0].Provider, store.calls[0].KeyID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithPointerType(t *testing.T) {
|
||||
// Test case 1: source is already *T
|
||||
expected := &TestConfig{
|
||||
Name: "test-plugin",
|
||||
Enabled: true,
|
||||
Count: 42,
|
||||
}
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](expected)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected same pointer, got different pointer")
|
||||
}
|
||||
|
||||
if result.Name != expected.Name {
|
||||
t.Errorf("Expected Name=%s, got %s", expected.Name, result.Name)
|
||||
}
|
||||
if result.Enabled != expected.Enabled {
|
||||
t.Errorf("Expected Enabled=%v, got %v", expected.Enabled, result.Enabled)
|
||||
}
|
||||
if result.Count != expected.Count {
|
||||
t.Errorf("Expected Count=%d, got %d", expected.Count, result.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithMap(t *testing.T) {
|
||||
// Test case 2: source is map[string]any
|
||||
configMap := map[string]any{
|
||||
"name": "test-plugin",
|
||||
"enabled": true,
|
||||
"count": 42,
|
||||
}
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](configMap)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected non-nil result")
|
||||
}
|
||||
|
||||
if result.Name != "test-plugin" {
|
||||
t.Errorf("Expected Name=test-plugin, got %s", result.Name)
|
||||
}
|
||||
if result.Enabled != true {
|
||||
t.Errorf("Expected Enabled=true, got %v", result.Enabled)
|
||||
}
|
||||
if result.Count != 42 {
|
||||
t.Errorf("Expected Count=42, got %d", result.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithString(t *testing.T) {
|
||||
// Test case 3: source is string (JSON)
|
||||
configStr := `{"name":"test-plugin","enabled":true,"count":42}`
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](configStr)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected non-nil result")
|
||||
}
|
||||
|
||||
if result.Name != "test-plugin" {
|
||||
t.Errorf("Expected Name=test-plugin, got %s", result.Name)
|
||||
}
|
||||
if result.Enabled != true {
|
||||
t.Errorf("Expected Enabled=true, got %v", result.Enabled)
|
||||
}
|
||||
if result.Count != 42 {
|
||||
t.Errorf("Expected Count=42, got %d", result.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithInvalidType(t *testing.T) {
|
||||
// Test case 4: source is invalid type (should return error)
|
||||
invalidSource := 12345
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](invalidSource)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid type, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for invalid type, got %v", result)
|
||||
}
|
||||
|
||||
expectedError := "invalid config type"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("Expected error message '%s', got '%s'", expectedError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithInvalidJSONString(t *testing.T) {
|
||||
// Test case 5: source is string but invalid JSON
|
||||
invalidJSON := `{"name":"test-plugin","enabled":true,count:42}` // missing quotes around count
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](invalidJSON)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid JSON, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for invalid JSON, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithInvalidMapData(t *testing.T) {
|
||||
// Test case 6: source is map but contains invalid data types
|
||||
configMap := map[string]any{
|
||||
"name": "test-plugin",
|
||||
"enabled": "not-a-boolean", // wrong type
|
||||
"count": 42,
|
||||
}
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](configMap)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid map data, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for invalid map data, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithEmptyMap(t *testing.T) {
|
||||
// Test case 7: source is empty map (should work, return zero values)
|
||||
configMap := map[string]any{}
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](configMap)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error for empty map, got: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected non-nil result")
|
||||
}
|
||||
|
||||
// All fields should have zero values
|
||||
if result.Name != "" {
|
||||
t.Errorf("Expected empty Name, got %s", result.Name)
|
||||
}
|
||||
if result.Enabled != false {
|
||||
t.Errorf("Expected Enabled=false, got %v", result.Enabled)
|
||||
}
|
||||
if result.Count != 0 {
|
||||
t.Errorf("Expected Count=0, got %d", result.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithEmptyString(t *testing.T) {
|
||||
// Test case 8: source is empty string (should fail as invalid JSON)
|
||||
configStr := ""
|
||||
|
||||
result, err := MarshalPluginConfig[TestConfig](configStr)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty string, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for empty string, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithNil(t *testing.T) {
|
||||
// Test case 9: source is nil (should return error as invalid type)
|
||||
result, err := MarshalPluginConfig[TestConfig](nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for nil source, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result for nil source, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkMarshalPluginConfig_WithPointerType(b *testing.B) {
|
||||
config := &TestConfig{
|
||||
Name: "test-plugin",
|
||||
Enabled: true,
|
||||
Count: 42,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = MarshalPluginConfig[TestConfig](config)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMarshalPluginConfig_WithMap(b *testing.B) {
|
||||
configMap := map[string]any{
|
||||
"name": "test-plugin",
|
||||
"enabled": true,
|
||||
"count": 42,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = MarshalPluginConfig[TestConfig](configMap)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMarshalPluginConfig_WithString(b *testing.B) {
|
||||
configStr := `{"name":"test-plugin","enabled":true,"count":42}`
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = MarshalPluginConfig[TestConfig](configStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Complex config for additional testing
|
||||
type ComplexConfig struct {
|
||||
Settings map[string]string `json:"settings"`
|
||||
Tags []string `json:"tags"`
|
||||
Metadata map[string]any `json:"metadata"`
|
||||
Nested *TestConfig `json:"nested"`
|
||||
}
|
||||
|
||||
func TestMarshalPluginConfig_WithComplexType(t *testing.T) {
|
||||
// Test with a more complex nested structure
|
||||
configMap := map[string]any{
|
||||
"settings": map[string]any{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
},
|
||||
"tags": []any{"tag1", "tag2", "tag3"},
|
||||
"metadata": map[string]any{
|
||||
"version": "1.0.0",
|
||||
"author": "test",
|
||||
},
|
||||
"nested": map[string]any{
|
||||
"name": "nested-config",
|
||||
"enabled": true,
|
||||
"count": 10,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := MarshalPluginConfig[ComplexConfig](configMap)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected non-nil result")
|
||||
}
|
||||
|
||||
if len(result.Settings) != 2 {
|
||||
t.Errorf("Expected 2 settings, got %d", len(result.Settings))
|
||||
}
|
||||
if len(result.Tags) != 3 {
|
||||
t.Errorf("Expected 3 tags, got %d", len(result.Tags))
|
||||
}
|
||||
if result.Nested == nil {
|
||||
t.Fatal("Expected non-nil nested config")
|
||||
}
|
||||
if result.Nested.Name != "nested-config" {
|
||||
t.Errorf("Expected nested name=nested-config, got %s", result.Nested.Name)
|
||||
}
|
||||
}
|
||||
208
transports/bifrost-http/server/utils.go
Normal file
208
transports/bifrost-http/server/utils.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// GetDefaultConfigDir returns the OS-specific default configuration directory for Bifrost.
|
||||
// This follows standard conventions:
|
||||
// - Linux/macOS: ~/.config/bifrost
|
||||
// - Windows: %APPDATA%\bifrost
|
||||
// - If appDir is provided (non-empty), it returns that instead
|
||||
func GetDefaultConfigDir(appDir string) string {
|
||||
// If appDir is provided, use it directly
|
||||
if appDir != "" {
|
||||
return appDir
|
||||
}
|
||||
|
||||
// Get OS-specific config directory
|
||||
var configDir string
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
// Windows: %APPDATA%\bifrost
|
||||
if appData := os.Getenv("APPDATA"); appData != "" {
|
||||
configDir = filepath.Join(appData, "bifrost")
|
||||
} else {
|
||||
// Fallback to user home directory
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
configDir = filepath.Join(homeDir, "AppData", "Roaming", "bifrost")
|
||||
}
|
||||
}
|
||||
default:
|
||||
// Linux, macOS and other Unix-like systems: ~/.config/bifrost
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
configDir = filepath.Join(homeDir, ".config", "bifrost")
|
||||
}
|
||||
}
|
||||
|
||||
// If we couldn't determine the config directory, fall back to current directory
|
||||
if configDir == "" {
|
||||
configDir = "./bifrost-data"
|
||||
}
|
||||
|
||||
return configDir
|
||||
}
|
||||
|
||||
// registerPluginWithStatus instantiates, registers, and updates status for a plugin (used by builtin plugins)
|
||||
func (s *BifrostHTTPServer) registerPluginWithStatus(ctx context.Context, name string, path *string, config any, failOnError bool) error {
|
||||
plugin, err := InstantiatePlugin(ctx, name, path, config, s.Config)
|
||||
if err != nil {
|
||||
logger.Error("failed to initialize %s plugin: %v", name, err)
|
||||
// Use name since plugin may be nil when InstantiatePlugin returns an error
|
||||
s.Config.UpdatePluginOverallStatus(name, name, schemas.PluginStatusError,
|
||||
[]string{fmt.Sprintf("error initializing %s plugin: %v", name, err)}, []schemas.PluginType{})
|
||||
if failOnError {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure plugin is not nil before using it (defensive check)
|
||||
if plugin == nil {
|
||||
logger.Error("plugin %s instantiated but returned nil", name)
|
||||
s.Config.UpdatePluginOverallStatus(name, name, schemas.PluginStatusError,
|
||||
[]string{fmt.Sprintf("plugin %s instantiated but returned nil", name)}, []schemas.PluginType{})
|
||||
if failOnError {
|
||||
return fmt.Errorf("plugin %s instantiated but returned nil", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
s.Config.ReloadPlugin(plugin)
|
||||
s.Config.UpdatePluginOverallStatus(name, name, schemas.PluginStatusActive,
|
||||
[]string{fmt.Sprintf("%s plugin initialized successfully", name)}, InferPluginTypes(plugin))
|
||||
return nil
|
||||
}
|
||||
|
||||
// CollectObservabilityPlugins gathers all loaded plugins that implement ObservabilityPlugin interface
|
||||
func (s *BifrostHTTPServer) CollectObservabilityPlugins() []schemas.ObservabilityPlugin {
|
||||
var observabilityPlugins []schemas.ObservabilityPlugin
|
||||
|
||||
// Check LLM plugins
|
||||
for _, plugin := range s.Config.GetLoadedLLMPlugins() {
|
||||
if observabilityPlugin, ok := plugin.(schemas.ObservabilityPlugin); ok {
|
||||
observabilityPlugins = append(observabilityPlugins, observabilityPlugin)
|
||||
}
|
||||
}
|
||||
|
||||
// Check MCP plugins
|
||||
for _, plugin := range s.Config.GetLoadedMCPPlugins() {
|
||||
if observabilityPlugin, ok := plugin.(schemas.ObservabilityPlugin); ok {
|
||||
observabilityPlugins = append(observabilityPlugins, observabilityPlugin)
|
||||
}
|
||||
}
|
||||
|
||||
return observabilityPlugins
|
||||
}
|
||||
|
||||
// MarshalPluginConfig marshals the plugin configuration
|
||||
func MarshalPluginConfig[T any](source any) (*T, error) {
|
||||
// If its a *T, then we will confirm
|
||||
if config, ok := source.(*T); ok {
|
||||
return config, nil
|
||||
}
|
||||
// Initialize a new instance for unmarshaling
|
||||
config := new(T)
|
||||
// If its a map[string]any, then we will JSON parse and confirm
|
||||
if configMap, ok := source.(map[string]any); ok {
|
||||
configString, err := sonic.Marshal(configMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := sonic.Unmarshal([]byte(configString), config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
// If its a string, then we will JSON parse and confirm
|
||||
if configStr, ok := source.(string); ok {
|
||||
if err := sonic.Unmarshal([]byte(configStr), config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid config type")
|
||||
}
|
||||
|
||||
// updateKeyStatus updates the model discovery status for keys or providers based on key statuses.
|
||||
// For keyed providers: updates individual key status
|
||||
// For keyless providers: updates provider-level status
|
||||
func (s *BifrostHTTPServer) updateKeyStatus(
|
||||
ctx context.Context,
|
||||
keyStatuses []schemas.KeyStatus,
|
||||
) {
|
||||
if s.Config == nil || s.Config.ConfigStore == nil || len(keyStatuses) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Update each key/provider status individually
|
||||
for _, ks := range keyStatuses {
|
||||
errorMsg := ""
|
||||
if ks.Error != nil && ks.Error.Error != nil {
|
||||
errorMsg = ks.Error.Error.Message
|
||||
}
|
||||
|
||||
if err := s.Config.ConfigStore.UpdateStatus(ctx, ks.Provider, ks.KeyID, string(ks.Status), errorMsg); err != nil {
|
||||
target := ks.KeyID
|
||||
if target == "" {
|
||||
target = string(ks.Provider)
|
||||
}
|
||||
logger.Error("failed to update model discovery status for %s: %v", target, err)
|
||||
continue // Skip in-memory update if DB update failed
|
||||
}
|
||||
|
||||
s.Config.Mu.Lock()
|
||||
|
||||
providerConfig, exists := s.Config.Providers[ks.Provider]
|
||||
if !exists {
|
||||
s.Config.Mu.Unlock()
|
||||
logger.Warn("provider %s not found in memory during status update", ks.Provider)
|
||||
continue
|
||||
}
|
||||
|
||||
isKeylessProvider := providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.IsKeyLess
|
||||
|
||||
if ks.KeyID == "" {
|
||||
if !isKeylessProvider {
|
||||
logger.Warn("received provider-level status update for keyed provider %s; skipping in-memory update", ks.Provider)
|
||||
s.Config.Mu.Unlock()
|
||||
continue
|
||||
}
|
||||
providerConfig.Status = string(ks.Status)
|
||||
providerConfig.Description = errorMsg
|
||||
s.Config.Providers[ks.Provider] = providerConfig
|
||||
logger.Debug("updated in-memory status for keyless provider %s", ks.Provider)
|
||||
s.Config.Mu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
// Find and update the specific key in the Keys slice
|
||||
updated := false
|
||||
for i := range providerConfig.Keys {
|
||||
if providerConfig.Keys[i].ID == ks.KeyID {
|
||||
// Update Status and Description fields
|
||||
providerConfig.Keys[i].Status = ks.Status
|
||||
providerConfig.Keys[i].Description = errorMsg
|
||||
updated = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if updated {
|
||||
// Write the modified config back to the map
|
||||
s.Config.Providers[ks.Provider] = providerConfig
|
||||
logger.Debug("updated in-memory status for key %s of provider %s", ks.KeyID, ks.Provider)
|
||||
} else {
|
||||
logger.Warn("key %s not found in provider %s during in-memory update", ks.KeyID, ks.Provider)
|
||||
}
|
||||
|
||||
s.Config.Mu.Unlock()
|
||||
}
|
||||
}
|
||||
140
transports/bifrost-http/websocket/connection.go
Normal file
140
transports/bifrost-http/websocket/connection.go
Normal file
@@ -0,0 +1,140 @@
|
||||
// Package websocket provides upstream WebSocket connection management for the Bifrost gateway.
|
||||
// It manages pooled connections to provider WebSocket APIs (e.g., OpenAI Responses WS mode,
|
||||
// Realtime API) and client session bindings.
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
ws "github.com/fasthttp/websocket"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// UpstreamConn wraps a WebSocket connection to an upstream provider.
|
||||
// Thread-safe for concurrent read/write via separate mutexes.
|
||||
type UpstreamConn struct {
|
||||
conn *ws.Conn
|
||||
provider schemas.ModelProvider
|
||||
keyID string
|
||||
endpoint string
|
||||
createdAt time.Time
|
||||
lastUsed atomic.Int64 // unix nano
|
||||
|
||||
writeMu sync.Mutex
|
||||
readMu sync.Mutex
|
||||
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
// newUpstreamConn creates a new UpstreamConn wrapping the given websocket connection.
|
||||
func newUpstreamConn(conn *ws.Conn, provider schemas.ModelProvider, keyID, endpoint string) *UpstreamConn {
|
||||
uc := &UpstreamConn{
|
||||
conn: conn,
|
||||
provider: provider,
|
||||
keyID: keyID,
|
||||
endpoint: endpoint,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
uc.lastUsed.Store(time.Now().UnixNano())
|
||||
return uc
|
||||
}
|
||||
|
||||
// WriteMessage sends a message to the upstream provider. Thread-safe.
|
||||
func (c *UpstreamConn) WriteMessage(messageType int, data []byte) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
c.lastUsed.Store(time.Now().UnixNano())
|
||||
return c.conn.WriteMessage(messageType, data)
|
||||
}
|
||||
|
||||
// WriteJSON sends a JSON-encoded message to the upstream provider. Thread-safe.
|
||||
func (c *UpstreamConn) WriteJSON(v interface{}) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
c.lastUsed.Store(time.Now().UnixNano())
|
||||
return c.conn.WriteJSON(v)
|
||||
}
|
||||
|
||||
// ReadMessage reads a message from the upstream provider. Thread-safe.
|
||||
func (c *UpstreamConn) ReadMessage() (messageType int, p []byte, err error) {
|
||||
c.readMu.Lock()
|
||||
defer c.readMu.Unlock()
|
||||
c.lastUsed.Store(time.Now().UnixNano())
|
||||
return c.conn.ReadMessage()
|
||||
}
|
||||
|
||||
// Close closes the underlying WebSocket connection.
|
||||
func (c *UpstreamConn) Close() error {
|
||||
if c.closed.CompareAndSwap(false, true) {
|
||||
return c.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsClosed returns whether the connection has been closed.
|
||||
func (c *UpstreamConn) IsClosed() bool {
|
||||
return c.closed.Load()
|
||||
}
|
||||
|
||||
// Provider returns the provider this connection is for.
|
||||
func (c *UpstreamConn) Provider() schemas.ModelProvider {
|
||||
return c.provider
|
||||
}
|
||||
|
||||
// KeyID returns the API key ID used for this connection.
|
||||
func (c *UpstreamConn) KeyID() string {
|
||||
return c.keyID
|
||||
}
|
||||
|
||||
// CreatedAt returns when this connection was established.
|
||||
func (c *UpstreamConn) CreatedAt() time.Time {
|
||||
return c.createdAt
|
||||
}
|
||||
|
||||
// LastUsed returns the last time this connection was used.
|
||||
func (c *UpstreamConn) LastUsed() time.Time {
|
||||
return time.Unix(0, c.lastUsed.Load())
|
||||
}
|
||||
|
||||
// Age returns how long this connection has been alive.
|
||||
func (c *UpstreamConn) Age() time.Duration {
|
||||
return time.Since(c.createdAt)
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline on the underlying connection.
|
||||
func (c *UpstreamConn) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the write deadline on the underlying connection.
|
||||
func (c *UpstreamConn) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// SetPongHandler sets a handler for pong messages received from the upstream.
|
||||
func (c *UpstreamConn) SetPongHandler(h func(appData string) error) {
|
||||
c.conn.SetPongHandler(h)
|
||||
}
|
||||
|
||||
// WritePing sends a ping message to the upstream. Thread-safe.
|
||||
func (c *UpstreamConn) WritePing(data []byte) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
c.lastUsed.Store(time.Now().UnixNano())
|
||||
return c.conn.WriteMessage(ws.PingMessage, data)
|
||||
}
|
||||
|
||||
// Dial creates a new WebSocket connection to the given URL with the provided headers.
|
||||
func Dial(url string, headers map[string]string) (*ws.Conn, *http.Response, error) {
|
||||
dialer := ws.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
h := http.Header{}
|
||||
for k, v := range headers {
|
||||
h.Set(k, v)
|
||||
}
|
||||
return dialer.Dial(url, h)
|
||||
}
|
||||
8
transports/bifrost-http/websocket/errors.go
Normal file
8
transports/bifrost-http/websocket/errors.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package websocket
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrConnectionLimitReached = errors.New("websocket connection limit reached")
|
||||
ErrPoolClosed = errors.New("websocket pool is closed")
|
||||
)
|
||||
221
transports/bifrost-http/websocket/pool.go
Normal file
221
transports/bifrost-http/websocket/pool.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// PoolKey uniquely identifies a group of upstream connections.
|
||||
type PoolKey struct {
|
||||
Provider schemas.ModelProvider
|
||||
KeyID string
|
||||
Endpoint string
|
||||
}
|
||||
|
||||
// Pool manages a pool of upstream WebSocket connections keyed by (provider, keyID, endpoint).
|
||||
// Idle connections are cached for reuse. Connections exceeding max lifetime are discarded.
|
||||
type Pool struct {
|
||||
mu sync.Mutex
|
||||
idle map[PoolKey][]*UpstreamConn
|
||||
inFlight int
|
||||
|
||||
config *schemas.WSPoolConfig
|
||||
|
||||
closed bool
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewPool creates a new upstream WebSocket connection pool.
|
||||
func NewPool(config *schemas.WSPoolConfig) *Pool {
|
||||
if config == nil {
|
||||
config = &schemas.WSPoolConfig{}
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
p := &Pool{
|
||||
idle: make(map[PoolKey][]*UpstreamConn),
|
||||
config: config,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go p.evictLoop()
|
||||
return p
|
||||
}
|
||||
|
||||
// Get retrieves an idle connection for the given key, or dials a new one.
|
||||
// The returned connection is removed from the idle pool and must be returned
|
||||
// via Return or discarded via Discard.
|
||||
func (p *Pool) Get(key PoolKey, headers map[string]string) (*UpstreamConn, error) {
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return nil, fmt.Errorf("pool is closed")
|
||||
}
|
||||
|
||||
conns := p.idle[key]
|
||||
for len(conns) > 0 {
|
||||
// Pop from the back (most recently returned)
|
||||
conn := conns[len(conns)-1]
|
||||
conns = conns[:len(conns)-1]
|
||||
p.idle[key] = conns
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
if conn.IsClosed() || p.isExpired(conn) {
|
||||
conn.Close()
|
||||
p.mu.Lock()
|
||||
conns = p.idle[key]
|
||||
continue
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.inFlight++
|
||||
p.mu.Unlock()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Check total capacity (idle + in-flight) before dialing
|
||||
totalIdle := 0
|
||||
for _, c := range p.idle {
|
||||
totalIdle += len(c)
|
||||
}
|
||||
if totalIdle+p.inFlight >= p.config.MaxTotalConnections {
|
||||
p.mu.Unlock()
|
||||
return nil, fmt.Errorf("pool capacity exhausted: %d idle + %d in-flight >= %d max", totalIdle, p.inFlight, p.config.MaxTotalConnections)
|
||||
}
|
||||
|
||||
// Reserve a slot before unlocking to dial
|
||||
p.inFlight++
|
||||
p.mu.Unlock()
|
||||
|
||||
conn, err := p.dial(key, headers)
|
||||
if err != nil {
|
||||
p.mu.Lock()
|
||||
p.inFlight--
|
||||
p.mu.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Return puts a connection back into the idle pool for reuse.
|
||||
// If the connection is expired or the pool is full, it is closed instead.
|
||||
func (p *Pool) Return(conn *UpstreamConn) {
|
||||
if conn == nil || conn.IsClosed() {
|
||||
return
|
||||
}
|
||||
if p.isExpired(conn) {
|
||||
conn.Close()
|
||||
p.mu.Lock()
|
||||
p.inFlight--
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
key := PoolKey{
|
||||
Provider: conn.provider,
|
||||
KeyID: conn.keyID,
|
||||
Endpoint: conn.endpoint,
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.inFlight--
|
||||
|
||||
if p.closed {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
conns := p.idle[key]
|
||||
if len(conns) >= p.config.MaxIdlePerKey {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
p.idle[key] = append(conns, conn)
|
||||
}
|
||||
|
||||
// Discard closes a connection without returning it to the pool.
|
||||
func (p *Pool) Discard(conn *UpstreamConn) {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
p.mu.Lock()
|
||||
p.inFlight--
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the pool and closes all idle connections.
|
||||
func (p *Pool) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.closed = true
|
||||
close(p.done)
|
||||
|
||||
for key, conns := range p.idle {
|
||||
for _, conn := range conns {
|
||||
conn.Close()
|
||||
}
|
||||
delete(p.idle, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) dial(key PoolKey, headers map[string]string) (*UpstreamConn, error) {
|
||||
wsConn, _, err := Dial(key.Endpoint, headers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial upstream websocket %s: %w", key.Endpoint, err)
|
||||
}
|
||||
return newUpstreamConn(wsConn, key.Provider, key.KeyID, key.Endpoint), nil
|
||||
}
|
||||
|
||||
func (p *Pool) isExpired(conn *UpstreamConn) bool {
|
||||
maxLifetime := time.Duration(p.config.MaxConnectionLifetimeSeconds) * time.Second
|
||||
if conn.Age() >= maxLifetime {
|
||||
return true
|
||||
}
|
||||
idleTimeout := time.Duration(p.config.IdleTimeoutSeconds) * time.Second
|
||||
return time.Since(conn.LastUsed()) >= idleTimeout
|
||||
}
|
||||
|
||||
// evictLoop periodically removes expired idle connections.
|
||||
func (p *Pool) evictLoop() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.evictExpired()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) evictExpired() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
for key, conns := range p.idle {
|
||||
alive := conns[:0]
|
||||
for _, conn := range conns {
|
||||
if conn.IsClosed() || p.isExpired(conn) {
|
||||
conn.Close()
|
||||
} else {
|
||||
alive = append(alive, conn)
|
||||
}
|
||||
}
|
||||
if len(alive) == 0 {
|
||||
delete(p.idle, key)
|
||||
} else {
|
||||
p.idle[key] = alive
|
||||
}
|
||||
}
|
||||
}
|
||||
160
transports/bifrost-http/websocket/pool_test.go
Normal file
160
transports/bifrost-http/websocket/pool_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
ws "github.com/fasthttp/websocket"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func startTestWSServer(t *testing.T) *httptest.Server {
|
||||
t.Helper()
|
||||
upgrader := ws.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
conn.WriteMessage(mt, msg)
|
||||
}
|
||||
}))
|
||||
return server
|
||||
}
|
||||
|
||||
func TestPoolGetAndReturn(t *testing.T) {
|
||||
server := startTestWSServer(t)
|
||||
defer server.Close()
|
||||
|
||||
config := &schemas.WSPoolConfig{
|
||||
MaxIdlePerKey: 5,
|
||||
MaxTotalConnections: 10,
|
||||
IdleTimeoutSeconds: 300,
|
||||
MaxConnectionLifetimeSeconds: 3600,
|
||||
}
|
||||
pool := NewPool(config)
|
||||
defer pool.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
|
||||
|
||||
// Get a new connection (pool is empty, should dial)
|
||||
conn, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
assert.Equal(t, schemas.OpenAI, conn.Provider())
|
||||
assert.Equal(t, "test-key", conn.KeyID())
|
||||
assert.False(t, conn.IsClosed())
|
||||
|
||||
// Return to pool
|
||||
pool.Return(conn)
|
||||
|
||||
// Get again — should reuse the same connection
|
||||
conn2, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
assert.Same(t, conn, conn2)
|
||||
pool.Return(conn2)
|
||||
}
|
||||
|
||||
func TestPoolMaxIdlePerKey(t *testing.T) {
|
||||
server := startTestWSServer(t)
|
||||
defer server.Close()
|
||||
|
||||
config := &schemas.WSPoolConfig{
|
||||
MaxIdlePerKey: 2,
|
||||
MaxTotalConnections: 10,
|
||||
IdleTimeoutSeconds: 300,
|
||||
MaxConnectionLifetimeSeconds: 3600,
|
||||
}
|
||||
pool := NewPool(config)
|
||||
defer pool.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
|
||||
|
||||
// Get 3 connections
|
||||
var conns []*UpstreamConn
|
||||
for range 3 {
|
||||
conn, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
|
||||
// Return all 3 — only 2 should be kept (MaxIdlePerKey=2)
|
||||
for _, conn := range conns {
|
||||
pool.Return(conn)
|
||||
}
|
||||
|
||||
pool.mu.Lock()
|
||||
idleCount := len(pool.idle[key])
|
||||
pool.mu.Unlock()
|
||||
|
||||
assert.Equal(t, 2, idleCount)
|
||||
}
|
||||
|
||||
func TestPoolClose(t *testing.T) {
|
||||
server := startTestWSServer(t)
|
||||
defer server.Close()
|
||||
|
||||
config := &schemas.WSPoolConfig{}
|
||||
config.CheckAndSetDefaults()
|
||||
pool := NewPool(config)
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
|
||||
|
||||
conn, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
pool.Return(conn)
|
||||
|
||||
pool.Close()
|
||||
|
||||
// Getting from a closed pool should fail
|
||||
_, err = pool.Get(key, nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestPoolExpiredConnection(t *testing.T) {
|
||||
server := startTestWSServer(t)
|
||||
defer server.Close()
|
||||
|
||||
config := &schemas.WSPoolConfig{
|
||||
MaxIdlePerKey: 5,
|
||||
MaxTotalConnections: 10,
|
||||
IdleTimeoutSeconds: 1,
|
||||
MaxConnectionLifetimeSeconds: 1,
|
||||
}
|
||||
pool := NewPool(config)
|
||||
defer pool.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
|
||||
|
||||
conn, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
pool.Return(conn)
|
||||
|
||||
// Wait for connection to expire
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
// Get should dial a new connection (old one expired)
|
||||
conn2, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
assert.NotSame(t, conn, conn2)
|
||||
pool.Discard(conn2)
|
||||
}
|
||||
450
transports/bifrost-http/websocket/session.go
Normal file
450
transports/bifrost-http/websocket/session.go
Normal file
@@ -0,0 +1,450 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
ws "github.com/fasthttp/websocket"
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// Session tracks the binding between a client WebSocket connection and its upstream state.
|
||||
// For Responses WS mode, it tracks previous_response_id → upstream connection pinning.
|
||||
type Session struct {
|
||||
mu sync.RWMutex
|
||||
writeMu sync.Mutex // serializes all WriteMessage calls to clientConn
|
||||
|
||||
id string
|
||||
|
||||
// Client connection
|
||||
clientConn *ws.Conn
|
||||
|
||||
// Upstream connection currently pinned to this session (for native WS mode).
|
||||
// nil when using HTTP bridge.
|
||||
upstream *UpstreamConn
|
||||
|
||||
// LastResponseID tracks the most recent response ID for previous_response_id chaining.
|
||||
lastResponseID string
|
||||
|
||||
// providerSessionID tracks the upstream provider's session identifier when exposed.
|
||||
providerSessionID string
|
||||
|
||||
// realtimeOutputText accumulates assistant/provider turn text until the terminal event.
|
||||
realtimeOutputText string
|
||||
|
||||
// realtimeTurnInputs accumulates finalized user/tool inputs in arrival order so the
|
||||
// completed assistant turn can persist the full turn history instead of only the
|
||||
// latest finalized input event.
|
||||
realtimeTurnInputs []RealtimeTurnInput
|
||||
|
||||
// realtimeConsumedTurnItemIDs tracks finalized item IDs that have already been
|
||||
// attached to a persisted turn, so late transcript updates do not pollute later turns.
|
||||
realtimeConsumedTurnItemIDs map[string]struct{}
|
||||
|
||||
// realtimeTurnHooks tracks the active turn-scoped plugin pipeline between
|
||||
// response.create and response.done.
|
||||
realtimeTurnHooks *RealtimeTurnPluginState
|
||||
realtimeTurnBusy bool
|
||||
|
||||
closed bool
|
||||
}
|
||||
|
||||
type RealtimeToolOutput struct {
|
||||
Summary string
|
||||
Raw string
|
||||
}
|
||||
|
||||
type RealtimeTurnInput struct {
|
||||
ItemID string
|
||||
Role string
|
||||
Summary string
|
||||
Raw string
|
||||
}
|
||||
|
||||
type RealtimeTurnPluginState struct {
|
||||
PostHookRunner schemas.PostHookRunner
|
||||
Cleanup func()
|
||||
RequestID string
|
||||
StartedAt time.Time
|
||||
PreHookValues map[any]any
|
||||
}
|
||||
|
||||
// NewSession creates a new session for a client WebSocket connection.
|
||||
func NewSession(clientConn *ws.Conn) *Session {
|
||||
return &Session{
|
||||
id: uuid.NewString(),
|
||||
clientConn: clientConn,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the stable Bifrost session identifier for this websocket session.
|
||||
func (s *Session) ID() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.id
|
||||
}
|
||||
|
||||
// ClientConn returns the client's WebSocket connection.
|
||||
func (s *Session) ClientConn() *ws.Conn {
|
||||
return s.clientConn
|
||||
}
|
||||
|
||||
// WriteMessage sends a message to the client WebSocket connection.
|
||||
// It serializes concurrent writes via writeMu to prevent panics from
|
||||
// simultaneous goroutine writes (e.g., heartbeat vs streaming relay).
|
||||
func (s *Session) WriteMessage(messageType int, data []byte) error {
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
return s.clientConn.WriteMessage(messageType, data)
|
||||
}
|
||||
|
||||
// SetUpstream pins an upstream connection to this session.
|
||||
func (s *Session) SetUpstream(conn *UpstreamConn) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.closed {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
if s.upstream != nil && s.upstream != conn {
|
||||
s.upstream.Close()
|
||||
}
|
||||
s.upstream = conn
|
||||
}
|
||||
|
||||
// Upstream returns the currently pinned upstream connection, or nil.
|
||||
func (s *Session) Upstream() *UpstreamConn {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.upstream
|
||||
}
|
||||
|
||||
// SetLastResponseID updates the last response ID for chaining.
|
||||
func (s *Session) SetLastResponseID(id string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.lastResponseID = id
|
||||
}
|
||||
|
||||
// LastResponseID returns the last response ID.
|
||||
func (s *Session) LastResponseID() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.lastResponseID
|
||||
}
|
||||
|
||||
// SetProviderSessionID stores the upstream provider session identifier when available.
|
||||
func (s *Session) SetProviderSessionID(id string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.providerSessionID = id
|
||||
}
|
||||
|
||||
// ProviderSessionID returns the upstream provider session identifier when known.
|
||||
func (s *Session) ProviderSessionID() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.providerSessionID
|
||||
}
|
||||
|
||||
// AppendRealtimeOutputText appends provider output content for the current realtime turn.
|
||||
func (s *Session) AppendRealtimeOutputText(text string) {
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.realtimeOutputText += text
|
||||
}
|
||||
|
||||
// ConsumeRealtimeOutputText returns the accumulated provider output and clears it.
|
||||
func (s *Session) ConsumeRealtimeOutputText() string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
text := s.realtimeOutputText
|
||||
s.realtimeOutputText = ""
|
||||
return text
|
||||
}
|
||||
|
||||
// AddRealtimeInput stores a finalized user turn event in arrival order.
|
||||
func (s *Session) AddRealtimeInput(summary, raw string) {
|
||||
if summary == "" && raw == "" {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{
|
||||
Role: string(schemas.ChatMessageRoleUser),
|
||||
Summary: summary,
|
||||
Raw: raw,
|
||||
})
|
||||
}
|
||||
|
||||
// RecordRealtimeInput stores or updates a finalized user turn event keyed by item ID.
|
||||
// Late updates for items already attached to a completed turn are ignored.
|
||||
func (s *Session) RecordRealtimeInput(itemID, summary, raw string) {
|
||||
s.recordRealtimeTurnInput(itemID, string(schemas.ChatMessageRoleUser), summary, raw)
|
||||
}
|
||||
|
||||
// AddRealtimeToolOutput stores a pending tool result for the next assistant turn.
|
||||
func (s *Session) AddRealtimeToolOutput(summary, raw string) {
|
||||
if summary == "" && raw == "" {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{
|
||||
Role: string(schemas.ChatMessageRoleTool),
|
||||
Summary: summary,
|
||||
Raw: raw,
|
||||
})
|
||||
}
|
||||
|
||||
// RecordRealtimeToolOutput stores or updates a finalized tool result keyed by item ID.
|
||||
// Late updates for items already attached to a completed turn are ignored.
|
||||
func (s *Session) RecordRealtimeToolOutput(itemID, summary, raw string) {
|
||||
s.recordRealtimeTurnInput(itemID, string(schemas.ChatMessageRoleTool), summary, raw)
|
||||
}
|
||||
|
||||
func (s *Session) recordRealtimeTurnInput(itemID, role, summary, raw string) {
|
||||
if summary == "" && raw == "" {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
itemID = strings.TrimSpace(itemID)
|
||||
if itemID != "" {
|
||||
if _, consumed := s.realtimeConsumedTurnItemIDs[itemID]; consumed {
|
||||
return
|
||||
}
|
||||
for idx := range s.realtimeTurnInputs {
|
||||
if s.realtimeTurnInputs[idx].ItemID != itemID || s.realtimeTurnInputs[idx].Role != role {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(summary) != "" {
|
||||
s.realtimeTurnInputs[idx].Summary = summary
|
||||
}
|
||||
if strings.TrimSpace(raw) != "" {
|
||||
existingRaw := strings.TrimSpace(s.realtimeTurnInputs[idx].Raw)
|
||||
incomingRaw := strings.TrimSpace(raw)
|
||||
switch {
|
||||
case existingRaw == "":
|
||||
s.realtimeTurnInputs[idx].Raw = raw
|
||||
case incomingRaw == "" || existingRaw == incomingRaw:
|
||||
default:
|
||||
s.realtimeTurnInputs[idx].Raw = existingRaw + "\n\n" + incomingRaw
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{
|
||||
ItemID: itemID,
|
||||
Role: role,
|
||||
Summary: summary,
|
||||
Raw: raw,
|
||||
})
|
||||
}
|
||||
|
||||
// ConsumeRealtimeTurnInputs returns pending realtime turn inputs and clears them.
|
||||
func (s *Session) ConsumeRealtimeTurnInputs() []RealtimeTurnInput {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
inputs := append([]RealtimeTurnInput(nil), s.realtimeTurnInputs...)
|
||||
if len(inputs) > 0 {
|
||||
if s.realtimeConsumedTurnItemIDs == nil {
|
||||
s.realtimeConsumedTurnItemIDs = make(map[string]struct{}, len(inputs))
|
||||
}
|
||||
for _, input := range inputs {
|
||||
if strings.TrimSpace(input.ItemID) != "" {
|
||||
s.realtimeConsumedTurnItemIDs[input.ItemID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.realtimeTurnInputs = nil
|
||||
return inputs
|
||||
}
|
||||
|
||||
// PeekRealtimeTurnInputs returns pending realtime turn inputs without clearing them.
|
||||
func (s *Session) PeekRealtimeTurnInputs() []RealtimeTurnInput {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return append([]RealtimeTurnInput(nil), s.realtimeTurnInputs...)
|
||||
}
|
||||
|
||||
// SetRealtimeTurnHooks stores the active turn-scoped plugin pipeline.
|
||||
func (s *Session) SetRealtimeTurnHooks(state *RealtimeTurnPluginState) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.realtimeTurnHooks != nil && s.realtimeTurnHooks.Cleanup != nil {
|
||||
s.realtimeTurnHooks.Cleanup()
|
||||
}
|
||||
s.realtimeTurnBusy = false
|
||||
if s.closed {
|
||||
if state != nil && state.Cleanup != nil {
|
||||
state.Cleanup()
|
||||
}
|
||||
s.realtimeTurnHooks = nil
|
||||
return
|
||||
}
|
||||
s.realtimeTurnHooks = state
|
||||
}
|
||||
|
||||
// TryBeginRealtimeTurnHooks reserves the single active turn slot.
|
||||
func (s *Session) TryBeginRealtimeTurnHooks() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.closed || s.realtimeTurnBusy || s.realtimeTurnHooks != nil {
|
||||
return false
|
||||
}
|
||||
s.realtimeTurnBusy = true
|
||||
return true
|
||||
}
|
||||
|
||||
// AbortRealtimeTurnHooks releases a reserved turn slot without installing hooks.
|
||||
func (s *Session) AbortRealtimeTurnHooks() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.realtimeTurnBusy = false
|
||||
}
|
||||
|
||||
// PeekRealtimeTurnHooks returns the active turn-scoped plugin pipeline without clearing it.
|
||||
func (s *Session) PeekRealtimeTurnHooks() *RealtimeTurnPluginState {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.realtimeTurnHooks
|
||||
}
|
||||
|
||||
// ConsumeRealtimeTurnHooks returns the active turn-scoped plugin pipeline and clears it.
|
||||
func (s *Session) ConsumeRealtimeTurnHooks() *RealtimeTurnPluginState {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
state := s.realtimeTurnHooks
|
||||
s.realtimeTurnHooks = nil
|
||||
s.realtimeTurnBusy = false
|
||||
return state
|
||||
}
|
||||
|
||||
// ClearRealtimeTurnHooks cleans up and clears any active turn-scoped plugin pipeline.
|
||||
func (s *Session) ClearRealtimeTurnHooks() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.realtimeTurnHooks != nil && s.realtimeTurnHooks.Cleanup != nil {
|
||||
s.realtimeTurnHooks.Cleanup()
|
||||
}
|
||||
s.realtimeTurnHooks = nil
|
||||
s.realtimeTurnBusy = false
|
||||
}
|
||||
|
||||
// Close closes the session and its upstream connection if pinned.
|
||||
func (s *Session) Close() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.closed {
|
||||
return
|
||||
}
|
||||
s.closed = true
|
||||
if s.realtimeTurnHooks != nil {
|
||||
if s.realtimeTurnHooks.Cleanup != nil {
|
||||
s.realtimeTurnHooks.Cleanup()
|
||||
}
|
||||
s.realtimeTurnHooks = nil
|
||||
}
|
||||
s.realtimeTurnBusy = false
|
||||
if s.clientConn != nil {
|
||||
_ = s.clientConn.Close()
|
||||
}
|
||||
if s.upstream != nil {
|
||||
s.upstream.Close()
|
||||
s.upstream = nil
|
||||
}
|
||||
}
|
||||
|
||||
// SessionManager tracks active sessions for connection limiting and cleanup.
|
||||
type SessionManager struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[*ws.Conn]*Session
|
||||
maxConns int
|
||||
}
|
||||
|
||||
// NewSessionManager creates a new session manager.
|
||||
func NewSessionManager(maxConns int) *SessionManager {
|
||||
return &SessionManager{
|
||||
sessions: make(map[*ws.Conn]*Session),
|
||||
maxConns: maxConns,
|
||||
}
|
||||
}
|
||||
|
||||
// Create creates and registers a new session for the given client connection.
|
||||
// Returns an error if the connection limit would be exceeded.
|
||||
func (m *SessionManager) Create(clientConn *ws.Conn) (*Session, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.maxConns > 0 && len(m.sessions) >= m.maxConns {
|
||||
return nil, ErrConnectionLimitReached
|
||||
}
|
||||
|
||||
session := NewSession(clientConn)
|
||||
m.sessions[clientConn] = session
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Get returns the session for the given client connection.
|
||||
func (m *SessionManager) Get(clientConn *ws.Conn) *Session {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.sessions[clientConn]
|
||||
}
|
||||
|
||||
// Remove removes and closes a session.
|
||||
func (m *SessionManager) Remove(clientConn *ws.Conn) {
|
||||
m.mu.Lock()
|
||||
session, ok := m.sessions[clientConn]
|
||||
if ok {
|
||||
delete(m.sessions, clientConn)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
if session != nil {
|
||||
session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Count returns the number of active sessions.
|
||||
func (m *SessionManager) Count() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.sessions)
|
||||
}
|
||||
|
||||
// CloseAll closes all active sessions.
|
||||
func (m *SessionManager) CloseAll() {
|
||||
m.mu.Lock()
|
||||
sessions := m.sessions
|
||||
m.sessions = make(map[*ws.Conn]*Session)
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, session := range sessions {
|
||||
session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Snapshot returns a copy of the currently tracked sessions.
|
||||
func (m *SessionManager) Snapshot() []*Session {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
sessions := make([]*Session, 0, len(m.sessions))
|
||||
for _, session := range m.sessions {
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
return sessions
|
||||
}
|
||||
156
transports/bifrost-http/websocket/session_test.go
Normal file
156
transports/bifrost-http/websocket/session_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
ws "github.com/fasthttp/websocket"
|
||||
)
|
||||
|
||||
func TestSessionManagerCreateAndGet(t *testing.T) {
|
||||
manager := NewSessionManager(2)
|
||||
conn := newTestConn()
|
||||
|
||||
session, err := manager.Create(conn)
|
||||
if err != nil {
|
||||
t.Fatalf("Create() unexpected error: %v", err)
|
||||
}
|
||||
if session == nil {
|
||||
t.Fatal("Create() returned nil session")
|
||||
}
|
||||
if got := manager.Get(conn); got != session {
|
||||
t.Fatal("Get() did not return the created session")
|
||||
}
|
||||
if got := manager.Count(); got != 1 {
|
||||
t.Fatalf("Count() = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionManagerConnectionLimit(t *testing.T) {
|
||||
manager := NewSessionManager(1)
|
||||
|
||||
if _, err := manager.Create(newTestConn()); err != nil {
|
||||
t.Fatalf("first Create() unexpected error: %v", err)
|
||||
}
|
||||
if _, err := manager.Create(newTestConn()); err != ErrConnectionLimitReached {
|
||||
t.Fatalf("second Create() error = %v, want %v", err, ErrConnectionLimitReached)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionManagerRemove(t *testing.T) {
|
||||
manager := NewSessionManager(2)
|
||||
conn := newTestConn()
|
||||
|
||||
session, err := manager.Create(conn)
|
||||
if err != nil {
|
||||
t.Fatalf("Create() unexpected error: %v", err)
|
||||
}
|
||||
|
||||
manager.Remove(conn)
|
||||
|
||||
if got := manager.Get(conn); got != nil {
|
||||
t.Fatal("Get() should return nil after Remove()")
|
||||
}
|
||||
if got := manager.Count(); got != 0 {
|
||||
t.Fatalf("Count() = %d, want 0", got)
|
||||
}
|
||||
if !session.closed {
|
||||
t.Fatal("expected removed session to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionLastResponseID(t *testing.T) {
|
||||
session := NewSession(newTestConn())
|
||||
session.SetLastResponseID("resp-123")
|
||||
|
||||
if got := session.LastResponseID(); got != "resp-123" {
|
||||
t.Fatalf("LastResponseID() = %q, want %q", got, "resp-123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionManagerCloseAll(t *testing.T) {
|
||||
manager := NewSessionManager(4)
|
||||
connA := newTestConn()
|
||||
connB := newTestConn()
|
||||
|
||||
sessionA, err := manager.Create(connA)
|
||||
if err != nil {
|
||||
t.Fatalf("Create(connA) unexpected error: %v", err)
|
||||
}
|
||||
sessionB, err := manager.Create(connB)
|
||||
if err != nil {
|
||||
t.Fatalf("Create(connB) unexpected error: %v", err)
|
||||
}
|
||||
|
||||
manager.CloseAll()
|
||||
|
||||
if got := manager.Count(); got != 0 {
|
||||
t.Fatalf("Count() = %d, want 0", got)
|
||||
}
|
||||
if !sessionA.closed || !sessionB.closed {
|
||||
t.Fatal("expected all sessions to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionRealtimeState(t *testing.T) {
|
||||
session := NewSession(newTestConn())
|
||||
if session.ID() == "" {
|
||||
t.Fatal("expected session ID to be populated")
|
||||
}
|
||||
|
||||
session.SetProviderSessionID("provider-session-1")
|
||||
if got := session.ProviderSessionID(); got != "provider-session-1" {
|
||||
t.Fatalf("ProviderSessionID() = %q, want %q", got, "provider-session-1")
|
||||
}
|
||||
|
||||
session.AppendRealtimeOutputText("hello")
|
||||
session.AppendRealtimeOutputText(" world")
|
||||
if got := session.ConsumeRealtimeOutputText(); got != "hello world" {
|
||||
t.Fatalf("ConsumeRealtimeOutputText() = %q, want %q", got, "hello world")
|
||||
}
|
||||
if got := session.ConsumeRealtimeOutputText(); got != "" {
|
||||
t.Fatalf("ConsumeRealtimeOutputText() after clear = %q, want empty string", got)
|
||||
}
|
||||
|
||||
session.AddRealtimeInput("hello", `{"type":"conversation.item.create","item":{"role":"user"}}`)
|
||||
session.AddRealtimeToolOutput("tool result", `{"type":"conversation.item.create","item":{"type":"function_call_output"}}`)
|
||||
turnInputs := session.ConsumeRealtimeTurnInputs()
|
||||
if len(turnInputs) != 2 {
|
||||
t.Fatalf("len(ConsumeRealtimeTurnInputs()) = %d, want 2", len(turnInputs))
|
||||
}
|
||||
if turnInputs[0].Role != "user" || turnInputs[0].Summary != "hello" {
|
||||
t.Fatalf("turnInputs[0] = %+v, want user hello", turnInputs[0])
|
||||
}
|
||||
if turnInputs[1].Role != "tool" || turnInputs[1].Summary != "tool result" {
|
||||
t.Fatalf("turnInputs[1] = %+v, want tool result", turnInputs[1])
|
||||
}
|
||||
if got := session.ConsumeRealtimeTurnInputs(); len(got) != 0 {
|
||||
t.Fatalf("len(ConsumeRealtimeTurnInputs()) after clear = %d, want 0", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionRecordRealtimeInputUpdatesPendingItemAndIgnoresConsumedLateUpdate(t *testing.T) {
|
||||
session := NewSession(newTestConn())
|
||||
|
||||
session.RecordRealtimeInput("item_1", "[Audio transcription unavailable]", `{"type":"conversation.item.done","item":{"id":"item_1"}}`)
|
||||
session.RecordRealtimeInput("item_1", "Hello.", `{"type":"conversation.item.input_audio_transcription.completed","item_id":"item_1","transcript":"Hello."}`)
|
||||
|
||||
turnInputs := session.ConsumeRealtimeTurnInputs()
|
||||
if len(turnInputs) != 1 {
|
||||
t.Fatalf("len(ConsumeRealtimeTurnInputs()) = %d, want 1", len(turnInputs))
|
||||
}
|
||||
if turnInputs[0].ItemID != "item_1" {
|
||||
t.Fatalf("turnInputs[0].ItemID = %q, want %q", turnInputs[0].ItemID, "item_1")
|
||||
}
|
||||
if turnInputs[0].Summary != "Hello." {
|
||||
t.Fatalf("turnInputs[0].Summary = %q, want %q", turnInputs[0].Summary, "Hello.")
|
||||
}
|
||||
|
||||
session.RecordRealtimeInput("item_1", "Hello.", `{"type":"conversation.item.input_audio_transcription.completed","item_id":"item_1","transcript":"Hello."}`)
|
||||
if got := session.ConsumeRealtimeTurnInputs(); len(got) != 0 {
|
||||
t.Fatalf("len(ConsumeRealtimeTurnInputs()) after late consumed update = %d, want 0", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func newTestConn() *ws.Conn {
|
||||
return &ws.Conn{}
|
||||
}
|
||||
0
transports/changelog.md
Normal file
0
transports/changelog.md
Normal file
4011
transports/config.schema.json
Normal file
4011
transports/config.schema.json
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user