package handlers

import (
	"errors"
	"io"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/fasthttp/router"
	ws "github.com/fasthttp/websocket"
	bifrost "github.com/maximhq/bifrost/core"
	"github.com/maximhq/bifrost/core/schemas"
	"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
	"github.com/maximhq/bifrost/transports/bifrost-http/lib"
	bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
	"github.com/valyala/fasthttp"
)

// Timing knobs for the client-facing realtime WebSocket connection.
const (
	// realtimeWSPingInterval is how often the heartbeat goroutine pings the client.
	realtimeWSPingInterval = 15 * time.Second
	// realtimeWSPongTimeout is the read-deadline window; it is re-armed on every
	// successfully read message and on every pong.
	realtimeWSPongTimeout = 45 * time.Second
	// realtimeWSPingWriteTimeout bounds a single ping write.
	realtimeWSPingWriteTimeout = 10 * time.Second
	// realtimeWSWriteTimeout bounds a single data-message write to the client.
	realtimeWSWriteTimeout = 30 * time.Second
)

// WSRealtimeHandler handles bidirectional WebSocket proxying for the Realtime API.
type WSRealtimeHandler struct {
	client       *bifrost.Bifrost
	config       *lib.Config
	handlerStore lib.HandlerStore
	pool         *bfws.Pool
	sessions     *bfws.SessionManager
}

// NewWSRealtimeHandler creates a new Realtime WebSocket handler.
//
// config is used both as the handler configuration and as the handler store;
// the session manager is built from config.WebSocketConfig.MaxConnections.
func NewWSRealtimeHandler(client *bifrost.Bifrost, config *lib.Config, pool *bfws.Pool) *WSRealtimeHandler {
	maxConns := config.WebSocketConfig.MaxConnections
	return &WSRealtimeHandler{
		client:       client,
		config:       config,
		handlerStore: config,
		pool:         pool,
		sessions:     bfws.NewSessionManager(maxConns),
	}
}

// RegisterRoutes registers the Realtime WebSocket endpoint at the base path and OpenAI integration paths.
func (h *WSRealtimeHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	handler := lib.ChainMiddlewares(h.handleUpgrade, middlewares...)
	r.GET("/v1/realtime", handler)
	for _, path := range integrations.OpenAIRealtimePaths("/openai") {
		r.GET(path, handler)
	}
}

// Close asks the session manager to close all sessions. It is a no-op on a nil
// handler or a handler without a session manager.
func (h *WSRealtimeHandler) Close() {
	if h == nil || h.sessions == nil {
		return
	}
	h.sessions.CloseAll()
}

// handleUpgrade resolves the target provider/model from the request, upgrades
// the HTTP request to a WebSocket, and runs the relay session.
//
// Validation failures are reported in-band: the connection is still upgraded
// (without a negotiated subprotocol), a single error event is written, and the
// connection is closed.
func (h *WSRealtimeHandler) handleUpgrade(ctx *fasthttp.RequestCtx) {
	path := string(ctx.Path())
	modelParam := string(ctx.QueryArgs().Peek("model"))
	deploymentParam := string(ctx.QueryArgs().Peek("deployment"))
	auth := captureAuthHeaders(ctx)

	// OpenAI's SDK sends the API key via a WebSocket subprotocol entry made of
	// the "openai-insecure-api-key." prefix immediately followed by the key.
	// Extract it into the auth headers so downstream processing recognizes it.
	if auth.authorization == "" {
		if token := extractRealtimeSubprotocolAPIKey(ctx); token != "" {
			auth.authorization = "Bearer " + token
		}
	}

	providerKey, model, err := resolveRealtimeTarget(path, modelParam, deploymentParam)
	if err != nil {
		h.rejectRealtimeUpgrade(ctx, path, err.Error())
		return
	}

	provider := h.client.GetProviderByKey(providerKey)
	rtProvider, ok := provider.(schemas.RealtimeProvider)
	if provider == nil || !ok || !rtProvider.SupportsRealtimeAPI() {
		h.rejectRealtimeUpgrade(ctx, path, "provider does not support realtime: "+string(providerKey))
		return
	}

	upgrader := h.websocketUpgrader(rtProvider.RealtimeWebSocketSubprotocol())
	err = upgrader.Upgrade(ctx, func(conn *ws.Conn) {
		defer conn.Close()
		clientConn := newRealtimeClientConn(conn)
		session, sessionErr := h.sessions.Create(conn)
		if sessionErr != nil {
			clientConn.writeRealtimeError(newRealtimeWireBifrostError(429, "rate_limit_exceeded", sessionErr.Error()))
			return
		}
		defer h.sessions.Remove(conn)
		h.runRealtimeSession(clientConn, session, auth, path, providerKey, model)
	})
	if err != nil {
		logger.Warn("websocket upgrade failed for %s: %v", path, err)
	}
}

