first commit
This commit is contained in:
105
transports/bifrost-http/lib/account.go
Normal file
105
transports/bifrost-http/lib/account.go
Normal file
@@ -0,0 +1,105 @@
|
||||
// Package lib provides core functionality for the Bifrost HTTP service,
|
||||
// including context propagation, header management, and integration with monitoring systems.
|
||||
package lib
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// BaseAccount implements the Account interface for Bifrost.
|
||||
// It manages provider configurations using a in-memory store for persistent storage.
|
||||
// All data processing (environment variables, key configs) is done upfront in the store.
|
||||
type BaseAccount struct {
|
||||
store *Config // store for in-memory configuration
|
||||
}
|
||||
|
||||
// NewBaseAccount creates a new BaseAccount with the given store
|
||||
func NewBaseAccount(store *Config) *BaseAccount {
|
||||
return &BaseAccount{
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfiguredProviders returns a list of all configured providers.
|
||||
// Implements the Account interface.
|
||||
func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
|
||||
if baseAccount.store == nil {
|
||||
return nil, fmt.Errorf("store not initialized")
|
||||
}
|
||||
return baseAccount.store.GetAllProviders()
|
||||
}
|
||||
|
||||
// GetKeysForProvider returns the API keys configured for a specific provider.
|
||||
// Keys are already processed (environment variables resolved) by the store.
|
||||
// Implements the Account interface.
|
||||
func (baseAccount *BaseAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) {
|
||||
if baseAccount.store == nil {
|
||||
return nil, fmt.Errorf("store not initialized")
|
||||
}
|
||||
config, err := baseAccount.store.GetProviderConfigRaw(providerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys := config.Keys
|
||||
if v := ctx.Value(schemas.BifrostContextKeyGovernanceIncludeOnlyKeys); v != nil {
|
||||
if includeOnlyKeys, ok := v.([]string); ok {
|
||||
if len(includeOnlyKeys) == 0 {
|
||||
// header present but empty means "no keys allowed"
|
||||
keys = nil
|
||||
} else {
|
||||
set := make(map[string]struct{}, len(includeOnlyKeys))
|
||||
for _, id := range includeOnlyKeys {
|
||||
set[id] = struct{}{}
|
||||
}
|
||||
filtered := make([]schemas.Key, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
if _, ok := set[key.ID]; ok {
|
||||
filtered = append(filtered, key)
|
||||
}
|
||||
}
|
||||
keys = filtered
|
||||
}
|
||||
}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// GetConfigForProvider returns the complete configuration for a specific provider.
|
||||
// Configuration is already fully processed (environment variables, key configs) by the store.
|
||||
// Implements the Account interface.
|
||||
func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) {
|
||||
if baseAccount.store == nil {
|
||||
return nil, fmt.Errorf("store not initialized")
|
||||
}
|
||||
config, err := baseAccount.store.GetProviderConfigRaw(providerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providerConfig := &schemas.ProviderConfig{}
|
||||
if config.ProxyConfig != nil {
|
||||
providerConfig.ProxyConfig = config.ProxyConfig
|
||||
}
|
||||
if config.NetworkConfig != nil {
|
||||
providerConfig.NetworkConfig = *config.NetworkConfig
|
||||
} else {
|
||||
providerConfig.NetworkConfig = schemas.DefaultNetworkConfig
|
||||
}
|
||||
if config.ConcurrencyAndBufferSize != nil {
|
||||
providerConfig.ConcurrencyAndBufferSize = *config.ConcurrencyAndBufferSize
|
||||
} else {
|
||||
providerConfig.ConcurrencyAndBufferSize = schemas.DefaultConcurrencyAndBufferSize
|
||||
}
|
||||
providerConfig.SendBackRawRequest = config.SendBackRawRequest
|
||||
providerConfig.SendBackRawResponse = config.SendBackRawResponse
|
||||
providerConfig.StoreRawRequestResponse = config.StoreRawRequestResponse
|
||||
if config.CustomProviderConfig != nil {
|
||||
providerConfig.CustomProviderConfig = config.CustomProviderConfig
|
||||
}
|
||||
if config.OpenAIConfig != nil {
|
||||
providerConfig.OpenAIConfig = config.OpenAIConfig
|
||||
}
|
||||
return providerConfig, nil
|
||||
}
|
||||
4310
transports/bifrost-http/lib/config.go
Normal file
4310
transports/bifrost-http/lib/config.go
Normal file
File diff suppressed because it is too large
Load Diff
17838
transports/bifrost-http/lib/config_test.go
Normal file
17838
transports/bifrost-http/lib/config_test.go
Normal file
File diff suppressed because it is too large
Load Diff
644
transports/bifrost-http/lib/ctx.go
Normal file
644
transports/bifrost-http/lib/ctx.go
Normal file
@@ -0,0 +1,644 @@
|
||||
// Package lib provides core functionality for the Bifrost HTTP service,
|
||||
// including context propagation, header management, and integration with monitoring systems.
|
||||
//
|
||||
// This package handles the conversion of FastHTTP request contexts to Bifrost contexts,
|
||||
// ensuring that important metadata and tracking information is preserved across the system.
|
||||
// It supports propagation of both Prometheus metrics and Maxim tracing data through HTTP headers.
|
||||
package lib
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/plugins/governance"
|
||||
"github.com/maximhq/bifrost/plugins/maxim"
|
||||
"github.com/maximhq/bifrost/plugins/semanticcache"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
|
||||
// FastHTTPUserValueBifrostContext stores the active *schemas.BifrostContext on fasthttp.RequestCtx.
|
||||
// This allows transport middleware and request handlers to share the same context instance.
|
||||
FastHTTPUserValueBifrostContext = "__bifrost_context"
|
||||
// FastHTTPUserValueBifrostCancel stores the cancel func for the active shared Bifrost context.
|
||||
FastHTTPUserValueBifrostCancel = "__bifrost_context_cancel"
|
||||
// FastHTTPUserValueLargeResponseMode marks requests that streamed a large response body.
|
||||
// It is used by transport middleware to avoid re-buffering response bodies for post-hooks.
|
||||
FastHTTPUserValueLargeResponseMode = "__bifrost_large_response_mode"
|
||||
)
|
||||
|
||||
// ParseSessionIDFromBaggage extracts the session-id baggage member value.
|
||||
// It supports simple W3C baggage parsing sufficient for log grouping.
|
||||
func ParseSessionIDFromBaggage(header string) string {
|
||||
for _, member := range strings.Split(header, ",") {
|
||||
member = strings.TrimSpace(member)
|
||||
if member == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(member, ";", 2)
|
||||
kv := strings.SplitN(strings.TrimSpace(parts[0]), "=", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.ToLower(strings.TrimSpace(kv[0]))
|
||||
value := strings.TrimSpace(kv[1])
|
||||
if key != "session-id" || value == "" {
|
||||
continue
|
||||
}
|
||||
if len(value) > 255 {
|
||||
if logger != nil {
|
||||
logger.Warn("session-id exceeds 255 chars, ignoring: length=%d, prefix=%s", len(value), value[:255])
|
||||
}
|
||||
continue
|
||||
}
|
||||
return value
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ConvertToBifrostContext converts a FastHTTP RequestCtx to a Bifrost context,
|
||||
// preserving important header values for monitoring and tracing purposes.
|
||||
//
|
||||
// The function processes several types of special headers:
|
||||
// 1. Prometheus Headers (x-bf-prom-*):
|
||||
// - All headers prefixed with 'x-bf-prom-' are copied to the context
|
||||
// - The prefix is stripped and the remainder becomes the context key
|
||||
// - Example: 'x-bf-prom-latency' becomes 'latency' in the context
|
||||
//
|
||||
// 2. Maxim Tracing Headers (x-bf-maxim-*):
|
||||
// - Specifically handles 'x-bf-maxim-traceID' and 'x-bf-maxim-generationID'
|
||||
// - These headers enable trace correlation across service boundaries
|
||||
// - Values are stored using Maxim's context keys for consistency
|
||||
//
|
||||
// 3. MCP Headers (x-bf-mcp-*):
|
||||
// - Specifically handles 'x-bf-mcp-include-clients' and 'x-bf-mcp-include-tools' (include-only filtering)
|
||||
// - These headers enable MCP client and tool filtering
|
||||
// - Values are stored using MCP context keys for consistency
|
||||
//
|
||||
// 4. Governance Headers:
|
||||
// - x-bf-vk: Virtual key for governance (required for governance to work)
|
||||
//
|
||||
// 5. API Key Headers:
|
||||
// - Authorization: Bearer token format only (e.g., "Bearer sk-...") - OpenAI style
|
||||
// - x-api-key: Direct API key value - Anthropic style
|
||||
// - x-goog-api-key: Direct API key value - Google Gemini style
|
||||
// - x-bf-api-key references a stored API key name rather than the raw secret.
|
||||
// - Keys are extracted and stored in the context using schemas.BifrostContextKey
|
||||
// - This enables explicit key usage for requests via headers
|
||||
//
|
||||
// 6. Cancellable Context:
|
||||
// - Creates a cancellable context that can be used to cancel upstream requests when clients disconnect
|
||||
// - This is critical for streaming requests where write errors indicate client disconnects
|
||||
// - Also useful for non-streaming requests to allow provider-level cancellation
|
||||
//
|
||||
// 7. Extra Headers (x-bf-eh-*):
|
||||
// - Any header starting with 'x-bf-eh-' is collected and added to the map stored under schemas.BifrostContextKeyExtraHeaders
|
||||
// - The prefix is stripped, the remainder is lower-cased, and duplicate names append values
|
||||
// - This allows callers to send arbitrary context metadata without needing to extend the public schema
|
||||
//
|
||||
// 8. Session Stickiness Headers:
|
||||
// - x-bf-session-id: Session identifier for key binding (reuse same key across requests)
|
||||
// - x-bf-session-ttl: Per-request TTL override (duration string e.g. "30m" or seconds integer)
|
||||
//
|
||||
// 9. Raw Capture Headers (per-request override of provider config; accepts "true" or "false"):
|
||||
// - x-bf-send-back-raw-request: include raw provider request in the BifrostResponse returned to the caller
|
||||
// - x-bf-send-back-raw-response: include raw provider response in the BifrostResponse returned to the caller
|
||||
// - x-bf-store-raw-request-response: capture raw request/response for logging only (stripped from client response)
|
||||
|
||||
// Parameters:
|
||||
// - ctx: The FastHTTP request context containing the original headers
|
||||
// - allowDirectKeys: Whether to allow direct API key usage from headers
|
||||
//
|
||||
// Returns:
|
||||
// - *context.Context: A new cancellable context.Context containing the propagated values
|
||||
// - context.CancelFunc: Function to cancel the context (should be called when request completes)
|
||||
//
|
||||
// Example Usage:
|
||||
//
|
||||
// fastCtx := &fasthttp.RequestCtx{...}
|
||||
// bifrostCtx, cancel := ConvertToBifrostContext(fastCtx, true, nil)
|
||||
// defer cancel() // Ensure cleanup
|
||||
// // bifrostCtx now contains propagated header values including Prometheus metrics,
|
||||
// // Maxim tracing data, MCP filters, governance keys, API keys, cache settings,
|
||||
// // session stickiness, and extra headers
|
||||
|
||||
func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, matcher *HeaderMatcher, mcpHeaderCombinedAllowlist schemas.WhiteList) (*schemas.BifrostContext, context.CancelFunc) {
|
||||
// Reuse a shared request-scoped context when available.
|
||||
var bifrostCtx *schemas.BifrostContext
|
||||
var cancel context.CancelFunc
|
||||
if existing, ok := ctx.UserValue(FastHTTPUserValueBifrostContext).(*schemas.BifrostContext); ok && existing != nil {
|
||||
if existingCancel, ok := ctx.UserValue(FastHTTPUserValueBifrostCancel).(context.CancelFunc); ok && existingCancel != nil {
|
||||
bifrostCtx = existing
|
||||
cancel = existingCancel
|
||||
} else {
|
||||
// Create one cancellable child context and promote it as the shared context.
|
||||
bifrostCtx, cancel = schemas.NewBifrostContextWithCancel(existing)
|
||||
ctx.SetUserValue(FastHTTPUserValueBifrostContext, bifrostCtx)
|
||||
ctx.SetUserValue(FastHTTPUserValueBifrostCancel, cancel)
|
||||
}
|
||||
}
|
||||
if bifrostCtx == nil {
|
||||
// Create cancellable context for requests that don't have a shared context yet.
|
||||
parent := context.Context(ctx)
|
||||
func() {
|
||||
// Zero-value fasthttp.RequestCtx can panic on Done(); fall back safely.
|
||||
defer func() {
|
||||
if recover() != nil {
|
||||
parent = context.Background()
|
||||
}
|
||||
}()
|
||||
_ = ctx.Done()
|
||||
}()
|
||||
bifrostCtx, cancel = schemas.NewBifrostContextWithCancel(parent)
|
||||
ctx.SetUserValue(FastHTTPUserValueBifrostContext, bifrostCtx)
|
||||
ctx.SetUserValue(FastHTTPUserValueBifrostCancel, cancel)
|
||||
}
|
||||
|
||||
// Preserve existing request-id if already present on the shared context.
|
||||
if existingRequestID, ok := bifrostCtx.Value(schemas.BifrostContextKeyRequestID).(string); !ok || existingRequestID == "" {
|
||||
// First, check if x-request-id header exists
|
||||
requestID := string(ctx.Request.Header.Peek("x-request-id"))
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
}
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyRequestID, requestID)
|
||||
}
|
||||
// Populating all user values from the request context
|
||||
ctx.VisitUserValuesAll(func(key, value any) {
|
||||
bifrostCtx.SetValue(key, value)
|
||||
})
|
||||
// Initialize tags map for collecting maxim tags
|
||||
maximTags := make(map[string]string)
|
||||
// Initialize extra headers map for headers prefixed with x-bf-eh-
|
||||
extraHeaders := make(map[string][]string)
|
||||
// Initialize extra headers map for headers in the mcp header combined allowlist
|
||||
mcpExtraHeaders := make(map[string][]string)
|
||||
// Security denylist of header names that should never be accepted (case-insensitive)
|
||||
// This denylist is always enforced regardless of user configuration
|
||||
securityDenylist := map[string]bool{
|
||||
"proxy-authorization": true,
|
||||
"cookie": true,
|
||||
"host": true,
|
||||
"content-length": true,
|
||||
"connection": true,
|
||||
"transfer-encoding": true,
|
||||
|
||||
// prevent auth/key overrides via x-bf-eh-*
|
||||
"x-api-key": true,
|
||||
"x-goog-api-key": true,
|
||||
"x-bf-api-key": true,
|
||||
"x-bf-api-key-id": true,
|
||||
"x-bf-vk": true,
|
||||
}
|
||||
|
||||
// Debug: Log header matcher state
|
||||
if logger != nil {
|
||||
if matcher != nil {
|
||||
logger.Debug("headerMatcher hasAllowlist=%v, hasDenylist=%v", matcher.HasAllowlist(), matcher.hasDenylist)
|
||||
} else {
|
||||
logger.Debug("headerMatcher is nil (allow all)")
|
||||
}
|
||||
}
|
||||
|
||||
// Then process other headers
|
||||
ctx.Request.Header.All()(func(key, value []byte) bool {
|
||||
keyStr := strings.ToLower(string(key))
|
||||
if keyStr == "baggage" {
|
||||
if sessionID := ParseSessionIDFromBaggage(string(value)); sessionID != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyParentRequestID, sessionID)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if labelName, ok := strings.CutPrefix(keyStr, "x-bf-prom-"); ok {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
|
||||
return true
|
||||
}
|
||||
// Checking for maxim headers
|
||||
if labelName, ok := strings.CutPrefix(keyStr, "x-bf-maxim-"); ok {
|
||||
switch labelName {
|
||||
case string(maxim.GenerationIDKey):
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
|
||||
case string(maxim.TraceIDKey):
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
|
||||
case string(maxim.SessionIDKey):
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
|
||||
case string(maxim.TraceNameKey):
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
|
||||
case string(maxim.GenerationNameKey):
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
|
||||
case string(maxim.LogRepoIDKey):
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
|
||||
default:
|
||||
// apart from these all headers starting with x-bf-maxim- are keys for tags
|
||||
// collect them in the maximTags map
|
||||
maximTags[labelName] = string(value)
|
||||
}
|
||||
return true
|
||||
}
|
||||
// MCP control headers (include-only filtering)
|
||||
if labelName, ok := strings.CutPrefix(keyStr, "x-bf-mcp-"); ok {
|
||||
switch labelName {
|
||||
case "include-clients":
|
||||
fallthrough
|
||||
case "include-tools":
|
||||
// Parse comma-separated values into []string
|
||||
valueStr := string(value)
|
||||
var parsedValues []string
|
||||
if valueStr != "" {
|
||||
// Split by comma and trim whitespace
|
||||
for _, v := range strings.Split(valueStr, ",") {
|
||||
if trimmed := strings.TrimSpace(v); trimmed != "" {
|
||||
parsedValues = append(parsedValues, trimmed)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
parsedValues = []string{""}
|
||||
}
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey("mcp-"+labelName), parsedValues)
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Handle virtual key header (x-bf-vk, authorization, x-api-key, x-goog-api-key headers)
|
||||
if keyStr == string(schemas.BifrostContextKeyVirtualKey) {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, string(value))
|
||||
return true
|
||||
}
|
||||
if keyStr == "authorization" {
|
||||
valueStr := string(value)
|
||||
// Only accept Bearer token format: "Bearer ..."
|
||||
if strings.HasPrefix(strings.ToLower(valueStr), "bearer ") {
|
||||
authHeaderValue := strings.TrimSpace(valueStr[7:]) // Remove "Bearer " prefix
|
||||
if authHeaderValue != "" && strings.HasPrefix(strings.ToLower(authHeaderValue), governance.VirtualKeyPrefix) {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, authHeaderValue)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
if keyStr == "x-api-key" && strings.HasPrefix(strings.ToLower(string(value)), governance.VirtualKeyPrefix) {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, string(value))
|
||||
return true
|
||||
}
|
||||
if keyStr == "x-goog-api-key" && strings.HasPrefix(strings.ToLower(string(value)), governance.VirtualKeyPrefix) {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, string(value))
|
||||
return true
|
||||
}
|
||||
if keyStr == "x-bf-api-key" {
|
||||
if keyName := strings.TrimSpace(string(value)); keyName != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyName, keyName)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if keyStr == "x-bf-api-key-id" {
|
||||
if keyID := strings.TrimSpace(string(value)); keyID != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyID, keyID)
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Handle cache key header (x-bf-cache-key)
|
||||
if keyStr == "x-bf-cache-key" {
|
||||
bifrostCtx.SetValue(semanticcache.CacheKey, string(value))
|
||||
return true
|
||||
}
|
||||
// Handle cache TTL header (x-bf-cache-ttl)
|
||||
if keyStr == "x-bf-cache-ttl" {
|
||||
valueStr := string(value)
|
||||
var ttlDuration time.Duration
|
||||
var err error
|
||||
|
||||
// First try to parse as duration (e.g., "30s", "5m", "1h")
|
||||
if ttlDuration, err = time.ParseDuration(valueStr); err != nil {
|
||||
// If that fails, try to parse as plain number and treat as seconds
|
||||
if seconds, parseErr := strconv.Atoi(valueStr); parseErr == nil && seconds > 0 {
|
||||
ttlDuration = time.Duration(seconds) * time.Second
|
||||
err = nil // Reset error since we successfully parsed as seconds
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
bifrostCtx.SetValue(semanticcache.CacheTTLKey, ttlDuration)
|
||||
}
|
||||
// If both parsing attempts fail, we silently ignore the header and use default TTL
|
||||
return true
|
||||
}
|
||||
// Cache threshold header
|
||||
if keyStr == "x-bf-cache-threshold" {
|
||||
threshold, err := strconv.ParseFloat(string(value), 64)
|
||||
if err == nil {
|
||||
// Clamp threshold to the inclusive range [0.0, 1.0]
|
||||
if threshold < 0.0 {
|
||||
threshold = 0.0
|
||||
} else if threshold > 1.0 {
|
||||
threshold = 1.0
|
||||
}
|
||||
bifrostCtx.SetValue(semanticcache.CacheThresholdKey, threshold)
|
||||
}
|
||||
// If parsing fails, silently ignore the header (no context value set)
|
||||
return true
|
||||
}
|
||||
// Cache type header
|
||||
if keyStr == "x-bf-cache-type" {
|
||||
bifrostCtx.SetValue(semanticcache.CacheTypeKey, semanticcache.CacheType(string(value)))
|
||||
return true
|
||||
}
|
||||
// Cache no store header
|
||||
if keyStr == "x-bf-cache-no-store" {
|
||||
if valueStr := string(value); valueStr == "true" {
|
||||
bifrostCtx.SetValue(semanticcache.CacheNoStoreKey, true)
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Session stickiness: session ID for key binding
|
||||
if keyStr == "x-bf-session-id" {
|
||||
if valueStr := strings.TrimSpace(string(value)); valueStr != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeySessionID, valueStr)
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Session stickiness: per-request TTL override (duration string or seconds integer)
|
||||
if keyStr == "x-bf-session-ttl" {
|
||||
valueStr := strings.TrimSpace(string(value))
|
||||
var ttlDuration time.Duration
|
||||
var err error
|
||||
if ttlDuration, err = time.ParseDuration(valueStr); err != nil {
|
||||
if seconds, parseErr := strconv.Atoi(valueStr); parseErr == nil && seconds > 0 {
|
||||
ttlDuration = time.Duration(seconds) * time.Second
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
if err == nil && ttlDuration > 0 {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeySessionTTL, ttlDuration)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if labelName, ok := strings.CutPrefix(keyStr, "x-bf-eh-"); ok {
|
||||
// Skip empty header names after prefix removal
|
||||
if labelName == "" {
|
||||
return true
|
||||
}
|
||||
// Normalize header name to lowercase
|
||||
labelName = strings.ToLower(labelName)
|
||||
// Validate against security denylist (always enforced)
|
||||
if securityDenylist[labelName] {
|
||||
return true
|
||||
}
|
||||
// Apply configurable header filter
|
||||
if !matcher.ShouldAllow(labelName) {
|
||||
return true
|
||||
}
|
||||
// Append header value (allow multiple values for the same header)
|
||||
extraHeaders[labelName] = append(extraHeaders[labelName], string(value))
|
||||
return true
|
||||
}
|
||||
// Direct header forwarding: when allowlist is configured, any header explicitly
|
||||
// in the allowlist can be forwarded directly without the x-bf-eh- prefix.
|
||||
// This enables forwarding arbitrary headers like "anthropic-beta" directly.
|
||||
// Only applies when allowlist is non-empty (backward compatible).
|
||||
if matcher.HasAllowlist() {
|
||||
if matcher.MatchesAllow(keyStr) {
|
||||
// Skip reserved x-bf-* headers (handled separately)
|
||||
if strings.HasPrefix(keyStr, "x-bf-") {
|
||||
return true
|
||||
}
|
||||
// Validate against security denylist (always enforced)
|
||||
if securityDenylist[keyStr] {
|
||||
return true
|
||||
}
|
||||
// Check denylist
|
||||
if matcher.MatchesDeny(keyStr) {
|
||||
return true
|
||||
}
|
||||
// Forward the header directly with its original name
|
||||
if logger != nil {
|
||||
logger.Debug("forwarding header via allowlist: %s", keyStr)
|
||||
}
|
||||
extraHeaders[keyStr] = append(extraHeaders[keyStr], string(value))
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Handle MCP extra headers
|
||||
if mcpHeaderCombinedAllowlist.IsAllowed(keyStr) {
|
||||
mcpExtraHeaders[keyStr] = append(mcpExtraHeaders[keyStr], string(value))
|
||||
return true
|
||||
}
|
||||
// Raw capture headers — all three support "true"/"false" to fully override the
|
||||
// provider-level config for this request.
|
||||
if keyStr == "x-bf-send-back-raw-request" {
|
||||
if b, err := strconv.ParseBool(string(value)); err == nil {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeySendBackRawRequest, b)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if keyStr == "x-bf-send-back-raw-response" {
|
||||
if b, err := strconv.ParseBool(string(value)); err == nil {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeySendBackRawResponse, b)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if keyStr == "x-bf-store-raw-request-response" {
|
||||
if b, err := strconv.ParseBool(string(value)); err == nil {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyStoreRawRequestResponse, b)
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Parent request ID header (for linking MCP tool calls to parent LLM requests)
|
||||
if keyStr == "x-bf-parent-request-id" {
|
||||
if valueStr := strings.TrimSpace(string(value)); valueStr != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostMCPAgentOriginalRequestID, valueStr)
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Add passthrough extra params header support
|
||||
if keyStr == "x-bf-passthrough-extra-params" {
|
||||
if valueStr := string(value); valueStr == "true" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Compat header: per-request override of compat plugin settings.
|
||||
// Accepts: "true" (enable all), JSON array of feature names, or ["*"] (enable all).
|
||||
// An empty array [] or absent header means no overrides.
|
||||
if keyStr == "x-bf-compat" {
|
||||
bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatConvertTextToChat)
|
||||
bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatConvertChatToResponses)
|
||||
bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatShouldDropParams)
|
||||
bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatShouldConvertParams)
|
||||
valueStr := strings.TrimSpace(string(value))
|
||||
if valueStr == "true" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertTextToChat, true)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertChatToResponses, true)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldDropParams, true)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldConvertParams, true)
|
||||
} else if strings.HasPrefix(valueStr, "[") {
|
||||
var features []string
|
||||
if err := json.Unmarshal([]byte(valueStr), &features); err == nil {
|
||||
if len(features) == 1 && features[0] == "*" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertTextToChat, true)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertChatToResponses, true)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldDropParams, true)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldConvertParams, true)
|
||||
} else {
|
||||
for _, f := range features {
|
||||
switch f {
|
||||
case "convert_text_to_chat":
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertTextToChat, true)
|
||||
case "convert_chat_to_responses":
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertChatToResponses, true)
|
||||
case "should_drop_params":
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldDropParams, true)
|
||||
case "should_convert_params":
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldConvertParams, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Store the collected maxim tags in the context
|
||||
if len(maximTags) > 0 {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKey(maxim.TagsKey), maximTags)
|
||||
}
|
||||
|
||||
// Store collected extra headers in the context if any were found
|
||||
if len(extraHeaders) > 0 {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders)
|
||||
}
|
||||
|
||||
// Store collected MCP extra headers in the context if any were found
|
||||
if len(mcpExtraHeaders) > 0 {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyMCPExtraHeaders, mcpExtraHeaders)
|
||||
}
|
||||
|
||||
// Collect all request headers for downstream use (e.g., governance required headers check)
|
||||
// Keys are lowercased for case-insensitive lookup
|
||||
allHeaders := make(map[string]string)
|
||||
ctx.Request.Header.All()(func(key, value []byte) bool {
|
||||
allHeaders[strings.ToLower(string(key))] = string(value)
|
||||
return true
|
||||
})
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyRequestHeaders, allHeaders)
|
||||
|
||||
// Extract per-user MCP OAuth user identifier from X-Bf-User-Id header
|
||||
if mcpUserID := string(ctx.Request.Header.Peek("X-Bf-User-Id")); mcpUserID != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyMCPUserID, mcpUserID)
|
||||
}
|
||||
|
||||
// Build and set OAuth redirect URI for per-user OAuth flows
|
||||
scheme := "http"
|
||||
if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
host := string(ctx.Host())
|
||||
if host != "" {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyOAuthRedirectURI, fmt.Sprintf("%s://%s/api/oauth/callback", scheme, host))
|
||||
}
|
||||
|
||||
if allowDirectKeys {
|
||||
// Extract API key from Authorization header (Bearer format), x-api-key, or x-goog-api-key header
|
||||
var apiKey string
|
||||
|
||||
// TODO: fix plugin data leak
|
||||
// Check Authorization header (Bearer format only - OpenAI style)
|
||||
authHeader := string(ctx.Request.Header.Peek("Authorization"))
|
||||
if authHeader != "" {
|
||||
// Only accept Bearer token format: "Bearer ..."
|
||||
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
authHeaderValue := strings.TrimSpace(authHeader[7:]) // Remove "Bearer " prefix
|
||||
if authHeaderValue != "" && !strings.HasPrefix(strings.ToLower(authHeaderValue), governance.VirtualKeyPrefix) {
|
||||
apiKey = authHeaderValue
|
||||
}
|
||||
} else {
|
||||
apiKey = authHeader
|
||||
}
|
||||
}
|
||||
|
||||
if apiKey == "" {
|
||||
// Check x-api-key (Anthropic style) header if no valid Authorization header found
|
||||
xAPIKey := string(ctx.Request.Header.Peek("x-api-key"))
|
||||
if xAPIKey != "" && !strings.HasPrefix(strings.ToLower(xAPIKey), governance.VirtualKeyPrefix) {
|
||||
apiKey = strings.TrimSpace(xAPIKey)
|
||||
} else {
|
||||
// Check x-goog-api-key (Google Gemini style) header if no valid Authorization header found
|
||||
xGoogleAPIKey := string(ctx.Request.Header.Peek("x-goog-api-key"))
|
||||
if xGoogleAPIKey != "" && !strings.HasPrefix(strings.ToLower(xGoogleAPIKey), governance.VirtualKeyPrefix) {
|
||||
apiKey = strings.TrimSpace(xGoogleAPIKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we found an API key, create a Key object and store it in context
|
||||
if apiKey != "" {
|
||||
key := schemas.Key{
|
||||
ID: "header-provided", // Identifier for header-provided keys
|
||||
Value: *schemas.NewEnvVar(apiKey),
|
||||
Models: schemas.WhiteList{"*"}, // Allow all models
|
||||
Weight: 1.0, // Default weight
|
||||
}
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, key)
|
||||
}
|
||||
}
|
||||
return bifrostCtx, cancel
|
||||
}
|
||||
|
||||
// BuildHTTPRequestFromFastHTTP creates an HTTPRequest from fasthttp context for streaming handlers.
|
||||
// The returned request should be released with schemas.ReleaseHTTPRequest when done.
|
||||
// Note: Body is not copied for streaming (body was already consumed for the request).
|
||||
func BuildHTTPRequestFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.HTTPRequest {
|
||||
req := schemas.AcquireHTTPRequest()
|
||||
req.Method = string(ctx.Method())
|
||||
req.Path = string(ctx.Path())
|
||||
|
||||
// Copy headers
|
||||
for key, value := range ctx.Request.Header.All() {
|
||||
req.Headers[string(key)] = string(value)
|
||||
}
|
||||
|
||||
// Copy query params
|
||||
for key, value := range ctx.Request.URI().QueryArgs().All() {
|
||||
req.Query[string(key)] = string(value)
|
||||
}
|
||||
|
||||
// Copy path parameters from user values
|
||||
ctx.VisitUserValuesAll(func(key, value any) {
|
||||
keyStr, keyIsString := key.(string)
|
||||
valueStr, valueIsString := value.(string)
|
||||
if !keyIsString || !valueIsString {
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(keyStr, "bifrost-") ||
|
||||
keyStr == "BifrostContextKeyRequestID" ||
|
||||
keyStr == "trace_id" ||
|
||||
keyStr == "span_id" {
|
||||
return
|
||||
}
|
||||
req.PathParams[keyStr] = valueStr
|
||||
})
|
||||
|
||||
// Note: Body not copied - for streaming, body was already consumed
|
||||
return req
|
||||
}
|
||||
|
||||
// BuildHTTPResponseFromFastHTTP creates an HTTPResponse snapshot from fasthttp context.
|
||||
// Only captures status code and headers — body is skipped because for streaming
|
||||
// responses it is an active io.Reader that cannot be materialized.
|
||||
// The returned response should be released with schemas.ReleaseHTTPResponse when done.
|
||||
func BuildHTTPResponseFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.HTTPResponse {
|
||||
resp := schemas.AcquireHTTPResponse()
|
||||
resp.StatusCode = ctx.Response.StatusCode()
|
||||
for key, value := range ctx.Response.Header.All() {
|
||||
resp.Headers[string(key)] = string(value)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
275
transports/bifrost-http/lib/ctx_test.go
Normal file
275
transports/bifrost-http/lib/ctx_test.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestParseSessionIDFromBaggage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
want string
|
||||
}{
|
||||
{name: "single member", header: "session-id=abc", want: "abc"},
|
||||
{name: "multiple members", header: "foo=bar, session-id=abc, baz=qux", want: "abc"},
|
||||
{name: "member with properties", header: "session-id=abc;ttl=60", want: "abc"},
|
||||
{name: "spaces preserved around parsing", header: " foo=bar , session-id = abc123 ;ttl=60 ", want: "abc123"},
|
||||
{name: "missing member", header: "foo=bar", want: ""},
|
||||
{name: "malformed ignored", header: "session-id, foo=bar", want: ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ParseSessionIDFromBaggage(tt.header); got != tt.want {
|
||||
t.Fatalf("ParseSessionIDFromBaggage(%q) = %q, want %q", tt.header, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToBifrostContext_ReusesSharedContext(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
base := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
base.SetValue(schemas.BifrostContextKeyRequestID, "req-shared")
|
||||
ctx.SetUserValue(FastHTTPUserValueBifrostContext, base)
|
||||
|
||||
converted, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
if converted == nil {
|
||||
t.Fatal("expected non-nil converted context")
|
||||
}
|
||||
if got, _ := converted.Value(schemas.BifrostContextKeyRequestID).(string); got != "req-shared" {
|
||||
t.Fatalf("expected converted context to preserve parent values, got request-id=%q", got)
|
||||
}
|
||||
if stored, ok := ctx.UserValue(FastHTTPUserValueBifrostContext).(*schemas.BifrostContext); !ok || stored == nil {
|
||||
t.Fatal("expected shared context pointer to be stored on fasthttp user values")
|
||||
}
|
||||
if ctx.UserValue(FastHTTPUserValueBifrostCancel) == nil {
|
||||
t.Fatal("expected shared cancel function to be stored on fasthttp user values")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToBifrostContext_SecondCallReturnsSameSharedContext(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
first, cancelFirst := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
|
||||
defer cancelFirst()
|
||||
if first == nil {
|
||||
t.Fatal("expected first context to be non-nil")
|
||||
}
|
||||
|
||||
second, cancelSecond := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
|
||||
defer cancelSecond()
|
||||
if second == nil {
|
||||
t.Fatal("expected second context to be non-nil")
|
||||
}
|
||||
if first != second {
|
||||
t.Fatal("expected ConvertToBifrostContext to reuse the shared context on repeated calls")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToBifrostContext_StarAllowlistSecurityHeadersBlocked verifies that
|
||||
// even with a "*" allowlist (allow all), the hardcoded security denylist in
|
||||
// ConvertToBifrostContext still blocks security-sensitive headers.
|
||||
func TestConvertToBifrostContext_StarAllowlistSecurityHeadersBlocked(t *testing.T) {
|
||||
matcher := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"*"},
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
// x-bf-eh-* prefixed headers
|
||||
ctx.Request.Header.Set("x-bf-eh-custom-header", "allowed-value")
|
||||
ctx.Request.Header.Set("x-bf-eh-cookie", "should-be-blocked")
|
||||
ctx.Request.Header.Set("x-bf-eh-x-api-key", "should-be-blocked")
|
||||
ctx.Request.Header.Set("x-bf-eh-host", "should-be-blocked")
|
||||
ctx.Request.Header.Set("x-bf-eh-connection", "should-be-blocked")
|
||||
ctx.Request.Header.Set("x-bf-eh-proxy-authorization", "should-be-blocked")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
|
||||
|
||||
// custom-header should be forwarded
|
||||
if _, ok := extraHeaders["custom-header"]; !ok {
|
||||
t.Error("expected custom-header to be forwarded via x-bf-eh- prefix")
|
||||
}
|
||||
|
||||
// Security headers should be blocked even with * allowlist
|
||||
securityHeaders := []string{"cookie", "x-api-key", "host", "connection", "proxy-authorization"}
|
||||
for _, h := range securityHeaders {
|
||||
if _, ok := extraHeaders[h]; ok {
|
||||
t.Errorf("expected security header %q to be blocked even with * allowlist", h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToBifrostContext_StarAllowlistDirectForwardingSecurityBlocked verifies
|
||||
// that direct header forwarding with "*" allowlist forwards non-security headers
|
||||
// but still blocks security headers.
|
||||
func TestConvertToBifrostContext_StarAllowlistDirectForwardingSecurityBlocked(t *testing.T) {
|
||||
matcher := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"*"},
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
// Direct headers (not prefixed with x-bf-eh-)
|
||||
ctx.Request.Header.Set("custom-header", "allowed-value")
|
||||
ctx.Request.Header.Set("anthropic-beta", "some-beta-feature")
|
||||
// Security headers sent directly — should be blocked
|
||||
ctx.Request.Header.Set("proxy-authorization", "should-be-blocked")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
|
||||
|
||||
// Direct non-security headers should be forwarded when allowlist has *
|
||||
if _, ok := extraHeaders["custom-header"]; !ok {
|
||||
t.Error("expected custom-header to be forwarded directly")
|
||||
}
|
||||
if _, ok := extraHeaders["anthropic-beta"]; !ok {
|
||||
t.Error("expected anthropic-beta to be forwarded directly")
|
||||
}
|
||||
|
||||
// Security headers should still be blocked in direct forwarding path
|
||||
directSecurityHeaders := []string{"proxy-authorization", "cookie", "host", "connection"}
|
||||
for _, h := range directSecurityHeaders {
|
||||
if _, ok := extraHeaders[h]; ok {
|
||||
t.Errorf("expected security header %q to be blocked in direct forwarding even with * allowlist", h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToBifrostContext_PrefixWildcardDirectForwarding verifies that
|
||||
// prefix wildcard patterns like "anthropic-*" work for direct header forwarding
|
||||
// (without x-bf-eh- prefix).
|
||||
func TestConvertToBifrostContext_PrefixWildcardDirectForwarding(t *testing.T) {
|
||||
matcher := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-*"},
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
// Direct headers matching the wildcard pattern
|
||||
ctx.Request.Header.Set("anthropic-beta", "beta-value")
|
||||
ctx.Request.Header.Set("anthropic-version", "2024-01-01")
|
||||
// Header not matching the pattern
|
||||
ctx.Request.Header.Set("openai-version", "should-not-forward")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
|
||||
|
||||
if _, ok := extraHeaders["anthropic-beta"]; !ok {
|
||||
t.Error("expected anthropic-beta to be forwarded directly via wildcard allowlist")
|
||||
}
|
||||
if _, ok := extraHeaders["anthropic-version"]; !ok {
|
||||
t.Error("expected anthropic-version to be forwarded directly via wildcard allowlist")
|
||||
}
|
||||
if _, ok := extraHeaders["openai-version"]; ok {
|
||||
t.Error("expected openai-version to NOT be forwarded (doesn't match anthropic-*)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToBifrostContext_WildcardAllowlistFiltering verifies wildcard patterns
|
||||
// correctly filter headers via the x-bf-eh- prefix path.
|
||||
func TestConvertToBifrostContext_WildcardAllowlistFiltering(t *testing.T) {
|
||||
matcher := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-*"},
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("x-bf-eh-anthropic-beta", "beta-value")
|
||||
ctx.Request.Header.Set("x-bf-eh-anthropic-version", "2024-01-01")
|
||||
ctx.Request.Header.Set("x-bf-eh-openai-version", "should-be-blocked")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
|
||||
|
||||
if _, ok := extraHeaders["anthropic-beta"]; !ok {
|
||||
t.Error("expected anthropic-beta to be forwarded")
|
||||
}
|
||||
if _, ok := extraHeaders["anthropic-version"]; !ok {
|
||||
t.Error("expected anthropic-version to be forwarded")
|
||||
}
|
||||
if _, ok := extraHeaders["openai-version"]; ok {
|
||||
t.Error("expected openai-version to be blocked (not matching anthropic-*)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToBifrostContext_WildcardDenylistBlocking verifies wildcard denylist
|
||||
// patterns block matching headers.
|
||||
func TestConvertToBifrostContext_WildcardDenylistBlocking(t *testing.T) {
|
||||
matcher := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"x-internal-*"},
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("x-bf-eh-x-internal-id", "blocked-value")
|
||||
ctx.Request.Header.Set("x-bf-eh-x-internal-secret", "blocked-value")
|
||||
ctx.Request.Header.Set("x-bf-eh-custom-header", "allowed-value")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
|
||||
|
||||
if _, ok := extraHeaders["x-internal-id"]; ok {
|
||||
t.Error("expected x-internal-id to be blocked by denylist")
|
||||
}
|
||||
if _, ok := extraHeaders["x-internal-secret"]; ok {
|
||||
t.Error("expected x-internal-secret to be blocked by denylist")
|
||||
}
|
||||
if _, ok := extraHeaders["custom-header"]; !ok {
|
||||
t.Error("expected custom-header to be forwarded")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToBifrostContext_NilMatcher verifies nil matcher allows all headers.
|
||||
func TestConvertToBifrostContext_NilMatcher(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("x-bf-eh-custom-header", "allowed-value")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
|
||||
|
||||
if _, ok := extraHeaders["custom-header"]; !ok {
|
||||
t.Error("expected custom-header to be forwarded with nil matcher")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToBifrostContext_BaggageSessionIDSetsGrouping(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("baggage", "foo=bar, session-id=rt-123, baz=qux")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
if got, _ := bifrostCtx.Value(schemas.BifrostContextKeyParentRequestID).(string); got != "rt-123" {
|
||||
t.Fatalf("parent request id = %q, want %q", got, "rt-123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToBifrostContext_EmptyBaggageSessionIDIgnored(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("baggage", "session-id= ")
|
||||
|
||||
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
|
||||
defer cancel()
|
||||
|
||||
if got := bifrostCtx.Value(schemas.BifrostContextKeyParentRequestID); got != nil {
|
||||
t.Fatalf("parent request id should be unset, got %#v", got)
|
||||
}
|
||||
}
|
||||
6
transports/bifrost-http/lib/errors.go
Normal file
6
transports/bifrost-http/lib/errors.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package lib
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrNotFound = errors.New("not found")
|
||||
var ErrAlreadyExists = errors.New("already exists")
|
||||
136
transports/bifrost-http/lib/headermatcher.go
Normal file
136
transports/bifrost-http/lib/headermatcher.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
// HeaderMatchesPattern returns true if headerName matches the pattern.
|
||||
// Patterns support trailing wildcard: "anthropic-*" matches "anthropic-beta".
|
||||
// A bare "*" matches everything. All comparisons are case-insensitive.
|
||||
func HeaderMatchesPattern(pattern, headerName string) bool {
|
||||
pattern = strings.ToLower(strings.TrimSpace(pattern))
|
||||
headerName = strings.ToLower(strings.TrimSpace(headerName))
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
if strings.HasSuffix(pattern, "*") {
|
||||
return strings.HasPrefix(headerName, pattern[:len(pattern)-1])
|
||||
}
|
||||
return pattern == headerName
|
||||
}
|
||||
|
||||
// HeaderMatcher holds precomputed header filter data for O(1) exact-match lookups
|
||||
// and fast prefix matching. Compiled once on config change, safe for concurrent reads.
|
||||
type HeaderMatcher struct {
|
||||
allowExact map[string]bool
|
||||
allowPrefixes []string // lowercased prefixes (without trailing *)
|
||||
allowAll bool
|
||||
hasAllowlist bool
|
||||
denyExact map[string]bool
|
||||
denyPrefixes []string
|
||||
denyAll bool
|
||||
hasDenylist bool
|
||||
}
|
||||
|
||||
// NewHeaderMatcher compiles a GlobalHeaderFilterConfig into an optimized HeaderMatcher.
|
||||
// Returns nil if config is nil (callers should treat nil as "allow all").
|
||||
func NewHeaderMatcher(config *configstoreTables.GlobalHeaderFilterConfig) *HeaderMatcher {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
m := &HeaderMatcher{
|
||||
allowExact: make(map[string]bool, len(config.Allowlist)),
|
||||
denyExact: make(map[string]bool, len(config.Denylist)),
|
||||
}
|
||||
for _, p := range config.Allowlist {
|
||||
lp := strings.ToLower(strings.TrimSpace(p))
|
||||
if lp == "" {
|
||||
continue
|
||||
}
|
||||
if lp == "*" {
|
||||
m.allowAll = true
|
||||
} else if strings.HasSuffix(lp, "*") {
|
||||
m.allowPrefixes = append(m.allowPrefixes, lp[:len(lp)-1])
|
||||
} else {
|
||||
m.allowExact[lp] = true
|
||||
}
|
||||
}
|
||||
for _, p := range config.Denylist {
|
||||
lp := strings.ToLower(strings.TrimSpace(p))
|
||||
if lp == "" {
|
||||
continue
|
||||
}
|
||||
if lp == "*" {
|
||||
m.denyAll = true
|
||||
} else if strings.HasSuffix(lp, "*") {
|
||||
m.denyPrefixes = append(m.denyPrefixes, lp[:len(lp)-1])
|
||||
} else {
|
||||
m.denyExact[lp] = true
|
||||
}
|
||||
}
|
||||
m.hasAllowlist = m.allowAll || len(m.allowExact) > 0 || len(m.allowPrefixes) > 0
|
||||
m.hasDenylist = m.denyAll || len(m.denyExact) > 0 || len(m.denyPrefixes) > 0
|
||||
return m
|
||||
}
|
||||
|
||||
// HasAllowlist returns true if the matcher has a non-empty allowlist.
|
||||
func (m *HeaderMatcher) HasAllowlist() bool {
|
||||
if m == nil {
|
||||
return false
|
||||
}
|
||||
return m.hasAllowlist
|
||||
}
|
||||
|
||||
// MatchesAllow returns true if headerName matches any allowlist entry.
|
||||
// headerName must be lowercased by the caller.
|
||||
func (m *HeaderMatcher) MatchesAllow(headerName string) bool {
|
||||
if m.allowAll {
|
||||
return true
|
||||
}
|
||||
if m.allowExact[headerName] {
|
||||
return true
|
||||
}
|
||||
for _, prefix := range m.allowPrefixes {
|
||||
if strings.HasPrefix(headerName, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MatchesDeny returns true if headerName matches any denylist entry.
|
||||
// headerName must be lowercased by the caller.
|
||||
func (m *HeaderMatcher) MatchesDeny(headerName string) bool {
|
||||
if m.denyAll {
|
||||
return true
|
||||
}
|
||||
if m.denyExact[headerName] {
|
||||
return true
|
||||
}
|
||||
for _, prefix := range m.denyPrefixes {
|
||||
if strings.HasPrefix(headerName, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ShouldAllow determines if a header should be forwarded based on the
|
||||
// configurable header filter config (separate from the security denylist).
|
||||
// Returns true if the header passes both allowlist and denylist checks.
|
||||
// headerName is lowercased internally for case-insensitive matching.
|
||||
func (m *HeaderMatcher) ShouldAllow(headerName string) bool {
|
||||
if m == nil {
|
||||
return true
|
||||
}
|
||||
headerName = strings.ToLower(headerName)
|
||||
if m.hasAllowlist && !m.MatchesAllow(headerName) {
|
||||
return false
|
||||
}
|
||||
if m.hasDenylist && m.MatchesDeny(headerName) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
251
transports/bifrost-http/lib/headermatcher_test.go
Normal file
251
transports/bifrost-http/lib/headermatcher_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
|
||||
)
|
||||
|
||||
func TestHeaderMatchesPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
pattern string
|
||||
headerName string
|
||||
want bool
|
||||
}{
|
||||
// Exact match
|
||||
{"anthropic-beta", "anthropic-beta", true},
|
||||
{"anthropic-beta", "anthropic-alpha", false},
|
||||
|
||||
// Case insensitive exact match
|
||||
{"Anthropic-Beta", "anthropic-beta", true},
|
||||
{"anthropic-beta", "Anthropic-Beta", true},
|
||||
|
||||
// Star matches all
|
||||
{"*", "anything", true},
|
||||
{"*", "", true},
|
||||
|
||||
// Prefix wildcard
|
||||
{"anthropic-*", "anthropic-beta", true},
|
||||
{"anthropic-*", "anthropic-version", true},
|
||||
{"anthropic-*", "anthropic-", true},
|
||||
{"anthropic-*", "openai-version", false},
|
||||
{"anthropic-*", "anthropic", false},
|
||||
|
||||
// Case insensitive prefix wildcard
|
||||
{"Anthropic-*", "anthropic-beta", true},
|
||||
{"anthropic-*", "Anthropic-Beta", true},
|
||||
|
||||
// No match
|
||||
{"foo", "bar", false},
|
||||
{"", "foo", false},
|
||||
|
||||
// Pattern without wildcard doesn't prefix match
|
||||
{"anthropic-", "anthropic-beta", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.pattern+"_"+tt.headerName, func(t *testing.T) {
|
||||
got := HeaderMatchesPattern(tt.pattern, tt.headerName)
|
||||
if got != tt.want {
|
||||
t.Errorf("HeaderMatchesPattern(%q, %q) = %v, want %v", tt.pattern, tt.headerName, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHeaderMatcher_Nil(t *testing.T) {
|
||||
m := NewHeaderMatcher(nil)
|
||||
if m != nil {
|
||||
t.Fatal("expected nil matcher for nil config")
|
||||
}
|
||||
// nil matcher should allow everything
|
||||
if !m.ShouldAllow("anything") {
|
||||
t.Error("nil matcher should allow all headers")
|
||||
}
|
||||
if m.HasAllowlist() {
|
||||
t.Error("nil matcher should have no allowlist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHeaderMatcher_Empty(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{})
|
||||
if m == nil {
|
||||
t.Fatal("expected non-nil matcher for empty config")
|
||||
}
|
||||
if m.HasAllowlist() {
|
||||
t.Error("empty config should have no allowlist")
|
||||
}
|
||||
if !m.ShouldAllow("anything") {
|
||||
t.Error("empty config should allow all headers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_ExactAllowlist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-beta", "custom-id"},
|
||||
})
|
||||
if !m.ShouldAllow("anthropic-beta") {
|
||||
t.Error("should allow anthropic-beta")
|
||||
}
|
||||
if !m.ShouldAllow("custom-id") {
|
||||
t.Error("should allow custom-id")
|
||||
}
|
||||
if m.ShouldAllow("openai-version") {
|
||||
t.Error("should not allow openai-version")
|
||||
}
|
||||
// Case insensitive
|
||||
if !m.ShouldAllow("Anthropic-Beta") {
|
||||
t.Error("should allow Anthropic-Beta (case insensitive)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_WildcardAllowlist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-*"},
|
||||
})
|
||||
if !m.ShouldAllow("anthropic-beta") {
|
||||
t.Error("should allow anthropic-beta")
|
||||
}
|
||||
if !m.ShouldAllow("anthropic-version") {
|
||||
t.Error("should allow anthropic-version")
|
||||
}
|
||||
if m.ShouldAllow("openai-version") {
|
||||
t.Error("should not allow openai-version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_StarAllowlist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"*"},
|
||||
})
|
||||
if !m.ShouldAllow("anything") {
|
||||
t.Error("* should allow anything")
|
||||
}
|
||||
if !m.ShouldAllow("") {
|
||||
t.Error("* should allow empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_ExactDenylist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"secret-token"},
|
||||
})
|
||||
if m.ShouldAllow("secret-token") {
|
||||
t.Error("should deny secret-token")
|
||||
}
|
||||
if !m.ShouldAllow("public-key") {
|
||||
t.Error("should allow public-key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_WildcardDenylist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"x-internal-*"},
|
||||
})
|
||||
if m.ShouldAllow("x-internal-id") {
|
||||
t.Error("should deny x-internal-id")
|
||||
}
|
||||
if m.ShouldAllow("x-internal-secret") {
|
||||
t.Error("should deny x-internal-secret")
|
||||
}
|
||||
if !m.ShouldAllow("x-external-id") {
|
||||
t.Error("should allow x-external-id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_StarDenylist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"*"},
|
||||
})
|
||||
if m.ShouldAllow("anything") {
|
||||
t.Error("* denylist should deny everything")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_AllowlistWithDenylist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"*"},
|
||||
Denylist: []string{"x-internal-*"},
|
||||
})
|
||||
if !m.ShouldAllow("anthropic-beta") {
|
||||
t.Error("should allow anthropic-beta")
|
||||
}
|
||||
if m.ShouldAllow("x-internal-id") {
|
||||
t.Error("should deny x-internal-id (denylist overrides)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_AllowlistPrefixWithDenylistExact(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-*"},
|
||||
Denylist: []string{"anthropic-dangerous"},
|
||||
})
|
||||
if !m.ShouldAllow("anthropic-beta") {
|
||||
t.Error("should allow anthropic-beta")
|
||||
}
|
||||
if m.ShouldAllow("anthropic-dangerous") {
|
||||
t.Error("should deny anthropic-dangerous")
|
||||
}
|
||||
if m.ShouldAllow("openai-version") {
|
||||
t.Error("should not allow openai-version (not in allowlist)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_CaseInsensitive(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"Anthropic-*"},
|
||||
Denylist: []string{"X-Internal-*"},
|
||||
})
|
||||
if !m.ShouldAllow("anthropic-beta") {
|
||||
t.Error("should allow anthropic-beta (case insensitive)")
|
||||
}
|
||||
if m.ShouldAllow("x-internal-id") {
|
||||
t.Error("should deny x-internal-id (case insensitive)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_MatchesAllow(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"anthropic-*", "custom-id"},
|
||||
})
|
||||
if !m.MatchesAllow("anthropic-beta") {
|
||||
t.Error("should match anthropic-beta")
|
||||
}
|
||||
if !m.MatchesAllow("custom-id") {
|
||||
t.Error("should match custom-id")
|
||||
}
|
||||
if m.MatchesAllow("openai-version") {
|
||||
t.Error("should not match openai-version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_MatchesDeny(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"secret-*", "blocked"},
|
||||
})
|
||||
if !m.MatchesDeny("secret-token") {
|
||||
t.Error("should match secret-token")
|
||||
}
|
||||
if !m.MatchesDeny("blocked") {
|
||||
t.Error("should match blocked")
|
||||
}
|
||||
if m.MatchesDeny("allowed") {
|
||||
t.Error("should not match allowed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderMatcher_HasAllowlist(t *testing.T) {
|
||||
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Allowlist: []string{"foo"},
|
||||
})
|
||||
if !m.HasAllowlist() {
|
||||
t.Error("should have allowlist")
|
||||
}
|
||||
|
||||
m2 := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
|
||||
Denylist: []string{"bar"},
|
||||
})
|
||||
if m2.HasAllowlist() {
|
||||
t.Error("should not have allowlist")
|
||||
}
|
||||
}
|
||||
58
transports/bifrost-http/lib/lib.go
Normal file
58
transports/bifrost-http/lib/lib.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
var logger schemas.Logger
|
||||
|
||||
// SetLogger sets the logger for the application.
|
||||
func SetLogger(l schemas.Logger) {
|
||||
logger = l
|
||||
}
|
||||
|
||||
// StreamLargeResponseBody extracts the large response reader from context and streams
|
||||
// it directly to the client. Sets status 200, content-type, and content-length headers.
|
||||
// Returns false if the reader is not available (caller should send an error response).
|
||||
func StreamLargeResponseBody(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) bool {
|
||||
if bifrostCtx == nil {
|
||||
return false
|
||||
}
|
||||
reader, ok := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseReader).(io.ReadCloser)
|
||||
if !ok || reader == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
contentLength, _ := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseContentLength).(int)
|
||||
contentType, _ := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseContentType).(string)
|
||||
contentDisposition, _ := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseContentDisposition).(string)
|
||||
|
||||
// Mirror large-response-mode to fasthttp UserValue so post-hook middleware
|
||||
// (which only sees ctx.UserValue, not bifrostCtx) can skip body materialization.
|
||||
ctx.SetUserValue(FastHTTPUserValueLargeResponseMode, true)
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
if contentType != "" {
|
||||
ctx.SetContentType(contentType)
|
||||
} else {
|
||||
ctx.SetContentType("application/json")
|
||||
}
|
||||
if contentDisposition != "" {
|
||||
ctx.Response.Header.Set("Content-Disposition", contentDisposition)
|
||||
}
|
||||
// bodySize for SetBodyStream: positive = known size, -1 = unknown (read until EOF).
|
||||
// fasthttp treats 0 as "known empty", so default to -1 when CL is unavailable.
|
||||
bodySize := contentLength
|
||||
if bodySize > 0 {
|
||||
ctx.Response.Header.Set("Content-Length", strconv.Itoa(contentLength))
|
||||
} else {
|
||||
bodySize = -1
|
||||
}
|
||||
|
||||
ctx.Response.SetBodyStream(reader, bodySize)
|
||||
return true
|
||||
}
|
||||
23
transports/bifrost-http/lib/middleware.go
Normal file
23
transports/bifrost-http/lib/middleware.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ChainMiddlewares chains multiple middlewares together
|
||||
// Middlewares are applied in order: the first middleware wraps the second, etc.
|
||||
// This allows earlier middlewares to short-circuit by not calling next(ctx)
|
||||
func ChainMiddlewares(handler fasthttp.RequestHandler, middlewares ...schemas.BifrostHTTPMiddleware) fasthttp.RequestHandler {
|
||||
// If no middlewares, return the original handler
|
||||
if len(middlewares) == 0 {
|
||||
return handler
|
||||
}
|
||||
// Build the chain from right to left (last middleware wraps the handler)
|
||||
// This ensures execution order is left to right (first middleware executes first)
|
||||
chained := handler
|
||||
for i := len(middlewares) - 1; i >= 0; i-- {
|
||||
chained = middlewares[i](chained)
|
||||
}
|
||||
return chained
|
||||
}
|
||||
1077
transports/bifrost-http/lib/pricing_integration_test.go
Normal file
1077
transports/bifrost-http/lib/pricing_integration_test.go
Normal file
File diff suppressed because it is too large
Load Diff
174
transports/bifrost-http/lib/semantic_cache_config_test.go
Normal file
174
transports/bifrost-http/lib/semantic_cache_config_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/configstore"
|
||||
"github.com/maximhq/bifrost/plugins/semanticcache"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyMode(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"dimension": 1,
|
||||
"ttl": "5m",
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
configMap, ok := pluginConfig.Config.(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
_, hasKeys := configMap["keys"]
|
||||
require.False(t, hasKeys, "direct-only mode should not inject provider keys")
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyModeRemovesStaleProviderBackedFields(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"dimension": 1,
|
||||
"keys": []schemas.Key{{Name: "stale-key"}},
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
configMap, ok := pluginConfig.Config.(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
_, hasKeys := configMap["keys"]
|
||||
require.False(t, hasKeys, "direct-only mode should remove stale provider keys")
|
||||
_, hasEmbeddingModel := configMap["embedding_model"]
|
||||
require.False(t, hasEmbeddingModel, "direct-only mode should remove stale embedding_model")
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_InjectsProviderKeys(t *testing.T) {
|
||||
config := &Config{
|
||||
Providers: map[schemas.ModelProvider]configstore.ProviderConfig{
|
||||
schemas.OpenAI: {
|
||||
Keys: []schemas.Key{
|
||||
{
|
||||
Name: "openai-key",
|
||||
Value: *schemas.NewEnvVar("sk-test"),
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"provider": "openai",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"dimension": 1536,
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
configMap, ok := pluginConfig.Config.(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
keys, ok := configMap["keys"].([]schemas.Key)
|
||||
require.True(t, ok, "provider-backed mode should inject provider keys")
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, "openai-key", keys[0].Name)
|
||||
require.Equal(t, "openai", configMap["provider"])
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_SemanticModeMissingProvider(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"dimension": 1536,
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "requires 'provider' for semantic mode")
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingDimension(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"provider": "openai",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "requires 'dimension' for provider-backed semantic mode")
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeDimensionOne(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"provider": "openai",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"dimension": 1,
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "requires 'dimension' > 1")
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingEmbeddingModel(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"provider": "openai",
|
||||
"dimension": 1536,
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "requires 'embedding_model'")
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionZero(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"dimension": 0,
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "'dimension' must be >= 1")
|
||||
}
|
||||
|
||||
func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionNegative(t *testing.T) {
|
||||
config := &Config{}
|
||||
pluginConfig := &schemas.PluginConfig{
|
||||
Name: semanticcache.PluginName,
|
||||
Config: map[string]interface{}{
|
||||
"dimension": -1,
|
||||
},
|
||||
}
|
||||
|
||||
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "'dimension' must be >= 1")
|
||||
}
|
||||
114
transports/bifrost-http/lib/streamreader.go
Normal file
114
transports/bifrost-http/lib/streamreader.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SSEStreamReader is an io.ReadCloser that delivers one event per Read call,
|
||||
// bypassing fasthttp's internal pipe mechanism (fasthttputil.PipeConns) which
|
||||
// batches multiple events into single TCP segments.
|
||||
//
|
||||
// Usage:
|
||||
// 1. Create with NewSSEStreamReader()
|
||||
// 2. Pass to ctx.Response.SetBodyStream(reader, -1)
|
||||
// 3. Start a producer goroutine that calls Send()/SendEvent()/SendError() for each event
|
||||
// 4. Producer calls Done() when finished (closes the event channel)
|
||||
// 5. fasthttp calls Close() on write errors (signals producer to stop)
|
||||
type SSEStreamReader struct {
|
||||
eventCh chan []byte
|
||||
closeCh chan struct{}
|
||||
closeOnce sync.Once
|
||||
current []byte // remaining bytes from a partial read
|
||||
}
|
||||
|
||||
// NewSSEStreamReader creates a new SSEStreamReader with a buffered event channel.
|
||||
// Channel capacity of 1 allows one event of pipeline parallelism between
|
||||
// the producer goroutine and fasthttp's writeBodyChunked loop.
|
||||
func NewSSEStreamReader() *SSEStreamReader {
|
||||
return &SSEStreamReader{
|
||||
eventCh: make(chan []byte, 1),
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Read implements io.Reader. It blocks until an event is available, then returns
|
||||
// that event's bytes. If the caller's buffer is smaller than the event, remaining
|
||||
// bytes are stored and returned on subsequent calls. Returns io.EOF when Done()
|
||||
// has been called and all events have been consumed.
|
||||
func (r *SSEStreamReader) Read(p []byte) (int, error) {
|
||||
if len(r.current) == 0 {
|
||||
event, ok := <-r.eventCh
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
r.current = event
|
||||
}
|
||||
n := copy(p, r.current)
|
||||
r.current = r.current[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Close implements io.Closer. Called by fasthttp when writeBodyChunked encounters
|
||||
// a write error (client disconnect). Signals the producer goroutine to stop via closeCh.
|
||||
// Safe to call multiple times.
|
||||
func (r *SSEStreamReader) Close() error {
|
||||
r.closeOnce.Do(func() {
|
||||
close(r.closeCh)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send delivers a pre-formatted event to the reader. Returns false if the reader
|
||||
// has been closed (client disconnected), in which case the producer should stop.
|
||||
func (r *SSEStreamReader) Send(event []byte) bool {
|
||||
// Check closeCh first (non-blocking) to avoid sending after Close
|
||||
select {
|
||||
case <-r.closeCh:
|
||||
return false
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case r.eventCh <- event:
|
||||
return true
|
||||
case <-r.closeCh:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// SendEvent sends an SSE-framed event. If eventType is empty, it sends "data: <data>\n\n".
|
||||
// If eventType is non-empty, it sends "event: <eventType>\ndata: <data>\n\n".
|
||||
// Returns false if the reader has been closed (client disconnected).
|
||||
func (r *SSEStreamReader) SendEvent(eventType string, data []byte) bool {
|
||||
var buf []byte
|
||||
if eventType != "" {
|
||||
buf = make([]byte, 0, 7+len(eventType)+7+len(data)+2)
|
||||
buf = append(buf, "event: "...)
|
||||
buf = append(buf, eventType...)
|
||||
buf = append(buf, "\ndata: "...)
|
||||
} else {
|
||||
buf = make([]byte, 0, 6+len(data)+2)
|
||||
buf = append(buf, "data: "...)
|
||||
}
|
||||
buf = append(buf, data...)
|
||||
buf = append(buf, '\n', '\n')
|
||||
return r.Send(buf)
|
||||
}
|
||||
|
||||
// SendError sends an SSE error event: "event: error\ndata: <data>\n\n".
|
||||
// Returns false if the reader has been closed (client disconnected).
|
||||
func (r *SSEStreamReader) SendError(data []byte) bool {
|
||||
return r.SendEvent("error", data)
|
||||
}
|
||||
|
||||
// SendDone sends the standard SSE done marker: "data: [DONE]\n\n".
|
||||
// Returns false if the reader has been closed (client disconnected).
|
||||
func (r *SSEStreamReader) SendDone() bool {
|
||||
return r.Send([]byte("data: [DONE]\n\n"))
|
||||
}
|
||||
|
||||
// Done closes the event channel, signaling to Read that the stream is finished.
|
||||
// Must be called exactly once by the producer goroutine when streaming is complete.
|
||||
func (r *SSEStreamReader) Done() {
|
||||
close(r.eventCh)
|
||||
}
|
||||
714
transports/bifrost-http/lib/streamreader_test.go
Normal file
714
transports/bifrost-http/lib/streamreader_test.go
Normal file
@@ -0,0 +1,714 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSSEStreamReaderSingleEventPerRead(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
events := [][]byte{
|
||||
[]byte("data: {\"chunk\":1}\n\n"),
|
||||
[]byte("data: {\"chunk\":2}\n\n"),
|
||||
[]byte("data: {\"chunk\":3}\n\n"),
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
for _, e := range events {
|
||||
if !r.Send(e) {
|
||||
select {
|
||||
case errCh <- fmt.Errorf("Send returned false unexpectedly"):
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
for i, want := range events {
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("event %d: unexpected error: %v", i, err)
|
||||
}
|
||||
got := string(buf[:n])
|
||||
if got != string(want) {
|
||||
t.Errorf("event %d: got %q, want %q", i, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Next read should return EOF
|
||||
n, err := r.Read(buf)
|
||||
if err != io.EOF {
|
||||
t.Errorf("expected io.EOF, got err=%v n=%d", err, n)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Error(err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderPartialRead(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
event := []byte("data: {\"content\":\"hello world\"}\n\n")
|
||||
|
||||
go func() {
|
||||
r.Send(event)
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
// Read with a small buffer (5 bytes at a time)
|
||||
var result []byte
|
||||
buf := make([]byte, 5)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
result = append(result, buf[:n]...)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if string(result) != string(event) {
|
||||
t.Errorf("reassembled data: got %q, want %q", result, event)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderEOFOnDone(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
r.Done() // Close immediately
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != io.EOF {
|
||||
t.Errorf("expected io.EOF, got err=%v n=%d", err, n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderCloseSignalsProducer(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
r.Close()
|
||||
|
||||
if r.Send([]byte("data: test\n\n")) {
|
||||
t.Error("Send should return false after Close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderIdempotentClose(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
// Should not panic
|
||||
r.Close()
|
||||
r.Close()
|
||||
r.Close()
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderConcurrent(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
const numEvents = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
// Producer
|
||||
go func() {
|
||||
for i := 0; i < numEvents; i++ {
|
||||
if !r.Send([]byte("data: event\n\n")) {
|
||||
break
|
||||
}
|
||||
}
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
// Consumer
|
||||
errCh := make(chan error, 2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buf := make([]byte, 4096)
|
||||
count := 0
|
||||
for {
|
||||
_, err := r.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
select {
|
||||
case errCh <- fmt.Errorf("unexpected error: %v", err):
|
||||
default:
|
||||
}
|
||||
break
|
||||
}
|
||||
count++
|
||||
}
|
||||
if count != numEvents {
|
||||
select {
|
||||
case errCh <- fmt.Errorf("got %d events, want %d", count, numEvents):
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
for err := range errCh {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderSendEvent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
eventType string
|
||||
data []byte
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "data only",
|
||||
eventType: "",
|
||||
data: []byte(`{"chunk":1}`),
|
||||
want: "data: {\"chunk\":1}\n\n",
|
||||
},
|
||||
{
|
||||
name: "with event type",
|
||||
eventType: "response.delta",
|
||||
data: []byte(`{"delta":"hi"}`),
|
||||
want: "event: response.delta\ndata: {\"delta\":\"hi\"}\n\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
go func() {
|
||||
r.SendEvent(tt.eventType, tt.data)
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got := string(buf[:n]); got != tt.want {
|
||||
t.Errorf("got %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderSendError(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
go func() {
|
||||
r.SendError([]byte(`{"error":"bad"}`))
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
want := "event: error\ndata: {\"error\":\"bad\"}\n\n"
|
||||
if got := string(buf[:n]); got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderSendDone(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
go func() {
|
||||
r.SendDone()
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
want := "data: [DONE]\n\n"
|
||||
if got := string(buf[:n]); got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderSendEventAfterClose(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
r.Close()
|
||||
|
||||
if r.SendEvent("test", []byte("data")) {
|
||||
t.Error("SendEvent should return false after Close")
|
||||
}
|
||||
if r.SendError([]byte("err")) {
|
||||
t.Error("SendError should return false after Close")
|
||||
}
|
||||
if r.SendDone() {
|
||||
t.Error("SendDone should return false after Close")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderSendEventByteAccuracy verifies that SendEvent produces
|
||||
// the exact same bytes that the old manual buffer assembly in the handlers did.
|
||||
func TestSSEStreamReaderSendEventByteAccuracy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
eventType string
|
||||
data []byte
|
||||
want []byte
|
||||
}{
|
||||
{
|
||||
name: "standard SSE data (old inference.go pattern)",
|
||||
eventType: "",
|
||||
data: []byte(`{"id":"chatcmpl-123","choices":[{"delta":{"content":"Hello"}}]}`),
|
||||
want: func() []byte {
|
||||
// Old code: buf = append(buf, "data: "...); buf = append(buf, chunkJSON...); buf = append(buf, '\n', '\n')
|
||||
data := []byte(`{"id":"chatcmpl-123","choices":[{"delta":{"content":"Hello"}}]}`)
|
||||
buf := make([]byte, 0, len(data)+8)
|
||||
buf = append(buf, "data: "...)
|
||||
buf = append(buf, data...)
|
||||
buf = append(buf, '\n', '\n')
|
||||
return buf
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "OpenAI responses format with event type (old inference.go pattern)",
|
||||
eventType: "response.output_item.added",
|
||||
data: []byte(`{"type":"response.output_item.added","item":{"id":"item_1"}}`),
|
||||
want: func() []byte {
|
||||
// Old code: buf = append(buf, "event: "...); buf = append(buf, eventType...); buf = append(buf, "\ndata: "...); ...
|
||||
eventType := "response.output_item.added"
|
||||
data := []byte(`{"type":"response.output_item.added","item":{"id":"item_1"}}`)
|
||||
buf := make([]byte, 0, len(eventType)+len(data)+16)
|
||||
buf = append(buf, "event: "...)
|
||||
buf = append(buf, eventType...)
|
||||
buf = append(buf, "\ndata: "...)
|
||||
buf = append(buf, data...)
|
||||
buf = append(buf, '\n', '\n')
|
||||
return buf
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "error event (old interceptor pattern)",
|
||||
eventType: "error",
|
||||
data: []byte(`{"error":"stream interrupted"}`),
|
||||
want: func() []byte {
|
||||
data := []byte(`{"error":"stream interrupted"}`)
|
||||
buf := make([]byte, 0, len(data)+24)
|
||||
buf = append(buf, "event: error\ndata: "...)
|
||||
buf = append(buf, data...)
|
||||
buf = append(buf, '\n', '\n')
|
||||
return buf
|
||||
}(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
go func() {
|
||||
r.SendEvent(tt.eventType, tt.data)
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
got := buf[:n]
|
||||
if string(got) != string(tt.want) {
|
||||
t.Errorf("byte mismatch:\n got: %q\n want: %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderSendErrorByteAccuracy verifies SendError matches
|
||||
// the old "event: error\ndata: ..." manual assembly.
|
||||
func TestSSEStreamReaderSendErrorByteAccuracy(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
errorJSON := []byte(`{"error":{"type":"internal_error","message":"An error occurred"}}`)
|
||||
|
||||
go func() {
|
||||
r.SendError(errorJSON)
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Must match the old pattern exactly:
|
||||
// buf = append(buf, "event: error\ndata: "...)
|
||||
// buf = append(buf, errorJSON...)
|
||||
// buf = append(buf, '\n', '\n')
|
||||
want := "event: error\ndata: " + string(errorJSON) + "\n\n"
|
||||
if got := string(buf[:n]); got != want {
|
||||
t.Errorf("byte mismatch:\n got: %q\n want: %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderMixedMethodStream simulates a realistic stream
|
||||
// that uses multiple methods (like router.go does): data events,
|
||||
// typed events, error, and done marker.
|
||||
func TestSSEStreamReaderMixedMethodStream(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
expected := []string{
|
||||
"data: {\"chunk\":1}\n\n",
|
||||
"event: response.delta\ndata: {\"delta\":\"hi\"}\n\n",
|
||||
"data: {\"chunk\":2}\n\n",
|
||||
"event: error\ndata: {\"error\":\"timeout\"}\n\n",
|
||||
"data: [DONE]\n\n",
|
||||
}
|
||||
|
||||
go func() {
|
||||
r.SendEvent("", []byte(`{"chunk":1}`))
|
||||
r.SendEvent("response.delta", []byte(`{"delta":"hi"}`))
|
||||
r.SendEvent("", []byte(`{"chunk":2}`))
|
||||
r.SendError([]byte(`{"error":"timeout"}`))
|
||||
r.SendDone()
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
for i, want := range expected {
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("event %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if got := string(buf[:n]); got != want {
|
||||
t.Errorf("event %d:\n got: %q\n want: %q", i, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Should be EOF after all events
|
||||
n, err := r.Read(buf)
|
||||
if err != io.EOF {
|
||||
t.Errorf("expected EOF, got err=%v n=%d", err, n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderRawAndWrapperMixed simulates the router.go pattern
|
||||
// where raw Send (for Bedrock/passthrough) is mixed with wrapper methods.
|
||||
func TestSSEStreamReaderRawAndWrapperMixed(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
// Simulate: Bedrock binary event (raw), followed by SSE events, then done
|
||||
bedrockBinary := []byte{0x00, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00, 0x2A} // fake binary
|
||||
preformattedSSE := []byte("event: content_block_delta\ndata: {\"delta\":\"test\"}\n\n")
|
||||
|
||||
expected := [][]byte{
|
||||
bedrockBinary,
|
||||
preformattedSSE,
|
||||
[]byte("data: {\"final\":true}\n\n"),
|
||||
}
|
||||
|
||||
go func() {
|
||||
r.Send(bedrockBinary) // raw binary passthrough
|
||||
r.Send(preformattedSSE) // pre-formatted SSE string
|
||||
r.SendEvent("", []byte(`{"final":true}`)) // wrapper method
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
for i, want := range expected {
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("event %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if string(buf[:n]) != string(want) {
|
||||
t.Errorf("event %d:\n got: %q\n want: %q", i, buf[:n], want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderSendEventEmptyData verifies behavior with empty data payload.
|
||||
func TestSSEStreamReaderSendEventEmptyData(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
eventType string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty data no event type",
|
||||
eventType: "",
|
||||
want: "data: \n\n",
|
||||
},
|
||||
{
|
||||
name: "empty data with event type",
|
||||
eventType: "heartbeat",
|
||||
want: "event: heartbeat\ndata: \n\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
go func() {
|
||||
r.SendEvent(tt.eventType, []byte{})
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got := string(buf[:n]); got != tt.want {
|
||||
t.Errorf("got %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderSendEventNilData verifies behavior with nil data payload.
|
||||
func TestSSEStreamReaderSendEventNilData(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
go func() {
|
||||
r.SendEvent("", nil)
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// nil data should produce same as empty: "data: \n\n"
|
||||
if got := string(buf[:n]); got != "data: \n\n" {
|
||||
t.Errorf("got %q, want %q", got, "data: \n\n")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderSendEventLargePayload verifies no corruption with large JSON payloads.
|
||||
func TestSSEStreamReaderSendEventLargePayload(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
// Build a large JSON payload (~64KB, larger than typical ReadBufferSize)
|
||||
largeContent := make([]byte, 65536)
|
||||
for i := range largeContent {
|
||||
largeContent[i] = 'A' + byte(i%26)
|
||||
}
|
||||
data := append([]byte(`{"content":"`), largeContent...)
|
||||
data = append(data, '"', '}')
|
||||
|
||||
go func() {
|
||||
r.SendEvent("response.delta", data)
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
// Read the entire event using small buffer to exercise partial reads
|
||||
var result []byte
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
result = append(result, buf[:n]...)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
want := "event: response.delta\ndata: " + string(data) + "\n\n"
|
||||
if string(result) != want {
|
||||
t.Errorf("large payload mismatch: got len=%d, want len=%d", len(result), len(want))
|
||||
// Check prefix and suffix for debugging
|
||||
if len(result) > 40 {
|
||||
t.Errorf(" got prefix: %q", result[:40])
|
||||
t.Errorf(" want prefix: %q", want[:40])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderMidStreamDisconnect simulates a client disconnecting
|
||||
// mid-stream while the producer is using SendEvent.
|
||||
func TestSSEStreamReaderMidStreamDisconnect(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
producerDone := make(chan int) // reports how many events were sent
|
||||
go func() {
|
||||
sent := 0
|
||||
for i := 0; i < 100; i++ {
|
||||
if !r.SendEvent("", []byte(fmt.Sprintf(`{"chunk":%d}`, i))) {
|
||||
break
|
||||
}
|
||||
sent++
|
||||
}
|
||||
close(producerDone)
|
||||
}()
|
||||
|
||||
// Read a few events then simulate client disconnect
|
||||
buf := make([]byte, 4096)
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("event %d: unexpected error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Client disconnects
|
||||
r.Close()
|
||||
|
||||
// Producer should stop promptly
|
||||
<-producerDone
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderSendErrorThenDone verifies the handler pattern
|
||||
// of sending an error event and immediately closing the stream.
|
||||
func TestSSEStreamReaderSendErrorThenDone(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
go func() {
|
||||
// Send a few normal events
|
||||
r.SendEvent("", []byte(`{"chunk":1}`))
|
||||
r.SendEvent("", []byte(`{"chunk":2}`))
|
||||
// Error occurs, send error and stop
|
||||
r.SendError([]byte(`{"error":"rate_limit"}`))
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
expected := []string{
|
||||
"data: {\"chunk\":1}\n\n",
|
||||
"data: {\"chunk\":2}\n\n",
|
||||
"event: error\ndata: {\"error\":\"rate_limit\"}\n\n",
|
||||
}
|
||||
|
||||
for i, want := range expected {
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("event %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if got := string(buf[:n]); got != want {
|
||||
t.Errorf("event %d: got %q, want %q", i, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Should be EOF (stream ended after error, no [DONE] marker)
|
||||
n, err := r.Read(buf)
|
||||
if err != io.EOF {
|
||||
t.Errorf("expected EOF after error event, got err=%v n=%d data=%q", err, n, buf[:n])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderSendDoneByteExact verifies SendDone produces
|
||||
// exactly "data: [DONE]\n\n" — the standard OpenAI SSE terminator.
|
||||
func TestSSEStreamReaderSendDoneByteExact(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
go func() {
|
||||
r.SendDone()
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
// Use exact-size buffer to verify no extra bytes
|
||||
want := []byte("data: [DONE]\n\n")
|
||||
buf := make([]byte, len(want))
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if n != len(want) {
|
||||
t.Errorf("expected %d bytes, got %d", len(want), n)
|
||||
}
|
||||
if string(buf[:n]) != string(want) {
|
||||
t.Errorf("got %q, want %q", buf[:n], want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSEStreamReaderConcurrentSendEvent verifies thread safety of SendEvent
|
||||
// with multiple concurrent producers (not a real pattern but validates safety).
|
||||
func TestSSEStreamReaderConcurrentSendEvent(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
const numProducers = 5
|
||||
const eventsPerProducer = 20
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numProducers)
|
||||
|
||||
// Launch multiple producers
|
||||
for p := 0; p < numProducers; p++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < eventsPerProducer; i++ {
|
||||
if !r.SendEvent("", []byte(fmt.Sprintf(`{"p":%d,"i":%d}`, id, i))) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}(p)
|
||||
}
|
||||
|
||||
// Close after all producers finish
|
||||
go func() {
|
||||
wg.Wait()
|
||||
r.Done()
|
||||
}()
|
||||
|
||||
// Consume all events
|
||||
buf := make([]byte, 4096)
|
||||
count := 0
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Every event must be a valid SSE data line
|
||||
got := string(buf[:n])
|
||||
if len(got) < 8 || got[:6] != "data: " || got[len(got)-2:] != "\n\n" {
|
||||
t.Errorf("event %d: invalid SSE format: %q", count, got)
|
||||
}
|
||||
count++
|
||||
}
|
||||
|
||||
if count != numProducers*eventsPerProducer {
|
||||
t.Errorf("got %d events, want %d", count, numProducers*eventsPerProducer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEStreamReaderCloseUnblocksProducer(t *testing.T) {
|
||||
r := NewSSEStreamReader()
|
||||
|
||||
done := make(chan struct{})
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(done)
|
||||
// Fill the channel buffer (cap=1)
|
||||
r.Send([]byte("data: first\n\n"))
|
||||
// This Send should block until Close is called
|
||||
r.Send([]byte("data: second\n\n"))
|
||||
// After Close, the next Send should return false
|
||||
if r.Send([]byte("data: third\n\n")) {
|
||||
select {
|
||||
case errCh <- fmt.Errorf("Send should return false after Close"):
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Close unblocks the blocked Send
|
||||
r.Close()
|
||||
<-done
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Error(err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
96
transports/bifrost-http/lib/validator.go
Normal file
96
transports/bifrost-http/lib/validator.go
Normal file
@@ -0,0 +1,96 @@
|
||||
// Package lib provides core functionality for the Bifrost HTTP service.
|
||||
// This file contains JSON schema validation for config files.
|
||||
package lib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/santhosh-tekuri/jsonschema/v6"
|
||||
)
|
||||
|
||||
// localSchemaCandidates lists paths (relative to CWD) where config.schema.json may be found
|
||||
// when running from a source checkout. Checked in order before falling back to the remote URL.
|
||||
var localSchemaCandidates = []string{
|
||||
"config.schema.json", // running from transports/
|
||||
"../config.schema.json", // running from transports/bifrost-http/
|
||||
"transports/config.schema.json", // running from repo root
|
||||
}
|
||||
|
||||
// tryLoadLocalSchema attempts to read config.schema.json from known local paths.
|
||||
// Returns nil if none are found.
|
||||
func tryLoadLocalSchema() []byte {
|
||||
for _, p := range localSchemaCandidates {
|
||||
data, err := os.ReadFile(p)
|
||||
if err == nil {
|
||||
return data
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConfigSchema validates config data against the JSON schema.
|
||||
// Returns nil if valid, or a formatted error describing all validation failures.
|
||||
// An optional schemaOverride can be provided to use a local schema instead of fetching from the remote URL.
|
||||
func ValidateConfigSchema(data []byte, schemaOverride ...[]byte) error {
|
||||
var configSchemaJSONBytes []byte
|
||||
if len(schemaOverride) > 0 && len(schemaOverride[0]) > 0 {
|
||||
configSchemaJSONBytes = schemaOverride[0]
|
||||
} else if localSchema := tryLoadLocalSchema(); localSchema != nil {
|
||||
// Prefer the local schema file from the source checkout when available.
|
||||
// This avoids validating against a potentially stale remote schema.
|
||||
configSchemaJSONBytes = localSchema
|
||||
} else {
|
||||
// Pulling config.schema from https://www.getbifrost.ai/schema
|
||||
configSchemaJSON, err := http.Get("https://www.getbifrost.ai/schema")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get config schema: %w", err)
|
||||
}
|
||||
defer configSchemaJSON.Body.Close()
|
||||
var readErr error
|
||||
configSchemaJSONBytes, readErr = io.ReadAll(configSchemaJSON.Body)
|
||||
if readErr != nil {
|
||||
logger.Warn("failed to download config schema: %v. running without config.json schema validation", readErr)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// Parse the schema JSON
|
||||
schemaDoc, err := jsonschema.UnmarshalJSON(bytes.NewReader(configSchemaJSONBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse config schema JSON: %w", err)
|
||||
}
|
||||
c := jsonschema.NewCompiler()
|
||||
if err := c.AddResource("config.schema.json", schemaDoc); err != nil {
|
||||
return fmt.Errorf("failed to add config schema resource: %w", err)
|
||||
}
|
||||
// Compile the schema
|
||||
compiledSchema, err := c.Compile("config.schema.json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to compile config schema: %w", err)
|
||||
}
|
||||
var v any
|
||||
if err := json.Unmarshal(data, &v); err != nil {
|
||||
return fmt.Errorf("invalid JSON: %w", err)
|
||||
}
|
||||
err = compiledSchema.Validate(v)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
// Format validation errors for better readability
|
||||
return formatValidationError(err)
|
||||
}
|
||||
|
||||
// formatValidationError converts jsonschema validation errors into user-friendly messages
|
||||
func formatValidationError(err error) error {
|
||||
validationErr, ok := err.(*jsonschema.ValidationError)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
// Use the GoString format which provides detailed hierarchical output
|
||||
return fmt.Errorf("schema validation failed:\n%s", validationErr.GoString())
|
||||
}
|
||||
1476
transports/bifrost-http/lib/validator_test.go
Normal file
1476
transports/bifrost-http/lib/validator_test.go
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user