package handlers import ( "context" "strings" "github.com/bytedance/sonic" "github.com/fasthttp/router" ws "github.com/fasthttp/websocket" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/transports/bifrost-http/integrations" "github.com/maximhq/bifrost/transports/bifrost-http/lib" bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket" "github.com/valyala/fasthttp" ) // wsWriter abstracts a WebSocket write target. Both *ws.Conn (pre-session) // and *bfws.Session (post-session, mutex-protected) satisfy this interface. type wsWriter interface { WriteMessage(messageType int, data []byte) error } // WSResponsesHandler handles WebSocket connections for the Responses API WebSocket Mode. // Clients connect via `GET /v1/responses` with a WS upgrade and send `response.create` events. // Each event is routed through the standard Bifrost inference pipeline (PreLLMHook, key selection, // provider call, PostLLMHook) via the HTTP bridge, with native WS upstream as an optimization. type WSResponsesHandler struct { client *bifrost.Bifrost config *lib.Config handlerStore lib.HandlerStore pool *bfws.Pool sessions *bfws.SessionManager upgrader ws.FastHTTPUpgrader } // NewWSResponsesHandler creates a new WebSocket Responses handler. func NewWSResponsesHandler(client *bifrost.Bifrost, config *lib.Config, pool *bfws.Pool) *WSResponsesHandler { maxConns := config.WebSocketConfig.MaxConnections return &WSResponsesHandler{ client: client, config: config, handlerStore: config, pool: pool, sessions: bfws.NewSessionManager(maxConns), upgrader: ws.FastHTTPUpgrader{ ReadBufferSize: 4096, WriteBufferSize: 4096, CheckOrigin: func(ctx *fasthttp.RequestCtx) bool { origin := string(ctx.Request.Header.Peek("Origin")) if origin == "" { return true } return IsOriginAllowed(origin, config.ClientConfig.AllowedOrigins) }, }, } } // Close gracefully shuts down all active WebSocket responses sessions. 
func (h *WSResponsesHandler) Close() { if h == nil || h.sessions == nil { return } h.sessions.CloseAll() } // RegisterRoutes registers the WebSocket Responses endpoint at the base path // and all OpenAI integration paths. func (h *WSResponsesHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { handler := lib.ChainMiddlewares(h.handleUpgrade, middlewares...) // Base path (outside integration prefix) r.GET("/v1/responses", handler) // OpenAI integration paths (/openai/v1/responses, /openai/responses, /openai/openai/responses) for _, path := range integrations.OpenAIWSResponsesPaths("/openai") { r.GET(path, handler) } } // handleUpgrade upgrades the HTTP connection to WebSocket and starts the event loop. func (h *WSResponsesHandler) handleUpgrade(ctx *fasthttp.RequestCtx) { err := h.upgrader.Upgrade(ctx, func(conn *ws.Conn) { defer conn.Close() session, sessionErr := h.sessions.Create(conn) if sessionErr != nil { writeWSError(conn, 429, "websocket_connection_limit_reached", sessionErr.Error()) return } defer h.sessions.Remove(conn) // Capture auth headers from the upgrade request for per-event context creation authHeaders := captureAuthHeaders(ctx) h.eventLoop(conn, session, authHeaders) }) if err != nil { logger.Warn("websocket upgrade failed for /v1/responses: %v", err) } } // authHeaders holds auth-related headers captured during the WS upgrade. type authHeaders struct { authorization string virtualKey string apiKey string googAPIKey string baggage string extraHeaders map[string]string } // captureAuthHeaders captures the auth headers from the request. 
func captureAuthHeaders(ctx *fasthttp.RequestCtx) *authHeaders { ah := &authHeaders{ authorization: string(ctx.Request.Header.Peek("Authorization")), virtualKey: string(ctx.Request.Header.Peek("x-bf-vk")), apiKey: string(ctx.Request.Header.Peek("x-api-key")), googAPIKey: string(ctx.Request.Header.Peek("x-goog-api-key")), baggage: string(ctx.Request.Header.Peek("baggage")), extraHeaders: make(map[string]string), } for key, value := range ctx.Request.Header.All() { k := string(key) lk := strings.ToLower(k) if strings.HasPrefix(lk, "x-bf-") { ah.extraHeaders[k] = string(value) } } return ah } // eventLoop reads events from the client WebSocket and processes them. func (h *WSResponsesHandler) eventLoop(conn *ws.Conn, session *bfws.Session, auth *authHeaders) { for { _, message, err := conn.ReadMessage() if err != nil { if ws.IsUnexpectedCloseError(err, ws.CloseGoingAway, ws.CloseNormalClosure) { logger.Warn("websocket read error: %v", err) } return } // Parse the event type var envelope struct { Type string `json:"type"` } if err := sonic.Unmarshal(message, &envelope); err != nil { writeWSError(session, 400, "invalid_request_error", "failed to parse event JSON") continue } switch schemas.WebSocketEventType(envelope.Type) { case schemas.WSEventResponseCreate: h.handleResponseCreate(session, auth, message) default: writeWSError(session, 400, "invalid_request_error", "unsupported event type: "+envelope.Type) } } } // handleResponseCreate processes a response.create event. // Strategy: try native WS upstream for providers that support it, otherwise use HTTP bridge. // If native WS upstream fails mid-stream, falls back to HTTP bridge. 
func (h *WSResponsesHandler) handleResponseCreate(session *bfws.Session, auth *authHeaders, message []byte) { var event schemas.WebSocketResponsesEvent if err := sonic.Unmarshal(message, &event); err != nil { writeWSError(session, 400, "invalid_request_error", "failed to parse response.create event") return } // Store override: default to store=true (Codex sends false by default but expects true). // If DisableStore is set in provider config, force store=false. // If client explicitly sets store, respect that value unless DisableStore overrides it. provider, modelName := schemas.ParseModelString(event.Model, "") if provider == "" || modelName == "" { writeWSError(session, 400, "invalid_request_error", "failed to parse model string") return } if providerCfg, cfgErr := h.config.GetProviderConfigRaw(provider); cfgErr == nil && providerCfg.OpenAIConfig != nil && providerCfg.OpenAIConfig.DisableStore { event.Store = schemas.Ptr(false) } else { event.Store = schemas.Ptr(true) } bifrostReq, err := h.convertEventToRequest(&event) if err != nil { writeWSError(session, 400, "invalid_request_error", err.Error()) return } // Extract extra params (unknown fields) and forward them, matching the HTTP path behavior extraParams, extractErr := extractExtraParams(message, wsResponsesKnownFields) if extractErr == nil && len(extraParams) > 0 { if bifrostReq.Params == nil { bifrostReq.Params = &schemas.ResponsesParameters{} } bifrostReq.Params.ExtraParams = extraParams } bifrostCtx, cancel := createBifrostContextFromAuth(h.handlerStore, auth) if bifrostCtx == nil { writeWSError(session, 500, "server_error", "failed to create request context") return } // Try native WS upstream first if h.tryNativeWSUpstream(session, bifrostCtx, bifrostReq, message) { cancel() return } // Fall back to HTTP bridge h.executeHTTPBridge(session, bifrostCtx, cancel, bifrostReq) } // tryNativeWSUpstream attempts to forward the event to a native WS upstream connection. 
// It returns true when the event has been fully handled here: either relayed
// successfully, or answered with an error event written to the client.
// It returns false when the caller should fall back to the HTTP bridge, which
// happens when:
//   - the provider is unknown or not WebSocket-capable,
//   - no upstream connection could be obtained from the pool,
//   - the upstream write failed, or
//   - the upstream read failed before any event was relayed to the client.
//
// rawEvent is written to the upstream verbatim. Changes made to req, whether
// by the caller (e.g. the store override) or by plugin pre-hooks, are therefore
// not reflected in what the upstream receives.
//
// NOTE(review): on the write-failure and early read-failure paths, plugin
// pre-hooks have already run (RunStreamPreHooks) when false is returned. The
// caller then sends the same request through the HTTP bridge, where the hooks
// presumably run again. Confirm this is acceptable for side-effecting plugins
// (rate limiting, budgets).
func (h *WSResponsesHandler) tryNativeWSUpstream(
	session *bfws.Session,
	ctx *schemas.BifrostContext,
	req *schemas.BifrostResponsesRequest,
	rawEvent []byte,
) bool {
	provider := h.client.GetProviderByKey(req.Provider)
	if provider == nil {
		return false
	}
	wsProvider, ok := provider.(schemas.WebSocketCapableProvider)
	if !ok || !wsProvider.SupportsWebSocketMode() {
		return false
	}
	// Key selection errors are reported to the client; they do not fall back.
	key, err := h.client.SelectKeyForProviderRequestType(ctx, schemas.WebSocketResponsesRequest, req.Provider, req.Model)
	if err != nil {
		writeWSError(session, 400, "invalid_request_error", err.Error())
		return true
	}
	wsURL := wsProvider.WebSocketResponsesURL(key)
	upstream := session.Upstream()
	// Validate the pinned upstream matches the current request's provider/key
	// (the endpoint URL is not compared here).
	if upstream != nil && !upstream.IsClosed() && (upstream.Provider() != req.Provider || upstream.KeyID() != key.ID) {
		h.pool.Discard(upstream)
		session.SetUpstream(nil)
		upstream = nil
	}
	// If no upstream connection pinned, get one from the pool or dial
	if upstream == nil || upstream.IsClosed() {
		headers := wsProvider.WebSocketHeaders(key)
		poolKey := bfws.PoolKey{
			Provider: req.Provider,
			KeyID:    key.ID,
			Endpoint: wsURL,
		}
		upstream, err = h.pool.Get(poolKey, headers)
		if err != nil {
			logger.Warn("failed to get upstream WS connection for %s: %v, falling back to HTTP bridge", req.Provider, err)
			return false
		}
		// Pin the connection to the session so later events can reuse it.
		session.SetUpstream(upstream)
	}
	// Run plugin pre-hooks before forwarding to upstream
	bifrostReq := &schemas.BifrostRequest{
		RequestType:      schemas.WebSocketResponsesRequest,
		ResponsesRequest: req,
	}
	hooks, preErr := h.client.RunStreamPreHooks(ctx, bifrostReq)
	if preErr != nil {
		writeWSBifrostError(session, preErr)
		return true
	}
	defer hooks.Cleanup()
	// If a plugin short-circuited with a cached response, write it and skip upstream
	if hooks.ShortCircuitResponse != nil {
		writeWSShortCircuitResponse(session, hooks.ShortCircuitResponse)
		return true
	}
	// Forward the raw event to upstream
	if err := upstream.WriteMessage(ws.TextMessage, rawEvent); err != nil {
		logger.Warn("upstream WS write failed for %s: %v, falling back to HTTP bridge", req.Provider, err)
		h.pool.Discard(upstream)
		session.SetUpstream(nil)
		return false
	}
	// Retrieve tracer and traceID for chunk accumulation
	tracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
	traceID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string)
	// Read response events from upstream and relay to client, running post-hooks per chunk.
	// NOTE(review): no read deadline or ctx cancellation check is visible in
	// this loop. A stalled upstream would block here, and with it this
	// client's event loop. Confirm the pool/connection sets deadlines.
	forwardedAny := false
	for {
		msgType, data, readErr := upstream.ReadMessage()
		if readErr != nil {
			// NOTE(review): this log says "falling back" even when
			// forwardedAny is true, where an error event is sent instead.
			logger.Warn("upstream WS read failed for %s: %v, falling back to HTTP bridge", req.Provider, readErr)
			h.pool.Discard(upstream)
			session.SetUpstream(nil)
			if !forwardedAny {
				// Nothing reached the client yet, so the HTTP bridge can retry cleanly.
				return false
			}
			writeWSError(session, 502, "upstream_connection_error", "upstream websocket stream interrupted")
			return true
		}
		// Unparseable events are still relayed; they just skip tracing and post-hooks.
		streamResp := parseUpstreamWSEvent(data, req.Provider, req.Model)
		isTerminal := streamResp != nil && isTerminalStreamType(streamResp.Type)
		if isTerminal {
			ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
		}
		if streamResp != nil {
			resp := &schemas.BifrostResponse{ResponsesStreamResponse: streamResp}
			if tracer != nil && traceID != "" {
				tracer.AddStreamingChunk(traceID, resp)
			}
			_, postErr := hooks.PostHookRunner(ctx, resp, nil)
			if postErr != nil {
				h.pool.Discard(upstream)
				session.SetUpstream(nil)
				writeWSBifrostError(session, postErr)
				return true
			}
		}
		// The original upstream bytes (not a re-marshaled response) are relayed.
		if writeErr := session.WriteMessage(msgType, data); writeErr != nil {
			h.pool.Discard(upstream)
			session.SetUpstream(nil)
			return true
		}
		forwardedAny = true
		if isTerminal {
			h.trackResponseID(session, data)
			return true
		}
	}
}

