first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,105 @@
// Package lib provides core functionality for the Bifrost HTTP service,
// including context propagation, header management, and integration with monitoring systems.
package lib
import (
"context"
"fmt"
"github.com/maximhq/bifrost/core/schemas"
)
// BaseAccount implements the Account interface for Bifrost.
// It manages provider configurations using a in-memory store for persistent storage.
// All data processing (environment variables, key configs) is done upfront in the store.
type BaseAccount struct {
store *Config // store for in-memory configuration
}
// NewBaseAccount creates a new BaseAccount with the given store
func NewBaseAccount(store *Config) *BaseAccount {
return &BaseAccount{
store: store,
}
}
// GetConfiguredProviders returns a list of all configured providers.
// Implements the Account interface.
func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
if baseAccount.store == nil {
return nil, fmt.Errorf("store not initialized")
}
return baseAccount.store.GetAllProviders()
}
// GetKeysForProvider returns the API keys configured for a specific provider.
// Keys are already processed (environment variables resolved) by the store.
// Implements the Account interface.
func (baseAccount *BaseAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) {
if baseAccount.store == nil {
return nil, fmt.Errorf("store not initialized")
}
config, err := baseAccount.store.GetProviderConfigRaw(providerKey)
if err != nil {
return nil, err
}
keys := config.Keys
if v := ctx.Value(schemas.BifrostContextKeyGovernanceIncludeOnlyKeys); v != nil {
if includeOnlyKeys, ok := v.([]string); ok {
if len(includeOnlyKeys) == 0 {
// header present but empty means "no keys allowed"
keys = nil
} else {
set := make(map[string]struct{}, len(includeOnlyKeys))
for _, id := range includeOnlyKeys {
set[id] = struct{}{}
}
filtered := make([]schemas.Key, 0, len(keys))
for _, key := range keys {
if _, ok := set[key.ID]; ok {
filtered = append(filtered, key)
}
}
keys = filtered
}
}
}
return keys, nil
}
// GetConfigForProvider returns the complete configuration for a specific provider.
// Configuration is already fully processed (environment variables, key configs) by the store.
// Implements the Account interface.
func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) {
if baseAccount.store == nil {
return nil, fmt.Errorf("store not initialized")
}
config, err := baseAccount.store.GetProviderConfigRaw(providerKey)
if err != nil {
return nil, err
}
providerConfig := &schemas.ProviderConfig{}
if config.ProxyConfig != nil {
providerConfig.ProxyConfig = config.ProxyConfig
}
if config.NetworkConfig != nil {
providerConfig.NetworkConfig = *config.NetworkConfig
} else {
providerConfig.NetworkConfig = schemas.DefaultNetworkConfig
}
if config.ConcurrencyAndBufferSize != nil {
providerConfig.ConcurrencyAndBufferSize = *config.ConcurrencyAndBufferSize
} else {
providerConfig.ConcurrencyAndBufferSize = schemas.DefaultConcurrencyAndBufferSize
}
providerConfig.SendBackRawRequest = config.SendBackRawRequest
providerConfig.SendBackRawResponse = config.SendBackRawResponse
providerConfig.StoreRawRequestResponse = config.StoreRawRequestResponse
if config.CustomProviderConfig != nil {
providerConfig.CustomProviderConfig = config.CustomProviderConfig
}
if config.OpenAIConfig != nil {
providerConfig.OpenAIConfig = config.OpenAIConfig
}
return providerConfig, nil
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,644 @@
// Package lib provides core functionality for the Bifrost HTTP service,
// including context propagation, header management, and integration with monitoring systems.
//
// This package handles the conversion of FastHTTP request contexts to Bifrost contexts,
// ensuring that important metadata and tracking information is preserved across the system.
// It supports propagation of both Prometheus metrics and Maxim tracing data through HTTP headers.
package lib
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/plugins/governance"
"github.com/maximhq/bifrost/plugins/maxim"
"github.com/maximhq/bifrost/plugins/semanticcache"
"github.com/valyala/fasthttp"
)
const (
// FastHTTPUserValueBifrostContext stores the active *schemas.BifrostContext on fasthttp.RequestCtx.
// This allows transport middleware and request handlers to share the same context instance.
FastHTTPUserValueBifrostContext = "__bifrost_context"
// FastHTTPUserValueBifrostCancel stores the cancel func for the active shared Bifrost context.
FastHTTPUserValueBifrostCancel = "__bifrost_context_cancel"
// FastHTTPUserValueLargeResponseMode marks requests that streamed a large response body.
// It is used by transport middleware to avoid re-buffering response bodies for post-hooks.
FastHTTPUserValueLargeResponseMode = "__bifrost_large_response_mode"
)
// ParseSessionIDFromBaggage extracts the session-id baggage member value.
// It supports simple W3C baggage parsing sufficient for log grouping.
func ParseSessionIDFromBaggage(header string) string {
for _, member := range strings.Split(header, ",") {
member = strings.TrimSpace(member)
if member == "" {
continue
}
parts := strings.SplitN(member, ";", 2)
kv := strings.SplitN(strings.TrimSpace(parts[0]), "=", 2)
if len(kv) != 2 {
continue
}
key := strings.ToLower(strings.TrimSpace(kv[0]))
value := strings.TrimSpace(kv[1])
if key != "session-id" || value == "" {
continue
}
if len(value) > 255 {
if logger != nil {
logger.Warn("session-id exceeds 255 chars, ignoring: length=%d, prefix=%s", len(value), value[:255])
}
continue
}
return value
}
return ""
}
// ConvertToBifrostContext converts a FastHTTP RequestCtx to a Bifrost context,
// preserving important header values for monitoring and tracing purposes.
//
// The function processes several types of special headers:
// 1. Prometheus Headers (x-bf-prom-*):
// - All headers prefixed with 'x-bf-prom-' are copied to the context
// - The prefix is stripped and the remainder becomes the context key
// - Example: 'x-bf-prom-latency' becomes 'latency' in the context
//
// 2. Maxim Tracing Headers (x-bf-maxim-*):
// - Specifically handles 'x-bf-maxim-traceID' and 'x-bf-maxim-generationID'
// - These headers enable trace correlation across service boundaries
// - Values are stored using Maxim's context keys for consistency
//
// 3. MCP Headers (x-bf-mcp-*):
// - Specifically handles 'x-bf-mcp-include-clients' and 'x-bf-mcp-include-tools' (include-only filtering)
// - These headers enable MCP client and tool filtering
// - Values are stored using MCP context keys for consistency
//
// 4. Governance Headers:
// - x-bf-vk: Virtual key for governance (required for governance to work)
//
// 5. API Key Headers:
// - Authorization: Bearer token format only (e.g., "Bearer sk-...") - OpenAI style
// - x-api-key: Direct API key value - Anthropic style
// - x-goog-api-key: Direct API key value - Google Gemini style
// - x-bf-api-key references a stored API key name rather than the raw secret.
// - Keys are extracted and stored in the context using schemas.BifrostContextKey
// - This enables explicit key usage for requests via headers
//
// 6. Cancellable Context:
// - Creates a cancellable context that can be used to cancel upstream requests when clients disconnect
// - This is critical for streaming requests where write errors indicate client disconnects
// - Also useful for non-streaming requests to allow provider-level cancellation
//
// 7. Extra Headers (x-bf-eh-*):
// - Any header starting with 'x-bf-eh-' is collected and added to the map stored under schemas.BifrostContextKeyExtraHeaders
// - The prefix is stripped, the remainder is lower-cased, and duplicate names append values
// - This allows callers to send arbitrary context metadata without needing to extend the public schema
//
// 8. Session Stickiness Headers:
// - x-bf-session-id: Session identifier for key binding (reuse same key across requests)
// - x-bf-session-ttl: Per-request TTL override (duration string e.g. "30m" or seconds integer)
//
// 9. Raw Capture Headers (per-request override of provider config; accepts "true" or "false"):
// - x-bf-send-back-raw-request: include raw provider request in the BifrostResponse returned to the caller
// - x-bf-send-back-raw-response: include raw provider response in the BifrostResponse returned to the caller
// - x-bf-store-raw-request-response: capture raw request/response for logging only (stripped from client response)
// Parameters:
// - ctx: The FastHTTP request context containing the original headers
// - allowDirectKeys: Whether to allow direct API key usage from headers
//
// Returns:
// - *context.Context: A new cancellable context.Context containing the propagated values
// - context.CancelFunc: Function to cancel the context (should be called when request completes)
//
// Example Usage:
//
// fastCtx := &fasthttp.RequestCtx{...}
// bifrostCtx, cancel := ConvertToBifrostContext(fastCtx, true, nil)
// defer cancel() // Ensure cleanup
// // bifrostCtx now contains propagated header values including Prometheus metrics,
// // Maxim tracing data, MCP filters, governance keys, API keys, cache settings,
// // session stickiness, and extra headers
func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, matcher *HeaderMatcher, mcpHeaderCombinedAllowlist schemas.WhiteList) (*schemas.BifrostContext, context.CancelFunc) {
// Reuse a shared request-scoped context when available.
var bifrostCtx *schemas.BifrostContext
var cancel context.CancelFunc
if existing, ok := ctx.UserValue(FastHTTPUserValueBifrostContext).(*schemas.BifrostContext); ok && existing != nil {
if existingCancel, ok := ctx.UserValue(FastHTTPUserValueBifrostCancel).(context.CancelFunc); ok && existingCancel != nil {
bifrostCtx = existing
cancel = existingCancel
} else {
// Create one cancellable child context and promote it as the shared context.
bifrostCtx, cancel = schemas.NewBifrostContextWithCancel(existing)
ctx.SetUserValue(FastHTTPUserValueBifrostContext, bifrostCtx)
ctx.SetUserValue(FastHTTPUserValueBifrostCancel, cancel)
}
}
if bifrostCtx == nil {
// Create cancellable context for requests that don't have a shared context yet.
parent := context.Context(ctx)
func() {
// Zero-value fasthttp.RequestCtx can panic on Done(); fall back safely.
defer func() {
if recover() != nil {
parent = context.Background()
}
}()
_ = ctx.Done()
}()
bifrostCtx, cancel = schemas.NewBifrostContextWithCancel(parent)
ctx.SetUserValue(FastHTTPUserValueBifrostContext, bifrostCtx)
ctx.SetUserValue(FastHTTPUserValueBifrostCancel, cancel)
}
// Preserve existing request-id if already present on the shared context.
if existingRequestID, ok := bifrostCtx.Value(schemas.BifrostContextKeyRequestID).(string); !ok || existingRequestID == "" {
// First, check if x-request-id header exists
requestID := string(ctx.Request.Header.Peek("x-request-id"))
if requestID == "" {
requestID = uuid.New().String()
}
bifrostCtx.SetValue(schemas.BifrostContextKeyRequestID, requestID)
}
// Populating all user values from the request context
ctx.VisitUserValuesAll(func(key, value any) {
bifrostCtx.SetValue(key, value)
})
// Initialize tags map for collecting maxim tags
maximTags := make(map[string]string)
// Initialize extra headers map for headers prefixed with x-bf-eh-
extraHeaders := make(map[string][]string)
// Initialize extra headers map for headers in the mcp header combined allowlist
mcpExtraHeaders := make(map[string][]string)
// Security denylist of header names that should never be accepted (case-insensitive)
// This denylist is always enforced regardless of user configuration
securityDenylist := map[string]bool{
"proxy-authorization": true,
"cookie": true,
"host": true,
"content-length": true,
"connection": true,
"transfer-encoding": true,
// prevent auth/key overrides via x-bf-eh-*
"x-api-key": true,
"x-goog-api-key": true,
"x-bf-api-key": true,
"x-bf-api-key-id": true,
"x-bf-vk": true,
}
// Debug: Log header matcher state
if logger != nil {
if matcher != nil {
logger.Debug("headerMatcher hasAllowlist=%v, hasDenylist=%v", matcher.HasAllowlist(), matcher.hasDenylist)
} else {
logger.Debug("headerMatcher is nil (allow all)")
}
}
// Then process other headers
ctx.Request.Header.All()(func(key, value []byte) bool {
keyStr := strings.ToLower(string(key))
if keyStr == "baggage" {
if sessionID := ParseSessionIDFromBaggage(string(value)); sessionID != "" {
bifrostCtx.SetValue(schemas.BifrostContextKeyParentRequestID, sessionID)
}
return true
}
if labelName, ok := strings.CutPrefix(keyStr, "x-bf-prom-"); ok {
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
return true
}
// Checking for maxim headers
if labelName, ok := strings.CutPrefix(keyStr, "x-bf-maxim-"); ok {
switch labelName {
case string(maxim.GenerationIDKey):
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
case string(maxim.TraceIDKey):
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
case string(maxim.SessionIDKey):
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
case string(maxim.TraceNameKey):
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
case string(maxim.GenerationNameKey):
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
case string(maxim.LogRepoIDKey):
bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value))
default:
// apart from these all headers starting with x-bf-maxim- are keys for tags
// collect them in the maximTags map
maximTags[labelName] = string(value)
}
return true
}
// MCP control headers (include-only filtering)
if labelName, ok := strings.CutPrefix(keyStr, "x-bf-mcp-"); ok {
switch labelName {
case "include-clients":
fallthrough
case "include-tools":
// Parse comma-separated values into []string
valueStr := string(value)
var parsedValues []string
if valueStr != "" {
// Split by comma and trim whitespace
for _, v := range strings.Split(valueStr, ",") {
if trimmed := strings.TrimSpace(v); trimmed != "" {
parsedValues = append(parsedValues, trimmed)
}
}
} else {
parsedValues = []string{""}
}
bifrostCtx.SetValue(schemas.BifrostContextKey("mcp-"+labelName), parsedValues)
return true
}
}
// Handle virtual key header (x-bf-vk, authorization, x-api-key, x-goog-api-key headers)
if keyStr == string(schemas.BifrostContextKeyVirtualKey) {
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, string(value))
return true
}
if keyStr == "authorization" {
valueStr := string(value)
// Only accept Bearer token format: "Bearer ..."
if strings.HasPrefix(strings.ToLower(valueStr), "bearer ") {
authHeaderValue := strings.TrimSpace(valueStr[7:]) // Remove "Bearer " prefix
if authHeaderValue != "" && strings.HasPrefix(strings.ToLower(authHeaderValue), governance.VirtualKeyPrefix) {
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, authHeaderValue)
return true
}
}
}
if keyStr == "x-api-key" && strings.HasPrefix(strings.ToLower(string(value)), governance.VirtualKeyPrefix) {
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, string(value))
return true
}
if keyStr == "x-goog-api-key" && strings.HasPrefix(strings.ToLower(string(value)), governance.VirtualKeyPrefix) {
bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, string(value))
return true
}
if keyStr == "x-bf-api-key" {
if keyName := strings.TrimSpace(string(value)); keyName != "" {
bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyName, keyName)
}
return true
}
if keyStr == "x-bf-api-key-id" {
if keyID := strings.TrimSpace(string(value)); keyID != "" {
bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyID, keyID)
}
return true
}
// Handle cache key header (x-bf-cache-key)
if keyStr == "x-bf-cache-key" {
bifrostCtx.SetValue(semanticcache.CacheKey, string(value))
return true
}
// Handle cache TTL header (x-bf-cache-ttl)
if keyStr == "x-bf-cache-ttl" {
valueStr := string(value)
var ttlDuration time.Duration
var err error
// First try to parse as duration (e.g., "30s", "5m", "1h")
if ttlDuration, err = time.ParseDuration(valueStr); err != nil {
// If that fails, try to parse as plain number and treat as seconds
if seconds, parseErr := strconv.Atoi(valueStr); parseErr == nil && seconds > 0 {
ttlDuration = time.Duration(seconds) * time.Second
err = nil // Reset error since we successfully parsed as seconds
}
}
if err == nil {
bifrostCtx.SetValue(semanticcache.CacheTTLKey, ttlDuration)
}
// If both parsing attempts fail, we silently ignore the header and use default TTL
return true
}
// Cache threshold header
if keyStr == "x-bf-cache-threshold" {
threshold, err := strconv.ParseFloat(string(value), 64)
if err == nil {
// Clamp threshold to the inclusive range [0.0, 1.0]
if threshold < 0.0 {
threshold = 0.0
} else if threshold > 1.0 {
threshold = 1.0
}
bifrostCtx.SetValue(semanticcache.CacheThresholdKey, threshold)
}
// If parsing fails, silently ignore the header (no context value set)
return true
}
// Cache type header
if keyStr == "x-bf-cache-type" {
bifrostCtx.SetValue(semanticcache.CacheTypeKey, semanticcache.CacheType(string(value)))
return true
}
// Cache no store header
if keyStr == "x-bf-cache-no-store" {
if valueStr := string(value); valueStr == "true" {
bifrostCtx.SetValue(semanticcache.CacheNoStoreKey, true)
}
return true
}
// Session stickiness: session ID for key binding
if keyStr == "x-bf-session-id" {
if valueStr := strings.TrimSpace(string(value)); valueStr != "" {
bifrostCtx.SetValue(schemas.BifrostContextKeySessionID, valueStr)
}
return true
}
// Session stickiness: per-request TTL override (duration string or seconds integer)
if keyStr == "x-bf-session-ttl" {
valueStr := strings.TrimSpace(string(value))
var ttlDuration time.Duration
var err error
if ttlDuration, err = time.ParseDuration(valueStr); err != nil {
if seconds, parseErr := strconv.Atoi(valueStr); parseErr == nil && seconds > 0 {
ttlDuration = time.Duration(seconds) * time.Second
err = nil
}
}
if err == nil && ttlDuration > 0 {
bifrostCtx.SetValue(schemas.BifrostContextKeySessionTTL, ttlDuration)
}
return true
}
if labelName, ok := strings.CutPrefix(keyStr, "x-bf-eh-"); ok {
// Skip empty header names after prefix removal
if labelName == "" {
return true
}
// Normalize header name to lowercase
labelName = strings.ToLower(labelName)
// Validate against security denylist (always enforced)
if securityDenylist[labelName] {
return true
}
// Apply configurable header filter
if !matcher.ShouldAllow(labelName) {
return true
}
// Append header value (allow multiple values for the same header)
extraHeaders[labelName] = append(extraHeaders[labelName], string(value))
return true
}
// Direct header forwarding: when allowlist is configured, any header explicitly
// in the allowlist can be forwarded directly without the x-bf-eh- prefix.
// This enables forwarding arbitrary headers like "anthropic-beta" directly.
// Only applies when allowlist is non-empty (backward compatible).
if matcher.HasAllowlist() {
if matcher.MatchesAllow(keyStr) {
// Skip reserved x-bf-* headers (handled separately)
if strings.HasPrefix(keyStr, "x-bf-") {
return true
}
// Validate against security denylist (always enforced)
if securityDenylist[keyStr] {
return true
}
// Check denylist
if matcher.MatchesDeny(keyStr) {
return true
}
// Forward the header directly with its original name
if logger != nil {
logger.Debug("forwarding header via allowlist: %s", keyStr)
}
extraHeaders[keyStr] = append(extraHeaders[keyStr], string(value))
return true
}
}
// Handle MCP extra headers
if mcpHeaderCombinedAllowlist.IsAllowed(keyStr) {
mcpExtraHeaders[keyStr] = append(mcpExtraHeaders[keyStr], string(value))
return true
}
// Raw capture headers — all three support "true"/"false" to fully override the
// provider-level config for this request.
if keyStr == "x-bf-send-back-raw-request" {
if b, err := strconv.ParseBool(string(value)); err == nil {
bifrostCtx.SetValue(schemas.BifrostContextKeySendBackRawRequest, b)
}
return true
}
if keyStr == "x-bf-send-back-raw-response" {
if b, err := strconv.ParseBool(string(value)); err == nil {
bifrostCtx.SetValue(schemas.BifrostContextKeySendBackRawResponse, b)
}
return true
}
if keyStr == "x-bf-store-raw-request-response" {
if b, err := strconv.ParseBool(string(value)); err == nil {
bifrostCtx.SetValue(schemas.BifrostContextKeyStoreRawRequestResponse, b)
}
return true
}
// Parent request ID header (for linking MCP tool calls to parent LLM requests)
if keyStr == "x-bf-parent-request-id" {
if valueStr := strings.TrimSpace(string(value)); valueStr != "" {
bifrostCtx.SetValue(schemas.BifrostMCPAgentOriginalRequestID, valueStr)
}
return true
}
// Add passthrough extra params header support
if keyStr == "x-bf-passthrough-extra-params" {
if valueStr := string(value); valueStr == "true" {
bifrostCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
}
return true
}
// Compat header: per-request override of compat plugin settings.
// Accepts: "true" (enable all), JSON array of feature names, or ["*"] (enable all).
// An empty array [] or absent header means no overrides.
if keyStr == "x-bf-compat" {
bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatConvertTextToChat)
bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatConvertChatToResponses)
bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatShouldDropParams)
bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatShouldConvertParams)
valueStr := strings.TrimSpace(string(value))
if valueStr == "true" {
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertTextToChat, true)
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertChatToResponses, true)
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldDropParams, true)
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldConvertParams, true)
} else if strings.HasPrefix(valueStr, "[") {
var features []string
if err := json.Unmarshal([]byte(valueStr), &features); err == nil {
if len(features) == 1 && features[0] == "*" {
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertTextToChat, true)
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertChatToResponses, true)
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldDropParams, true)
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldConvertParams, true)
} else {
for _, f := range features {
switch f {
case "convert_text_to_chat":
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertTextToChat, true)
case "convert_chat_to_responses":
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertChatToResponses, true)
case "should_drop_params":
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldDropParams, true)
case "should_convert_params":
bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldConvertParams, true)
}
}
}
}
}
return true
}
return true
})
// Store the collected maxim tags in the context
if len(maximTags) > 0 {
bifrostCtx.SetValue(schemas.BifrostContextKey(maxim.TagsKey), maximTags)
}
// Store collected extra headers in the context if any were found
if len(extraHeaders) > 0 {
bifrostCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders)
}
// Store collected MCP extra headers in the context if any were found
if len(mcpExtraHeaders) > 0 {
bifrostCtx.SetValue(schemas.BifrostContextKeyMCPExtraHeaders, mcpExtraHeaders)
}
// Collect all request headers for downstream use (e.g., governance required headers check)
// Keys are lowercased for case-insensitive lookup
allHeaders := make(map[string]string)
ctx.Request.Header.All()(func(key, value []byte) bool {
allHeaders[strings.ToLower(string(key))] = string(value)
return true
})
bifrostCtx.SetValue(schemas.BifrostContextKeyRequestHeaders, allHeaders)
// Extract per-user MCP OAuth user identifier from X-Bf-User-Id header
if mcpUserID := string(ctx.Request.Header.Peek("X-Bf-User-Id")); mcpUserID != "" {
bifrostCtx.SetValue(schemas.BifrostContextKeyMCPUserID, mcpUserID)
}
// Build and set OAuth redirect URI for per-user OAuth flows
scheme := "http"
if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" {
scheme = "https"
}
host := string(ctx.Host())
if host != "" {
bifrostCtx.SetValue(schemas.BifrostContextKeyOAuthRedirectURI, fmt.Sprintf("%s://%s/api/oauth/callback", scheme, host))
}
if allowDirectKeys {
// Extract API key from Authorization header (Bearer format), x-api-key, or x-goog-api-key header
var apiKey string
// TODO: fix plugin data leak
// Check Authorization header (Bearer format only - OpenAI style)
authHeader := string(ctx.Request.Header.Peek("Authorization"))
if authHeader != "" {
// Only accept Bearer token format: "Bearer ..."
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
authHeaderValue := strings.TrimSpace(authHeader[7:]) // Remove "Bearer " prefix
if authHeaderValue != "" && !strings.HasPrefix(strings.ToLower(authHeaderValue), governance.VirtualKeyPrefix) {
apiKey = authHeaderValue
}
} else {
apiKey = authHeader
}
}
if apiKey == "" {
// Check x-api-key (Anthropic style) header if no valid Authorization header found
xAPIKey := string(ctx.Request.Header.Peek("x-api-key"))
if xAPIKey != "" && !strings.HasPrefix(strings.ToLower(xAPIKey), governance.VirtualKeyPrefix) {
apiKey = strings.TrimSpace(xAPIKey)
} else {
// Check x-goog-api-key (Google Gemini style) header if no valid Authorization header found
xGoogleAPIKey := string(ctx.Request.Header.Peek("x-goog-api-key"))
if xGoogleAPIKey != "" && !strings.HasPrefix(strings.ToLower(xGoogleAPIKey), governance.VirtualKeyPrefix) {
apiKey = strings.TrimSpace(xGoogleAPIKey)
}
}
}
// If we found an API key, create a Key object and store it in context
if apiKey != "" {
key := schemas.Key{
ID: "header-provided", // Identifier for header-provided keys
Value: *schemas.NewEnvVar(apiKey),
Models: schemas.WhiteList{"*"}, // Allow all models
Weight: 1.0, // Default weight
}
bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, key)
}
}
return bifrostCtx, cancel
}
// BuildHTTPRequestFromFastHTTP creates an HTTPRequest from fasthttp context for streaming handlers.
// The returned request should be released with schemas.ReleaseHTTPRequest when done.
// Note: Body is not copied for streaming (body was already consumed for the request).
func BuildHTTPRequestFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.HTTPRequest {
req := schemas.AcquireHTTPRequest()
req.Method = string(ctx.Method())
req.Path = string(ctx.Path())
// Copy headers
for key, value := range ctx.Request.Header.All() {
req.Headers[string(key)] = string(value)
}
// Copy query params
for key, value := range ctx.Request.URI().QueryArgs().All() {
req.Query[string(key)] = string(value)
}
// Copy path parameters from user values
ctx.VisitUserValuesAll(func(key, value any) {
keyStr, keyIsString := key.(string)
valueStr, valueIsString := value.(string)
if !keyIsString || !valueIsString {
return
}
if strings.HasPrefix(keyStr, "bifrost-") ||
keyStr == "BifrostContextKeyRequestID" ||
keyStr == "trace_id" ||
keyStr == "span_id" {
return
}
req.PathParams[keyStr] = valueStr
})
// Note: Body not copied - for streaming, body was already consumed
return req
}
// BuildHTTPResponseFromFastHTTP creates an HTTPResponse snapshot from fasthttp context.
// Only captures status code and headers — body is skipped because for streaming
// responses it is an active io.Reader that cannot be materialized.
// The returned response should be released with schemas.ReleaseHTTPResponse when done.
func BuildHTTPResponseFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.HTTPResponse {
resp := schemas.AcquireHTTPResponse()
resp.StatusCode = ctx.Response.StatusCode()
for key, value := range ctx.Response.Header.All() {
resp.Headers[string(key)] = string(value)
}
return resp
}