// rejectRealtimeUpgrade upgrades the request without a subprotocol, writes one
// 400 invalid_request_error event carrying message, and closes the connection.
// An upgrade failure is logged with the request path.
func (h *WSRealtimeHandler) rejectRealtimeUpgrade(ctx *fasthttp.RequestCtx, path, message string) {
	upgrader := h.websocketUpgrader("")
	upgradeErr := upgrader.Upgrade(ctx, func(conn *ws.Conn) {
		defer conn.Close()
		clientConn := newRealtimeClientConn(conn)
		clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", message))
	})
	if upgradeErr != nil {
		logger.Warn("websocket upgrade failed for %s: %v", path, upgradeErr)
	}
}

// websocketUpgrader builds the upgrader used for client connections.
//
// Requests without an Origin header are always accepted; otherwise the origin
// is checked against the configured allowed origins. A non-blank subprotocol
// is offered as the only supported subprotocol.
func (h *WSRealtimeHandler) websocketUpgrader(subprotocol string) ws.FastHTTPUpgrader {
	upgrader := ws.FastHTTPUpgrader{
		ReadBufferSize:  4096,
		WriteBufferSize: 4096,
		CheckOrigin: func(ctx *fasthttp.RequestCtx) bool {
			origin := string(ctx.Request.Header.Peek("Origin"))
			if origin == "" {
				return true
			}
			return IsOriginAllowed(origin, h.config.ClientConfig.AllowedOrigins)
		},
	}
	if strings.TrimSpace(subprotocol) != "" {
		upgrader.Subprotocols = []string{subprotocol}
	}
	return upgrader
}