// writeWSShortCircuitResponse writes a plugin short-circuit response to the client as WebSocket text messages.
func writeWSShortCircuitResponse(session *bfws.Session, resp *schemas.BifrostResponse) { if resp.ResponsesResponse != nil { data, err := sonic.Marshal(resp.ResponsesResponse) if err != nil { return } if err := session.WriteMessage(ws.TextMessage, data); err != nil { return } if resp.ResponsesResponse.ID != nil && *resp.ResponsesResponse.ID != "" { session.SetLastResponseID(*resp.ResponsesResponse.ID) } } else if resp.ResponsesStreamResponse != nil { data, err := sonic.Marshal(resp.ResponsesStreamResponse) if err != nil { return } session.WriteMessage(ws.TextMessage, data) } } // parseUpstreamWSEvent attempts to parse a raw upstream WS event into a BifrostResponsesStreamResponse. // It populates ExtraFields so downstream plugins (logging, tracing) can identify the request type. // Returns nil if the data cannot be parsed (non-fatal, the raw bytes are still relayed). func parseUpstreamWSEvent(data []byte, provider schemas.ModelProvider, model string) *schemas.BifrostResponsesStreamResponse { var streamResp schemas.BifrostResponsesStreamResponse if err := sonic.Unmarshal(data, &streamResp); err != nil { return nil } if streamResp.Type == "" { return nil } streamResp.ExtraFields.RequestType = schemas.ResponsesStreamRequest streamResp.ExtraFields.Provider = provider streamResp.ExtraFields.OriginalModelRequested = model return &streamResp } // isTerminalStreamType returns true if the event type signals the end of a response stream. func isTerminalStreamType(t schemas.ResponsesStreamResponseType) bool { switch t { case schemas.ResponsesStreamResponseTypeCompleted, schemas.ResponsesStreamResponseTypeFailed, schemas.ResponsesStreamResponseTypeIncomplete, schemas.ResponsesStreamResponseTypeError: return true } return false } // trackResponseID extracts and stores the response ID from terminal events. 
func (h *WSResponsesHandler) trackResponseID(session *bfws.Session, data []byte) { var envelope struct { Response struct { ID string `json:"id"` } `json:"response"` } if err := sonic.Unmarshal(data, &envelope); err == nil && envelope.Response.ID != "" { session.SetLastResponseID(envelope.Response.ID) } } // convertEventToRequest converts a WebSocket response.create event to a BifrostResponsesRequest. func (h *WSResponsesHandler) convertEventToRequest(event *schemas.WebSocketResponsesEvent) (*schemas.BifrostResponsesRequest, error) { provider, modelName := schemas.ParseModelString(event.Model, "") if provider == "" || modelName == "" { return nil, errModelFormat } var input []schemas.ResponsesMessage if event.Input != nil { // Try parsing as array first if err := sonic.Unmarshal(event.Input, &input); err != nil { // Try as string var inputStr string if strErr := sonic.Unmarshal(event.Input, &inputStr); strErr != nil { return nil, errInputRequired } input = []schemas.ResponsesMessage{ { Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), Content: &schemas.ResponsesMessageContent{ContentStr: &inputStr}, }, } } } if len(input) == 0 { return nil, errInputRequired } params := &schemas.ResponsesParameters{} if event.Temperature != nil { params.Temperature = event.Temperature } if event.TopP != nil { params.TopP = event.TopP } if event.MaxOutputTokens != nil { params.MaxOutputTokens = event.MaxOutputTokens } if event.Instructions != "" { params.Instructions = &event.Instructions } if event.PreviousResponseID != "" { params.PreviousResponseID = &event.PreviousResponseID } if event.Store != nil { params.Store = event.Store } if event.Tools != nil { var tools []schemas.ResponsesTool if err := sonic.Unmarshal(event.Tools, &tools); err == nil { params.Tools = tools } } if event.ToolChoice != nil { var tc schemas.ResponsesToolChoice if err := sonic.Unmarshal(event.ToolChoice, &tc); err == nil { params.ToolChoice = &tc } } if event.Reasoning != nil { var reasoning 
schemas.ResponsesParametersReasoning if err := sonic.Unmarshal(event.Reasoning, &reasoning); err == nil { params.Reasoning = &reasoning } } if event.Text != nil { var text schemas.ResponsesTextConfig if err := sonic.Unmarshal(event.Text, &text); err == nil { params.Text = &text } } if event.Metadata != nil { var metadata map[string]any if err := sonic.Unmarshal(event.Metadata, &metadata); err == nil { params.Metadata = &metadata } } if event.Truncation != "" { params.Truncation = &event.Truncation } return &schemas.BifrostResponsesRequest{ Provider: schemas.ModelProvider(provider), Model: modelName, Input: input, Params: params, }, nil } // createBifrostContextFromAuth builds a BifrostContext from the auth headers captured during upgrade. func createBifrostContextFromAuth(handlerStore lib.HandlerStore, auth *authHeaders) (*schemas.BifrostContext, context.CancelFunc) { ctx, cancel := schemas.NewBifrostContextWithCancel(context.Background()) if sessionID := lib.ParseSessionIDFromBaggage(auth.baggage); sessionID != "" { ctx.SetValue(schemas.BifrostContextKeyParentRequestID, sessionID) } if auth.virtualKey != "" { ctx.SetValue(schemas.BifrostContextKeyVirtualKey, auth.virtualKey) } // Handle Bearer token with sk-bf- prefix (virtual key via Authorization header) if auth.authorization != "" { if strings.HasPrefix(auth.authorization, "Bearer ") { token := strings.TrimPrefix(auth.authorization, "Bearer ") if strings.HasPrefix(token, "sk-bf-") { ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(token, "sk-bf-")) } else if handlerStore.ShouldAllowDirectKeys() { key := schemas.Key{ ID: "header-provided", Value: *schemas.NewEnvVar(token), Models: schemas.WhiteList{"*"}, Weight: 1.0, } ctx.SetValue(schemas.BifrostContextKeyDirectKey, key) } } } if auth.apiKey != "" { if strings.HasPrefix(auth.apiKey, "sk-bf-") { ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(auth.apiKey, "sk-bf-")) } else if handlerStore.ShouldAllowDirectKeys() { key := 
schemas.Key{ ID: "header-provided", Value: *schemas.NewEnvVar(auth.apiKey), Models: schemas.WhiteList{"*"}, Weight: 1.0, } ctx.SetValue(schemas.BifrostContextKeyDirectKey, key) } } if auth.googAPIKey != "" { if strings.HasPrefix(auth.googAPIKey, "sk-bf-") { ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(auth.googAPIKey, "sk-bf-")) } else if handlerStore.ShouldAllowDirectKeys() { key := schemas.Key{ ID: "header-provided", Value: *schemas.NewEnvVar(auth.googAPIKey), Models: schemas.WhiteList{"*"}, Weight: 1.0, } ctx.SetValue(schemas.BifrostContextKeyDirectKey, key) } } // Forward x-bf-* headers for k, v := range auth.extraHeaders { lk := strings.ToLower(k) switch { case lk == "x-bf-vk": // Already handled above case lk == "x-bf-api-key": ctx.SetValue(schemas.BifrostContextKeyAPIKeyName, v) case strings.HasPrefix(lk, "x-bf-eh-"): suffix := strings.TrimPrefix(lk, "x-bf-eh-") existing, _ := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string]string) if existing == nil { existing = make(map[string]string) } existing[suffix] = v ctx.SetValue(schemas.BifrostContextKeyExtraHeaders, existing) } } return ctx, cancel } // executeHTTPBridge runs the response through the existing streaming inference pipeline. 
func (h *WSResponsesHandler) executeHTTPBridge( session *bfws.Session, ctx *schemas.BifrostContext, cancel context.CancelFunc, req *schemas.BifrostResponsesRequest, ) { defer cancel() stream, bifrostErr := h.client.ResponsesStreamRequest(ctx, req) if bifrostErr != nil { writeWSBifrostError(session, bifrostErr) return } // Relay streaming chunks as WS messages for chunk := range stream { if chunk == nil { continue } chunkJSON, err := sonic.Marshal(chunk) if err != nil { logger.Warn("failed to marshal stream chunk: %v", err) continue } if writeErr := session.WriteMessage(ws.TextMessage, chunkJSON); writeErr != nil { return } // Track last response ID for session chaining if chunk.BifrostResponsesStreamResponse != nil && chunk.BifrostResponsesStreamResponse.Response != nil && chunk.BifrostResponsesStreamResponse.Response.ID != nil && *chunk.BifrostResponsesStreamResponse.Response.ID != "" { session.SetLastResponseID(*chunk.BifrostResponsesStreamResponse.Response.ID) } } } // writeWSError sends a JSON error event to a WebSocket write target. // Accepts either a raw *ws.Conn (pre-session) or a *bfws.Session (mutex-protected). func writeWSError(w wsWriter, status int, code, message string) { event := schemas.WebSocketErrorEvent{ Type: schemas.WSEventError, Status: status, Error: &schemas.WebSocketErrorBody{ Code: code, Message: message, }, } data, err := sonic.Marshal(event) if err != nil { return } w.WriteMessage(ws.TextMessage, data) } // writeWSBifrostError converts a BifrostError to a WS error event. 
func writeWSBifrostError(w wsWriter, bifrostErr *schemas.BifrostError) { status := 500 if bifrostErr.StatusCode != nil && *bifrostErr.StatusCode > 0 { status = *bifrostErr.StatusCode } code := "server_error" msg := "internal server error" if bifrostErr.Error != nil { if bifrostErr.Error.Code != nil && *bifrostErr.Error.Code != "" { code = *bifrostErr.Error.Code } else if bifrostErr.Error.Type != nil && *bifrostErr.Error.Type != "" { code = *bifrostErr.Error.Type } if bifrostErr.Error.Message != "" { msg = bifrostErr.Error.Message } } writeWSError(w, status, code, msg) } // wsResponsesKnownFields lists the fields explicitly handled by WebSocketResponsesEvent. // Anything not in this set is treated as an extra param and forwarded as-is to the provider. var wsResponsesKnownFields = map[string]bool{ "type": true, "model": true, "store": true, "input": true, "instructions": true, "previous_response_id": true, "tools": true, "tool_choice": true, "temperature": true, "top_p": true, "max_output_tokens": true, "reasoning": true, "metadata": true, "text": true, "truncation": true, } var ( errModelFormat = errorf("model should be in provider/model format") errInputRequired = errorf("input is required for responses") ) func errorf(msg string) error { return &simpleError{msg: msg} } type simpleError struct { msg string } func (e *simpleError) Error() string { return e.msg }