View File

@@ -0,0 +1,275 @@
package lib
import (
"context"
"testing"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
func TestParseSessionIDFromBaggage(t *testing.T) {
tests := []struct {
name string
header string
want string
}{
{name: "single member", header: "session-id=abc", want: "abc"},
{name: "multiple members", header: "foo=bar, session-id=abc, baz=qux", want: "abc"},
{name: "member with properties", header: "session-id=abc;ttl=60", want: "abc"},
{name: "spaces preserved around parsing", header: " foo=bar , session-id = abc123 ;ttl=60 ", want: "abc123"},
{name: "missing member", header: "foo=bar", want: ""},
{name: "malformed ignored", header: "session-id, foo=bar", want: ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ParseSessionIDFromBaggage(tt.header); got != tt.want {
t.Fatalf("ParseSessionIDFromBaggage(%q) = %q, want %q", tt.header, got, tt.want)
}
})
}
}
func TestConvertToBifrostContext_ReusesSharedContext(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
base := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
base.SetValue(schemas.BifrostContextKeyRequestID, "req-shared")
ctx.SetUserValue(FastHTTPUserValueBifrostContext, base)
converted, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
defer cancel()
if converted == nil {
t.Fatal("expected non-nil converted context")
}
if got, _ := converted.Value(schemas.BifrostContextKeyRequestID).(string); got != "req-shared" {
t.Fatalf("expected converted context to preserve parent values, got request-id=%q", got)
}
if stored, ok := ctx.UserValue(FastHTTPUserValueBifrostContext).(*schemas.BifrostContext); !ok || stored == nil {
t.Fatal("expected shared context pointer to be stored on fasthttp user values")
}
if ctx.UserValue(FastHTTPUserValueBifrostCancel) == nil {
t.Fatal("expected shared cancel function to be stored on fasthttp user values")
}
}
func TestConvertToBifrostContext_SecondCallReturnsSameSharedContext(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
first, cancelFirst := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
defer cancelFirst()
if first == nil {
t.Fatal("expected first context to be non-nil")
}
second, cancelSecond := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
defer cancelSecond()
if second == nil {
t.Fatal("expected second context to be non-nil")
}
if first != second {
t.Fatal("expected ConvertToBifrostContext to reuse the shared context on repeated calls")
}
}
// TestConvertToBifrostContext_StarAllowlistSecurityHeadersBlocked verifies that
// even with a "*" allowlist (allow all), the hardcoded security denylist in
// ConvertToBifrostContext still blocks security-sensitive headers.
func TestConvertToBifrostContext_StarAllowlistSecurityHeadersBlocked(t *testing.T) {
matcher := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Allowlist: []string{"*"},
})
ctx := &fasthttp.RequestCtx{}
// x-bf-eh-* prefixed headers
ctx.Request.Header.Set("x-bf-eh-custom-header", "allowed-value")
ctx.Request.Header.Set("x-bf-eh-cookie", "should-be-blocked")
ctx.Request.Header.Set("x-bf-eh-x-api-key", "should-be-blocked")
ctx.Request.Header.Set("x-bf-eh-host", "should-be-blocked")
ctx.Request.Header.Set("x-bf-eh-connection", "should-be-blocked")
ctx.Request.Header.Set("x-bf-eh-proxy-authorization", "should-be-blocked")
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{})
defer cancel()
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
// custom-header should be forwarded
if _, ok := extraHeaders["custom-header"]; !ok {
t.Error("expected custom-header to be forwarded via x-bf-eh- prefix")
}
// Security headers should be blocked even with * allowlist
securityHeaders := []string{"cookie", "x-api-key", "host", "connection", "proxy-authorization"}
for _, h := range securityHeaders {
if _, ok := extraHeaders[h]; ok {
t.Errorf("expected security header %q to be blocked even with * allowlist", h)
}
}
}
// TestConvertToBifrostContext_StarAllowlistDirectForwardingSecurityBlocked verifies
// that direct header forwarding with "*" allowlist forwards non-security headers
// but still blocks security headers.
func TestConvertToBifrostContext_StarAllowlistDirectForwardingSecurityBlocked(t *testing.T) {
matcher := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Allowlist: []string{"*"},
})
ctx := &fasthttp.RequestCtx{}
// Direct headers (not prefixed with x-bf-eh-)
ctx.Request.Header.Set("custom-header", "allowed-value")
ctx.Request.Header.Set("anthropic-beta", "some-beta-feature")
// Security headers sent directly — should be blocked
ctx.Request.Header.Set("proxy-authorization", "should-be-blocked")
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{})
defer cancel()
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
// Direct non-security headers should be forwarded when allowlist has *
if _, ok := extraHeaders["custom-header"]; !ok {
t.Error("expected custom-header to be forwarded directly")
}
if _, ok := extraHeaders["anthropic-beta"]; !ok {
t.Error("expected anthropic-beta to be forwarded directly")
}
// Security headers should still be blocked in direct forwarding path
directSecurityHeaders := []string{"proxy-authorization", "cookie", "host", "connection"}
for _, h := range directSecurityHeaders {
if _, ok := extraHeaders[h]; ok {
t.Errorf("expected security header %q to be blocked in direct forwarding even with * allowlist", h)
}
}
}
// TestConvertToBifrostContext_PrefixWildcardDirectForwarding verifies that
// prefix wildcard patterns like "anthropic-*" work for direct header forwarding
// (without x-bf-eh- prefix).
func TestConvertToBifrostContext_PrefixWildcardDirectForwarding(t *testing.T) {
matcher := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Allowlist: []string{"anthropic-*"},
})
ctx := &fasthttp.RequestCtx{}
// Direct headers matching the wildcard pattern
ctx.Request.Header.Set("anthropic-beta", "beta-value")
ctx.Request.Header.Set("anthropic-version", "2024-01-01")
// Header not matching the pattern
ctx.Request.Header.Set("openai-version", "should-not-forward")
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{})
defer cancel()
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
if _, ok := extraHeaders["anthropic-beta"]; !ok {
t.Error("expected anthropic-beta to be forwarded directly via wildcard allowlist")
}
if _, ok := extraHeaders["anthropic-version"]; !ok {
t.Error("expected anthropic-version to be forwarded directly via wildcard allowlist")
}
if _, ok := extraHeaders["openai-version"]; ok {
t.Error("expected openai-version to NOT be forwarded (doesn't match anthropic-*)")
}
}
// TestConvertToBifrostContext_WildcardAllowlistFiltering verifies wildcard patterns
// correctly filter headers via the x-bf-eh- prefix path.
func TestConvertToBifrostContext_WildcardAllowlistFiltering(t *testing.T) {
matcher := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Allowlist: []string{"anthropic-*"},
})
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.Set("x-bf-eh-anthropic-beta", "beta-value")
ctx.Request.Header.Set("x-bf-eh-anthropic-version", "2024-01-01")
ctx.Request.Header.Set("x-bf-eh-openai-version", "should-be-blocked")
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{})
defer cancel()
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
if _, ok := extraHeaders["anthropic-beta"]; !ok {
t.Error("expected anthropic-beta to be forwarded")
}
if _, ok := extraHeaders["anthropic-version"]; !ok {
t.Error("expected anthropic-version to be forwarded")
}
if _, ok := extraHeaders["openai-version"]; ok {
t.Error("expected openai-version to be blocked (not matching anthropic-*)")
}
}
// TestConvertToBifrostContext_WildcardDenylistBlocking verifies wildcard denylist
// patterns block matching headers.
func TestConvertToBifrostContext_WildcardDenylistBlocking(t *testing.T) {
matcher := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Denylist: []string{"x-internal-*"},
})
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.Set("x-bf-eh-x-internal-id", "blocked-value")
ctx.Request.Header.Set("x-bf-eh-x-internal-secret", "blocked-value")
ctx.Request.Header.Set("x-bf-eh-custom-header", "allowed-value")
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{})
defer cancel()
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
if _, ok := extraHeaders["x-internal-id"]; ok {
t.Error("expected x-internal-id to be blocked by denylist")
}
if _, ok := extraHeaders["x-internal-secret"]; ok {
t.Error("expected x-internal-secret to be blocked by denylist")
}
if _, ok := extraHeaders["custom-header"]; !ok {
t.Error("expected custom-header to be forwarded")
}
}
// TestConvertToBifrostContext_NilMatcher verifies nil matcher allows all headers.
func TestConvertToBifrostContext_NilMatcher(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.Set("x-bf-eh-custom-header", "allowed-value")
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
defer cancel()
extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
if _, ok := extraHeaders["custom-header"]; !ok {
t.Error("expected custom-header to be forwarded with nil matcher")
}
}
func TestConvertToBifrostContext_BaggageSessionIDSetsGrouping(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.Set("baggage", "foo=bar, session-id=rt-123, baz=qux")
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
defer cancel()
if got, _ := bifrostCtx.Value(schemas.BifrostContextKeyParentRequestID).(string); got != "rt-123" {
t.Fatalf("parent request id = %q, want %q", got, "rt-123")
}
}
func TestConvertToBifrostContext_EmptyBaggageSessionIDIgnored(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.Set("baggage", "session-id= ")
bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{})
defer cancel()
if got := bifrostCtx.Value(schemas.BifrostContextKeyParentRequestID); got != nil {
t.Fatalf("parent request id should be unset, got %#v", got)
}
}

