first commit
This commit is contained in:
702
transports/bifrost-http/handlers/wsresponses.go
Normal file
702
transports/bifrost-http/handlers/wsresponses.go
Normal file
@@ -0,0 +1,702 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/fasthttp/router"
|
||||
ws "github.com/fasthttp/websocket"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/integrations"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// wsWriter is the minimal write surface this file needs from a WebSocket
// target. It is satisfied by both *ws.Conn (used before a session exists)
// and *bfws.Session (used afterwards, mutex-protected), which lets the error
// helpers below accept either one.
type wsWriter interface {
	// WriteMessage sends one WebSocket message of the given type.
	WriteMessage(messageType int, data []byte) error
}
|
||||
|
||||
// WSResponsesHandler handles WebSocket connections for the Responses API WebSocket Mode.
|
||||
// Clients connect via `GET /v1/responses` with a WS upgrade and send `response.create` events.
|
||||
// Each event is routed through the standard Bifrost inference pipeline (PreLLMHook, key selection,
|
||||
// provider call, PostLLMHook) via the HTTP bridge, with native WS upstream as an optimization.
|
||||
type WSResponsesHandler struct {
|
||||
client *bifrost.Bifrost
|
||||
config *lib.Config
|
||||
handlerStore lib.HandlerStore
|
||||
pool *bfws.Pool
|
||||
sessions *bfws.SessionManager
|
||||
upgrader ws.FastHTTPUpgrader
|
||||
}
|
||||
|
||||
// NewWSResponsesHandler creates a new WebSocket Responses handler.
|
||||
func NewWSResponsesHandler(client *bifrost.Bifrost, config *lib.Config, pool *bfws.Pool) *WSResponsesHandler {
|
||||
maxConns := config.WebSocketConfig.MaxConnections
|
||||
|
||||
return &WSResponsesHandler{
|
||||
client: client,
|
||||
config: config,
|
||||
handlerStore: config,
|
||||
pool: pool,
|
||||
sessions: bfws.NewSessionManager(maxConns),
|
||||
upgrader: ws.FastHTTPUpgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(ctx *fasthttp.RequestCtx) bool {
|
||||
origin := string(ctx.Request.Header.Peek("Origin"))
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
return IsOriginAllowed(origin, config.ClientConfig.AllowedOrigins)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down all active WebSocket responses sessions.
|
||||
func (h *WSResponsesHandler) Close() {
|
||||
if h == nil || h.sessions == nil {
|
||||
return
|
||||
}
|
||||
h.sessions.CloseAll()
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the WebSocket Responses endpoint at the base path
|
||||
// and all OpenAI integration paths.
|
||||
func (h *WSResponsesHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
|
||||
handler := lib.ChainMiddlewares(h.handleUpgrade, middlewares...)
|
||||
// Base path (outside integration prefix)
|
||||
r.GET("/v1/responses", handler)
|
||||
// OpenAI integration paths (/openai/v1/responses, /openai/responses, /openai/openai/responses)
|
||||
for _, path := range integrations.OpenAIWSResponsesPaths("/openai") {
|
||||
r.GET(path, handler)
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpgrade upgrades the HTTP connection to WebSocket and starts the event loop.
|
||||
func (h *WSResponsesHandler) handleUpgrade(ctx *fasthttp.RequestCtx) {
|
||||
err := h.upgrader.Upgrade(ctx, func(conn *ws.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
session, sessionErr := h.sessions.Create(conn)
|
||||
if sessionErr != nil {
|
||||
writeWSError(conn, 429, "websocket_connection_limit_reached", sessionErr.Error())
|
||||
return
|
||||
}
|
||||
defer h.sessions.Remove(conn)
|
||||
|
||||
// Capture auth headers from the upgrade request for per-event context creation
|
||||
authHeaders := captureAuthHeaders(ctx)
|
||||
|
||||
h.eventLoop(conn, session, authHeaders)
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn("websocket upgrade failed for /v1/responses: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// authHeaders is a snapshot of the auth-related request headers taken at
// WebSocket upgrade time. It is reused to build a fresh request context for
// every event received on the connection.
type authHeaders struct {
	// authorization is the value of the "Authorization" header.
	authorization string
	// virtualKey is the value of the "x-bf-vk" header.
	virtualKey string
	// apiKey is the value of the "x-api-key" header.
	apiKey string
	// googAPIKey is the value of the "x-goog-api-key" header.
	googAPIKey string
	// baggage is the value of the "baggage" header.
	baggage string
	// extraHeaders holds every header whose lowercased name starts with
	// "x-bf-", keyed by the header name as received.
	extraHeaders map[string]string
}
|
||||
|
||||
// captureAuthHeaders captures the auth headers from the request.
|
||||
func captureAuthHeaders(ctx *fasthttp.RequestCtx) *authHeaders {
|
||||
ah := &authHeaders{
|
||||
authorization: string(ctx.Request.Header.Peek("Authorization")),
|
||||
virtualKey: string(ctx.Request.Header.Peek("x-bf-vk")),
|
||||
apiKey: string(ctx.Request.Header.Peek("x-api-key")),
|
||||
googAPIKey: string(ctx.Request.Header.Peek("x-goog-api-key")),
|
||||
baggage: string(ctx.Request.Header.Peek("baggage")),
|
||||
extraHeaders: make(map[string]string),
|
||||
}
|
||||
|
||||
for key, value := range ctx.Request.Header.All() {
|
||||
k := string(key)
|
||||
lk := strings.ToLower(k)
|
||||
if strings.HasPrefix(lk, "x-bf-") {
|
||||
ah.extraHeaders[k] = string(value)
|
||||
}
|
||||
}
|
||||
return ah
|
||||
}
|
||||
|
||||
// eventLoop reads events from the client WebSocket and processes them.
|
||||
func (h *WSResponsesHandler) eventLoop(conn *ws.Conn, session *bfws.Session, auth *authHeaders) {
|
||||
for {
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if ws.IsUnexpectedCloseError(err, ws.CloseGoingAway, ws.CloseNormalClosure) {
|
||||
logger.Warn("websocket read error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the event type
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := sonic.Unmarshal(message, &envelope); err != nil {
|
||||
writeWSError(session, 400, "invalid_request_error", "failed to parse event JSON")
|
||||
continue
|
||||
}
|
||||
|
||||
switch schemas.WebSocketEventType(envelope.Type) {
|
||||
case schemas.WSEventResponseCreate:
|
||||
h.handleResponseCreate(session, auth, message)
|
||||
default:
|
||||
writeWSError(session, 400, "invalid_request_error", "unsupported event type: "+envelope.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleResponseCreate processes a response.create event.
|
||||
// Strategy: try native WS upstream for providers that support it, otherwise use HTTP bridge.
|
||||
// If native WS upstream fails mid-stream, falls back to HTTP bridge.
|
||||
func (h *WSResponsesHandler) handleResponseCreate(session *bfws.Session, auth *authHeaders, message []byte) {
|
||||
var event schemas.WebSocketResponsesEvent
|
||||
|
||||
if err := sonic.Unmarshal(message, &event); err != nil {
|
||||
writeWSError(session, 400, "invalid_request_error", "failed to parse response.create event")
|
||||
return
|
||||
}
|
||||
|
||||
// Store override: default to store=true (Codex sends false by default but expects true).
|
||||
// If DisableStore is set in provider config, force store=false.
|
||||
// If client explicitly sets store, respect that value unless DisableStore overrides it.
|
||||
provider, modelName := schemas.ParseModelString(event.Model, "")
|
||||
if provider == "" || modelName == "" {
|
||||
writeWSError(session, 400, "invalid_request_error", "failed to parse model string")
|
||||
return
|
||||
}
|
||||
|
||||
if providerCfg, cfgErr := h.config.GetProviderConfigRaw(provider); cfgErr == nil &&
|
||||
providerCfg.OpenAIConfig != nil && providerCfg.OpenAIConfig.DisableStore {
|
||||
event.Store = schemas.Ptr(false)
|
||||
} else {
|
||||
event.Store = schemas.Ptr(true)
|
||||
}
|
||||
|
||||
bifrostReq, err := h.convertEventToRequest(&event)
|
||||
if err != nil {
|
||||
writeWSError(session, 400, "invalid_request_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Extract extra params (unknown fields) and forward them, matching the HTTP path behavior
|
||||
extraParams, extractErr := extractExtraParams(message, wsResponsesKnownFields)
|
||||
if extractErr == nil && len(extraParams) > 0 {
|
||||
if bifrostReq.Params == nil {
|
||||
bifrostReq.Params = &schemas.ResponsesParameters{}
|
||||
}
|
||||
bifrostReq.Params.ExtraParams = extraParams
|
||||
}
|
||||
|
||||
bifrostCtx, cancel := createBifrostContextFromAuth(h.handlerStore, auth)
|
||||
if bifrostCtx == nil {
|
||||
writeWSError(session, 500, "server_error", "failed to create request context")
|
||||
return
|
||||
}
|
||||
|
||||
// Try native WS upstream first
|
||||
if h.tryNativeWSUpstream(session, bifrostCtx, bifrostReq, message) {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
// Fall back to HTTP bridge
|
||||
h.executeHTTPBridge(session, bifrostCtx, cancel, bifrostReq)
|
||||
}
|
||||
|
||||
// tryNativeWSUpstream attempts to forward the event to a native WS upstream connection.
|
||||
// Returns true if the event was handled (successfully or with error sent to client).
|
||||
// Returns false if the provider doesn't support WS and we should fall back to HTTP bridge.
|
||||
func (h *WSResponsesHandler) tryNativeWSUpstream(
|
||||
session *bfws.Session,
|
||||
ctx *schemas.BifrostContext,
|
||||
req *schemas.BifrostResponsesRequest,
|
||||
rawEvent []byte,
|
||||
) bool {
|
||||
provider := h.client.GetProviderByKey(req.Provider)
|
||||
if provider == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
wsProvider, ok := provider.(schemas.WebSocketCapableProvider)
|
||||
if !ok || !wsProvider.SupportsWebSocketMode() {
|
||||
return false
|
||||
}
|
||||
|
||||
key, err := h.client.SelectKeyForProviderRequestType(ctx, schemas.WebSocketResponsesRequest, req.Provider, req.Model)
|
||||
if err != nil {
|
||||
writeWSError(session, 400, "invalid_request_error", err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
wsURL := wsProvider.WebSocketResponsesURL(key)
|
||||
upstream := session.Upstream()
|
||||
|
||||
// Validate the pinned upstream matches the current request's provider/key
|
||||
if upstream != nil && !upstream.IsClosed() &&
|
||||
(upstream.Provider() != req.Provider || upstream.KeyID() != key.ID) {
|
||||
h.pool.Discard(upstream)
|
||||
session.SetUpstream(nil)
|
||||
upstream = nil
|
||||
}
|
||||
|
||||
// If no upstream connection pinned, get one from the pool or dial
|
||||
if upstream == nil || upstream.IsClosed() {
|
||||
headers := wsProvider.WebSocketHeaders(key)
|
||||
poolKey := bfws.PoolKey{
|
||||
Provider: req.Provider,
|
||||
KeyID: key.ID,
|
||||
Endpoint: wsURL,
|
||||
}
|
||||
|
||||
upstream, err = h.pool.Get(poolKey, headers)
|
||||
if err != nil {
|
||||
logger.Warn("failed to get upstream WS connection for %s: %v, falling back to HTTP bridge", req.Provider, err)
|
||||
return false
|
||||
}
|
||||
session.SetUpstream(upstream)
|
||||
}
|
||||
|
||||
// Run plugin pre-hooks before forwarding to upstream
|
||||
bifrostReq := &schemas.BifrostRequest{
|
||||
RequestType: schemas.WebSocketResponsesRequest,
|
||||
ResponsesRequest: req,
|
||||
}
|
||||
|
||||
hooks, preErr := h.client.RunStreamPreHooks(ctx, bifrostReq)
|
||||
if preErr != nil {
|
||||
writeWSBifrostError(session, preErr)
|
||||
return true
|
||||
}
|
||||
defer hooks.Cleanup()
|
||||
|
||||
// If a plugin short-circuited with a cached response, write it and skip upstream
|
||||
if hooks.ShortCircuitResponse != nil {
|
||||
writeWSShortCircuitResponse(session, hooks.ShortCircuitResponse)
|
||||
return true
|
||||
}
|
||||
|
||||
// Forward the raw event to upstream
|
||||
if err := upstream.WriteMessage(ws.TextMessage, rawEvent); err != nil {
|
||||
logger.Warn("upstream WS write failed for %s: %v, falling back to HTTP bridge", req.Provider, err)
|
||||
h.pool.Discard(upstream)
|
||||
session.SetUpstream(nil)
|
||||
return false
|
||||
}
|
||||
|
||||
// Retrieve tracer and traceID for chunk accumulation
|
||||
tracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
|
||||
traceID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string)
|
||||
|
||||
// Read response events from upstream and relay to client, running post-hooks per chunk
|
||||
forwardedAny := false
|
||||
for {
|
||||
msgType, data, readErr := upstream.ReadMessage()
|
||||
if readErr != nil {
|
||||
logger.Warn("upstream WS read failed for %s: %v, falling back to HTTP bridge", req.Provider, readErr)
|
||||
h.pool.Discard(upstream)
|
||||
session.SetUpstream(nil)
|
||||
if !forwardedAny {
|
||||
return false
|
||||
}
|
||||
writeWSError(session, 502, "upstream_connection_error", "upstream websocket stream interrupted")
|
||||
return true
|
||||
}
|
||||
|
||||
streamResp := parseUpstreamWSEvent(data, req.Provider, req.Model)
|
||||
isTerminal := streamResp != nil && isTerminalStreamType(streamResp.Type)
|
||||
|
||||
if isTerminal {
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
}
|
||||
|
||||
if streamResp != nil {
|
||||
resp := &schemas.BifrostResponse{ResponsesStreamResponse: streamResp}
|
||||
|
||||
if tracer != nil && traceID != "" {
|
||||
tracer.AddStreamingChunk(traceID, resp)
|
||||
}
|
||||
|
||||
_, postErr := hooks.PostHookRunner(ctx, resp, nil)
|
||||
if postErr != nil {
|
||||
h.pool.Discard(upstream)
|
||||
session.SetUpstream(nil)
|
||||
writeWSBifrostError(session, postErr)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if writeErr := session.WriteMessage(msgType, data); writeErr != nil {
|
||||
h.pool.Discard(upstream)
|
||||
session.SetUpstream(nil)
|
||||
return true
|
||||
}
|
||||
forwardedAny = true
|
||||
|
||||
if isTerminal {
|
||||
h.trackResponseID(session, data)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeWSShortCircuitResponse writes a short-circuited plugin response as WS events.
|
||||
func writeWSShortCircuitResponse(session *bfws.Session, resp *schemas.BifrostResponse) {
|
||||
if resp.ResponsesResponse != nil {
|
||||
data, err := sonic.Marshal(resp.ResponsesResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := session.WriteMessage(ws.TextMessage, data); err != nil {
|
||||
return
|
||||
}
|
||||
if resp.ResponsesResponse.ID != nil && *resp.ResponsesResponse.ID != "" {
|
||||
session.SetLastResponseID(*resp.ResponsesResponse.ID)
|
||||
}
|
||||
} else if resp.ResponsesStreamResponse != nil {
|
||||
data, err := sonic.Marshal(resp.ResponsesStreamResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
session.WriteMessage(ws.TextMessage, data)
|
||||
}
|
||||
}
|
||||
|
||||
// parseUpstreamWSEvent attempts to parse a raw upstream WS event into a BifrostResponsesStreamResponse.
|
||||
// It populates ExtraFields so downstream plugins (logging, tracing) can identify the request type.
|
||||
// Returns nil if the data cannot be parsed (non-fatal, the raw bytes are still relayed).
|
||||
func parseUpstreamWSEvent(data []byte, provider schemas.ModelProvider, model string) *schemas.BifrostResponsesStreamResponse {
|
||||
var streamResp schemas.BifrostResponsesStreamResponse
|
||||
if err := sonic.Unmarshal(data, &streamResp); err != nil {
|
||||
return nil
|
||||
}
|
||||
if streamResp.Type == "" {
|
||||
return nil
|
||||
}
|
||||
streamResp.ExtraFields.RequestType = schemas.ResponsesStreamRequest
|
||||
streamResp.ExtraFields.Provider = provider
|
||||
streamResp.ExtraFields.OriginalModelRequested = model
|
||||
return &streamResp
|
||||
}
|
||||
|
||||
// isTerminalStreamType returns true if the event type signals the end of a response stream.
|
||||
func isTerminalStreamType(t schemas.ResponsesStreamResponseType) bool {
|
||||
switch t {
|
||||
case schemas.ResponsesStreamResponseTypeCompleted,
|
||||
schemas.ResponsesStreamResponseTypeFailed,
|
||||
schemas.ResponsesStreamResponseTypeIncomplete,
|
||||
schemas.ResponsesStreamResponseTypeError:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// trackResponseID extracts and stores the response ID from terminal events.
|
||||
func (h *WSResponsesHandler) trackResponseID(session *bfws.Session, data []byte) {
|
||||
var envelope struct {
|
||||
Response struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"response"`
|
||||
}
|
||||
if err := sonic.Unmarshal(data, &envelope); err == nil && envelope.Response.ID != "" {
|
||||
session.SetLastResponseID(envelope.Response.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// convertEventToRequest converts a WebSocket response.create event to a BifrostResponsesRequest.
|
||||
func (h *WSResponsesHandler) convertEventToRequest(event *schemas.WebSocketResponsesEvent) (*schemas.BifrostResponsesRequest, error) {
|
||||
provider, modelName := schemas.ParseModelString(event.Model, "")
|
||||
if provider == "" || modelName == "" {
|
||||
return nil, errModelFormat
|
||||
}
|
||||
|
||||
var input []schemas.ResponsesMessage
|
||||
if event.Input != nil {
|
||||
// Try parsing as array first
|
||||
if err := sonic.Unmarshal(event.Input, &input); err != nil {
|
||||
// Try as string
|
||||
var inputStr string
|
||||
if strErr := sonic.Unmarshal(event.Input, &inputStr); strErr != nil {
|
||||
return nil, errInputRequired
|
||||
}
|
||||
input = []schemas.ResponsesMessage{
|
||||
{
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{ContentStr: &inputStr},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(input) == 0 {
|
||||
return nil, errInputRequired
|
||||
}
|
||||
|
||||
params := &schemas.ResponsesParameters{}
|
||||
if event.Temperature != nil {
|
||||
params.Temperature = event.Temperature
|
||||
}
|
||||
if event.TopP != nil {
|
||||
params.TopP = event.TopP
|
||||
}
|
||||
if event.MaxOutputTokens != nil {
|
||||
params.MaxOutputTokens = event.MaxOutputTokens
|
||||
}
|
||||
if event.Instructions != "" {
|
||||
params.Instructions = &event.Instructions
|
||||
}
|
||||
if event.PreviousResponseID != "" {
|
||||
params.PreviousResponseID = &event.PreviousResponseID
|
||||
}
|
||||
if event.Store != nil {
|
||||
params.Store = event.Store
|
||||
}
|
||||
if event.Tools != nil {
|
||||
var tools []schemas.ResponsesTool
|
||||
if err := sonic.Unmarshal(event.Tools, &tools); err == nil {
|
||||
params.Tools = tools
|
||||
}
|
||||
}
|
||||
if event.ToolChoice != nil {
|
||||
var tc schemas.ResponsesToolChoice
|
||||
if err := sonic.Unmarshal(event.ToolChoice, &tc); err == nil {
|
||||
params.ToolChoice = &tc
|
||||
}
|
||||
}
|
||||
if event.Reasoning != nil {
|
||||
var reasoning schemas.ResponsesParametersReasoning
|
||||
if err := sonic.Unmarshal(event.Reasoning, &reasoning); err == nil {
|
||||
params.Reasoning = &reasoning
|
||||
}
|
||||
}
|
||||
if event.Text != nil {
|
||||
var text schemas.ResponsesTextConfig
|
||||
if err := sonic.Unmarshal(event.Text, &text); err == nil {
|
||||
params.Text = &text
|
||||
}
|
||||
}
|
||||
if event.Metadata != nil {
|
||||
var metadata map[string]any
|
||||
if err := sonic.Unmarshal(event.Metadata, &metadata); err == nil {
|
||||
params.Metadata = &metadata
|
||||
}
|
||||
}
|
||||
if event.Truncation != "" {
|
||||
params.Truncation = &event.Truncation
|
||||
}
|
||||
|
||||
return &schemas.BifrostResponsesRequest{
|
||||
Provider: schemas.ModelProvider(provider),
|
||||
Model: modelName,
|
||||
Input: input,
|
||||
Params: params,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createBifrostContextFromAuth builds a BifrostContext from the auth headers captured during upgrade.
|
||||
func createBifrostContextFromAuth(handlerStore lib.HandlerStore, auth *authHeaders) (*schemas.BifrostContext, context.CancelFunc) {
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(context.Background())
|
||||
|
||||
if sessionID := lib.ParseSessionIDFromBaggage(auth.baggage); sessionID != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyParentRequestID, sessionID)
|
||||
}
|
||||
|
||||
if auth.virtualKey != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyVirtualKey, auth.virtualKey)
|
||||
}
|
||||
|
||||
// Handle Bearer token with sk-bf- prefix (virtual key via Authorization header)
|
||||
if auth.authorization != "" {
|
||||
if strings.HasPrefix(auth.authorization, "Bearer ") {
|
||||
token := strings.TrimPrefix(auth.authorization, "Bearer ")
|
||||
if strings.HasPrefix(token, "sk-bf-") {
|
||||
ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(token, "sk-bf-"))
|
||||
} else if handlerStore.ShouldAllowDirectKeys() {
|
||||
key := schemas.Key{
|
||||
ID: "header-provided",
|
||||
Value: *schemas.NewEnvVar(token),
|
||||
Models: schemas.WhiteList{"*"},
|
||||
Weight: 1.0,
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyDirectKey, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
if auth.apiKey != "" {
|
||||
if strings.HasPrefix(auth.apiKey, "sk-bf-") {
|
||||
ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(auth.apiKey, "sk-bf-"))
|
||||
} else if handlerStore.ShouldAllowDirectKeys() {
|
||||
key := schemas.Key{
|
||||
ID: "header-provided",
|
||||
Value: *schemas.NewEnvVar(auth.apiKey),
|
||||
Models: schemas.WhiteList{"*"},
|
||||
Weight: 1.0,
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyDirectKey, key)
|
||||
}
|
||||
}
|
||||
if auth.googAPIKey != "" {
|
||||
if strings.HasPrefix(auth.googAPIKey, "sk-bf-") {
|
||||
ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(auth.googAPIKey, "sk-bf-"))
|
||||
} else if handlerStore.ShouldAllowDirectKeys() {
|
||||
key := schemas.Key{
|
||||
ID: "header-provided",
|
||||
Value: *schemas.NewEnvVar(auth.googAPIKey),
|
||||
Models: schemas.WhiteList{"*"},
|
||||
Weight: 1.0,
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyDirectKey, key)
|
||||
}
|
||||
}
|
||||
|
||||
// Forward x-bf-* headers
|
||||
for k, v := range auth.extraHeaders {
|
||||
lk := strings.ToLower(k)
|
||||
switch {
|
||||
case lk == "x-bf-vk":
|
||||
// Already handled above
|
||||
case lk == "x-bf-api-key":
|
||||
ctx.SetValue(schemas.BifrostContextKeyAPIKeyName, v)
|
||||
case strings.HasPrefix(lk, "x-bf-eh-"):
|
||||
suffix := strings.TrimPrefix(lk, "x-bf-eh-")
|
||||
existing, _ := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string]string)
|
||||
if existing == nil {
|
||||
existing = make(map[string]string)
|
||||
}
|
||||
existing[suffix] = v
|
||||
ctx.SetValue(schemas.BifrostContextKeyExtraHeaders, existing)
|
||||
}
|
||||
}
|
||||
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
// executeHTTPBridge runs the response through the existing streaming inference pipeline.
|
||||
func (h *WSResponsesHandler) executeHTTPBridge(
|
||||
session *bfws.Session,
|
||||
ctx *schemas.BifrostContext,
|
||||
cancel context.CancelFunc,
|
||||
req *schemas.BifrostResponsesRequest,
|
||||
) {
|
||||
defer cancel()
|
||||
|
||||
stream, bifrostErr := h.client.ResponsesStreamRequest(ctx, req)
|
||||
if bifrostErr != nil {
|
||||
writeWSBifrostError(session, bifrostErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Relay streaming chunks as WS messages
|
||||
for chunk := range stream {
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
chunkJSON, err := sonic.Marshal(chunk)
|
||||
if err != nil {
|
||||
logger.Warn("failed to marshal stream chunk: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if writeErr := session.WriteMessage(ws.TextMessage, chunkJSON); writeErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Track last response ID for session chaining
|
||||
if chunk.BifrostResponsesStreamResponse != nil &&
|
||||
chunk.BifrostResponsesStreamResponse.Response != nil &&
|
||||
chunk.BifrostResponsesStreamResponse.Response.ID != nil &&
|
||||
*chunk.BifrostResponsesStreamResponse.Response.ID != "" {
|
||||
session.SetLastResponseID(*chunk.BifrostResponsesStreamResponse.Response.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeWSError sends a JSON error event to a WebSocket write target.
|
||||
// Accepts either a raw *ws.Conn (pre-session) or a *bfws.Session (mutex-protected).
|
||||
func writeWSError(w wsWriter, status int, code, message string) {
|
||||
event := schemas.WebSocketErrorEvent{
|
||||
Type: schemas.WSEventError,
|
||||
Status: status,
|
||||
Error: &schemas.WebSocketErrorBody{
|
||||
Code: code,
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
data, err := sonic.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.WriteMessage(ws.TextMessage, data)
|
||||
}
|
||||
|
||||
// writeWSBifrostError converts a BifrostError to a WS error event.
|
||||
func writeWSBifrostError(w wsWriter, bifrostErr *schemas.BifrostError) {
|
||||
status := 500
|
||||
if bifrostErr.StatusCode != nil && *bifrostErr.StatusCode > 0 {
|
||||
status = *bifrostErr.StatusCode
|
||||
}
|
||||
code := "server_error"
|
||||
msg := "internal server error"
|
||||
if bifrostErr.Error != nil {
|
||||
if bifrostErr.Error.Code != nil && *bifrostErr.Error.Code != "" {
|
||||
code = *bifrostErr.Error.Code
|
||||
} else if bifrostErr.Error.Type != nil && *bifrostErr.Error.Type != "" {
|
||||
code = *bifrostErr.Error.Type
|
||||
}
|
||||
if bifrostErr.Error.Message != "" {
|
||||
msg = bifrostErr.Error.Message
|
||||
}
|
||||
}
|
||||
writeWSError(w, status, code, msg)
|
||||
}
|
||||
|
||||
// wsResponsesKnownFields is the set of top-level JSON fields that
// WebSocketResponsesEvent handles explicitly. Any field of an incoming event
// that is not listed here is treated as an extra param and forwarded to the
// provider as-is.
var wsResponsesKnownFields = map[string]bool{
	// Envelope and routing.
	"type":  true,
	"model": true,
	"store": true,

	// Conversation content.
	"input":                true,
	"instructions":         true,
	"previous_response_id": true,

	// Tooling.
	"tools":       true,
	"tool_choice": true,

	// Sampling and output controls.
	"temperature":       true,
	"top_p":             true,
	"max_output_tokens": true,
	"reasoning":         true,
	"metadata":          true,
	"text":              true,
	"truncation":        true,
}
|
||||
|
||||
var (
|
||||
errModelFormat = errorf("model should be in provider/model format")
|
||||
errInputRequired = errorf("input is required for responses")
|
||||
)
|
||||
|
||||
// simpleError is a minimal error implementation that carries a fixed message.
type simpleError struct {
	msg string
}

// Error returns the message the error was created with.
func (e *simpleError) Error() string {
	return e.msg
}

// errorf returns a new error whose Error() is exactly msg. Despite the name,
// msg is used verbatim and no formatting is applied.
func errorf(msg string) error {
	e := new(simpleError)
	e.msg = msg
	return e
}
|
||||
Reference in New Issue
Block a user