// runRealtimeSession drives one upgraded client connection: it starts the
// heartbeat, builds the request context, selects a key, obtains an upstream
// connection from the pool, and runs the two relay directions until either
// one ends. Setup failures are reported to the client as error events.
func (h *WSRealtimeHandler) runRealtimeSession(
	clientConn *realtimeClientConn,
	session *bfws.Session,
	auth *authHeaders,
	path string,
	providerKey schemas.ModelProvider,
	model string,
) {
	clientConn.startHeartbeat()
	defer clientConn.stopHeartbeat()

	bifrostCtx, cancel := createBifrostContextFromAuth(h.handlerStore, auth)
	if bifrostCtx == nil {
		// Release whatever the constructor may have set up even though no
		// usable context came back.
		if cancel != nil {
			cancel()
		}
		clientConn.writeRealtimeError(newRealtimeWireBifrostError(500, "server_error", "failed to create request context"))
		return
	}
	defer cancel()

	// Resolve ephemeral key mapping to restore virtual key context.
	token := extractRealtimeBearerTokenFromHeader(auth.authorization)
	if isRealtimeEphemeralToken(token) {
		mapping, ok := lookupRealtimeEphemeralKeyMapping(h.handlerStore.GetKVStore(), token)
		if ok {
			applyRealtimeEphemeralKeyMapping(bifrostCtx, mapping)
		}
	}

	bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest)
	// NOTE(review): this matches the bare "/openai" prefix while
	// realtimeDefaultProviderForPath matches "/openai/" — confirm whether the
	// difference is intentional.
	if strings.HasPrefix(path, "/openai") {
		bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai")
	}

	provider := h.client.GetProviderByKey(providerKey)
	if provider == nil {
		clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider not found: "+string(providerKey)))
		return
	}
	rtProvider, ok := provider.(schemas.RealtimeProvider)
	if !ok || !rtProvider.SupportsRealtimeAPI() {
		clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider does not support realtime: "+string(providerKey)))
		return
	}

	key, err := h.client.SelectKeyForProviderRequestType(bifrostCtx, schemas.RealtimeRequest, providerKey, model)
	if err != nil {
		clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()))
		return
	}

	// Resolve model alias so the provider receives the actual model identifier.
	model = key.Aliases.Resolve(model)

	wsURL := rtProvider.RealtimeWebSocketURL(key, model)
	upstream, err := h.pool.Get(bfws.PoolKey{
		Provider: providerKey,
		KeyID:    key.ID,
		Endpoint: wsURL,
	}, rtProvider.RealtimeHeaders(key))
	if err != nil {
		clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", err.Error()))
		return
	}
	defer h.pool.Discard(upstream)

	// Buffered for both relays so neither goroutine blocks on send after this
	// function stops caring about its result.
	errCh := make(chan error, 2)
	go func() {
		errCh <- h.relayClientToRealtimeProvider(clientConn, session, upstream, rtProvider, bifrostCtx, providerKey, model, key)
	}()
	go func() {
		errCh <- h.relayRealtimeProviderToClient(clientConn, session, upstream, rtProvider, bifrostCtx, providerKey, model, key)
	}()

	// Once one direction ends, close both sockets so the other direction's
	// blocking read returns, then wait for it.
	firstErr := <-errCh
	_ = upstream.Close()
	_ = clientConn.Close()
	secondErr := <-errCh

	if logErr := selectRealtimeRelayError(firstErr, secondErr); logErr != nil {
		logger.Warn("realtime websocket relay ended for %s/%s on %s: %v", providerKey, model, path, logErr)
	}
}