View File

@@ -0,0 +1,6 @@
package lib
import "errors"
var ErrNotFound = errors.New("not found")
var ErrAlreadyExists = errors.New("already exists")

View File

@@ -0,0 +1,136 @@
package lib
import (
"strings"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)
// HeaderMatchesPattern returns true if headerName matches the pattern.
// Patterns support trailing wildcard: "anthropic-*" matches "anthropic-beta".
// A bare "*" matches everything. All comparisons are case-insensitive.
func HeaderMatchesPattern(pattern, headerName string) bool {
pattern = strings.ToLower(strings.TrimSpace(pattern))
headerName = strings.ToLower(strings.TrimSpace(headerName))
if pattern == "*" {
return true
}
if strings.HasSuffix(pattern, "*") {
return strings.HasPrefix(headerName, pattern[:len(pattern)-1])
}
return pattern == headerName
}
// HeaderMatcher holds precomputed header filter data for O(1) exact-match lookups
// and fast prefix matching. Compiled once on config change, safe for concurrent reads.
type HeaderMatcher struct {
allowExact map[string]bool
allowPrefixes []string // lowercased prefixes (without trailing *)
allowAll bool
hasAllowlist bool
denyExact map[string]bool
denyPrefixes []string
denyAll bool
hasDenylist bool
}
// NewHeaderMatcher compiles a GlobalHeaderFilterConfig into an optimized HeaderMatcher.
// Returns nil if config is nil (callers should treat nil as "allow all").
func NewHeaderMatcher(config *configstoreTables.GlobalHeaderFilterConfig) *HeaderMatcher {
if config == nil {
return nil
}
m := &HeaderMatcher{
allowExact: make(map[string]bool, len(config.Allowlist)),
denyExact: make(map[string]bool, len(config.Denylist)),
}
for _, p := range config.Allowlist {
lp := strings.ToLower(strings.TrimSpace(p))
if lp == "" {
continue
}
if lp == "*" {
m.allowAll = true
} else if strings.HasSuffix(lp, "*") {
m.allowPrefixes = append(m.allowPrefixes, lp[:len(lp)-1])
} else {
m.allowExact[lp] = true
}
}
for _, p := range config.Denylist {
lp := strings.ToLower(strings.TrimSpace(p))
if lp == "" {
continue
}
if lp == "*" {
m.denyAll = true
} else if strings.HasSuffix(lp, "*") {
m.denyPrefixes = append(m.denyPrefixes, lp[:len(lp)-1])
} else {
m.denyExact[lp] = true
}
}
m.hasAllowlist = m.allowAll || len(m.allowExact) > 0 || len(m.allowPrefixes) > 0
m.hasDenylist = m.denyAll || len(m.denyExact) > 0 || len(m.denyPrefixes) > 0
return m
}
// HasAllowlist returns true if the matcher has a non-empty allowlist.
func (m *HeaderMatcher) HasAllowlist() bool {
if m == nil {
return false
}
return m.hasAllowlist
}
// MatchesAllow returns true if headerName matches any allowlist entry.
// headerName must be lowercased by the caller.
func (m *HeaderMatcher) MatchesAllow(headerName string) bool {
if m.allowAll {
return true
}
if m.allowExact[headerName] {
return true
}
for _, prefix := range m.allowPrefixes {
if strings.HasPrefix(headerName, prefix) {
return true
}
}
return false
}
// MatchesDeny returns true if headerName matches any denylist entry.
// headerName must be lowercased by the caller.
func (m *HeaderMatcher) MatchesDeny(headerName string) bool {
if m.denyAll {
return true
}
if m.denyExact[headerName] {
return true
}
for _, prefix := range m.denyPrefixes {
if strings.HasPrefix(headerName, prefix) {
return true
}
}
return false
}
// ShouldAllow determines if a header should be forwarded based on the
// configurable header filter config (separate from the security denylist).
// Returns true if the header passes both allowlist and denylist checks.
// headerName is lowercased internally for case-insensitive matching.
func (m *HeaderMatcher) ShouldAllow(headerName string) bool {
if m == nil {
return true
}
headerName = strings.ToLower(headerName)
if m.hasAllowlist && !m.MatchesAllow(headerName) {
return false
}
if m.hasDenylist && m.MatchesDeny(headerName) {
return false
}
return true
}

