Files
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

667 lines
20 KiB
Go

package handlers
import (
"errors"
"io"
"net"
"strings"
"sync"
"time"
"github.com/fasthttp/router"
ws "github.com/fasthttp/websocket"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
"github.com/valyala/fasthttp"
)
// Timing parameters for the client-facing realtime WebSocket connection.
const (
// realtimeWSPingInterval is the period of the heartbeat ticker that sends ping frames to the client.
realtimeWSPingInterval = 15 * time.Second
// realtimeWSPongTimeout is how far into the future the client read deadline is
// pushed each time it is refreshed (heartbeat start, successful read, pong).
realtimeWSPongTimeout = 45 * time.Second
// realtimeWSPingWriteTimeout is the write deadline applied while sending a ping frame.
realtimeWSPingWriteTimeout = 10 * time.Second
// realtimeWSWriteTimeout is the write deadline applied while sending a regular message.
realtimeWSWriteTimeout = 30 * time.Second
)
// WSRealtimeHandler handles bidirectional WebSocket proxying for the Realtime API.
type WSRealtimeHandler struct {
// client is used to look up providers, select provider keys and run turn hooks.
client *bifrost.Bifrost
// config supplies the WebSocket connection limit and the allowed-origins list.
config *lib.Config
// handlerStore is used to build the per-session Bifrost context and to reach
// the KV store for ephemeral key lookups. The constructor sets it to config.
handlerStore lib.HandlerStore
// pool provides the upstream provider connections used by each session.
pool *bfws.Pool
// sessions tracks the active client sessions created on upgrade.
sessions *bfws.SessionManager
}
// NewWSRealtimeHandler builds a Realtime WebSocket handler around the given
// Bifrost client, configuration and upstream connection pool. The session
// manager is sized from config.WebSocketConfig.MaxConnections, and config also
// serves as the handler store.
func NewWSRealtimeHandler(client *bifrost.Bifrost, config *lib.Config, pool *bfws.Pool) *WSRealtimeHandler {
	handler := &WSRealtimeHandler{
		client:       client,
		config:       config,
		handlerStore: config,
		pool:         pool,
	}
	handler.sessions = bfws.NewSessionManager(config.WebSocketConfig.MaxConnections)
	return handler
}
// RegisterRoutes mounts the Realtime WebSocket upgrade handler, wrapped in the
// supplied middlewares, on "/v1/realtime" and on every OpenAI integration
// realtime path under the "/openai" prefix.
func (h *WSRealtimeHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
	wrapped := lib.ChainMiddlewares(h.handleUpgrade, middlewares...)

	r.GET("/v1/realtime", wrapped)

	openAIPaths := integrations.OpenAIRealtimePaths("/openai")
	for i := range openAIPaths {
		r.GET(openAIPaths[i], wrapped)
	}
}
// Close closes every session tracked by the handler. It is a no-op when the
// receiver is nil or no session manager has been set.
func (h *WSRealtimeHandler) Close() {
	if h != nil && h.sessions != nil {
		h.sessions.CloseAll()
	}
}
// handleUpgrade resolves the realtime target from the request, upgrades the
// connection to a WebSocket and hands it to runRealtimeSession.
//
// Requests that cannot be served (unresolvable model/deployment, or a provider
// that does not support realtime) are still upgraded so the failure can be
// delivered to the client as a realtime error event before the socket closes.
func (h *WSRealtimeHandler) handleUpgrade(ctx *fasthttp.RequestCtx) {
	path := string(ctx.Path())
	query := ctx.QueryArgs()
	modelParam := string(query.Peek("model"))
	deploymentParam := string(query.Peek("deployment"))

	auth := captureAuthHeaders(ctx)
	// OpenAI's SDK sends the API key via WebSocket subprotocol: "openai-insecure-api-key.<key>".
	// Promote it to an Authorization header so downstream processing recognizes it.
	if auth.authorization == "" {
		if token := extractRealtimeSubprotocolAPIKey(ctx); token != "" {
			auth.authorization = "Bearer " + token
		}
	}

	// reject upgrades without a subprotocol, reports message as a 400
	// invalid_request_error event, and closes the connection.
	reject := func(message string) {
		upgrader := h.websocketUpgrader("")
		rejectErr := upgrader.Upgrade(ctx, func(conn *ws.Conn) {
			defer conn.Close()
			rejected := newRealtimeClientConn(conn)
			rejected.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", message))
		})
		if rejectErr != nil {
			logger.Warn("websocket upgrade failed for %s: %v", path, rejectErr)
		}
	}

	providerKey, model, err := resolveRealtimeTarget(path, modelParam, deploymentParam)
	if err != nil {
		reject(err.Error())
		return
	}

	provider := h.client.GetProviderByKey(providerKey)
	rtProvider, ok := provider.(schemas.RealtimeProvider)
	if provider == nil || !ok || !rtProvider.SupportsRealtimeAPI() {
		reject("provider does not support realtime: " + string(providerKey))
		return
	}

	upgrader := h.websocketUpgrader(rtProvider.RealtimeWebSocketSubprotocol())
	upgradeErr := upgrader.Upgrade(ctx, func(conn *ws.Conn) {
		defer conn.Close()
		clientConn := newRealtimeClientConn(conn)

		session, sessionErr := h.sessions.Create(conn)
		if sessionErr != nil {
			clientConn.writeRealtimeError(newRealtimeWireBifrostError(429, "rate_limit_exceeded", sessionErr.Error()))
			return
		}
		defer h.sessions.Remove(conn)

		h.runRealtimeSession(clientConn, session, auth, path, providerKey, model)
	})
	if upgradeErr != nil {
		logger.Warn("websocket upgrade failed for %s: %v", path, upgradeErr)
	}
}
// websocketUpgrader returns an upgrader with 4 KiB read/write buffers and an
// origin check against the configured allowed origins. Requests without an
// Origin header are accepted. A non-blank subprotocol is advertised as the only
// supported subprotocol; a blank one leaves the list unset.
func (h *WSRealtimeHandler) websocketUpgrader(subprotocol string) ws.FastHTTPUpgrader {
	var subprotocols []string
	if strings.TrimSpace(subprotocol) != "" {
		subprotocols = []string{subprotocol}
	}

	return ws.FastHTTPUpgrader{
		ReadBufferSize:  4096,
		WriteBufferSize: 4096,
		Subprotocols:    subprotocols,
		CheckOrigin: func(ctx *fasthttp.RequestCtx) bool {
			if origin := string(ctx.Request.Header.Peek("Origin")); origin != "" {
				return IsOriginAllowed(origin, h.config.ClientConfig.AllowedOrigins)
			}
			return true
		},
	}
}
// runRealtimeSession drives one upgraded client connection for its whole lifetime.
// It starts the client heartbeat, builds the Bifrost context from the captured
// auth headers, selects a provider key, obtains an upstream connection from the
// pool, and then runs both relay directions concurrently until they return.
// Any failure before the relay starts is reported to the client as a realtime
// error event and ends the session.
func (h *WSRealtimeHandler) runRealtimeSession(
clientConn *realtimeClientConn,
session *bfws.Session,
auth *authHeaders,
path string,
providerKey schemas.ModelProvider,
model string,
) {
clientConn.startHeartbeat()
defer clientConn.stopHeartbeat()
bifrostCtx, cancel := createBifrostContextFromAuth(h.handlerStore, auth)
if bifrostCtx == nil {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(500, "server_error", "failed to create request context"))
return
}
defer cancel()
// Resolve ephemeral key mapping to restore virtual key context.
token := extractRealtimeBearerTokenFromHeader(auth.authorization)
if isRealtimeEphemeralToken(token) {
mapping, ok := lookupRealtimeEphemeralKeyMapping(h.handlerStore.GetKVStore(), token)
if ok {
applyRealtimeEphemeralKeyMapping(bifrostCtx, mapping)
}
}
// Tag the context as a realtime request, and as an OpenAI integration request
// when the route lives under the "/openai" prefix.
bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest)
if strings.HasPrefix(path, "/openai") {
bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai")
}
// The provider is looked up again here and re-checked for realtime support
// before any key is selected.
provider := h.client.GetProviderByKey(providerKey)
if provider == nil {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider not found: "+string(providerKey)))
return
}
rtProvider, ok := provider.(schemas.RealtimeProvider)
if !ok || !rtProvider.SupportsRealtimeAPI() {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider does not support realtime: "+string(providerKey)))
return
}
key, err := h.client.SelectKeyForProviderRequestType(bifrostCtx, schemas.RealtimeRequest, providerKey, model)
if err != nil {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()))
return
}
// Resolve model alias so the provider receives the actual model identifier.
model = key.Aliases.Resolve(model)
wsURL := rtProvider.RealtimeWebSocketURL(key, model)
upstream, err := h.pool.Get(bfws.PoolKey{
Provider: providerKey,
KeyID: key.ID,
Endpoint: wsURL,
}, rtProvider.RealtimeHeaders(key))
if err != nil {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", err.Error()))
return
}
// The upstream connection is always discarded when the session ends.
defer h.pool.Discard(upstream)
// Capacity 2 lets each relay goroutine deliver its result without blocking.
errCh := make(chan error, 2)
go func() {
errCh <- h.relayClientToRealtimeProvider(clientConn, session, upstream, rtProvider, bifrostCtx, providerKey, model, key)
}()
go func() {
errCh <- h.relayRealtimeProviderToClient(clientConn, session, upstream, rtProvider, bifrostCtx, providerKey, model, key)
}()
// Once the first relay returns, close both sockets and then wait for the
// second relay to return as well.
firstErr := <-errCh
_ = upstream.Close()
_ = clientConn.Close()
secondErr := <-errCh
// Only errors that are not an expected shutdown are logged.
if logErr := selectRealtimeRelayError(firstErr, secondErr); logErr != nil {
logger.Warn("realtime websocket relay ended for %s/%s on %s: %v", providerKey, model, path, logErr)
}
}
// relayClientToRealtimeProvider reads messages from the client, validates and
// translates each one into the provider's wire format, and writes it upstream.
// It loops until the client read fails, a non-text message arrives, turn hooks
// reject the request, or the upstream write fails.
//
// It returns nil for normal closures and for terminations that were already
// reported to the client; otherwise it returns the transport error.
func (h *WSRealtimeHandler) relayClientToRealtimeProvider(
clientConn *realtimeClientConn,
session *bfws.Session,
upstream *bfws.UpstreamConn,
provider schemas.RealtimeProvider,
bifrostCtx *schemas.BifrostContext,
providerKey schemas.ModelProvider,
model string,
key schemas.Key,
) error {
for {
messageType, message, err := clientConn.ReadMessage()
if err != nil {
// The client side is gone: finalize turn hooks with a 499 before deciding
// whether the read error is worth returning.
finalizeRealtimeTurnHooksOnTransportError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
499,
"client_closed_request",
"client realtime websocket disconnected before turn completed",
)
if isNormalWebSocketClosure(err) {
return nil
}
return err
}
// A non-text frame ends the relay after the error has been sent.
if messageType != ws.TextMessage {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "realtime websocket only accepts text messages"))
return nil
}
// Unparseable JSON is reported, but the session keeps running.
event, err := schemas.ParseRealtimeEvent(message)
if err != nil {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "failed to parse realtime event JSON"))
continue
}
// Extract pending tool/input summaries but defer recording until the event
// passes validation — rejected events must not pollute session state.
toolItemID, toolSummary := pendingRealtimeToolOutputUpdate(event)
inputItemID, inputSummary := pendingRealtimeInputUpdate(event)
startsTurn := provider.ShouldStartRealtimeTurn(event)
if startsTurn {
// Only one turn may be in flight: reject a new turn while hooks are still pending.
if session.PeekRealtimeTurnHooks() != nil {
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "Conversation already has an active response in progress."))
continue
}
// For turn-starting events the summaries are recorded before the turn hooks start.
if toolSummary != "" {
session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(message))
}
if inputSummary != "" {
session.RecordRealtimeInput(inputItemID, inputSummary, string(message))
}
if bifrostErr := startRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, event.Type); bifrostErr != nil {
clientConn.writeRealtimeError(bifrostErr)
return nil
}
}
providerEvent, err := provider.ToProviderRealtimeEvent(event)
if err != nil {
// Translation failed. If this event started a turn, finalize the turn hooks
// with the error before reporting it to the client.
if startsTurn {
if finalizeErr := finalizeRealtimeTurnHooksWithError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
schemas.RTEventError,
nil,
newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()),
); finalizeErr != nil {
clientConn.writeRealtimeError(finalizeErr)
return nil
}
}
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()))
continue
}
// Record tool output / input only after the event passed validation.
if !startsTurn {
if toolSummary != "" {
session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(message))
}
if inputSummary != "" {
session.RecordRealtimeInput(inputItemID, inputSummary, string(message))
}
}
if err := upstream.WriteMessage(ws.TextMessage, providerEvent); err != nil {
// The return value of the finalize call is discarded here; the client is
// sent the 502 below either way.
finalizeRealtimeTurnHooksWithError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
schemas.RTEventError,
nil,
newRealtimeWireBifrostError(502, "server_error", "failed to write realtime event upstream"),
)
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to write realtime event upstream"))
return err
}
}
}
// relayRealtimeProviderToClient reads messages from the upstream provider and
// forwards them to the client. Text messages are translated into Bifrost
// realtime events first so that session state and turn hooks can be updated;
// other message types are forwarded unchanged.
//
// It returns nil for normal closures and for terminations that were already
// reported to the client; otherwise it returns the transport or translation
// error.
func (h *WSRealtimeHandler) relayRealtimeProviderToClient(
clientConn *realtimeClientConn,
session *bfws.Session,
upstream *bfws.UpstreamConn,
provider schemas.RealtimeProvider,
bifrostCtx *schemas.BifrostContext,
providerKey schemas.ModelProvider,
model string,
key schemas.Key,
) error {
for {
// Reset per message; set when an upstream error event should end the relay
// after it has been forwarded to the client.
disconnectAfterWrite := false
messageType, message, err := upstream.ReadMessage()
if err != nil {
finalizeRealtimeTurnHooksOnTransportError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
502,
"upstream_connection_error",
"upstream realtime websocket closed before turn completed",
)
if isNormalWebSocketClosure(err) {
return nil
}
// NOTE(review): this is a second finalize call after the transport-error
// finalize above; presumably a no-op once the hooks are consumed — confirm.
finalizeRealtimeTurnHooksWithError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
schemas.RTEventError,
nil,
newRealtimeWireBifrostError(502, "server_error", "upstream realtime websocket stream interrupted"),
)
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "upstream realtime websocket stream interrupted"))
return err
}
if messageType == ws.TextMessage {
event, err := provider.ToBifrostRealtimeEvent(message)
if err != nil {
// An untranslatable upstream event ends the relay: finalize hooks with a
// 502 and tell the client.
finalizeRealtimeTurnHooksWithError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
schemas.RTEventError,
message,
newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event"),
)
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event"))
return err
}
// First pass over the event: update session bookkeeping. This runs even for
// events that are not forwarded to the client.
if event != nil {
if event.Session != nil && event.Session.ID != "" {
session.SetProviderSessionID(event.Session.ID)
}
if event.Delta != nil && provider.ShouldAccumulateRealtimeOutput(event.Type) {
session.AppendRealtimeOutputText(event.Delta.Text)
session.AppendRealtimeOutputText(event.Delta.Transcript)
}
// An upstream event may also start a turn when none is already in progress.
if provider.ShouldStartRealtimeTurn(event) && session.PeekRealtimeTurnHooks() == nil {
if bifrostErr := startRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, event.Type); bifrostErr != nil {
clientConn.writeRealtimeError(bifrostErr)
return nil
}
}
}
// Second pass: decide whether to forward the event and finalize the turn on
// its final event or on an error event.
if event != nil {
inputItemID, inputSummary := pendingRealtimeInputUpdate(event)
if !provider.ShouldForwardRealtimeEvent(event) {
continue
}
if event.Type == provider.RealtimeTurnFinalEvent() {
// The accumulated output text is consumed and passed to the finalize call.
contentOverride := session.ConsumeRealtimeOutputText()
if bifrostErr := finalizeRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, message, contentOverride); bifrostErr != nil {
clientConn.writeRealtimeError(bifrostErr)
return nil
}
} else if event.Error != nil {
turnErr := newBifrostErrorFromRealtimeError(providerKey, model, message, event.Error)
finalizeErr := finalizeRealtimeTurnHooksWithError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
event.Type,
message,
turnErr,
)
if finalizeErr != nil {
clientConn.writeRealtimeError(finalizeErr)
return nil
}
// Defer the disconnect so the normal translated-write path
// below still runs — otherwise terminal errors from translated
// providers would reach the client in provider-native format.
disconnectAfterWrite = shouldGracefullyDisconnectRealtime(turnErr)
} else if inputSummary != "" {
session.RecordRealtimeInput(inputItemID, inputSummary, string(message))
}
// Events without raw data are re-encoded from the translated event before
// being written to the client.
if len(event.RawData) == 0 {
message, err = provider.ToProviderRealtimeEvent(event)
if err != nil {
// NOTE(review): unlike the translate-failure path above, this path returns
// without finalizing turn hooks — confirm that is intended.
clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to encode translated realtime event"))
return err
}
}
}
}
if err := clientConn.WriteMessage(messageType, message); err != nil {
// The client write failed: finalize turn hooks with a 499, as on a client read failure.
finalizeRealtimeTurnHooksOnTransportError(
h.client,
bifrostCtx,
session,
providerKey,
model,
&key,
499,
"client_closed_request",
"client realtime websocket disconnected before turn completed",
)
if isNormalWebSocketClosure(err) {
return nil
}
return err
}
if disconnectAfterWrite {
return nil
}
}
}
// resolveRealtimeTarget determines the provider and model for a realtime
// connection. The "model" query parameter takes precedence over "deployment";
// whichever is used is trimmed and parsed with the path's default provider.
//
// It returns errRealtimeModelRequired when both parameters are blank, and the
// parameter-specific format error when parsing yields an empty provider or
// model.
func resolveRealtimeTarget(path, modelParam, deploymentParam string) (schemas.ModelProvider, string, error) {
	defaultProvider := realtimeDefaultProviderForPath(path)

	target, formatErr := strings.TrimSpace(modelParam), errRealtimeModelFormat
	if target == "" {
		target, formatErr = strings.TrimSpace(deploymentParam), errRealtimeDeploymentFormat
	}
	if target == "" {
		return "", "", errRealtimeModelRequired
	}

	provider, model := schemas.ParseModelString(target, defaultProvider)
	model = strings.TrimSpace(model)
	if provider == "" || model == "" {
		return "", "", formatErr
	}
	return provider, model, nil
}
// realtimeDefaultProviderForPath returns schemas.OpenAI for paths under
// "/openai/" and an empty provider for every other path.
func realtimeDefaultProviderForPath(path string) schemas.ModelProvider {
	if !strings.HasPrefix(path, "/openai/") {
		return ""
	}
	return schemas.OpenAI
}
// isNormalWebSocketClosure reports whether err is a WebSocket close error
// carrying one of the close codes this handler treats as an ordinary shutdown:
// normal closure, going away, or no status received.
func isNormalWebSocketClosure(err error) bool {
	normalCodes := []int{
		ws.CloseNormalClosure,
		ws.CloseGoingAway,
		ws.CloseNoStatusReceived,
	}
	return ws.IsCloseError(err, normalCodes...)
}
// isExpectedRealtimeRelayShutdown reports whether err is an outcome of ordinary
// relay teardown: nil, a normal WebSocket closure, io.EOF, net.ErrClosed, or an
// error whose text mentions a closed network connection.
func isExpectedRealtimeRelayShutdown(err error) bool {
	switch {
	case err == nil:
		return true
	case isNormalWebSocketClosure(err), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
		return true
	default:
		// Relay teardown closes the opposite socket after the first side exits, which
		// can surface as a plain network-close read error instead of a close frame.
		return strings.Contains(err.Error(), "use of closed network connection")
	}
}
// selectRealtimeRelayError returns the first error in errs that is not an
// expected relay shutdown, or nil when there is none.
func selectRealtimeRelayError(errs ...error) error {
	for i := range errs {
		candidate := errs[i]
		if candidate == nil || isExpectedRealtimeRelayShutdown(candidate) {
			continue
		}
		return candidate
	}
	return nil
}
// Errors returned by resolveRealtimeTarget.
var (
// errRealtimeModelRequired is returned when both the model and deployment query parameters are blank.
errRealtimeModelRequired = errorf("model or deployment query parameter is required for realtime websocket")
// errRealtimeModelFormat is returned when the model parameter does not yield both a provider and a model.
errRealtimeModelFormat = errorf("model query parameter must resolve to provider/model for realtime websocket")
// errRealtimeDeploymentFormat is returned when the deployment parameter does not yield both a provider and a model.
errRealtimeDeploymentFormat = errorf("deployment query parameter must resolve to provider/model for realtime websocket")
)
// realtimeClientConn wraps the client-side WebSocket connection with serialized
// writes, write deadlines and a ping heartbeat.
type realtimeClientConn struct {
// conn is the underlying client WebSocket connection.
conn *ws.Conn
// writeMu serializes message writes and heartbeat pings on conn.
writeMu sync.Mutex
// closeOnce ensures done is closed only once.
closeOnce sync.Once
// done is closed to stop the heartbeat goroutine.
done chan struct{}
}
// newRealtimeClientConn wraps conn and allocates the channel used to signal the
// heartbeat goroutine to stop. The heartbeat itself is not started here.
func newRealtimeClientConn(conn *ws.Conn) *realtimeClientConn {
	clientConn := new(realtimeClientConn)
	clientConn.conn = conn
	clientConn.done = make(chan struct{})
	return clientConn
}
// ReadMessage reads the next message from the client. A successful read pushes
// the read deadline forward; a failure to set that deadline is ignored.
func (c *realtimeClientConn) ReadMessage() (messageType int, p []byte, err error) {
	if messageType, p, err = c.conn.ReadMessage(); err != nil {
		return messageType, p, err
	}
	_ = c.refreshReadDeadline()
	return messageType, p, nil
}
// WriteMessage sends one message to the client while holding the write lock.
// The write runs under a realtimeWSWriteTimeout deadline, which is cleared
// again after a successful write. The first error encountered is returned.
func (c *realtimeClientConn) WriteMessage(messageType int, data []byte) error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	err := c.conn.SetWriteDeadline(time.Now().Add(realtimeWSWriteTimeout))
	if err == nil {
		err = c.conn.WriteMessage(messageType, data)
	}
	if err != nil {
		return err
	}
	return c.conn.SetWriteDeadline(time.Time{})
}
// startHeartbeat installs the pong handler, arms the initial read deadline and
// starts a goroutine that pings the client every realtimeWSPingInterval. The
// goroutine exits when done is closed, or closes the connection and exits when
// a ping cannot be written.
func (c *realtimeClientConn) startHeartbeat() {
	c.installPongHandler()
	_ = c.refreshReadDeadline()

	go func() {
		pings := time.NewTicker(realtimeWSPingInterval)
		defer pings.Stop()

		for {
			select {
			case <-c.done:
				return
			case <-pings.C:
				if pingErr := c.writePing(); pingErr != nil {
					_ = c.Close()
					return
				}
			}
		}
	}()
}
// stopHeartbeat signals the heartbeat goroutine to exit by closing the done
// channel. It does not close the underlying connection.
func (c *realtimeClientConn) stopHeartbeat() {
c.closeDone()
}
// installPongHandler registers a pong handler that pushes the read deadline
// forward whenever it is invoked.
func (c *realtimeClientConn) installPongHandler() {
	onPong := func(string) error {
		return c.refreshReadDeadline()
	}
	c.conn.SetPongHandler(onPong)
}
// refreshReadDeadline sets the read deadline to realtimeWSPongTimeout from now.
func (c *realtimeClientConn) refreshReadDeadline() error {
	deadline := time.Now().Add(realtimeWSPongTimeout)
	return c.conn.SetReadDeadline(deadline)
}
// writePing sends a ping frame to the client while holding the write lock. The
// write runs under a realtimeWSPingWriteTimeout deadline, which is cleared
// again after a successful write. The first error encountered is returned.
func (c *realtimeClientConn) writePing() error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	err := c.conn.SetWriteDeadline(time.Now().Add(realtimeWSPingWriteTimeout))
	if err == nil {
		err = c.conn.WriteMessage(ws.PingMessage, nil)
	}
	if err != nil {
		return err
	}
	return c.conn.SetWriteDeadline(time.Time{})
}
// closeDone closes the done channel exactly once, however many times it is
// called.
func (c *realtimeClientConn) closeDone() {
	signalDone := func() { close(c.done) }
	c.closeOnce.Do(signalDone)
}
// writeRealtimeError sends bifrostErr to the client as a realtime error event
// in a text message. Delivery is best-effort: a write failure is ignored.
func (c *realtimeClientConn) writeRealtimeError(bifrostErr *schemas.BifrostError) {
	_ = c.WriteMessage(ws.TextMessage, newRealtimeTurnErrorEventPayload(bifrostErr))
}
// Close signals the heartbeat goroutine to stop and then closes the underlying
// connection, returning the connection's close error.
func (c *realtimeClientConn) Close() error {
c.closeDone()
return c.conn.Close()
}
// realtimeSubprotocolAPIKeyPrefix is the prefix of the Sec-WebSocket-Protocol
// entry that carries an API key; the key is the text following the prefix.
const realtimeSubprotocolAPIKeyPrefix = "openai-insecure-api-key."
// extractRealtimeSubprotocolAPIKey pulls an API key out of the request's
// Sec-WebSocket-Protocol header, which the OpenAI SDK sends as
// "realtime, openai-insecure-api-key.<key>". It returns whatever follows the
// prefix in the first matching entry, or "" when no entry carries the prefix.
func extractRealtimeSubprotocolAPIKey(ctx *fasthttp.RequestCtx) string {
	offered := strings.Split(string(ctx.Request.Header.Peek("Sec-WebSocket-Protocol")), ",")
	for i := range offered {
		candidate := strings.TrimSpace(offered[i])
		if !strings.HasPrefix(candidate, realtimeSubprotocolAPIKeyPrefix) {
			continue
		}
		return candidate[len(realtimeSubprotocolAPIKeyPrefix):]
	}
	return ""
}
// newRealtimeWireBifrostError builds a BifrostError for delivery over the
// realtime WebSocket. code is used for the top-level type as well as the nested
// error's type and code; message becomes the nested error message.
func newRealtimeWireBifrostError(status int, code, message string) *schemas.BifrostError {
	// One copy of code backs all three pointer fields, so they alias each other.
	label := code

	detail := &schemas.ErrorField{Message: message}
	detail.Type = &label
	detail.Code = &label

	wireErr := &schemas.BifrostError{Error: detail}
	wireErr.StatusCode = &status
	wireErr.Type = &label
	return wireErr
}