// relayClientToRealtimeProvider reads events from the client, validates and
// translates them, and writes them to the upstream provider connection.
//
// It returns nil on a normal client closure, on a non-text client message, and
// when turn hooks reject the event; it returns the underlying error on abnormal
// client read failures and on upstream write failures. Unparseable or
// untranslatable events are answered with an error event and skipped.
func (h *WSRealtimeHandler) relayClientToRealtimeProvider(
	clientConn *realtimeClientConn,
	session *bfws.Session,
	upstream *bfws.UpstreamConn,
	provider schemas.RealtimeProvider,
	bifrostCtx *schemas.BifrostContext,
	providerKey schemas.ModelProvider,
	model string,
	key schemas.Key,
) error {
	for {
		messageType, message, err := clientConn.ReadMessage()
		if err != nil {
			finalizeRealtimeTurnHooksOnTransportError(
				h.client,
				bifrostCtx,
				session,
				providerKey,
				model,
				&key,
				499,
				"client_closed_request",
				"client realtime websocket disconnected before turn completed",
			)
			if isNormalWebSocketClosure(err) {
				return nil
			}
			return err
		}
		if messageType != ws.TextMessage {
			clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "realtime websocket only accepts text messages"))
			return nil
		}

		event, err := schemas.ParseRealtimeEvent(message)
		if err != nil {
			clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "failed to parse realtime event JSON"))
			continue
		}

		// Extract pending tool/input summaries but defer recording until the event
		// passes validation — rejected events must not pollute session state.
		toolItemID, toolSummary := pendingRealtimeToolOutputUpdate(event)
		inputItemID, inputSummary := pendingRealtimeInputUpdate(event)

		startsTurn := provider.ShouldStartRealtimeTurn(event)
		if startsTurn {
			// Only one turn may be in flight per session.
			if session.PeekRealtimeTurnHooks() != nil {
				clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "Conversation already has an active response in progress."))
				continue
			}
			// Turn-starting events are recorded before the hooks start.
			if toolSummary != "" {
				session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(message))
			}
			if inputSummary != "" {
				session.RecordRealtimeInput(inputItemID, inputSummary, string(message))
			}
			if bifrostErr := startRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, event.Type); bifrostErr != nil {
				clientConn.writeRealtimeError(bifrostErr)
				return nil
			}
		}

		providerEvent, err := provider.ToProviderRealtimeEvent(event)
		if err != nil {
			// A turn that was just started must be finalized before the
			// translation error is reported.
			if startsTurn {
				if finalizeErr := finalizeRealtimeTurnHooksWithError(
					h.client,
					bifrostCtx,
					session,
					providerKey,
					model,
					&key,
					schemas.RTEventError,
					nil,
					newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()),
				); finalizeErr != nil {
					clientConn.writeRealtimeError(finalizeErr)
					return nil
				}
			}
			clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()))
			continue
		}

		// Record tool output / input only after the event passed validation.
		if !startsTurn {
			if toolSummary != "" {
				session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(message))
			}
			if inputSummary != "" {
				session.RecordRealtimeInput(inputItemID, inputSummary, string(message))
			}
		}

		if err := upstream.WriteMessage(ws.TextMessage, providerEvent); err != nil {
			finalizeRealtimeTurnHooksWithError(
				h.client,
				bifrostCtx,
				session,
				providerKey,
				model,
				&key,
				schemas.RTEventError,
				nil,
				newRealtimeWireBifrostError(502, "server_error", "failed to write realtime event upstream"),
			)
			clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to write realtime event upstream"))
			return err
		}
	}
}