View File

@@ -0,0 +1,251 @@
package lib
import (
"testing"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
)
func TestHeaderMatchesPattern(t *testing.T) {
tests := []struct {
pattern string
headerName string
want bool
}{
// Exact match
{"anthropic-beta", "anthropic-beta", true},
{"anthropic-beta", "anthropic-alpha", false},
// Case insensitive exact match
{"Anthropic-Beta", "anthropic-beta", true},
{"anthropic-beta", "Anthropic-Beta", true},
// Star matches all
{"*", "anything", true},
{"*", "", true},
// Prefix wildcard
{"anthropic-*", "anthropic-beta", true},
{"anthropic-*", "anthropic-version", true},
{"anthropic-*", "anthropic-", true},
{"anthropic-*", "openai-version", false},
{"anthropic-*", "anthropic", false},
// Case insensitive prefix wildcard
{"Anthropic-*", "anthropic-beta", true},
{"anthropic-*", "Anthropic-Beta", true},
// No match
{"foo", "bar", false},
{"", "foo", false},
// Pattern without wildcard doesn't prefix match
{"anthropic-", "anthropic-beta", false},
}
for _, tt := range tests {
t.Run(tt.pattern+"_"+tt.headerName, func(t *testing.T) {
got := HeaderMatchesPattern(tt.pattern, tt.headerName)
if got != tt.want {
t.Errorf("HeaderMatchesPattern(%q, %q) = %v, want %v", tt.pattern, tt.headerName, got, tt.want)
}
})
}
}
func TestNewHeaderMatcher_Nil(t *testing.T) {
m := NewHeaderMatcher(nil)
if m != nil {
t.Fatal("expected nil matcher for nil config")
}
// nil matcher should allow everything
if !m.ShouldAllow("anything") {
t.Error("nil matcher should allow all headers")
}
if m.HasAllowlist() {
t.Error("nil matcher should have no allowlist")
}
}
func TestNewHeaderMatcher_Empty(t *testing.T) {
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{})
if m == nil {
t.Fatal("expected non-nil matcher for empty config")
}
if m.HasAllowlist() {
t.Error("empty config should have no allowlist")
}
if !m.ShouldAllow("anything") {
t.Error("empty config should allow all headers")
}
}
func TestHeaderMatcher_ExactAllowlist(t *testing.T) {
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Allowlist: []string{"anthropic-beta", "custom-id"},
})
if !m.ShouldAllow("anthropic-beta") {
t.Error("should allow anthropic-beta")
}
if !m.ShouldAllow("custom-id") {
t.Error("should allow custom-id")
}
if m.ShouldAllow("openai-version") {
t.Error("should not allow openai-version")
}
// Case insensitive
if !m.ShouldAllow("Anthropic-Beta") {
t.Error("should allow Anthropic-Beta (case insensitive)")
}
}
func TestHeaderMatcher_WildcardAllowlist(t *testing.T) {
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Allowlist: []string{"anthropic-*"},
})
if !m.ShouldAllow("anthropic-beta") {
t.Error("should allow anthropic-beta")
}
if !m.ShouldAllow("anthropic-version") {
t.Error("should allow anthropic-version")
}
if m.ShouldAllow("openai-version") {
t.Error("should not allow openai-version")
}
}
func TestHeaderMatcher_StarAllowlist(t *testing.T) {
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Allowlist: []string{"*"},
})
if !m.ShouldAllow("anything") {
t.Error("* should allow anything")
}
if !m.ShouldAllow("") {
t.Error("* should allow empty string")
}
}
func TestHeaderMatcher_ExactDenylist(t *testing.T) {
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Denylist: []string{"secret-token"},
})
if m.ShouldAllow("secret-token") {
t.Error("should deny secret-token")
}
if !m.ShouldAllow("public-key") {
t.Error("should allow public-key")
}
}
func TestHeaderMatcher_WildcardDenylist(t *testing.T) {
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Denylist: []string{"x-internal-*"},
})
if m.ShouldAllow("x-internal-id") {
t.Error("should deny x-internal-id")
}
if m.ShouldAllow("x-internal-secret") {
t.Error("should deny x-internal-secret")
}
if !m.ShouldAllow("x-external-id") {
t.Error("should allow x-external-id")
}
}
func TestHeaderMatcher_StarDenylist(t *testing.T) {
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Denylist: []string{"*"},
})
if m.ShouldAllow("anything") {
t.Error("* denylist should deny everything")
}
}
func TestHeaderMatcher_AllowlistWithDenylist(t *testing.T) {
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Allowlist: []string{"*"},
Denylist: []string{"x-internal-*"},
})
if !m.ShouldAllow("anthropic-beta") {
t.Error("should allow anthropic-beta")
}
if m.ShouldAllow("x-internal-id") {
t.Error("should deny x-internal-id (denylist overrides)")
}
}
func TestHeaderMatcher_AllowlistPrefixWithDenylistExact(t *testing.T) {
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Allowlist: []string{"anthropic-*"},
Denylist: []string{"anthropic-dangerous"},
})
if !m.ShouldAllow("anthropic-beta") {
t.Error("should allow anthropic-beta")
}
if m.ShouldAllow("anthropic-dangerous") {
t.Error("should deny anthropic-dangerous")
}
if m.ShouldAllow("openai-version") {
t.Error("should not allow openai-version (not in allowlist)")
}
}
func TestHeaderMatcher_CaseInsensitive(t *testing.T) {
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Allowlist: []string{"Anthropic-*"},
Denylist: []string{"X-Internal-*"},
})
if !m.ShouldAllow("anthropic-beta") {
t.Error("should allow anthropic-beta (case insensitive)")
}
if m.ShouldAllow("x-internal-id") {
t.Error("should deny x-internal-id (case insensitive)")
}
}
func TestHeaderMatcher_MatchesAllow(t *testing.T) {
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Allowlist: []string{"anthropic-*", "custom-id"},
})
if !m.MatchesAllow("anthropic-beta") {
t.Error("should match anthropic-beta")
}
if !m.MatchesAllow("custom-id") {
t.Error("should match custom-id")
}
if m.MatchesAllow("openai-version") {
t.Error("should not match openai-version")
}
}
func TestHeaderMatcher_MatchesDeny(t *testing.T) {
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Denylist: []string{"secret-*", "blocked"},
})
if !m.MatchesDeny("secret-token") {
t.Error("should match secret-token")
}
if !m.MatchesDeny("blocked") {
t.Error("should match blocked")
}
if m.MatchesDeny("allowed") {
t.Error("should not match allowed")
}
}
func TestHeaderMatcher_HasAllowlist(t *testing.T) {
m := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Allowlist: []string{"foo"},
})
if !m.HasAllowlist() {
t.Error("should have allowlist")
}
m2 := NewHeaderMatcher(&configstoreTables.GlobalHeaderFilterConfig{
Denylist: []string{"bar"},
})
if m2.HasAllowlist() {
t.Error("should not have allowlist")
}
}

