Files
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

141 lines
3.9 KiB
Go

// Package websocket provides upstream WebSocket connection management for the Bifrost gateway.
// It manages pooled connections to provider WebSocket APIs (e.g., OpenAI Responses WS mode,
// Realtime API) and client session bindings.
package websocket
import (
"net/http"
"sync"
"sync/atomic"
"time"
ws "github.com/fasthttp/websocket"
"github.com/maximhq/bifrost/core/schemas"
)
// UpstreamConn wraps a WebSocket connection to an upstream provider.
// Thread-safe for concurrent read/write via separate mutexes.
type UpstreamConn struct {
conn *ws.Conn
provider schemas.ModelProvider
keyID string
endpoint string
createdAt time.Time
lastUsed atomic.Int64 // unix nano
writeMu sync.Mutex
readMu sync.Mutex
closed atomic.Bool
}
// newUpstreamConn creates a new UpstreamConn wrapping the given websocket connection.
func newUpstreamConn(conn *ws.Conn, provider schemas.ModelProvider, keyID, endpoint string) *UpstreamConn {
uc := &UpstreamConn{
conn: conn,
provider: provider,
keyID: keyID,
endpoint: endpoint,
createdAt: time.Now(),
}
uc.lastUsed.Store(time.Now().UnixNano())
return uc
}
// WriteMessage sends a message to the upstream provider. Thread-safe.
func (c *UpstreamConn) WriteMessage(messageType int, data []byte) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
c.lastUsed.Store(time.Now().UnixNano())
return c.conn.WriteMessage(messageType, data)
}
// WriteJSON sends a JSON-encoded message to the upstream provider. Thread-safe.
func (c *UpstreamConn) WriteJSON(v interface{}) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
c.lastUsed.Store(time.Now().UnixNano())
return c.conn.WriteJSON(v)
}
// ReadMessage reads a message from the upstream provider. Thread-safe.
func (c *UpstreamConn) ReadMessage() (messageType int, p []byte, err error) {
c.readMu.Lock()
defer c.readMu.Unlock()
c.lastUsed.Store(time.Now().UnixNano())
return c.conn.ReadMessage()
}
// Close closes the underlying WebSocket connection.
func (c *UpstreamConn) Close() error {
if c.closed.CompareAndSwap(false, true) {
return c.conn.Close()
}
return nil
}
// IsClosed returns whether the connection has been closed.
func (c *UpstreamConn) IsClosed() bool {
return c.closed.Load()
}
// Provider returns the provider this connection is for.
func (c *UpstreamConn) Provider() schemas.ModelProvider {
return c.provider
}
// KeyID returns the API key ID used for this connection.
func (c *UpstreamConn) KeyID() string {
return c.keyID
}
// CreatedAt returns when this connection was established.
func (c *UpstreamConn) CreatedAt() time.Time {
return c.createdAt
}
// LastUsed returns the last time this connection was used.
func (c *UpstreamConn) LastUsed() time.Time {
return time.Unix(0, c.lastUsed.Load())
}
// Age returns how long this connection has been alive.
func (c *UpstreamConn) Age() time.Duration {
return time.Since(c.createdAt)
}
// SetReadDeadline sets the read deadline on the underlying connection.
func (c *UpstreamConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline on the underlying connection.
func (c *UpstreamConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
// SetPongHandler sets a handler for pong messages received from the upstream.
func (c *UpstreamConn) SetPongHandler(h func(appData string) error) {
c.conn.SetPongHandler(h)
}
// WritePing sends a ping message to the upstream. Thread-safe.
func (c *UpstreamConn) WritePing(data []byte) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
c.lastUsed.Store(time.Now().UnixNano())
return c.conn.WriteMessage(ws.PingMessage, data)
}
// Dial creates a new WebSocket connection to the given URL with the provided headers.
func Dial(url string, headers map[string]string) (*ws.Conn, *http.Response, error) {
dialer := ws.Dialer{
HandshakeTimeout: 10 * time.Second,
}
h := http.Header{}
for k, v := range headers {
h.Set(k, v)
}
return dialer.Dial(url, h)
}