141 lines
3.9 KiB
Go
141 lines
3.9 KiB
Go
// Package websocket provides upstream WebSocket connection management for the Bifrost gateway.
|
|
// It manages pooled connections to provider WebSocket APIs (e.g., OpenAI Responses WS mode,
|
|
// Realtime API) and client session bindings.
|
|
package websocket
|
|
|
|
import (
|
|
"net/http"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
ws "github.com/fasthttp/websocket"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// UpstreamConn wraps a WebSocket connection to an upstream provider.
|
|
// Thread-safe for concurrent read/write via separate mutexes.
|
|
type UpstreamConn struct {
|
|
conn *ws.Conn
|
|
provider schemas.ModelProvider
|
|
keyID string
|
|
endpoint string
|
|
createdAt time.Time
|
|
lastUsed atomic.Int64 // unix nano
|
|
|
|
writeMu sync.Mutex
|
|
readMu sync.Mutex
|
|
|
|
closed atomic.Bool
|
|
}
|
|
|
|
// newUpstreamConn creates a new UpstreamConn wrapping the given websocket connection.
|
|
func newUpstreamConn(conn *ws.Conn, provider schemas.ModelProvider, keyID, endpoint string) *UpstreamConn {
|
|
uc := &UpstreamConn{
|
|
conn: conn,
|
|
provider: provider,
|
|
keyID: keyID,
|
|
endpoint: endpoint,
|
|
createdAt: time.Now(),
|
|
}
|
|
uc.lastUsed.Store(time.Now().UnixNano())
|
|
return uc
|
|
}
|
|
|
|
// WriteMessage sends a message to the upstream provider. Thread-safe.
|
|
func (c *UpstreamConn) WriteMessage(messageType int, data []byte) error {
|
|
c.writeMu.Lock()
|
|
defer c.writeMu.Unlock()
|
|
c.lastUsed.Store(time.Now().UnixNano())
|
|
return c.conn.WriteMessage(messageType, data)
|
|
}
|
|
|
|
// WriteJSON sends a JSON-encoded message to the upstream provider. Thread-safe.
|
|
func (c *UpstreamConn) WriteJSON(v interface{}) error {
|
|
c.writeMu.Lock()
|
|
defer c.writeMu.Unlock()
|
|
c.lastUsed.Store(time.Now().UnixNano())
|
|
return c.conn.WriteJSON(v)
|
|
}
|
|
|
|
// ReadMessage reads a message from the upstream provider. Thread-safe.
|
|
func (c *UpstreamConn) ReadMessage() (messageType int, p []byte, err error) {
|
|
c.readMu.Lock()
|
|
defer c.readMu.Unlock()
|
|
c.lastUsed.Store(time.Now().UnixNano())
|
|
return c.conn.ReadMessage()
|
|
}
|
|
|
|
// Close closes the underlying WebSocket connection.
|
|
func (c *UpstreamConn) Close() error {
|
|
if c.closed.CompareAndSwap(false, true) {
|
|
return c.conn.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// IsClosed returns whether the connection has been closed.
|
|
func (c *UpstreamConn) IsClosed() bool {
|
|
return c.closed.Load()
|
|
}
|
|
|
|
// Provider returns the provider this connection is for.
|
|
func (c *UpstreamConn) Provider() schemas.ModelProvider {
|
|
return c.provider
|
|
}
|
|
|
|
// KeyID returns the API key ID used for this connection.
|
|
func (c *UpstreamConn) KeyID() string {
|
|
return c.keyID
|
|
}
|
|
|
|
// CreatedAt returns when this connection was established.
|
|
func (c *UpstreamConn) CreatedAt() time.Time {
|
|
return c.createdAt
|
|
}
|
|
|
|
// LastUsed returns the last time this connection was used.
|
|
func (c *UpstreamConn) LastUsed() time.Time {
|
|
return time.Unix(0, c.lastUsed.Load())
|
|
}
|
|
|
|
// Age returns how long this connection has been alive.
|
|
func (c *UpstreamConn) Age() time.Duration {
|
|
return time.Since(c.createdAt)
|
|
}
|
|
|
|
// SetReadDeadline sets the read deadline on the underlying connection.
|
|
func (c *UpstreamConn) SetReadDeadline(t time.Time) error {
|
|
return c.conn.SetReadDeadline(t)
|
|
}
|
|
|
|
// SetWriteDeadline sets the write deadline on the underlying connection.
|
|
func (c *UpstreamConn) SetWriteDeadline(t time.Time) error {
|
|
return c.conn.SetWriteDeadline(t)
|
|
}
|
|
|
|
// SetPongHandler sets a handler for pong messages received from the upstream.
|
|
func (c *UpstreamConn) SetPongHandler(h func(appData string) error) {
|
|
c.conn.SetPongHandler(h)
|
|
}
|
|
|
|
// WritePing sends a ping message to the upstream. Thread-safe.
|
|
func (c *UpstreamConn) WritePing(data []byte) error {
|
|
c.writeMu.Lock()
|
|
defer c.writeMu.Unlock()
|
|
c.lastUsed.Store(time.Now().UnixNano())
|
|
return c.conn.WriteMessage(ws.PingMessage, data)
|
|
}
|
|
|
|
// Dial creates a new WebSocket connection to the given URL with the provided headers.
|
|
func Dial(url string, headers map[string]string) (*ws.Conn, *http.Response, error) {
|
|
dialer := ws.Dialer{
|
|
HandshakeTimeout: 10 * time.Second,
|
|
}
|
|
h := http.Header{}
|
|
for k, v := range headers {
|
|
h.Set(k, v)
|
|
}
|
|
return dialer.Dial(url, h)
|
|
}
|