// relayRealtimeProviderToClient reads messages from the upstream provider,
// updates session/turn state for text events, and forwards them to the client.
//
// Non-text messages are forwarded unchanged. It returns nil on a normal
// upstream or client closure, when turn hooks reject an event, and after
// forwarding an error event that calls for a graceful disconnect; otherwise it
// returns the transport or translation error.
func (h *WSRealtimeHandler) relayRealtimeProviderToClient(
	clientConn *realtimeClientConn,
	session *bfws.Session,
	upstream *bfws.UpstreamConn,
	provider schemas.RealtimeProvider,
	bifrostCtx *schemas.BifrostContext,
	providerKey schemas.ModelProvider,
	model string,
	key schemas.Key,
) error {
	for {
		disconnectAfterWrite := false
		messageType, message, err := upstream.ReadMessage()
		if err != nil {
			finalizeRealtimeTurnHooksOnTransportError(
				h.client,
				bifrostCtx,
				session,
				providerKey,
				model,
				&key,
				502,
				"upstream_connection_error",
				"upstream realtime websocket closed before turn completed",
			)
			if isNormalWebSocketClosure(err) {
				return nil
			}
			// NOTE(review): both finalizers run on abnormal errors; presumably
			// the second is a no-op once the hooks were consumed — confirm.
			finalizeRealtimeTurnHooksWithError(
				h.client,
				bifrostCtx,
				session,
				providerKey,
				model,
				&key,
				schemas.RTEventError,
				nil,
				newRealtimeWireBifrostError(502, "server_error", "upstream realtime websocket stream interrupted"),
			)
			clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "upstream realtime websocket stream interrupted"))
			return err
		}

		if messageType == ws.TextMessage {
			event, err := provider.ToBifrostRealtimeEvent(message)
			if err != nil {
				finalizeRealtimeTurnHooksWithError(
					h.client,
					bifrostCtx,
					session,
					providerKey,
					model,
					&key,
					schemas.RTEventError,
					message,
					newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event"),
				)
				clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event"))
				return err
			}

			// Session bookkeeping: runs for every translated event, including
			// ones that are not forwarded to the client.
			if event != nil {
				if event.Session != nil && event.Session.ID != "" {
					session.SetProviderSessionID(event.Session.ID)
				}
				if event.Delta != nil && provider.ShouldAccumulateRealtimeOutput(event.Type) {
					session.AppendRealtimeOutputText(event.Delta.Text)
					session.AppendRealtimeOutputText(event.Delta.Transcript)
				}
				// Provider-initiated turns start hooks only when no turn is active.
				if provider.ShouldStartRealtimeTurn(event) && session.PeekRealtimeTurnHooks() == nil {
					if bifrostErr := startRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, event.Type); bifrostErr != nil {
						clientConn.writeRealtimeError(bifrostErr)
						return nil
					}
				}
			}

			// Forwarding decisions and turn finalization.
			if event != nil {
				inputItemID, inputSummary := pendingRealtimeInputUpdate(event)
				if !provider.ShouldForwardRealtimeEvent(event) {
					continue
				}
				if event.Type == provider.RealtimeTurnFinalEvent() {
					contentOverride := session.ConsumeRealtimeOutputText()
					if bifrostErr := finalizeRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, message, contentOverride); bifrostErr != nil {
						clientConn.writeRealtimeError(bifrostErr)
						return nil
					}
				} else if event.Error != nil {
					turnErr := newBifrostErrorFromRealtimeError(providerKey, model, message, event.Error)
					finalizeErr := finalizeRealtimeTurnHooksWithError(
						h.client,
						bifrostCtx,
						session,
						providerKey,
						model,
						&key,
						event.Type,
						message,
						turnErr,
					)
					if finalizeErr != nil {
						clientConn.writeRealtimeError(finalizeErr)
						return nil
					}
					// Defer the disconnect so the normal translated-write path
					// below still runs — otherwise terminal errors from translated
					// providers would reach the client in provider-native format.
					disconnectAfterWrite = shouldGracefullyDisconnectRealtime(turnErr)
				} else if inputSummary != "" {
					session.RecordRealtimeInput(inputItemID, inputSummary, string(message))
				}
				// Events without raw data are re-encoded from the translated
				// event; the encoded bytes replace the upstream payload.
				if len(event.RawData) == 0 {
					message, err = provider.ToProviderRealtimeEvent(event)
					if err != nil {
						clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to encode translated realtime event"))
						return err
					}
				}
			}
		}

		if err := clientConn.WriteMessage(messageType, message); err != nil {
			finalizeRealtimeTurnHooksOnTransportError(
				h.client,
				bifrostCtx,
				session,
				providerKey,
				model,
				&key,
				499,
				"client_closed_request",
				"client realtime websocket disconnected before turn completed",
			)
			if isNormalWebSocketClosure(err) {
				return nil
			}
			return err
		}
		if disconnectAfterWrite {
			return nil
		}
	}
}

// resolveRealtimeTarget derives the provider and model from the query
// parameters. A non-blank model parameter takes precedence over deployment;
// if both are blank errRealtimeModelRequired is returned. The chosen value is
// parsed with the path's default provider, and a result with an empty provider
// or blank model yields the matching format error.
func resolveRealtimeTarget(path, modelParam, deploymentParam string) (schemas.ModelProvider, string, error) {
	defaultProvider := realtimeDefaultProviderForPath(path)
	switch {
	case strings.TrimSpace(modelParam) != "":
		provider, model := schemas.ParseModelString(strings.TrimSpace(modelParam), defaultProvider)
		if provider == "" || strings.TrimSpace(model) == "" {
			return "", "", errRealtimeModelFormat
		}
		return provider, strings.TrimSpace(model), nil
	case strings.TrimSpace(deploymentParam) != "":
		provider, model := schemas.ParseModelString(strings.TrimSpace(deploymentParam), defaultProvider)
		if provider == "" || strings.TrimSpace(model) == "" {
			return "", "", errRealtimeDeploymentFormat
		}
		return provider, strings.TrimSpace(model), nil
	default:
		return "", "", errRealtimeModelRequired
	}
}

