Files
bifrost/transports/bifrost-http/handlers/wsresponses.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

703 lines
22 KiB
Go

package handlers
import (
"context"
"strings"
"github.com/bytedance/sonic"
"github.com/fasthttp/router"
ws "github.com/fasthttp/websocket"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
"github.com/valyala/fasthttp"
)
// wsWriter is the minimal write surface needed to send a WebSocket frame.
// It lets the error helpers target either the raw *ws.Conn (used before a
// session exists) or a *bfws.Session (used once the session is registered).
type wsWriter interface {
	// WriteMessage sends one frame of the given WebSocket message type.
	WriteMessage(messageType int, data []byte) error
}
// WSResponsesHandler serves the Responses API in WebSocket mode.
//
// A client opens `GET /v1/responses` (or one of the OpenAI integration aliases)
// with a WebSocket upgrade and then sends `response.create` events. Each event
// goes through the standard Bifrost inference pipeline (PreLLMHook, key
// selection, provider call, PostLLMHook) via the HTTP bridge. Providers that
// support a native WebSocket upstream are served over that connection instead,
// as an optimization.
type WSResponsesHandler struct {
	client       *bifrost.Bifrost     // core client: provider lookup, key selection, hooks, streaming
	config       *lib.Config          // provider configs, WebSocket limits and allowed origins
	handlerStore lib.HandlerStore     // set to config by NewWSResponsesHandler; used for direct-key policy
	pool         *bfws.Pool           // pool of upstream WebSocket connections keyed by provider/key/endpoint
	sessions     *bfws.SessionManager // active client sessions; Create fails once the connection limit is hit
	upgrader     ws.FastHTTPUpgrader  // HTTP -> WebSocket upgrader with an Origin check
}
// NewWSResponsesHandler builds a WSResponsesHandler around the given Bifrost
// client, configuration and upstream connection pool.
//
// The session manager is sized from config.WebSocketConfig.MaxConnections. The
// upgrader accepts requests with no Origin header. Otherwise it checks the
// origin against config.ClientConfig.AllowedOrigins, which is read on every
// check rather than snapshotted here.
func NewWSResponsesHandler(client *bifrost.Bifrost, config *lib.Config, pool *bfws.Pool) *WSResponsesHandler {
	checkOrigin := func(reqCtx *fasthttp.RequestCtx) bool {
		origin := string(reqCtx.Request.Header.Peek("Origin"))
		return origin == "" || IsOriginAllowed(origin, config.ClientConfig.AllowedOrigins)
	}

	handler := &WSResponsesHandler{
		client:       client,
		config:       config,
		handlerStore: config,
		pool:         pool,
		sessions:     bfws.NewSessionManager(config.WebSocketConfig.MaxConnections),
	}
	handler.upgrader = ws.FastHTTPUpgrader{
		ReadBufferSize:  4096,
		WriteBufferSize: 4096,
		CheckOrigin:     checkOrigin,
	}
	return handler
}
// Close shuts down every active WebSocket responses session.
// It is safe to call on a nil handler or one without a session manager.
func (h *WSResponsesHandler) Close() {
	if h != nil && h.sessions != nil {
		h.sessions.CloseAll()
	}
}
// RegisterRoutes mounts the WebSocket upgrade handler, wrapped in the given
// middlewares, on GET /v1/responses. It also mounts it on every OpenAI
// integration alias returned by integrations.OpenAIWSResponsesPaths("/openai").
func (h *WSResponsesHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	wrapped := lib.ChainMiddlewares(h.handleUpgrade, middlewares...)
	register := func(route string) { r.GET(route, wrapped) }

	// Base path, outside the integration prefix.
	register("/v1/responses")
	// Integration aliases, e.g. /openai/v1/responses, /openai/responses.
	for _, route := range integrations.OpenAIWSResponsesPaths("/openai") {
		register(route)
	}
}
// handleUpgrade upgrades the HTTP request to a WebSocket connection and runs
// the per-connection event loop until the client disconnects.
//
// The auth-related headers are copied out of the request before Upgrade is
// called. The upgrade callback runs on the hijacked connection, and fasthttp's
// Hijack contract forbids the hijack handler from retaining references to ctx
// members, so ctx.Request is not read from inside the callback.
//
// If the session manager refuses the connection (connection limit), a 429
// error event is written on the raw connection before it is closed.
func (h *WSResponsesHandler) handleUpgrade(ctx *fasthttp.RequestCtx) {
	auth := captureAuthHeaders(ctx)

	err := h.upgrader.Upgrade(ctx, func(conn *ws.Conn) {
		defer conn.Close()
		session, sessionErr := h.sessions.Create(conn)
		if sessionErr != nil {
			writeWSError(conn, 429, "websocket_connection_limit_reached", sessionErr.Error())
			return
		}
		defer h.sessions.Remove(conn)
		h.eventLoop(conn, session, auth)
	})
	if err != nil {
		logger.Warn("websocket upgrade failed for /v1/responses: %v", err)
	}
}
// authHeaders is a copy of the auth-related request headers taken at upgrade
// time. A fresh BifrostContext is built from it for every event on the
// connection.
type authHeaders struct {
	authorization string            // raw Authorization header value
	virtualKey    string            // x-bf-vk header value
	apiKey        string            // x-api-key header value
	googAPIKey    string            // x-goog-api-key header value
	baggage       string            // baggage header value (session ID is parsed from it)
	extraHeaders  map[string]string // every x-bf-* header, keyed by its name as received
}
// captureAuthHeaders copies the auth-related headers out of the request into
// a standalone authHeaders value, so the result never aliases fasthttp's
// request buffers.
//
// Every header whose name starts with "x-bf-" (compared case-insensitively) is
// also copied into extraHeaders under the name exactly as the request reports it.
func captureAuthHeaders(ctx *fasthttp.RequestCtx) *authHeaders {
	hdr := &ctx.Request.Header
	peek := func(name string) string { return string(hdr.Peek(name)) }

	captured := &authHeaders{
		authorization: peek("Authorization"),
		virtualKey:    peek("x-bf-vk"),
		apiKey:        peek("x-api-key"),
		googAPIKey:    peek("x-goog-api-key"),
		baggage:       peek("baggage"),
		extraHeaders:  map[string]string{},
	}

	for rawName, rawValue := range hdr.All() {
		name := string(rawName)
		if !strings.HasPrefix(strings.ToLower(name), "x-bf-") {
			continue
		}
		captured.extraHeaders[name] = string(rawValue)
	}
	return captured
}
// eventLoop reads client frames until the connection fails or closes, and
// dispatches each one by its JSON "type" field.
//
// Only response.create is supported. Frames that are not valid JSON, and frames
// with any other type, get a 400 invalid_request_error event and the loop keeps
// going. Read errors end the loop; they are logged only when they are not a
// normal or going-away close.
func (h *WSResponsesHandler) eventLoop(conn *ws.Conn, session *bfws.Session, auth *authHeaders) {
	type eventEnvelope struct {
		Type string `json:"type"`
	}

	for {
		_, payload, readErr := conn.ReadMessage()
		if readErr != nil {
			if ws.IsUnexpectedCloseError(readErr, ws.CloseGoingAway, ws.CloseNormalClosure) {
				logger.Warn("websocket read error: %v", readErr)
			}
			return
		}

		var env eventEnvelope
		if parseErr := sonic.Unmarshal(payload, &env); parseErr != nil {
			writeWSError(session, 400, "invalid_request_error", "failed to parse event JSON")
			continue
		}

		if schemas.WebSocketEventType(env.Type) != schemas.WSEventResponseCreate {
			writeWSError(session, 400, "invalid_request_error", "unsupported event type: "+env.Type)
			continue
		}
		h.handleResponseCreate(session, auth, payload)
	}
}
// handleResponseCreate processes a single response.create event.
//
// Strategy: try the native WS upstream for providers that support it, and
// otherwise (or if the native path declines the event) use the HTTP bridge.
//
// Store override: store is forced to true (Codex sends store=false by default
// but expects stored responses), unless the provider config sets DisableStore,
// in which case store is forced to false. The client's own store value is NOT
// respected. The override applies to the converted request (HTTP bridge) AND to
// the raw event forwarded to a native WS upstream. Previously the raw bytes went
// upstream untouched, so DisableStore was not honoured on the native path.
//
// For the same reason, the raw event's "model" is replaced with the bare model
// name (provider prefix removed), which is what the HTTP bridge sends.
func (h *WSResponsesHandler) handleResponseCreate(session *bfws.Session, auth *authHeaders, message []byte) {
	var event schemas.WebSocketResponsesEvent
	if err := sonic.Unmarshal(message, &event); err != nil {
		writeWSError(session, 400, "invalid_request_error", "failed to parse response.create event")
		return
	}

	provider, modelName := schemas.ParseModelString(event.Model, "")
	if provider == "" || modelName == "" {
		writeWSError(session, 400, "invalid_request_error", "failed to parse model string")
		return
	}
	if providerCfg, cfgErr := h.config.GetProviderConfigRaw(provider); cfgErr == nil &&
		providerCfg.OpenAIConfig != nil && providerCfg.OpenAIConfig.DisableStore {
		event.Store = schemas.Ptr(false)
	} else {
		event.Store = schemas.Ptr(true)
	}

	bifrostReq, err := h.convertEventToRequest(&event)
	if err != nil {
		writeWSError(session, 400, "invalid_request_error", err.Error())
		return
	}

	// Extract extra params (unknown fields) and forward them, matching the HTTP path behavior.
	// Extraction is best-effort: on failure the request proceeds without extra params.
	extraParams, extractErr := extractExtraParams(message, wsResponsesKnownFields)
	if extractErr == nil && len(extraParams) > 0 {
		if bifrostReq.Params == nil {
			bifrostReq.Params = &schemas.ResponsesParameters{}
		}
		bifrostReq.Params.ExtraParams = extraParams
	}

	bifrostCtx, cancel := createBifrostContextFromAuth(h.handlerStore, auth)
	if bifrostCtx == nil {
		if cancel != nil {
			cancel()
		}
		writeWSError(session, 500, "server_error", "failed to create request context")
		return
	}

	// Native WS upstream first, with the overridden fields applied to the raw event.
	upstreamEvent := rewriteWSEventForUpstream(message, bifrostReq.Model, *event.Store)
	if h.tryNativeWSUpstream(session, bifrostCtx, bifrostReq, upstreamEvent) {
		cancel()
		return
	}

	// Fall back to the HTTP bridge; it takes ownership of cancel.
	h.executeHTTPBridge(session, bifrostCtx, cancel, bifrostReq)
}

