package handlers import ( "bytes" "context" "encoding/base64" "errors" "fmt" "io" "slices" "strings" "sync/atomic" "time" "github.com/google/uuid" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" "github.com/maximhq/bifrost/framework/encrypt" "github.com/maximhq/bifrost/framework/tracing" "github.com/maximhq/bifrost/transports/bifrost-http/integrations" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) var loggingSkipPaths = []string{"/health", "/_next", "/api/dev"} var realtimeTransportPaths = buildRealtimeTransportPathSet() // SecurityHeadersMiddleware sets security-related HTTP headers on every response. // This should wrap the outermost handler so all responses (API, UI, errors) include these headers. func SecurityHeadersMiddleware() schemas.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { ctx.Response.Header.Set("X-Frame-Options", "DENY") ctx.Response.Header.Set("X-Content-Type-Options", "nosniff") ctx.Response.Header.Set("Referrer-Policy", "strict-origin-when-cross-origin") ctx.Response.Header.Set("Content-Security-Policy", "frame-ancestors 'none'") ctx.Response.Header.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()") // Only set HSTS when serving over HTTPS (detected via reverse proxy header or direct TLS) if string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" || ctx.IsTLS() { ctx.Response.Header.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") } next(ctx) } } } // CorsMiddleware handles CORS headers for localhost and configured allowed origins func CorsMiddleware(config *lib.Config) schemas.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { startTime := time.Now() // skip logging if it's a /health check request if slices.IndexFunc(loggingSkipPaths, func(path string) bool { return strings.HasPrefix(string(ctx.RequestURI()), path) }) != -1 { goto corsFlow } defer func() { statusCode := ctx.Response.Header.StatusCode() level := schemas.LogLevelInfo if statusCode >= 500 { level = schemas.LogLevelError } else if statusCode >= 400 { level = schemas.LogLevelWarn } logBuilder := logger.LogHTTPRequest(level, "request completed"). Str("http.method", string(ctx.Method())). Str("http.target", string(ctx.RequestURI())). Int("http.status_code", statusCode). Int64("http.request_duration_ms", time.Since(startTime).Milliseconds()). Str("http.remote_addr", ctx.RemoteAddr().String()). Str("http.user_agent", string(ctx.Request.Header.UserAgent())) if traceID, ok := ctx.UserValue(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { logBuilder = logBuilder.Str("trace_id", traceID) } logBuilder.Send() }() corsFlow: origin := string(ctx.Request.Header.Peek("Origin")) allowed := IsOriginAllowed(origin, config.ClientConfig.AllowedOrigins) // Credentialed responses are sent when the origin is not matched solely by a // wildcard AllowedOrigins — i.e. the origin is localhost or explicitly listed. credentialed := !slices.Contains(config.ClientConfig.AllowedOrigins, "*") || isLocalhostOrigin(origin) || slices.Contains(config.ClientConfig.AllowedOrigins, origin) allowedHeaders := []string{"Content-Type", "Authorization", "X-Requested-With", "X-Stainless-Timeout", "X-Api-Key", "X-OpenAI-Agents-SDK"} if slices.Contains(config.ClientConfig.AllowedHeaders, "*") { if credentialed { // Per the Fetch spec, Access-Control-Allow-Headers: * is NOT treated as a // wildcard when Access-Control-Allow-Credentials: true is set — browsers // interpret it as a literal header name. For credentialed preflight requests, // reflect back the requested headers instead. if requestedHeaders := string(ctx.Request.Header.Peek("Access-Control-Request-Headers")); requestedHeaders != "" { allowedHeaders = []string{requestedHeaders} } // For non-preflight requests (no Access-Control-Request-Headers), keep defaults. } else { allowedHeaders = []string{"*"} } } else if len(config.ClientConfig.AllowedHeaders) > 0 { // append allowed headers from config to the default headers for _, header := range config.ClientConfig.AllowedHeaders { if !slices.Contains(allowedHeaders, header) { allowedHeaders = append(allowedHeaders, header) } } } // Check if origin is allowed (localhost always allowed + configured origins) if allowed { ctx.Response.Header.Set("Access-Control-Allow-Origin", origin) ctx.Response.Header.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD") ctx.Response.Header.Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", ")) if credentialed { ctx.Response.Header.Set("Access-Control-Allow-Credentials", "true") } ctx.Response.Header.Set("Access-Control-Max-Age", "86400") // Vary: Origin tells caches that the response varies based on the Origin // request header, preventing incorrect CORS headers from being served. ctx.Response.Header.Set("Vary", "Origin") } // Handle preflight OPTIONS requests if string(ctx.Method()) == "OPTIONS" { if allowed { ctx.SetStatusCode(fasthttp.StatusOK) } else { ctx.SetStatusCode(fasthttp.StatusForbidden) } return } next(ctx) } } } // RequestDecompressionMiddleware transparently decompresses compressed request bodies. // Two paths based on compressed Content-Length: // - Large or chunked (CL > threshold or CL unknown): streaming decompression via // SetBodyStream, avoiding full body materialization. Uses pooled gzip readers // matching the response-side pattern in core/providers/utils. // - Small (CL ≤ threshold): buffered decompression via io.ReadAll + SetBodyRaw, // with decompression bomb protection via MaxRequestBodySizeMB. func RequestDecompressionMiddleware(config *lib.Config) schemas.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { if len(ctx.Request.Header.ContentEncoding()) == 0 { next(ctx) return } if shouldStreamDecompress(config, ctx) { cleanup, applied, err := streamingDecompress(ctx) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("invalid compressed request body: %v", err)) return } if applied { next(ctx) cleanup() return } // No body stream available (StreamRequestBody not enabled) — fall // through to the buffered decompression path below. } // Buffered path: small compressed request — materialize fully. maxRequestBodyBytes := 100 * 1024 * 1024 // default 100 MB (matches decodeRequestBodyWithLimit fallback) if config != nil && config.ClientConfig.MaxRequestBodySizeMB > 0 { maxRequestBodyBytes = config.ClientConfig.MaxRequestBodySizeMB * 1024 * 1024 } body, err := decodeRequestBodyWithLimit(&ctx.Request, maxRequestBodyBytes) if errors.Is(err, errRequestBodyTooLarge) { SendError(ctx, fasthttp.StatusRequestEntityTooLarge, fmt.Sprintf("decompressed request body exceeds max allowed size of %d bytes", maxRequestBodyBytes)) return } if err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("invalid compressed request body: %v", err)) return } ctx.Request.SetBodyRaw(body) ctx.Request.Header.Del(fasthttp.HeaderContentEncoding) ctx.Request.Header.Del(fasthttp.HeaderContentLength) next(ctx) } } } // shouldStreamDecompress returns true when the compressed request body should // use streaming decompression rather than full materialization. Uses the // config threshold (set by enterprise from LargePayloadConfig.RequestThresholdBytes) // or falls back to DefaultLargePayloadRequestThresholdBytes. // Chunked requests (unknown size) always stream to be safe. func shouldStreamDecompress(config *lib.Config, ctx *fasthttp.RequestCtx) bool { contentLength := ctx.Request.Header.ContentLength() // Chunked transfer encoding: fasthttp reports -1. Size unknown, stream to be safe. if contentLength < 0 { return true } var threshold int64 = schemas.DefaultLargePayloadRequestThresholdBytes if config != nil && config.StreamingDecompressThreshold > 0 { threshold = config.StreamingDecompressThreshold } return int64(contentLength) > threshold } // streamingDecompress wraps the request body stream with a streaming decompression // reader, avoiding full body materialization for large compressed requests. // Returns (cleanup, applied, err): // - applied=true: body stream was wrapped; caller must invoke cleanup after the // handler chain completes and the body is fully consumed. // - applied=false: no body stream available (StreamRequestBody not enabled on the // server). Caller should fall back to the buffered decompression path. func streamingDecompress(ctx *fasthttp.RequestCtx) (cleanup func(), applied bool, err error) { bodyStream := ctx.RequestBodyStream() if bodyStream == nil { return func() {}, false, nil } encoding := strings.ToLower(strings.TrimSpace( string(ctx.Request.Header.ContentEncoding()), )) decompReader, cleanup, err := newDecompressReader(bodyStream, encoding) if err != nil { return nil, false, err } ctx.Request.SetBodyStream(decompReader, -1) ctx.Request.Header.Del(fasthttp.HeaderContentEncoding) ctx.Request.Header.Del(fasthttp.HeaderContentLength) return cleanup, true, nil } var errRequestBodyTooLarge = errors.New("decompressed request body exceeds max allowed size") // decodeRequestBodyWithLimit decodes the request body with a limit on the size of the body. func decodeRequestBodyWithLimit(req *fasthttp.Request, maxRequestBodyBytes int) ([]byte, error) { encoding := strings.ToLower(strings.TrimSpace(string(req.Header.ContentEncoding()))) bodyReader := bytes.NewReader(req.Body()) var reader io.Reader = bodyReader cleanup := func() {} if encoding != "" { var err error reader, cleanup, err = newDecompressReader(bodyReader, encoding) if err != nil { return nil, err } } defer cleanup() if maxRequestBodyBytes <= 0 { maxRequestBodyBytes = 100 * 1024 * 1024 // 100 MB hard cap } limitedReader := &io.LimitedReader{R: reader, N: int64(maxRequestBodyBytes + 1)} body, err := io.ReadAll(limitedReader) if err != nil { return nil, err } if len(body) > maxRequestBodyBytes { return nil, errRequestBodyTooLarge } return body, nil } // newDecompressReader wraps r with a decompression reader for the given encoding. // All encodings use pooled readers from core/providers/utils. The returned cleanup // function must be called when the reader is no longer needed. func newDecompressReader(r io.Reader, encoding string) (io.Reader, func(), error) { switch encoding { case "gzip": gz, err := providerUtils.AcquireGzipReader(r) if err != nil { return nil, nil, err } return gz, func() { providerUtils.ReleaseGzipReader(gz) }, nil case "deflate": fr, err := providerUtils.AcquireFlateReader(r) if err != nil { return nil, nil, err } return fr, func() { providerUtils.ReleaseFlateReader(fr) }, nil case "br": br := providerUtils.AcquireBrotliReader(r) return br, func() { providerUtils.ReleaseBrotliReader(br) }, nil case "zstd": dec, err := providerUtils.AcquireZstdDecoder(r) if err != nil { return nil, nil, err } return dec, func() { providerUtils.ReleaseZstdDecoder(dec) }, nil default: return nil, nil, fmt.Errorf("%w: %q", fasthttp.ErrContentEncodingUnsupported, encoding) } } // TransportInterceptorMiddleware runs all plugin HTTP transport interceptors. // It converts the fasthttp request to a serializable HTTPRequest, runs all plugin interceptors, // and applies any modifications back to the fasthttp context. func TransportInterceptorMiddleware(config *lib.Config) schemas.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { plugins := config.GetLoadedHTTPTransportPlugins() if len(plugins) == 0 { next(ctx) return } // Get or create BifrostContext from fasthttp context bifrostCtx := getBifrostContextFromFastHTTP(ctx) // Acquire pooled request req := schemas.AcquireHTTPRequest() defer schemas.ReleaseHTTPRequest(req) fasthttpToHTTPRequest(ctx, req) // Run plugin interceptors for _, plugin := range plugins { pluginName := plugin.GetName() pluginCtx := bifrostCtx.WithPluginScope(&pluginName) resp, err := plugin.HTTPTransportPreHook(pluginCtx, req) pluginCtx.ReleasePluginScope() if err != nil { // Short-circuit with error — drain plugin logs before returning if logs := bifrostCtx.DrainPluginLogs(); len(logs) > 0 { ctx.SetUserValue(schemas.BifrostContextKeyTransportPluginLogs, logs) } ctx.SetStatusCode(fasthttp.StatusInternalServerError) ctx.SetBodyString(err.Error()) return } if resp != nil { // Short-circuit with response — drain plugin logs before returning if logs := bifrostCtx.DrainPluginLogs(); len(logs) > 0 { ctx.SetUserValue(schemas.BifrostContextKeyTransportPluginLogs, logs) } applyHTTPResponseToCtx(ctx, resp) return } // If we got here, the plugin may have modified req in-place } // Drain pre-hook plugin logs and store on fasthttp context for trace attachment if preHookLogs := bifrostCtx.DrainPluginLogs(); len(preHookLogs) > 0 { ctx.SetUserValue(schemas.BifrostContextKeyTransportPluginLogs, preHookLogs) } // Apply modifications back to fasthttp context applyHTTPRequestToCtx(ctx, req) // Adding user values for key, value := range bifrostCtx.GetUserValues() { ctx.SetUserValue(key, value) } next(ctx) // For streaming responses, store a callback to run post-hooks after the stream ends. // The streaming handler calls this BEFORE reader.Done() so that errors can // still be sent as SSE events. applyResponse=false because the response is // already on the wire and mutating ctx.Response would corrupt the chunked stream. // // IMPORTANT: The callback must NOT access ctx — fasthttp recycles RequestCtx // after the response body stream completes. All needed data is eagerly captured // here (while ctx is still valid) and passed through the closure. if deferred, ok := ctx.UserValue(schemas.BifrostContextKeyDeferTraceCompletion).(bool); ok && deferred { // Verify the completer slot exists before allocating pooled snapshots. // The streaming handler pre-allocates this *atomic.Value; if absent, // skip work to avoid leaking pooled HTTPRequest/HTTPResponse objects. slot, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPostHookCompleter).(*atomic.Value) if !ok { return } // Eagerly snapshot request/response from ctx before it can be recycled. capturedReq := lib.BuildHTTPRequestFromFastHTTP(ctx) capturedResp := lib.BuildHTTPResponseFromFastHTTP(ctx) // Snapshot pre-hook transport plugin logs already accumulated on ctx. var preHookLogs []schemas.PluginLogEntry if logs, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPluginLogs).([]schemas.PluginLogEntry); ok { preHookLogs = logs } completer := func() ([]schemas.PluginLogEntry, error) { defer schemas.ReleaseHTTPRequest(capturedReq) defer schemas.ReleaseHTTPResponse(capturedResp) postHookLogs, err := runTransportPostHooksCaptured(capturedReq, capturedResp, plugins, bifrostCtx) allLogs := preHookLogs if len(postHookLogs) > 0 { allLogs = append(allLogs, postHookLogs...) } return allLogs, err } // Store the completer in the atomic.Value slot that the streaming handler // placed on ctx. The goroutine reads from its closure-captured copy of // the slot, avoiding any ctx access after the handler returns. slot.Store(completer) return } _ = runTransportPostHooks(ctx, plugins, bifrostCtx, true) } } } // runTransportPostHooks runs HTTPTransportPostHook for all plugins in reverse order, // drains plugin logs, and applies the response back to the fasthttp context. // Used for both non-streaming (inline) and streaming (deferred callback) paths. // // Transport-level plugin logs are stored in fasthttp UserValues (keyed by // BifrostContextKeyTransportPluginLogs) rather than directly on BifrostContext, // because transport hooks operate at the fasthttp layer before/after the core // BifrostContext lifecycle. These logs are merged into the trace by the // TracingMiddleware at trace completion, alongside core-level plugin logs // which travel through BifrostContext → Trace → AttachPluginLogs. func runTransportPostHooks(ctx *fasthttp.RequestCtx, plugins []schemas.HTTPTransportPlugin, bifrostCtx *schemas.BifrostContext, applyResponse bool) error { shouldApplyShortCircuit := applyResponse httpResp := schemas.AcquireHTTPResponse() defer schemas.ReleaseHTTPResponse(httpResp) fasthttpResponseToHTTPResponse(ctx, httpResp) // Build request from current fasthttp state (original pooled req may have been released) req := schemas.AcquireHTTPRequest() defer schemas.ReleaseHTTPRequest(req) fasthttpToHTTPRequest(ctx, req) // Run http post-hooks in reverse order for i := len(plugins) - 1; i >= 0; i-- { plugin := plugins[i] pluginName := plugin.GetName() pluginCtx := bifrostCtx.WithPluginScope(&pluginName) err := plugin.HTTPTransportPostHook(pluginCtx, req, httpResp) pluginCtx.ReleasePluginScope() if err != nil { logger.Warn("error in HTTPTransportPostHook for plugin %s: %s", pluginName, err.Error()) // Drain plugin logs before returning on error if postHookLogs := bifrostCtx.DrainPluginLogs(); len(postHookLogs) > 0 { if existing, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPluginLogs).([]schemas.PluginLogEntry); ok { ctx.SetUserValue(schemas.BifrostContextKeyTransportPluginLogs, append(existing, postHookLogs...)) } else { ctx.SetUserValue(schemas.BifrostContextKeyTransportPluginLogs, postHookLogs) } } if shouldApplyShortCircuit { applyHTTPResponseToCtx(ctx, httpResp) } return fmt.Errorf("transport post-hook plugin %s: %w", pluginName, err) } } // Drain post-hook plugin logs and merge with pre-hook logs if postHookLogs := bifrostCtx.DrainPluginLogs(); len(postHookLogs) > 0 { if existing, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPluginLogs).([]schemas.PluginLogEntry); ok { ctx.SetUserValue(schemas.BifrostContextKeyTransportPluginLogs, append(existing, postHookLogs...)) } else { ctx.SetUserValue(schemas.BifrostContextKeyTransportPluginLogs, postHookLogs) } } if shouldApplyShortCircuit { applyHTTPResponseToCtx(ctx, httpResp) } return nil } // runTransportPostHooksCaptured is the goroutine-safe variant of runTransportPostHooks. // It uses pre-captured HTTPRequest and HTTPResponse snapshots instead of reading from // a fasthttp RequestCtx, which may have been recycled by the time this runs in a // streaming goroutine. Returns accumulated plugin logs (instead of writing them to // ctx.UserValue) so the caller can forward them to the trace completer. func runTransportPostHooksCaptured(capturedReq *schemas.HTTPRequest, capturedResp *schemas.HTTPResponse, plugins []schemas.HTTPTransportPlugin, bifrostCtx *schemas.BifrostContext) ([]schemas.PluginLogEntry, error) { // Clone into fresh pooled objects so plugins can mutate without affecting the snapshots. req := schemas.AcquireHTTPRequest() defer schemas.ReleaseHTTPRequest(req) req.Method = capturedReq.Method req.Path = capturedReq.Path for k, v := range capturedReq.Headers { req.Headers[k] = v } for k, v := range capturedReq.Query { req.Query[k] = v } for k, v := range capturedReq.PathParams { req.PathParams[k] = v } httpResp := schemas.AcquireHTTPResponse() defer schemas.ReleaseHTTPResponse(httpResp) httpResp.StatusCode = capturedResp.StatusCode for k, v := range capturedResp.Headers { httpResp.Headers[k] = v } var allLogs []schemas.PluginLogEntry // Run http post-hooks in reverse order for i := len(plugins) - 1; i >= 0; i-- { plugin := plugins[i] pluginName := plugin.GetName() pluginCtx := bifrostCtx.WithPluginScope(&pluginName) err := plugin.HTTPTransportPostHook(pluginCtx, req, httpResp) pluginCtx.ReleasePluginScope() if err != nil { logger.Warn("error in HTTPTransportPostHook for plugin %s: %s", pluginName, err.Error()) if postHookLogs := bifrostCtx.DrainPluginLogs(); len(postHookLogs) > 0 { allLogs = append(allLogs, postHookLogs...) } return allLogs, fmt.Errorf("transport post-hook plugin %s: %w", pluginName, err) } } // Drain post-hook plugin logs if postHookLogs := bifrostCtx.DrainPluginLogs(); len(postHookLogs) > 0 { allLogs = append(allLogs, postHookLogs...) } return allLogs, nil } // getBifrostContextFromFastHTTP gets or creates a BifrostContext from fasthttp context. func getBifrostContextFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.BifrostContext { return schemas.NewBifrostContext(ctx, schemas.NoDeadline) } // fasthttpToHTTPRequest populates a pooled HTTPRequest from fasthttp context. func fasthttpToHTTPRequest(ctx *fasthttp.RequestCtx, req *schemas.HTTPRequest) { req.Method = string(ctx.Method()) req.Path = string(ctx.Path()) // Copy headers for key, value := range ctx.Request.Header.All() { req.Headers[string(key)] = string(value) } // Copy query params for key, value := range ctx.Request.URI().QueryArgs().All() { req.Query[string(key)] = string(value) } // Copy path parameters from user values // The fasthttp router stores path variables (like {file_id}, {model}) as user values // We extract all string user values that are likely path parameters ctx.VisitUserValuesAll(func(key, value any) { // Only process string keys and string values keyStr, keyIsString := key.(string) valueStr, valueIsString := value.(string) if !keyIsString || !valueIsString { return } // Skip internal Bifrost system keys and tracing keys if strings.HasPrefix(keyStr, "bifrost-") || keyStr == "BifrostContextKeyRequestID" || keyStr == "trace_id" || keyStr == "span_id" { return } // Store as path parameter req.PathParams[keyStr] = valueStr }) // Skip body copy for large payloads. // Check threshold first (set by RequestThresholdMiddleware before this middleware runs) // because the large-payload-mode flag is only set later inside the handler hook. if threshold, ok := ctx.UserValue(schemas.BifrostContextKeyLargePayloadRequestThreshold).(int64); ok && threshold > 0 { cl := int64(ctx.Request.Header.ContentLength()) // Skip body copy when CL exceeds threshold OR CL is unknown (streaming/ // chunked, e.g. after streaming decompression deletes the header). if cl > threshold || cl < 0 { return } } if isLargePayload, ok := ctx.UserValue(schemas.BifrostContextKeyLargePayloadMode).(bool); ok && isLargePayload { return } body := ctx.Request.Body() if len(body) > 0 { req.Body = make([]byte, len(body)) copy(req.Body, body) } } // applyHTTPRequestToCtx applies modifications from HTTPRequest back to fasthttp context. func applyHTTPRequestToCtx(ctx *fasthttp.RequestCtx, req *schemas.HTTPRequest) { // If path/method is different, throw error if req.Method != string(ctx.Method()) || req.Path != string(ctx.Path()) { logger.Error("request method/path mismatch: %s %s != %s %s", req.Method, req.Path, string(ctx.Method()), string(ctx.Path())) SendError(ctx, fasthttp.StatusConflict, "request method/path was modified by a plugin, this is not allowed") return } // Apply headers for key, value := range req.Headers { ctx.Request.Header.Set(key, value) } // Apply query params for key, value := range req.Query { ctx.Request.URI().QueryArgs().Set(key, value) } // Apply body if set if req.Body != nil { ctx.Request.SetBody(req.Body) } } // applyHTTPResponseToCtx writes a short-circuit response to fasthttp context. func applyHTTPResponseToCtx(ctx *fasthttp.RequestCtx, resp *schemas.HTTPResponse) { ctx.SetStatusCode(resp.StatusCode) for key, value := range resp.Headers { ctx.Response.Header.Set(key, value) } if resp.Body != nil { ctx.SetBody(resp.Body) } } // fasthttpResponseToHTTPResponse populates a pooled HTTPResponse from fasthttp context. func fasthttpResponseToHTTPResponse(ctx *fasthttp.RequestCtx, resp *schemas.HTTPResponse) { resp.StatusCode = ctx.Response.StatusCode() for key, value := range ctx.Response.Header.All() { resp.Headers[string(key)] = string(value) } // Skip response body copy for streaming (SSE) responses — the body is an active // io.Reader consumed by fasthttp's writeBodyChunked. Calling Body() would race // with the chunked writer (Body() drains and closes the bodyStream). if deferred, ok := ctx.UserValue(schemas.BifrostContextKeyDeferTraceCompletion).(bool); ok && deferred { return } // Skip response body copy when large payload/response mode is active — the response is // streamed directly to the client and materializing it here would spike memory. if isLargePayload, ok := ctx.UserValue(schemas.BifrostContextKeyLargePayloadMode).(bool); ok && isLargePayload { return } if isLargeResponse, ok := ctx.UserValue(lib.FastHTTPUserValueLargeResponseMode).(bool); ok && isLargeResponse { return } // Also skip if response Content-Length exceeds the configured response threshold. if threshold, ok := ctx.UserValue(schemas.BifrostContextKeyLargeResponseThreshold).(int64); ok && threshold > 0 { if int64(ctx.Response.Header.ContentLength()) > threshold { return } } body := ctx.Response.Body() if len(body) > 0 { resp.Body = make([]byte, len(body)) copy(resp.Body, body) } } // validateSession checks if a session token is valid func validateSession(_ *fasthttp.RequestCtx, store configstore.ConfigStore, token string) bool { session, err := store.GetSession(context.Background(), token) if err != nil || session == nil { return false } if session.ExpiresAt.Before(time.Now()) { return false } return true } // isInferenceWSEndpoint returns true for WebSocket endpoints that should use // standard inference auth (Bearer/Basic/VK) rather than dashboard session tokens. func isInferenceWSEndpoint(path string) bool { for strings.HasPrefix(path, "/openai/") { path = strings.TrimPrefix(path, "/openai") } switch path { case "/v1/responses", "/responses", "/v1/realtime", "/realtime": return true default: return false } } func buildRealtimeTransportPathSet() map[string]struct{} { paths := map[string]struct{}{} for _, path := range integrations.OpenAIRealtimePaths("") { paths[path] = struct{}{} } for _, path := range integrations.OpenAIRealtimePaths("/openai") { paths[path] = struct{}{} } for _, path := range integrations.OpenAIRealtimeWebRTCCallsPaths("") { paths[path] = struct{}{} } for _, path := range integrations.OpenAIRealtimeWebRTCCallsPaths("/openai") { paths[path] = struct{}{} } return paths } func isRealtimeTransportEndpoint(path string) bool { _, ok := realtimeTransportPaths[path] return ok } // AuthMiddleware is a middleware that handles authentication for the API. type AuthMiddleware struct { store configstore.ConfigStore whitelistedRoutes atomic.Pointer[[]string] authConfig atomic.Pointer[configstore.AuthConfig] wsTicketStore *WSTicketStore } // InitAuthMiddleware initializes the auth middleware. func InitAuthMiddleware(store configstore.ConfigStore, wsTicketStore *WSTicketStore) (*AuthMiddleware, error) { if store == nil { return nil, fmt.Errorf("store is not present") } authConfig, err := store.GetAuthConfig(context.Background()) if err != nil { return nil, fmt.Errorf("failed to get auth config from store: %v", err) } am := &AuthMiddleware{ store: store, authConfig: atomic.Pointer[configstore.AuthConfig]{}, wsTicketStore: wsTicketStore, } am.authConfig.Store(authConfig) // Load whitelisted routes from client config clientConfig, err := store.GetClientConfig(context.Background()) if err == nil && clientConfig != nil { am.whitelistedRoutes.Store(&clientConfig.WhitelistedRoutes) } else { emptyRoutes := []string{} am.whitelistedRoutes.Store(&emptyRoutes) } return am, nil } func (m *AuthMiddleware) UpdateAuthConfig(authConfig *configstore.AuthConfig) { m.authConfig.Store(authConfig) } // UpdateWhitelistedRoutes updates the configured whitelisted routes that bypass auth middleware. func (m *AuthMiddleware) UpdateWhitelistedRoutes(routes []string) { m.whitelistedRoutes.Store(&routes) } // InferenceMiddleware is for inference requests (including MCP routes) if authConfig is set, it will skip authentication if disableAuthOnInference is true. func (m *AuthMiddleware) InferenceMiddleware() schemas.BifrostHTTPMiddleware { return m.middleware(func(authConfig *configstore.AuthConfig, url string) bool { return authConfig.DisableAuthOnInference }) } // APIMiddleware is for API requests if authConfig is set, it will verify authentication based on the request type. // Three authentication methods are supported: // - Basic auth: Uses username + password validation (no session tracking). Used for inference API calls. // - Bearer token: Uses session validation via validateSession(). Used for dashboard calls. // - WebSocket: Uses session validation via validateSession() with token from query parameters. // // Basic auth may be acceptable for limited use cases, while Bearer and WebSocket flows provide // session-based authentication suitable for production environments. func (m *AuthMiddleware) APIMiddleware() schemas.BifrostHTTPMiddleware { systemWhitelistedRoutes := []string{ "/api/session/is-auth-enabled", "/api/session/login", "/api/oauth/callback", "/health", } whitelistedPrefixes := []string{ "/api/oauth/callback", } return m.middleware(func(authConfig *configstore.AuthConfig, url string) bool { if slices.Contains(systemWhitelistedRoutes, url) || slices.IndexFunc(whitelistedPrefixes, func(prefix string) bool { return strings.HasPrefix(url, prefix) }) != -1 { return true } // Check user-configured whitelisted routes if configuredRoutes := m.whitelistedRoutes.Load(); configuredRoutes != nil { if slices.Contains(*configuredRoutes, url) || slices.IndexFunc(*configuredRoutes, func(route string) bool { if strings.HasSuffix(route, "*") { return strings.HasPrefix(url, strings.TrimSuffix(route, "*")) } return false }) != -1 { return true } } return false }) } // middleware is the core authentication middleware that checks if the request should be authenticated or not. func (m *AuthMiddleware) middleware(shouldSkip func(*configstore.AuthConfig, string) bool) schemas.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { authConfig := m.authConfig.Load() if authConfig == nil || !authConfig.IsEnabled { logger.Debug("auth middleware is disabled because auth config is not present or not enabled") ctx.SetUserValue(schemas.BifrostContextKeySessionToken, "") next(ctx) return } url := string(ctx.Request.URI().RequestURI()) // We skip authorization for the login route if shouldSkip(authConfig, url) { next(ctx) return } if isRealtimeTransportEndpoint(string(ctx.Path())) { next(ctx) return } // If inference is disabled, we skip authorization // Get the authorization header authorization := string(ctx.Request.Header.Peek("Authorization")) if authorization == "" { if string(ctx.Request.Header.Peek("Upgrade")) == "websocket" { path := string(ctx.Path()) if isInferenceWSEndpoint(path) { // Inference WS endpoints (/v1/responses, /v1/realtime) use the same // auth as HTTP inference: Bearer/Basic headers or governance VK validation. // If no Authorization header, fall through to return 401 below // (or the shouldSkip check above already passed them through). } else { // Prefer short-lived ticket-based auth (from POST /api/session/ws-ticket) ticket := string(ctx.Request.URI().QueryArgs().Peek("ticket")) if ticket != "" && m.wsTicketStore != nil { sessionToken := m.wsTicketStore.Consume(ticket) if sessionToken != "" && validateSession(ctx, m.store, sessionToken) { ctx.SetUserValue(schemas.BifrostContextKeySessionToken, sessionToken) next(ctx) return } SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized") return } // Fallback: legacy ?token= param (for backward compatibility) token := string(ctx.Request.URI().QueryArgs().Peek("token")) if token != "" { if validateSession(ctx, m.store, token) { ctx.SetUserValue(schemas.BifrostContextKeySessionToken, token) next(ctx) return } SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized") return } // Fallback: cookie-based WS auth cookieToken := string(ctx.Request.Header.Cookie("token")) if cookieToken != "" && validateSession(ctx, m.store, cookieToken) { ctx.SetUserValue(schemas.BifrostContextKeySessionToken, cookieToken) next(ctx) return } SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized") return } } // Cookie-based auth fallback: if no Authorization header, check for the HTTPOnly session cookie. // This supports the dashboard which relies on cookies instead of localStorage tokens. cookieToken := string(ctx.Request.Header.Cookie("token")) if cookieToken != "" && validateSession(ctx, m.store, cookieToken) { ctx.SetUserValue(schemas.BifrostContextKeySessionToken, cookieToken) next(ctx) return } SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized") return } // Split the authorization header into the scheme and the token scheme, token, ok := strings.Cut(authorization, " ") if !ok { SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized") return } // Checking basic auth for inference calls if scheme == "Basic" { // Decode the base64 token decodedBytes, err := base64.StdEncoding.DecodeString(token) if err != nil { SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized") return } // Split the decoded token into the username and password username, password, ok := strings.Cut(string(decodedBytes), ":") if !ok { SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized") return } // Verify the username and password if authConfig.AdminUserName == nil || username != authConfig.AdminUserName.GetValue() { SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized") return } if authConfig.AdminPassword == nil { SendError(ctx, fasthttp.StatusInternalServerError, "Authentication not properly configured") return } compare, err := encrypt.CompareHash(authConfig.AdminPassword.GetValue(), password) if err != nil { SendError(ctx, fasthttp.StatusInternalServerError, "Internal Server Error") return } if !compare { SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized") return } // Continue with the next handler next(ctx) return } // Checking bearer auth for dashboard calls if scheme == "Bearer" { // Verify the session if !validateSession(ctx, m.store, token) { // Here we will check if its the base64 of username:password // This is for backward compatibility with the old auth system decodedBytes, err := base64.StdEncoding.DecodeString(token) if err != nil { SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized") return } username, password, ok := strings.Cut(string(decodedBytes), ":") if !ok { SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized") return } // Verify the username and password if authConfig.AdminUserName == nil || username != authConfig.AdminUserName.GetValue() { SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized") return } if authConfig.AdminPassword == nil { SendError(ctx, fasthttp.StatusInternalServerError, "Authentication not properly configured") return } compare, err := encrypt.CompareHash(authConfig.AdminPassword.GetValue(), password) if err != nil { SendError(ctx, fasthttp.StatusInternalServerError, "Internal Server Error") return } if !compare { SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized") return } // Continue with the next handler next(ctx) return } // setting up session in the request ctx.SetUserValue(schemas.BifrostContextKeySessionToken, token) // Continue with the next handler next(ctx) return } SendError(ctx, fasthttp.StatusUnauthorized, "Unauthorized") } } } // TracingMiddleware creates distributed traces for requests and forwards completed traces // to observability plugins after the response has been written. // // The middleware: // 1. Extracts parent trace ID from incoming W3C traceparent header (if present) // 2. Creates a new trace in the store (only the lightweight trace ID is stored in context) // 3. Calls the next handler to process the request // 4. After response is written, asynchronously completes the trace and forwards it to observability plugins // // This middleware should be placed early in the middleware chain to capture the full request lifecycle. type TracingMiddleware struct { tracer atomic.Pointer[tracing.Tracer] } // NewTracingMiddleware creates a new tracing middleware func NewTracingMiddleware(tracer *tracing.Tracer) *TracingMiddleware { tm := &TracingMiddleware{ tracer: atomic.Pointer[tracing.Tracer]{}, } tm.tracer.Store(tracer) return tm } // SetObservabilityPlugins sets the observability plugins for the tracing middleware func (m *TracingMiddleware) SetObservabilityPlugins(obsPlugins []schemas.ObservabilityPlugin) { if tracer := m.tracer.Load(); tracer != nil { tracer.SetObservabilityPlugins(obsPlugins) } } // SetTracer sets the tracer for the tracing middleware func (m *TracingMiddleware) SetTracer(tracer *tracing.Tracer) { m.tracer.Store(tracer) } // Middleware returns the middleware function that creates distributed traces for requests and forwards completed traces func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { // Pin the tracer for the lifetime of this request so that a concurrent // SetTracer() swap cannot split a trace across two instances. tracer := m.tracer.Load() if tracer == nil { next(ctx) return } requestID := string(ctx.Request.Header.Peek("x-request-id")) if requestID == "" { requestID = uuid.New().String() // Injecting this back to be picked up by the next middleware ctx.Request.Header.Set("x-request-id", requestID) } // Extract trace ID from W3C traceparent header (if present) // This is the 32-char trace ID that links all spans in a distributed trace inheritedTraceID := tracing.ExtractParentID(&ctx.Request.Header) // Create trace in store - only ID returned (trace data stays in store) traceID := tracer.CreateTrace(inheritedTraceID, requestID) // Only trace ID goes into context (lightweight, no bloat) ctx.SetUserValue(schemas.BifrostContextKeyTraceID, traceID) // Extract parent span ID from W3C traceparent header (if present) // This is the 16-char span ID from the upstream service that should be // set as the ParentID of our root span for proper trace linking in Datadog/etc. parentSpanID := tracing.ExtractTraceParentSpanID(&ctx.Request.Header) if parentSpanID != "" { ctx.SetUserValue(schemas.BifrostContextKeyParentSpanID, parentSpanID) } // Store a trace completion callback for streaming handlers to use. // Accepts transport plugin logs as a parameter so it never reads from // ctx.UserValue — ctx may be recycled by the time this runs in a goroutine. ctx.SetUserValue(schemas.BifrostContextKeyTraceCompleter, func(transportLogs []schemas.PluginLogEntry) { if len(transportLogs) > 0 { tracer.AttachPluginLogs(traceID, transportLogs) } tracer.CompleteAndFlushTrace(traceID) }) // Create root span for the HTTP request spanCtx, rootSpan := tracer.StartSpan(ctx, string(ctx.RequestURI()), schemas.SpanKindHTTPRequest) if rootSpan != nil { tracer.SetAttribute(rootSpan, "http.method", string(ctx.Method())) tracer.SetAttribute(rootSpan, "http.url", string(ctx.RequestURI())) tracer.SetAttribute(rootSpan, "http.user_agent", string(ctx.Request.Header.UserAgent())) // Set root span ID in context for child span creation if spanID, ok := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); ok { ctx.SetUserValue(schemas.BifrostContextKeySpanID, spanID) } } defer func() { // Record response status on the root span if rootSpan != nil { tracer.SetAttribute(rootSpan, "http.status_code", ctx.Response.StatusCode()) if ctx.Response.StatusCode() >= 400 { tracer.EndSpan(rootSpan, schemas.SpanStatusError, fmt.Sprintf("HTTP %d", ctx.Response.StatusCode())) } else { tracer.EndSpan(rootSpan, schemas.SpanStatusOk, "") } } // Check if trace completion is deferred (for streaming requests) // If deferred, the streaming handler will complete the trace after stream ends if deferred, ok := ctx.UserValue(schemas.BifrostContextKeyDeferTraceCompletion).(bool); ok && deferred { return } // Attach transport plugin logs to trace before completion if transportLogs, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPluginLogs).([]schemas.PluginLogEntry); ok && len(transportLogs) > 0 { tracer.AttachPluginLogs(traceID, transportLogs) } // After response written - async flush tracer.CompleteAndFlushTrace(traceID) }() next(ctx) } } } // GetTracer returns the tracer instance for use by streaming handlers func (m *TracingMiddleware) GetTracer() *tracing.Tracer { return m.tracer.Load() } // GetObservabilityPlugins filters and returns only observability plugins from a list of plugins. // Uses Go type assertion to identify plugins implementing the ObservabilityPlugin interface. func GetObservabilityPlugins(plugins []schemas.BasePlugin) []schemas.ObservabilityPlugin { if len(plugins) == 0 { return nil } obsPlugins := make([]schemas.ObservabilityPlugin, 0) for _, plugin := range plugins { if obsPlugin, ok := plugin.(schemas.ObservabilityPlugin); ok { obsPlugins = append(obsPlugins, obsPlugin) } } return obsPlugins }