// realtimeDefaultProviderForPath returns schemas.OpenAI for paths under
// "/openai/" and the empty provider for everything else.
func realtimeDefaultProviderForPath(path string) schemas.ModelProvider {
	if strings.HasPrefix(path, "/openai/") {
		return schemas.OpenAI
	}
	return ""
}

// isNormalWebSocketClosure reports whether err is a WebSocket close error with
// a normal-closure, going-away, or no-status-received code.
func isNormalWebSocketClosure(err error) bool {
	return ws.IsCloseError(err, ws.CloseNormalClosure, ws.CloseGoingAway, ws.CloseNoStatusReceived)
}

// isExpectedRealtimeRelayShutdown reports whether err is nil or one of the
// errors that routinely accompany relay teardown and should not be logged.
func isExpectedRealtimeRelayShutdown(err error) bool {
	if err == nil {
		return true
	}
	if isNormalWebSocketClosure(err) || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
		return true
	}
	// Relay teardown closes the opposite socket after the first side exits, which can
	// surface as a plain network-close read error instead of a websocket close frame.
	return strings.Contains(err.Error(), "use of closed network connection")
}

// selectRealtimeRelayError returns the first error in errs that is not an
// expected relay shutdown, or nil if there is none.
func selectRealtimeRelayError(errs ...error) error {
	for _, err := range errs {
		if err != nil && !isExpectedRealtimeRelayShutdown(err) {
			return err
		}
	}
	return nil
}

// Validation errors returned by resolveRealtimeTarget.
var (
	errRealtimeModelRequired    = errorf("model or deployment query parameter is required for realtime websocket")
	errRealtimeModelFormat      = errorf("model query parameter must resolve to provider/model for realtime websocket")
	errRealtimeDeploymentFormat = errorf("deployment query parameter must resolve to provider/model for realtime websocket")
)

// realtimeClientConn wraps the client-side WebSocket connection with a
// serialized write path and a ping/pong heartbeat.
type realtimeClientConn struct {
	conn *ws.Conn
	// writeMu serializes data writes and heartbeat pings, which are issued
	// from different goroutines.
	writeMu sync.Mutex
	// closeOnce guards closing done exactly once.
	closeOnce sync.Once
	// done is closed to stop the heartbeat goroutine.
	done chan struct{}
}

// newRealtimeClientConn wraps conn; the heartbeat is not started until
// startHeartbeat is called.
func newRealtimeClientConn(conn *ws.Conn) *realtimeClientConn {
	return &realtimeClientConn{
		conn: conn,
		done: make(chan struct{}),
	}
}

// ReadMessage reads the next message from the client and, on success, re-arms
// the read deadline.
func (c *realtimeClientConn) ReadMessage() (messageType int, p []byte, err error) {
	messageType, p, err = c.conn.ReadMessage()
	if err == nil {
		// Best effort: a failure to re-arm the deadline is not reported to the
		// caller, which already has a valid message in hand.
		_ = c.refreshReadDeadline()
	}
	return messageType, p, err
}

// WriteMessage writes one message to the client under the write lock, bounded
// by realtimeWSWriteTimeout. The write deadline is cleared again after a
// successful write.
func (c *realtimeClientConn) WriteMessage(messageType int, data []byte) error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	if err := c.conn.SetWriteDeadline(time.Now().Add(realtimeWSWriteTimeout)); err != nil {
		return err
	}
	if err := c.conn.WriteMessage(messageType, data); err != nil {
		return err
	}
	return c.conn.SetWriteDeadline(time.Time{})
}