// wsUpstreamEventAPI decodes JSON numbers as json.Number, so rewriting a raw
// event does not push numeric fields through a float64 round-trip.
var wsUpstreamEventAPI = sonic.Config{UseNumber: true}.Froze()

// rewriteWSEventForUpstream returns a copy of the raw response.create event with
// its top-level "model" and "store" fields replaced. Every other field is kept
// as the client sent it (key order may change). If raw cannot be decoded as a
// JSON object or re-encoded, raw is returned unchanged.
func rewriteWSEventForUpstream(raw []byte, model string, store bool) []byte {
	var fields map[string]any
	if err := wsUpstreamEventAPI.Unmarshal(raw, &fields); err != nil || fields == nil {
		return raw
	}
	fields["model"] = model
	fields["store"] = store
	rewritten, err := wsUpstreamEventAPI.Marshal(fields)
	if err != nil {
		return raw
	}
	return rewritten
}
// tryNativeWSUpstream attempts to serve the event over a native WebSocket
// connection to the provider instead of the HTTP bridge.
//
// Returns false (the caller should fall back to the HTTP bridge) when:
//   - the provider is unknown or not WebSocket-capable,
//   - no upstream connection can be obtained from the pool,
//   - writing the event to the upstream fails, or
//   - reading from the upstream fails before any event was relayed to the client.
//
// Returns true when the event was handled here:
//   - relayed through to a terminal event, or
//   - an error event was written to the client (key selection failure, pre-hook
//     error, post-hook error, or an upstream read failure after at least one
//     event had been relayed), or
//   - a plugin short-circuited the request, or
//   - writing to the client failed.
//
// Side effects: may pin a new upstream on the session. On write, read, post-hook
// or client-write failures it discards the pinned upstream (pool.Discard and
// SetUpstream(nil)). Pre-hooks run before the event is forwarded, and their
// cleanup is deferred to this function's return.
// NOTE(review): when this returns false after pre-hooks ran, the HTTP bridge
// fallback presumably runs the pre-hooks again for the same event. Confirm that
// plugins (budgets, rate limits) tolerate this.
func (h *WSResponsesHandler) tryNativeWSUpstream(
	session *bfws.Session,
	ctx *schemas.BifrostContext,
	req *schemas.BifrostResponsesRequest,
	rawEvent []byte,
) bool {
	provider := h.client.GetProviderByKey(req.Provider)
	if provider == nil {
		return false
	}
	wsProvider, ok := provider.(schemas.WebSocketCapableProvider)
	if !ok || !wsProvider.SupportsWebSocketMode() {
		return false
	}
	// Key selection failures are reported to the client rather than falling back:
	// the HTTP bridge would presumably hit the same key-selection problem.
	key, err := h.client.SelectKeyForProviderRequestType(ctx, schemas.WebSocketResponsesRequest, req.Provider, req.Model)
	if err != nil {
		writeWSError(session, 400, "invalid_request_error", err.Error())
		return true
	}
	wsURL := wsProvider.WebSocketResponsesURL(key)
	upstream := session.Upstream()
	// Validate the pinned upstream matches the current request's provider/key.
	// (The endpoint is not compared here; only provider and key ID.)
	if upstream != nil && !upstream.IsClosed() &&
		(upstream.Provider() != req.Provider || upstream.KeyID() != key.ID) {
		h.pool.Discard(upstream)
		session.SetUpstream(nil)
		upstream = nil
	}
	// If no upstream connection pinned, get one from the pool or dial
	if upstream == nil || upstream.IsClosed() {
		headers := wsProvider.WebSocketHeaders(key)
		poolKey := bfws.PoolKey{
			Provider: req.Provider,
			KeyID:    key.ID,
			Endpoint: wsURL,
		}
		upstream, err = h.pool.Get(poolKey, headers)
		if err != nil {
			logger.Warn("failed to get upstream WS connection for %s: %v, falling back to HTTP bridge", req.Provider, err)
			return false
		}
		session.SetUpstream(upstream)
	}
	// Run plugin pre-hooks before forwarding to upstream
	bifrostReq := &schemas.BifrostRequest{
		RequestType:      schemas.WebSocketResponsesRequest,
		ResponsesRequest: req,
	}
	hooks, preErr := h.client.RunStreamPreHooks(ctx, bifrostReq)
	if preErr != nil {
		writeWSBifrostError(session, preErr)
		return true
	}
	defer hooks.Cleanup()
	// If a plugin short-circuited with a cached response, write it and skip upstream
	if hooks.ShortCircuitResponse != nil {
		writeWSShortCircuitResponse(session, hooks.ShortCircuitResponse)
		return true
	}
	// Forward the raw event to upstream. Nothing has reached the client yet, so a
	// write failure can still fall back to the HTTP bridge.
	if err := upstream.WriteMessage(ws.TextMessage, rawEvent); err != nil {
		logger.Warn("upstream WS write failed for %s: %v, falling back to HTTP bridge", req.Provider, err)
		h.pool.Discard(upstream)
		session.SetUpstream(nil)
		return false
	}
	// Retrieve tracer and traceID for chunk accumulation
	tracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
	traceID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string)
	// Read response events from upstream and relay to client, running post-hooks per chunk.
	// forwardedAny records whether any upstream event reached the client. Once one has,
	// falling back to the HTTP bridge would duplicate output, so failures become error events.
	forwardedAny := false
	for {
		msgType, data, readErr := upstream.ReadMessage()
		if readErr != nil {
			logger.Warn("upstream WS read failed for %s: %v, falling back to HTTP bridge", req.Provider, readErr)
			h.pool.Discard(upstream)
			session.SetUpstream(nil)
			if !forwardedAny {
				return false
			}
			writeWSError(session, 502, "upstream_connection_error", "upstream websocket stream interrupted")
			return true
		}
		// Unparseable events (nil streamResp) skip tracing/post-hooks but are still relayed below.
		streamResp := parseUpstreamWSEvent(data, req.Provider, req.Model)
		isTerminal := streamResp != nil && isTerminalStreamType(streamResp.Type)
		if isTerminal {
			// Set before the post-hooks run, presumably so plugins can recognise the final chunk.
			ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
		}
		if streamResp != nil {
			resp := &schemas.BifrostResponse{ResponsesStreamResponse: streamResp}
			if tracer != nil && traceID != "" {
				tracer.AddStreamingChunk(traceID, resp)
			}
			// The post-hook result is discarded: the client always receives the
			// upstream bytes (data) unchanged. Only a post-hook error affects the flow.
			_, postErr := hooks.PostHookRunner(ctx, resp, nil)
			if postErr != nil {
				h.pool.Discard(upstream)
				session.SetUpstream(nil)
				writeWSBifrostError(session, postErr)
				return true
			}
		}
		// A failed client write means the client cannot be reached; the upstream is
		// discarded because it may still hold an unread, partially streamed response.
		if writeErr := session.WriteMessage(msgType, data); writeErr != nil {
			h.pool.Discard(upstream)
			session.SetUpstream(nil)
			return true
		}
		forwardedAny = true
		if isTerminal {
			h.trackResponseID(session, data)
			return true
		}
	}
}
// writeWSShortCircuitResponse writes a plugin short-circuit response to the
// client as a single JSON text message.
//
// A full ResponsesResponse takes precedence. Otherwise a ResponsesStreamResponse
// is sent. If neither is set, nothing is written. In both cases the response ID
// (when present and non-empty) is recorded on the session after a successful
// write, so a follow-up response.create can chain from it. This matches how the
// HTTP bridge tracks IDs from stream chunks; previously the stream variant
// skipped ID tracking.
//
// Delivery is best-effort. Marshal failures are logged, write failures are
// ignored, and neither is reported to the client.
func writeWSShortCircuitResponse(session *bfws.Session, resp *schemas.BifrostResponse) {
	if resp == nil {
		return
	}

	var (
		payload any
		respID  *string
	)
	switch {
	case resp.ResponsesResponse != nil:
		payload = resp.ResponsesResponse
		respID = resp.ResponsesResponse.ID
	case resp.ResponsesStreamResponse != nil:
		payload = resp.ResponsesStreamResponse
		if resp.ResponsesStreamResponse.Response != nil {
			respID = resp.ResponsesStreamResponse.Response.ID
		}
	default:
		return
	}

	data, err := sonic.Marshal(payload)
	if err != nil {
		logger.Warn("failed to marshal short-circuit response: %v", err)
		return
	}
	if err := session.WriteMessage(ws.TextMessage, data); err != nil {
		return
	}
	if respID != nil && *respID != "" {
		session.SetLastResponseID(*respID)
	}
}
// parseUpstreamWSEvent decodes a raw upstream WebSocket event into a
// BifrostResponsesStreamResponse, for tracing and post-hooks.
//
// It returns nil when the bytes are not valid JSON for that type or carry no
// "type". That is non-fatal: the caller still relays the raw bytes. On success
// ExtraFields is stamped with the request type (ResponsesStreamRequest), the
// provider and the originally requested model, so downstream plugins (logging,
// tracing) can attribute the chunk.
func parseUpstreamWSEvent(data []byte, provider schemas.ModelProvider, model string) *schemas.BifrostResponsesStreamResponse {
	parsed := new(schemas.BifrostResponsesStreamResponse)
	if decodeErr := sonic.Unmarshal(data, parsed); decodeErr != nil || parsed.Type == "" {
		return nil
	}

	parsed.ExtraFields.RequestType = schemas.ResponsesStreamRequest
	parsed.ExtraFields.Provider = provider
	parsed.ExtraFields.OriginalModelRequested = model
	return parsed
}
// isTerminalStreamType reports whether t ends a response stream: completed,
// failed, incomplete, or error.
func isTerminalStreamType(t schemas.ResponsesStreamResponseType) bool {
	return t == schemas.ResponsesStreamResponseTypeCompleted ||
		t == schemas.ResponsesStreamResponseTypeFailed ||
		t == schemas.ResponsesStreamResponseTypeIncomplete ||
		t == schemas.ResponsesStreamResponseTypeError
}
// trackResponseID reads response.id from a terminal event's JSON and, if it is
// non-empty, records it as the session's last response ID. Undecodable data is
// ignored.
func (h *WSResponsesHandler) trackResponseID(session *bfws.Session, data []byte) {
	var terminal struct {
		Response struct {
			ID string `json:"id"`
		} `json:"response"`
	}
	if decodeErr := sonic.Unmarshal(data, &terminal); decodeErr != nil {
		return
	}
	if id := terminal.Response.ID; id != "" {
		session.SetLastResponseID(id)
	}
}
// convertEventToRequest converts a WebSocket response.create event to a BifrostResponsesRequest.
func (h *WSResponsesHandler) convertEventToRequest(event *schemas.WebSocketResponsesEvent) (*schemas.BifrostResponsesRequest, error) {
provider, modelName := schemas.ParseModelString(event.Model, "")
if provider == "" || modelName == "" {
return nil, errModelFormat
}
var input []schemas.ResponsesMessage
if event.Input != nil {
// Try parsing as array first
if err := sonic.Unmarshal(event.Input, &input); err != nil {
// Try as string
var inputStr string
if strErr := sonic.Unmarshal(event.Input, &inputStr); strErr != nil {
return nil, errInputRequired
}
input = []schemas.ResponsesMessage{
{
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
Content: &schemas.ResponsesMessageContent{ContentStr: &inputStr},
},
}
}
}
if len(input) == 0 {
return nil, errInputRequired
}
params := &schemas.ResponsesParameters{}
if event.Temperature != nil {
params.Temperature = event.Temperature
}
if event.TopP != nil {
params.TopP = event.TopP
}
if event.MaxOutputTokens != nil {
params.MaxOutputTokens = event.MaxOutputTokens
}
if event.Instructions != "" {
params.Instructions = &event.Instructions
}
if event.PreviousResponseID != "" {
params.PreviousResponseID = &event.PreviousResponseID
}
if event.Store != nil {
params.Store = event.Store
}
if event.Tools != nil {
var tools []schemas.ResponsesTool
if err := sonic.Unmarshal(event.Tools, &tools); err == nil {
params.Tools = tools
}
}
if event.ToolChoice != nil {
var tc schemas.ResponsesToolChoice
if err := sonic.Unmarshal(event.ToolChoice, &tc); err == nil {
params.ToolChoice = &tc
}
}
if event.Reasoning != nil {
var reasoning schemas.ResponsesParametersReasoning
if err := sonic.Unmarshal(event.Reasoning, &reasoning); err == nil {
params.Reasoning = &reasoning
}
}
if event.Text != nil {
var text schemas.ResponsesTextConfig
if err := sonic.Unmarshal(event.Text, &text); err == nil {
params.Text = &text
}
}
if event.Metadata != nil {
var metadata map[string]any
if err := sonic.Unmarshal(event.Metadata, &metadata); err == nil {
params.Metadata = &metadata
}
}
if event.Truncation != "" {
params.Truncation = &event.Truncation
}
return &schemas.BifrostResponsesRequest{
Provider: schemas.ModelProvider(provider),
Model: modelName,
Input: input,
Params: params,
}, nil
}
// createBifrostContextFromAuth builds a cancellable BifrostContext for one
// event from the headers captured at upgrade time.
//
// Population order (a later step overwrites the same context key):
//  1. Parent request ID from the session ID in the baggage header.
//  2. Virtual key from x-bf-vk.
//  3. Credentials from "Authorization: Bearer <token>", then x-api-key, then
//     x-goog-api-key (see setWSCredentialOnContext).
//  4. x-bf-api-key becomes the API key name. Each x-bf-eh-<name> header is
//     added to the extra-headers map under the lowercased <name>.
//
// The returned CancelFunc must be called by the caller once the event is done.
func createBifrostContextFromAuth(handlerStore lib.HandlerStore, auth *authHeaders) (*schemas.BifrostContext, context.CancelFunc) {
	ctx, cancel := schemas.NewBifrostContextWithCancel(context.Background())

	sessionID := lib.ParseSessionIDFromBaggage(auth.baggage)
	if sessionID != "" {
		ctx.SetValue(schemas.BifrostContextKeyParentRequestID, sessionID)
	}
	if auth.virtualKey != "" {
		ctx.SetValue(schemas.BifrostContextKeyVirtualKey, auth.virtualKey)
	}

	// Authorization counts only with the exact "Bearer " scheme prefix; the token
	// after it is used even when empty.
	if bearerToken, isBearer := strings.CutPrefix(auth.authorization, "Bearer "); isBearer {
		setWSCredentialOnContext(ctx, handlerStore, bearerToken)
	}
	if auth.apiKey != "" {
		setWSCredentialOnContext(ctx, handlerStore, auth.apiKey)
	}
	if auth.googAPIKey != "" {
		setWSCredentialOnContext(ctx, handlerStore, auth.googAPIKey)
	}

	for headerName, headerValue := range auth.extraHeaders {
		lowered := strings.ToLower(headerName)
		if lowered == "x-bf-vk" {
			continue // already applied from auth.virtualKey
		}
		if lowered == "x-bf-api-key" {
			ctx.SetValue(schemas.BifrostContextKeyAPIKeyName, headerValue)
			continue
		}
		forwardedName, isExtra := strings.CutPrefix(lowered, "x-bf-eh-")
		if !isExtra {
			continue
		}
		merged, _ := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string]string)
		if merged == nil {
			merged = make(map[string]string)
		}
		merged[forwardedName] = headerValue
		ctx.SetValue(schemas.BifrostContextKeyExtraHeaders, merged)
	}

	return ctx, cancel
}