View File

@@ -0,0 +1,58 @@
package lib
import (
"io"
"strconv"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
var logger schemas.Logger
// SetLogger sets the logger for the application.
func SetLogger(l schemas.Logger) {
logger = l
}
// StreamLargeResponseBody extracts the large response reader from context and streams
// it directly to the client. Sets status 200, content-type, and content-length headers.
// Returns false if the reader is not available (caller should send an error response).
func StreamLargeResponseBody(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) bool {
if bifrostCtx == nil {
return false
}
reader, ok := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseReader).(io.ReadCloser)
if !ok || reader == nil {
return false
}
contentLength, _ := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseContentLength).(int)
contentType, _ := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseContentType).(string)
contentDisposition, _ := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseContentDisposition).(string)
// Mirror large-response-mode to fasthttp UserValue so post-hook middleware
// (which only sees ctx.UserValue, not bifrostCtx) can skip body materialization.
ctx.SetUserValue(FastHTTPUserValueLargeResponseMode, true)
ctx.SetStatusCode(fasthttp.StatusOK)
if contentType != "" {
ctx.SetContentType(contentType)
} else {
ctx.SetContentType("application/json")
}
if contentDisposition != "" {
ctx.Response.Header.Set("Content-Disposition", contentDisposition)
}
// bodySize for SetBodyStream: positive = known size, -1 = unknown (read until EOF).
// fasthttp treats 0 as "known empty", so default to -1 when CL is unavailable.
bodySize := contentLength
if bodySize > 0 {
ctx.Response.Header.Set("Content-Length", strconv.Itoa(contentLength))
} else {
bodySize = -1
}
ctx.Response.SetBodyStream(reader, bodySize)
return true
}

View File