// startHeartbeat installs the pong handler, arms the initial read deadline,
// and starts a goroutine that pings the client every realtimeWSPingInterval.
// The goroutine exits when done is closed, or closes the connection and exits
// when a ping cannot be written.
func (c *realtimeClientConn) startHeartbeat() {
	c.installPongHandler()
	c.refreshReadDeadline()
	go func() {
		ticker := time.NewTicker(realtimeWSPingInterval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				if err := c.writePing(); err != nil {
					_ = c.Close()
					return
				}
			case <-c.done:
				return
			}
		}
	}()
}

// stopHeartbeat signals the heartbeat goroutine to exit. It is safe to call
// more than once.
func (c *realtimeClientConn) stopHeartbeat() {
	c.closeDone()
}

// installPongHandler makes every received pong extend the read deadline.
func (c *realtimeClientConn) installPongHandler() {
	c.conn.SetPongHandler(func(string) error {
		return c.refreshReadDeadline()
	})
}

// refreshReadDeadline sets the read deadline to now plus realtimeWSPongTimeout.
func (c *realtimeClientConn) refreshReadDeadline() error {
	return c.conn.SetReadDeadline(time.Now().Add(realtimeWSPongTimeout))
}

// writePing sends a ping frame under the write lock, bounded by
// realtimeWSPingWriteTimeout. The write deadline is cleared again after a
// successful write.
func (c *realtimeClientConn) writePing() error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	if err := c.conn.SetWriteDeadline(time.Now().Add(realtimeWSPingWriteTimeout)); err != nil {
		return err
	}
	if err := c.conn.WriteMessage(ws.PingMessage, nil); err != nil {
		return err
	}
	return c.conn.SetWriteDeadline(time.Time{})
}

// closeDone closes the done channel exactly once.
func (c *realtimeClientConn) closeDone() {
	c.closeOnce.Do(func() {
		close(c.done)
	})
}

// writeRealtimeError sends bifrostErr to the client as a text error event.
// Delivery is best effort; a write failure is ignored because callers are
// already on an error path.
func (c *realtimeClientConn) writeRealtimeError(bifrostErr *schemas.BifrostError) {
	payload := newRealtimeTurnErrorEventPayload(bifrostErr)
	_ = c.WriteMessage(ws.TextMessage, payload)
}

// Close stops the heartbeat and closes the underlying connection.
func (c *realtimeClientConn) Close() error {
	c.closeDone()
	return c.conn.Close()
}

// realtimeSubprotocolAPIKeyPrefix is the subprotocol prefix that carries an
// API key in the Sec-WebSocket-Protocol header.
const realtimeSubprotocolAPIKeyPrefix = "openai-insecure-api-key."

// extractRealtimeSubprotocolAPIKey extracts an API key from the Sec-WebSocket-Protocol
// header. The OpenAI SDK sends a comma-separated list such as "realtime" plus an
// entry made of the "openai-insecure-api-key." prefix immediately followed by the
// key. It returns the key from the first matching entry, or "" if there is none.
func extractRealtimeSubprotocolAPIKey(ctx *fasthttp.RequestCtx) string {
	header := string(ctx.Request.Header.Peek("Sec-WebSocket-Protocol"))
	for _, proto := range strings.Split(header, ",") {
		proto = strings.TrimSpace(proto)
		if strings.HasPrefix(proto, realtimeSubprotocolAPIKeyPrefix) {
			return strings.TrimPrefix(proto, realtimeSubprotocolAPIKeyPrefix)
		}
	}
	return ""
}

// newRealtimeWireBifrostError builds a BifrostError with the given HTTP-style
// status and message, using code for the top-level type and for the nested
// error's type and code.
func newRealtimeWireBifrostError(status int, code, message string) *schemas.BifrostError {
	// Give each pointer field its own string so a later write through one of
	// them cannot silently change the other two.
	topType, errType, errCode := code, code, code
	return &schemas.BifrostError{
		StatusCode: &status,
		Type:       &topType,
		Error: &schemas.ErrorField{
			Type:    &errType,
			Code:    &errCode,
			Message: message,
		},
	}
}