// setWSCredentialOnContext stores one header-provided credential on ctx.
// A token prefixed with "sk-bf-" is a Bifrost virtual key and is stored without
// the prefix. Any other token is stored as a direct provider key, but only when
// handlerStore allows direct keys.
func setWSCredentialOnContext(ctx *schemas.BifrostContext, handlerStore lib.HandlerStore, token string) {
	if virtualKey, isVirtual := strings.CutPrefix(token, "sk-bf-"); isVirtual {
		ctx.SetValue(schemas.BifrostContextKeyVirtualKey, virtualKey)
		return
	}
	if !handlerStore.ShouldAllowDirectKeys() {
		return
	}
	ctx.SetValue(schemas.BifrostContextKeyDirectKey, schemas.Key{
		ID:     "header-provided",
		Value:  *schemas.NewEnvVar(token),
		Models: schemas.WhiteList{"*"},
		Weight: 1.0,
	})
}
// executeHTTPBridge serves the event through the regular streaming inference
// pipeline (client.ResponsesStreamRequest) and relays each non-nil stream chunk
// to the client as one JSON text message.
//
// cancel is the CancelFunc paired with ctx. This function takes ownership of it
// and always calls it (deferred) before returning.
//
// While relaying, the response ID of any chunk that carries one is recorded on
// the session, so the next response.create can chain from it. Chunks that fail
// to marshal are logged and skipped.
//
// If a client write fails, relaying stops. The remaining stream is then drained
// in a background goroutine, so the producer is never left blocked sending on a
// channel nobody reads. The deferred cancel presumably also tells it to stop
// early. The drain ends when the producer closes the channel, which the relay
// loop below already relies on for normal termination.
func (h *WSResponsesHandler) executeHTTPBridge(
	session *bfws.Session,
	ctx *schemas.BifrostContext,
	cancel context.CancelFunc,
	req *schemas.BifrostResponsesRequest,
) {
	defer cancel()
	stream, bifrostErr := h.client.ResponsesStreamRequest(ctx, req)
	if bifrostErr != nil {
		writeWSBifrostError(session, bifrostErr)
		return
	}

	for chunk := range stream {
		if chunk == nil {
			continue
		}
		chunkJSON, err := sonic.Marshal(chunk)
		if err != nil {
			logger.Warn("failed to marshal stream chunk: %v", err)
			continue
		}
		if writeErr := session.WriteMessage(ws.TextMessage, chunkJSON); writeErr != nil {
			go func() {
				for range stream {
				}
			}()
			return
		}
		// Track last response ID for session chaining.
		if chunk.BifrostResponsesStreamResponse != nil &&
			chunk.BifrostResponsesStreamResponse.Response != nil &&
			chunk.BifrostResponsesStreamResponse.Response.ID != nil &&
			*chunk.BifrostResponsesStreamResponse.Response.ID != "" {
			session.SetLastResponseID(*chunk.BifrostResponsesStreamResponse.Response.ID)
		}
	}
}
// writeWSError sends a JSON error event ({type, status, error{code, message}})
// as a text frame to w. w may be a raw *ws.Conn (before a session exists) or a
// *bfws.Session. Delivery is best-effort: marshal and write failures are
// deliberately ignored.
func writeWSError(w wsWriter, status int, code, message string) {
	payload, marshalErr := sonic.Marshal(schemas.WebSocketErrorEvent{
		Type:   schemas.WSEventError,
		Status: status,
		Error:  &schemas.WebSocketErrorBody{Code: code, Message: message},
	})
	if marshalErr != nil {
		return
	}
	_ = w.WriteMessage(ws.TextMessage, payload)
}
// writeWSBifrostError maps a BifrostError to a WebSocket error event and sends
// it to w.
//
// The status comes from StatusCode when positive, else 500. The code is the
// error's Code, else its Type, else "server_error". The message is the error's
// Message, else "internal server error". bifrostErr must be non-nil.
func writeWSBifrostError(w wsWriter, bifrostErr *schemas.BifrostError) {
	status, code, msg := 500, "server_error", "internal server error"

	if sc := bifrostErr.StatusCode; sc != nil && *sc > 0 {
		status = *sc
	}
	if body := bifrostErr.Error; body != nil {
		switch {
		case body.Code != nil && *body.Code != "":
			code = *body.Code
		case body.Type != nil && *body.Type != "":
			code = *body.Type
		}
		if body.Message != "" {
			msg = body.Message
		}
	}

	writeWSError(w, status, code, msg)
}
// wsResponsesKnownFields is the set of top-level response.create keys that
// WebSocketResponsesEvent decodes explicitly. Any other top-level key is
// extracted by extractExtraParams and forwarded to the provider as-is.
var wsResponsesKnownFields = map[string]bool{
	// Envelope and routing.
	"type":  true,
	"model": true,
	// Conversation content and chaining.
	"input":                true,
	"instructions":         true,
	"previous_response_id": true,
	"store":                true,
	// Tooling.
	"tools":       true,
	"tool_choice": true,
	// Sampling and output controls.
	"temperature":       true,
	"top_p":             true,
	"max_output_tokens": true,
	"reasoning":         true,
	"text":              true,
	"truncation":        true,
	// Caller-supplied annotations.
	"metadata": true,
}
// Sentinel errors returned by convertEventToRequest. The caller sends their
// messages to the client verbatim as invalid_request_error messages.
var (
	errModelFormat   = errorf("model should be in provider/model format")
	errInputRequired = errorf("input is required for responses")
)

// errorf wraps a fixed message in an error value. Despite its name it does no
// formatting: msg is used verbatim. Each call returns a distinct error value.
func errorf(msg string) error {
	return &simpleError{msg}
}

// simpleError is a minimal error type that carries only a message.
type simpleError struct {
	msg string
}

// Error implements the error interface by returning the stored message.
func (e *simpleError) Error() string { return e.msg }