@@ -0,0 +1,23 @@
package lib
import (
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// ChainMiddlewares chains multiple middlewares together
// Middlewares are applied in order: the first middleware wraps the second, etc.
// This allows earlier middlewares to short-circuit by not calling next(ctx)
func ChainMiddlewares(handler fasthttp.RequestHandler, middlewares ...schemas.BifrostHTTPMiddleware) fasthttp.RequestHandler {
// If no middlewares, return the original handler
if len(middlewares) == 0 {
return handler
}
// Build the chain from right to left (last middleware wraps the handler)
// This ensures execution order is left to right (first middleware executes first)
chained := handler
for i := len(middlewares) - 1; i >= 0; i-- {
chained = middlewares[i](chained)
}
return chained
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,174 @@
package lib
import (
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
"github.com/maximhq/bifrost/plugins/semanticcache"
"github.com/stretchr/testify/require"
)
func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyMode(t *testing.T) {
config := &Config{}
pluginConfig := &schemas.PluginConfig{
Name: semanticcache.PluginName,
Config: map[string]interface{}{
"dimension": 1,
"ttl": "5m",
},
}
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
require.NoError(t, err)
configMap, ok := pluginConfig.Config.(map[string]interface{})
require.True(t, ok)
_, hasKeys := configMap["keys"]
require.False(t, hasKeys, "direct-only mode should not inject provider keys")
}
func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyModeRemovesStaleProviderBackedFields(t *testing.T) {
config := &Config{}
pluginConfig := &schemas.PluginConfig{
Name: semanticcache.PluginName,
Config: map[string]interface{}{
"dimension": 1,
"keys": []schemas.Key{{Name: "stale-key"}},
"embedding_model": "text-embedding-3-small",
},
}
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
require.NoError(t, err)
configMap, ok := pluginConfig.Config.(map[string]interface{})
require.True(t, ok)
_, hasKeys := configMap["keys"]
require.False(t, hasKeys, "direct-only mode should remove stale provider keys")
_, hasEmbeddingModel := configMap["embedding_model"]
require.False(t, hasEmbeddingModel, "direct-only mode should remove stale embedding_model")
}
func TestAddProviderKeysToSemanticCacheConfig_InjectsProviderKeys(t *testing.T) {
config := &Config{
Providers: map[schemas.ModelProvider]configstore.ProviderConfig{
schemas.OpenAI: {
Keys: []schemas.Key{
{
Name: "openai-key",
Value: *schemas.NewEnvVar("sk-test"),
Weight: 1,
},
},
},
},
}
pluginConfig := &schemas.PluginConfig{
Name: semanticcache.PluginName,
Config: map[string]interface{}{
"provider": "openai",
"embedding_model": "text-embedding-3-small",
"dimension": 1536,
},
}
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
require.NoError(t, err)
configMap, ok := pluginConfig.Config.(map[string]interface{})
require.True(t, ok)
keys, ok := configMap["keys"].([]schemas.Key)
require.True(t, ok, "provider-backed mode should inject provider keys")
require.Len(t, keys, 1)
require.Equal(t, "openai-key", keys[0].Name)
require.Equal(t, "openai", configMap["provider"])
}
func TestAddProviderKeysToSemanticCacheConfig_SemanticModeMissingProvider(t *testing.T) {
config := &Config{}
pluginConfig := &schemas.PluginConfig{
Name: semanticcache.PluginName,
Config: map[string]interface{}{
"dimension": 1536,
},
}
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
require.Error(t, err)
require.Contains(t, err.Error(), "requires 'provider' for semantic mode")
}
func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingDimension(t *testing.T) {
config := &Config{}
pluginConfig := &schemas.PluginConfig{
Name: semanticcache.PluginName,
Config: map[string]interface{}{
"provider": "openai",
"embedding_model": "text-embedding-3-small",
},
}
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
require.Error(t, err)
require.Contains(t, err.Error(), "requires 'dimension' for provider-backed semantic mode")
}
func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeDimensionOne(t *testing.T) {
config := &Config{}
pluginConfig := &schemas.PluginConfig{
Name: semanticcache.PluginName,
Config: map[string]interface{}{
"provider": "openai",
"embedding_model": "text-embedding-3-small",
"dimension": 1,
},
}
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
require.Error(t, err)
require.Contains(t, err.Error(), "requires 'dimension' > 1")
}
func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingEmbeddingModel(t *testing.T) {
config := &Config{}
pluginConfig := &schemas.PluginConfig{
Name: semanticcache.PluginName,
Config: map[string]interface{}{
"provider": "openai",
"dimension": 1536,
},
}
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
require.Error(t, err)
require.Contains(t, err.Error(), "requires 'embedding_model'")
}
func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionZero(t *testing.T) {
config := &Config{}
pluginConfig := &schemas.PluginConfig{
Name: semanticcache.PluginName,
Config: map[string]interface{}{
"dimension": 0,
},
}
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
require.Error(t, err)
require.Contains(t, err.Error(), "'dimension' must be >= 1")
}
func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionNegative(t *testing.T) {
config := &Config{}
pluginConfig := &schemas.PluginConfig{
Name: semanticcache.PluginName,
Config: map[string]interface{}{
"dimension": -1,
},
}
err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig)
require.Error(t, err)
require.Contains(t, err.Error(), "'dimension' must be >= 1")
}

View File

@@ -0,0 +1,114 @@
package lib
import (
"io"
"sync"
)
// SSEStreamReader is an io.ReadCloser that delivers one event per Read call,
// bypassing fasthttp's internal pipe mechanism (fasthttputil.PipeConns) which
// batches multiple events into single TCP segments.
//
// Usage:
// 1. Create with NewSSEStreamReader()
// 2. Pass to ctx.Response.SetBodyStream(reader, -1)
// 3. Start a producer goroutine that calls Send()/SendEvent()/SendError() for each event
// 4. Producer calls Done() when finished (closes the event channel)
// 5. fasthttp calls Close() on write errors (signals producer to stop)
type SSEStreamReader struct {
eventCh chan []byte
closeCh chan struct{}
closeOnce sync.Once
current []byte // remaining bytes from a partial read
}
// NewSSEStreamReader creates a new SSEStreamReader with a buffered event channel.
// Channel capacity of 1 allows one event of pipeline parallelism between
// the producer goroutine and fasthttp's writeBodyChunked loop.
func NewSSEStreamReader() *SSEStreamReader {
return &SSEStreamReader{
eventCh: make(chan []byte, 1),
closeCh: make(chan struct{}),
}
}
// Read implements io.Reader. It blocks until an event is available, then returns
// that event's bytes. If the caller's buffer is smaller than the event, remaining
// bytes are stored and returned on subsequent calls. Returns io.EOF when Done()
// has been called and all events have been consumed.
func (r *SSEStreamReader) Read(p []byte) (int, error) {
if len(r.current) == 0 {
event, ok := <-r.eventCh
if !ok {
return 0, io.EOF
}
r.current = event
}
n := copy(p, r.current)
r.current = r.current[n:]
return n, nil
}
// Close implements io.Closer. Called by fasthttp when writeBodyChunked encounters
// a write error (client disconnect). Signals the producer goroutine to stop via closeCh.
// Safe to call multiple times.
func (r *SSEStreamReader) Close() error {
r.closeOnce.Do(func() {
close(r.closeCh)
})
return nil
}
// Send delivers a pre-formatted event to the reader. Returns false if the reader
// has been closed (client disconnected), in which case the producer should stop.
func (r *SSEStreamReader) Send(event []byte) bool {
// Check closeCh first (non-blocking) to avoid sending after Close
select {
case <-r.closeCh:
return false
default:
}
select {
case r.eventCh <- event:
return true
case <-r.closeCh:
return false
}
}
// SendEvent sends an SSE-framed event. If eventType is empty, it sends "data: <data>\n\n".
// If eventType is non-empty, it sends "event: <eventType>\ndata: <data>\n\n".
// Returns false if the reader has been closed (client disconnected).
func (r *SSEStreamReader) SendEvent(eventType string, data []byte) bool {
var buf []byte
if eventType != "" {
buf = make([]byte, 0, 7+len(eventType)+7+len(data)+2)
buf = append(buf, "event: "...)
buf = append(buf, eventType...)
buf = append(buf, "\ndata: "...)
} else {
buf = make([]byte, 0, 6+len(data)+2)
buf = append(buf, "data: "...)
}
buf = append(buf, data...)
buf = append(buf, '\n', '\n')
return r.Send(buf)
}
// SendError sends an SSE error event: "event: error\ndata: <data>\n\n".
// Returns false if the reader has been closed (client disconnected).
func (r *SSEStreamReader) SendError(data []byte) bool {
return r.SendEvent("error", data)
}
// SendDone sends the standard SSE done marker: "data: [DONE]\n\n".
// Returns false if the reader has been closed (client disconnected).
func (r *SSEStreamReader) SendDone() bool {
return r.Send([]byte("data: [DONE]\n\n"))
}
// Done closes the event channel, signaling to Read that the stream is finished.
// Must be called exactly once by the producer goroutine when streaming is complete.
func (r *SSEStreamReader) Done() {
close(r.eventCh)
}

View File

@@ -0,0 +1,714 @@
package lib
import (
"fmt"
"io"
"sync"
"testing"
)
func TestSSEStreamReaderSingleEventPerRead(t *testing.T) {
r := NewSSEStreamReader()
events := [][]byte{
[]byte("data: {\"chunk\":1}\n\n"),
[]byte("data: {\"chunk\":2}\n\n"),
[]byte("data: {\"chunk\":3}\n\n"),
}
errCh := make(chan error, 1)
go func() {
for _, e := range events {
if !r.Send(e) {
select {
case errCh <- fmt.Errorf("Send returned false unexpectedly"):
default:
}
return
}
}
r.Done()
}()
buf := make([]byte, 4096)
for i, want := range events {
n, err := r.Read(buf)
if err != nil {
t.Fatalf("event %d: unexpected error: %v", i, err)
}
got := string(buf[:n])
if got != string(want) {
t.Errorf("event %d: got %q, want %q", i, got, want)
}
}
// Next read should return EOF
n, err := r.Read(buf)
if err != io.EOF {
t.Errorf("expected io.EOF, got err=%v n=%d", err, n)
}
select {
case err := <-errCh:
t.Error(err)
default:
}
}
func TestSSEStreamReaderPartialRead(t *testing.T) {
r := NewSSEStreamReader()
event := []byte("data: {\"content\":\"hello world\"}\n\n")
go func() {
r.Send(event)
r.Done()
}()
// Read with a small buffer (5 bytes at a time)
var result []byte
buf := make([]byte, 5)
for {
n, err := r.Read(buf)
result = append(result, buf[:n]...)
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
if string(result) != string(event) {
t.Errorf("reassembled data: got %q, want %q", result, event)
}
}
func TestSSEStreamReaderEOFOnDone(t *testing.T) {
r := NewSSEStreamReader()
r.Done() // Close immediately
buf := make([]byte, 4096)
n, err := r.Read(buf)
if err != io.EOF {
t.Errorf("expected io.EOF, got err=%v n=%d", err, n)
}
}
func TestSSEStreamReaderCloseSignalsProducer(t *testing.T) {
r := NewSSEStreamReader()
r.Close()
if r.Send([]byte("data: test\n\n")) {
t.Error("Send should return false after Close")
}
}
func TestSSEStreamReaderIdempotentClose(t *testing.T) {
r := NewSSEStreamReader()
// Should not panic
r.Close()
r.Close()
r.Close()
}
func TestSSEStreamReaderConcurrent(t *testing.T) {
r := NewSSEStreamReader()
const numEvents = 100
var wg sync.WaitGroup
wg.Add(1)
// Producer
go func() {
for i := 0; i < numEvents; i++ {
if !r.Send([]byte("data: event\n\n")) {
break
}
}
r.Done()
}()
// Consumer
errCh := make(chan error, 2)
go func() {
defer wg.Done()
buf := make([]byte, 4096)
count := 0
for {
_, err := r.Read(buf)
if err == io.EOF {
break
}
if err != nil {
select {
case errCh <- fmt.Errorf("unexpected error: %v", err):
default:
}
break
}
count++
}
if count != numEvents {
select {
case errCh <- fmt.Errorf("got %d events, want %d", count, numEvents):
default:
}
}
}()
wg.Wait()
close(errCh)
for err := range errCh {
t.Error(err)
}
}
func TestSSEStreamReaderSendEvent(t *testing.T) {
tests := []struct {
name string
eventType string
data []byte
want string
}{
{
name: "data only",
eventType: "",
data: []byte(`{"chunk":1}`),
want: "data: {\"chunk\":1}\n\n",
},
{
name: "with event type",
eventType: "response.delta",
data: []byte(`{"delta":"hi"}`),
want: "event: response.delta\ndata: {\"delta\":\"hi\"}\n\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := NewSSEStreamReader()
go func() {
r.SendEvent(tt.eventType, tt.data)
r.Done()
}()
buf := make([]byte, 4096)
n, err := r.Read(buf)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got := string(buf[:n]); got != tt.want {
t.Errorf("got %q, want %q", got, tt.want)
}
})
}
}
func TestSSEStreamReaderSendError(t *testing.T) {
r := NewSSEStreamReader()
go func() {
r.SendError([]byte(`{"error":"bad"}`))
r.Done()
}()
buf := make([]byte, 4096)
n, err := r.Read(buf)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := "event: error\ndata: {\"error\":\"bad\"}\n\n"
if got := string(buf[:n]); got != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestSSEStreamReaderSendDone(t *testing.T) {
r := NewSSEStreamReader()
go func() {
r.SendDone()
r.Done()
}()
buf := make([]byte, 4096)
n, err := r.Read(buf)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := "data: [DONE]\n\n"
if got := string(buf[:n]); got != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestSSEStreamReaderSendEventAfterClose(t *testing.T) {
r := NewSSEStreamReader()
r.Close()
if r.SendEvent("test", []byte("data")) {
t.Error("SendEvent should return false after Close")
}
if r.SendError([]byte("err")) {
t.Error("SendError should return false after Close")
}
if r.SendDone() {
t.Error("SendDone should return false after Close")
}
}
// TestSSEStreamReaderSendEventByteAccuracy verifies that SendEvent produces
// the exact same bytes that the old manual buffer assembly in the handlers did.
func TestSSEStreamReaderSendEventByteAccuracy(t *testing.T) {
tests := []struct {
name string
eventType string
data []byte
want []byte
}{
{
name: "standard SSE data (old inference.go pattern)",
eventType: "",
data: []byte(`{"id":"chatcmpl-123","choices":[{"delta":{"content":"Hello"}}]}`),
want: func() []byte {
// Old code: buf = append(buf, "data: "...); buf = append(buf, chunkJSON...); buf = append(buf, '\n', '\n')
data := []byte(`{"id":"chatcmpl-123","choices":[{"delta":{"content":"Hello"}}]}`)
buf := make([]byte, 0, len(data)+8)
buf = append(buf, "data: "...)
buf = append(buf, data...)
buf = append(buf, '\n', '\n')
return buf
}(),
},
{
name: "OpenAI responses format with event type (old inference.go pattern)",
eventType: "response.output_item.added",
data: []byte(`{"type":"response.output_item.added","item":{"id":"item_1"}}`),
want: func() []byte {
// Old code: buf = append(buf, "event: "...); buf = append(buf, eventType...); buf = append(buf, "\ndata: "...); ...
eventType := "response.output_item.added"
data := []byte(`{"type":"response.output_item.added","item":{"id":"item_1"}}`)
buf := make([]byte, 0, len(eventType)+len(data)+16)
buf = append(buf, "event: "...)
buf = append(buf, eventType...)
buf = append(buf, "\ndata: "...)
buf = append(buf, data...)
buf = append(buf, '\n', '\n')
return buf
}(),
},
{
name: "error event (old interceptor pattern)",
eventType: "error",
data: []byte(`{"error":"stream interrupted"}`),
want: func() []byte {
data := []byte(`{"error":"stream interrupted"}`)
buf := make([]byte, 0, len(data)+24)
buf = append(buf, "event: error\ndata: "...)
buf = append(buf, data...)
buf = append(buf, '\n', '\n')
return buf
}(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := NewSSEStreamReader()
go func() {
r.SendEvent(tt.eventType, tt.data)
r.Done()
}()
buf := make([]byte, 4096)
n, err := r.Read(buf)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
got := buf[:n]
if string(got) != string(tt.want) {
t.Errorf("byte mismatch:\n got: %q\n want: %q", got, tt.want)
}
})
}
}
// TestSSEStreamReaderSendErrorByteAccuracy verifies SendError matches
// the old "event: error\ndata: ..." manual assembly.
func TestSSEStreamReaderSendErrorByteAccuracy(t *testing.T) {
r := NewSSEStreamReader()
errorJSON := []byte(`{"error":{"type":"internal_error","message":"An error occurred"}}`)
go func() {
r.SendError(errorJSON)
r.Done()
}()
buf := make([]byte, 4096)
n, err := r.Read(buf)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Must match the old pattern exactly:
// buf = append(buf, "event: error\ndata: "...)
// buf = append(buf, errorJSON...)
// buf = append(buf, '\n', '\n')
want := "event: error\ndata: " + string(errorJSON) + "\n\n"
if got := string(buf[:n]); got != want {
t.Errorf("byte mismatch:\n got: %q\n want: %q", got, want)
}
}
// TestSSEStreamReaderMixedMethodStream simulates a realistic stream
// that uses multiple methods (like router.go does): data events,
// typed events, error, and done marker.
func TestSSEStreamReaderMixedMethodStream(t *testing.T) {
r := NewSSEStreamReader()
expected := []string{
"data: {\"chunk\":1}\n\n",
"event: response.delta\ndata: {\"delta\":\"hi\"}\n\n",
"data: {\"chunk\":2}\n\n",
"event: error\ndata: {\"error\":\"timeout\"}\n\n",
"data: [DONE]\n\n",
}
go func() {
r.SendEvent("", []byte(`{"chunk":1}`))
r.SendEvent("response.delta", []byte(`{"delta":"hi"}`))
r.SendEvent("", []byte(`{"chunk":2}`))
r.SendError([]byte(`{"error":"timeout"}`))
r.SendDone()
r.Done()
}()
buf := make([]byte, 4096)
for i, want := range expected {
n, err := r.Read(buf)
if err != nil {
t.Fatalf("event %d: unexpected error: %v", i, err)
}
if got := string(buf[:n]); got != want {
t.Errorf("event %d:\n got: %q\n want: %q", i, got, want)
}
}
// Should be EOF after all events
n, err := r.Read(buf)
if err != io.EOF {
t.Errorf("expected EOF, got err=%v n=%d", err, n)
}
}
// TestSSEStreamReaderRawAndWrapperMixed simulates the router.go pattern
// where raw Send (for Bedrock/passthrough) is mixed with wrapper methods.
func TestSSEStreamReaderRawAndWrapperMixed(t *testing.T) {
r := NewSSEStreamReader()
// Simulate: Bedrock binary event (raw), followed by SSE events, then done
bedrockBinary := []byte{0x00, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00, 0x2A} // fake binary
preformattedSSE := []byte("event: content_block_delta\ndata: {\"delta\":\"test\"}\n\n")
expected := [][]byte{
bedrockBinary,
preformattedSSE,
[]byte("data: {\"final\":true}\n\n"),
}
go func() {
r.Send(bedrockBinary) // raw binary passthrough
r.Send(preformattedSSE) // pre-formatted SSE string
r.SendEvent("", []byte(`{"final":true}`)) // wrapper method
r.Done()
}()
buf := make([]byte, 4096)
for i, want := range expected {
n, err := r.Read(buf)
if err != nil {
t.Fatalf("event %d: unexpected error: %v", i, err)
}
if string(buf[:n]) != string(want) {
t.Errorf("event %d:\n got: %q\n want: %q", i, buf[:n], want)
}
}
}
// TestSSEStreamReaderSendEventEmptyData verifies behavior with empty data payload.
func TestSSEStreamReaderSendEventEmptyData(t *testing.T) {
tests := []struct {
name string
eventType string
want string
}{
{
name: "empty data no event type",
eventType: "",
want: "data: \n\n",
},
{
name: "empty data with event type",
eventType: "heartbeat",
want: "event: heartbeat\ndata: \n\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := NewSSEStreamReader()
go func() {
r.SendEvent(tt.eventType, []byte{})
r.Done()
}()
buf := make([]byte, 4096)
n, err := r.Read(buf)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got := string(buf[:n]); got != tt.want {
t.Errorf("got %q, want %q", got, tt.want)
}
})
}
}
// TestSSEStreamReaderSendEventNilData verifies behavior with nil data payload.
func TestSSEStreamReaderSendEventNilData(t *testing.T) {
r := NewSSEStreamReader()
go func() {
r.SendEvent("", nil)
r.Done()
}()
buf := make([]byte, 4096)
n, err := r.Read(buf)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// nil data should produce same as empty: "data: \n\n"
if got := string(buf[:n]); got != "data: \n\n" {
t.Errorf("got %q, want %q", got, "data: \n\n")
}
}
// TestSSEStreamReaderSendEventLargePayload verifies no corruption with large JSON payloads.
func TestSSEStreamReaderSendEventLargePayload(t *testing.T) {
r := NewSSEStreamReader()
// Build a large JSON payload (~64KB, larger than typical ReadBufferSize)
largeContent := make([]byte, 65536)
for i := range largeContent {
largeContent[i] = 'A' + byte(i%26)
}
data := append([]byte(`{"content":"`), largeContent...)
data = append(data, '"', '}')
go func() {
r.SendEvent("response.delta", data)
r.Done()
}()
// Read the entire event using small buffer to exercise partial reads
var result []byte
buf := make([]byte, 1024)
for {
n, err := r.Read(buf)
result = append(result, buf[:n]...)
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
want := "event: response.delta\ndata: " + string(data) + "\n\n"
if string(result) != want {
t.Errorf("large payload mismatch: got len=%d, want len=%d", len(result), len(want))
// Check prefix and suffix for debugging
if len(result) > 40 {
t.Errorf(" got prefix: %q", result[:40])
t.Errorf(" want prefix: %q", want[:40])
}
}
}
// TestSSEStreamReaderMidStreamDisconnect simulates a client disconnecting
// mid-stream while the producer is using SendEvent.
func TestSSEStreamReaderMidStreamDisconnect(t *testing.T) {
r := NewSSEStreamReader()
producerDone := make(chan int) // reports how many events were sent
go func() {
sent := 0
for i := 0; i < 100; i++ {
if !r.SendEvent("", []byte(fmt.Sprintf(`{"chunk":%d}`, i))) {
break
}
sent++
}
close(producerDone)
}()
// Read a few events then simulate client disconnect
buf := make([]byte, 4096)
for i := 0; i < 3; i++ {
_, err := r.Read(buf)
if err != nil {
t.Fatalf("event %d: unexpected error: %v", i, err)
}
}
// Client disconnects
r.Close()
// Producer should stop promptly
<-producerDone
}
// TestSSEStreamReaderSendErrorThenDone verifies the handler pattern
// of sending an error event and immediately closing the stream.
func TestSSEStreamReaderSendErrorThenDone(t *testing.T) {
r := NewSSEStreamReader()
go func() {
// Send a few normal events
r.SendEvent("", []byte(`{"chunk":1}`))
r.SendEvent("", []byte(`{"chunk":2}`))
// Error occurs, send error and stop
r.SendError([]byte(`{"error":"rate_limit"}`))
r.Done()
}()
buf := make([]byte, 4096)
expected := []string{
"data: {\"chunk\":1}\n\n",
"data: {\"chunk\":2}\n\n",
"event: error\ndata: {\"error\":\"rate_limit\"}\n\n",
}
for i, want := range expected {
n, err := r.Read(buf)
if err != nil {
t.Fatalf("event %d: unexpected error: %v", i, err)
}
if got := string(buf[:n]); got != want {
t.Errorf("event %d: got %q, want %q", i, got, want)
}
}
// Should be EOF (stream ended after error, no [DONE] marker)
n, err := r.Read(buf)
if err != io.EOF {
t.Errorf("expected EOF after error event, got err=%v n=%d data=%q", err, n, buf[:n])
}
}
// TestSSEStreamReaderSendDoneByteExact verifies SendDone produces
// exactly "data: [DONE]\n\n" — the standard OpenAI SSE terminator.
func TestSSEStreamReaderSendDoneByteExact(t *testing.T) {
r := NewSSEStreamReader()
go func() {
r.SendDone()
r.Done()
}()
// Use exact-size buffer to verify no extra bytes
want := []byte("data: [DONE]\n\n")
buf := make([]byte, len(want))
n, err := r.Read(buf)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if n != len(want) {
t.Errorf("expected %d bytes, got %d", len(want), n)
}
if string(buf[:n]) != string(want) {
t.Errorf("got %q, want %q", buf[:n], want)
}
}
// TestSSEStreamReaderConcurrentSendEvent verifies thread safety of SendEvent
// with multiple concurrent producers (not a real pattern but validates safety).
func TestSSEStreamReaderConcurrentSendEvent(t *testing.T) {
r := NewSSEStreamReader()
const numProducers = 5
const eventsPerProducer = 20
var wg sync.WaitGroup
wg.Add(numProducers)
// Launch multiple producers
for p := 0; p < numProducers; p++ {
go func(id int) {
defer wg.Done()
for i := 0; i < eventsPerProducer; i++ {
if !r.SendEvent("", []byte(fmt.Sprintf(`{"p":%d,"i":%d}`, id, i))) {
return
}
}
}(p)
}
// Close after all producers finish
go func() {
wg.Wait()
r.Done()
}()
// Consume all events
buf := make([]byte, 4096)
count := 0
for {
n, err := r.Read(buf)
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Every event must be a valid SSE data line
got := string(buf[:n])
if len(got) < 8 || got[:6] != "data: " || got[len(got)-2:] != "\n\n" {
t.Errorf("event %d: invalid SSE format: %q", count, got)
}
count++
}
if count != numProducers*eventsPerProducer {
t.Errorf("got %d events, want %d", count, numProducers*eventsPerProducer)
}
}
func TestSSEStreamReaderCloseUnblocksProducer(t *testing.T) {
r := NewSSEStreamReader()
done := make(chan struct{})
errCh := make(chan error, 1)
go func() {
defer close(done)
// Fill the channel buffer (cap=1)
r.Send([]byte("data: first\n\n"))
// This Send should block until Close is called
r.Send([]byte("data: second\n\n"))
// After Close, the next Send should return false
if r.Send([]byte("data: third\n\n")) {
select {
case errCh <- fmt.Errorf("Send should return false after Close"):
default:
}
}
}()
// Close unblocks the blocked Send
r.Close()
<-done
select {
case err := <-errCh:
t.Error(err)
default:
}
}

View File

@@ -0,0 +1,96 @@
// Package lib provides core functionality for the Bifrost HTTP service.
// This file contains JSON schema validation for config files.
package lib
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"github.com/santhosh-tekuri/jsonschema/v6"
)
// localSchemaCandidates lists paths (relative to CWD) where config.schema.json may be found
// when running from a source checkout. Checked in order before falling back to the remote URL.
var localSchemaCandidates = []string{
"config.schema.json", // running from transports/
"../config.schema.json", // running from transports/bifrost-http/
"transports/config.schema.json", // running from repo root
}
// tryLoadLocalSchema attempts to read config.schema.json from known local paths.
// Returns nil if none are found.
func tryLoadLocalSchema() []byte {
for _, p := range localSchemaCandidates {
data, err := os.ReadFile(p)
if err == nil {
return data
}
}
return nil
}
// ValidateConfigSchema validates config data against the JSON schema.
// Returns nil if valid, or a formatted error describing all validation failures.
// An optional schemaOverride can be provided to use a local schema instead of fetching from the remote URL.
func ValidateConfigSchema(data []byte, schemaOverride ...[]byte) error {
var configSchemaJSONBytes []byte
if len(schemaOverride) > 0 && len(schemaOverride[0]) > 0 {
configSchemaJSONBytes = schemaOverride[0]
} else if localSchema := tryLoadLocalSchema(); localSchema != nil {
// Prefer the local schema file from the source checkout when available.
// This avoids validating against a potentially stale remote schema.
configSchemaJSONBytes = localSchema
} else {
// Pulling config.schema from https://www.getbifrost.ai/schema
configSchemaJSON, err := http.Get("https://www.getbifrost.ai/schema")
if err != nil {
return fmt.Errorf("failed to get config schema: %w", err)
}
defer configSchemaJSON.Body.Close()
var readErr error
configSchemaJSONBytes, readErr = io.ReadAll(configSchemaJSON.Body)
if readErr != nil {
logger.Warn("failed to download config schema: %v. running without config.json schema validation", readErr)
return nil
}
}
// Parse the schema JSON
schemaDoc, err := jsonschema.UnmarshalJSON(bytes.NewReader(configSchemaJSONBytes))
if err != nil {
return fmt.Errorf("failed to parse config schema JSON: %w", err)
}
c := jsonschema.NewCompiler()
if err := c.AddResource("config.schema.json", schemaDoc); err != nil {
return fmt.Errorf("failed to add config schema resource: %w", err)
}
// Compile the schema
compiledSchema, err := c.Compile("config.schema.json")
if err != nil {
return fmt.Errorf("failed to compile config schema: %w", err)
}
var v any
if err := json.Unmarshal(data, &v); err != nil {
return fmt.Errorf("invalid JSON: %w", err)
}
err = compiledSchema.Validate(v)
if err == nil {
return nil
}
// Format validation errors for better readability
return formatValidationError(err)
}
// formatValidationError converts jsonschema validation errors into user-friendly messages
func formatValidationError(err error) error {
validationErr, ok := err.(*jsonschema.ValidationError)
if !ok {
return err
}
// Use the GoString format which provides detailed hierarchical output
return fmt.Errorf("schema validation failed:\n%s", validationErr.GoString())
}

File diff suppressed because it is too large Load Diff