first commit
This commit is contained in:
140
transports/bifrost-http/websocket/connection.go
Normal file
140
transports/bifrost-http/websocket/connection.go
Normal file
@@ -0,0 +1,140 @@
|
||||
// Package websocket provides upstream WebSocket connection management for the Bifrost gateway.
|
||||
// It manages pooled connections to provider WebSocket APIs (e.g., OpenAI Responses WS mode,
|
||||
// Realtime API) and client session bindings.
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
ws "github.com/fasthttp/websocket"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// UpstreamConn wraps a WebSocket connection to an upstream provider.
|
||||
// Thread-safe for concurrent read/write via separate mutexes.
|
||||
type UpstreamConn struct {
|
||||
conn *ws.Conn
|
||||
provider schemas.ModelProvider
|
||||
keyID string
|
||||
endpoint string
|
||||
createdAt time.Time
|
||||
lastUsed atomic.Int64 // unix nano
|
||||
|
||||
writeMu sync.Mutex
|
||||
readMu sync.Mutex
|
||||
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
// newUpstreamConn creates a new UpstreamConn wrapping the given websocket connection.
|
||||
func newUpstreamConn(conn *ws.Conn, provider schemas.ModelProvider, keyID, endpoint string) *UpstreamConn {
|
||||
uc := &UpstreamConn{
|
||||
conn: conn,
|
||||
provider: provider,
|
||||
keyID: keyID,
|
||||
endpoint: endpoint,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
uc.lastUsed.Store(time.Now().UnixNano())
|
||||
return uc
|
||||
}
|
||||
|
||||
// WriteMessage sends a message to the upstream provider. Thread-safe.
|
||||
func (c *UpstreamConn) WriteMessage(messageType int, data []byte) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
c.lastUsed.Store(time.Now().UnixNano())
|
||||
return c.conn.WriteMessage(messageType, data)
|
||||
}
|
||||
|
||||
// WriteJSON sends a JSON-encoded message to the upstream provider. Thread-safe.
|
||||
func (c *UpstreamConn) WriteJSON(v interface{}) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
c.lastUsed.Store(time.Now().UnixNano())
|
||||
return c.conn.WriteJSON(v)
|
||||
}
|
||||
|
||||
// ReadMessage reads a message from the upstream provider. Thread-safe.
|
||||
func (c *UpstreamConn) ReadMessage() (messageType int, p []byte, err error) {
|
||||
c.readMu.Lock()
|
||||
defer c.readMu.Unlock()
|
||||
c.lastUsed.Store(time.Now().UnixNano())
|
||||
return c.conn.ReadMessage()
|
||||
}
|
||||
|
||||
// Close closes the underlying WebSocket connection.
|
||||
func (c *UpstreamConn) Close() error {
|
||||
if c.closed.CompareAndSwap(false, true) {
|
||||
return c.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsClosed returns whether the connection has been closed.
|
||||
func (c *UpstreamConn) IsClosed() bool {
|
||||
return c.closed.Load()
|
||||
}
|
||||
|
||||
// Provider returns the provider this connection is for.
|
||||
func (c *UpstreamConn) Provider() schemas.ModelProvider {
|
||||
return c.provider
|
||||
}
|
||||
|
||||
// KeyID returns the API key ID used for this connection.
|
||||
func (c *UpstreamConn) KeyID() string {
|
||||
return c.keyID
|
||||
}
|
||||
|
||||
// CreatedAt returns when this connection was established.
|
||||
func (c *UpstreamConn) CreatedAt() time.Time {
|
||||
return c.createdAt
|
||||
}
|
||||
|
||||
// LastUsed returns the last time this connection was used.
|
||||
func (c *UpstreamConn) LastUsed() time.Time {
|
||||
return time.Unix(0, c.lastUsed.Load())
|
||||
}
|
||||
|
||||
// Age returns how long this connection has been alive.
|
||||
func (c *UpstreamConn) Age() time.Duration {
|
||||
return time.Since(c.createdAt)
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline on the underlying connection.
|
||||
func (c *UpstreamConn) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the write deadline on the underlying connection.
|
||||
func (c *UpstreamConn) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// SetPongHandler sets a handler for pong messages received from the upstream.
|
||||
func (c *UpstreamConn) SetPongHandler(h func(appData string) error) {
|
||||
c.conn.SetPongHandler(h)
|
||||
}
|
||||
|
||||
// WritePing sends a ping message to the upstream. Thread-safe.
|
||||
func (c *UpstreamConn) WritePing(data []byte) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
c.lastUsed.Store(time.Now().UnixNano())
|
||||
return c.conn.WriteMessage(ws.PingMessage, data)
|
||||
}
|
||||
|
||||
// Dial creates a new WebSocket connection to the given URL with the provided headers.
|
||||
func Dial(url string, headers map[string]string) (*ws.Conn, *http.Response, error) {
|
||||
dialer := ws.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
h := http.Header{}
|
||||
for k, v := range headers {
|
||||
h.Set(k, v)
|
||||
}
|
||||
return dialer.Dial(url, h)
|
||||
}
|
||||
8
transports/bifrost-http/websocket/errors.go
Normal file
8
transports/bifrost-http/websocket/errors.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package websocket
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
	// ErrConnectionLimitReached is returned by SessionManager.Create when
	// registering another session would exceed the configured maximum.
	ErrConnectionLimitReached = errors.New("websocket connection limit reached")

	// ErrPoolClosed reports that the websocket pool has been closed.
	ErrPoolClosed = errors.New("websocket pool is closed")
)
|
||||
221
transports/bifrost-http/websocket/pool.go
Normal file
221
transports/bifrost-http/websocket/pool.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// PoolKey uniquely identifies a group of upstream connections.
|
||||
type PoolKey struct {
|
||||
Provider schemas.ModelProvider
|
||||
KeyID string
|
||||
Endpoint string
|
||||
}
|
||||
|
||||
// Pool manages a pool of upstream WebSocket connections keyed by (provider, keyID, endpoint).
|
||||
// Idle connections are cached for reuse. Connections exceeding max lifetime are discarded.
|
||||
type Pool struct {
|
||||
mu sync.Mutex
|
||||
idle map[PoolKey][]*UpstreamConn
|
||||
inFlight int
|
||||
|
||||
config *schemas.WSPoolConfig
|
||||
|
||||
closed bool
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewPool creates a new upstream WebSocket connection pool.
|
||||
func NewPool(config *schemas.WSPoolConfig) *Pool {
|
||||
if config == nil {
|
||||
config = &schemas.WSPoolConfig{}
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
p := &Pool{
|
||||
idle: make(map[PoolKey][]*UpstreamConn),
|
||||
config: config,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go p.evictLoop()
|
||||
return p
|
||||
}
|
||||
|
||||
// Get retrieves an idle connection for the given key, or dials a new one.
|
||||
// The returned connection is removed from the idle pool and must be returned
|
||||
// via Return or discarded via Discard.
|
||||
func (p *Pool) Get(key PoolKey, headers map[string]string) (*UpstreamConn, error) {
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return nil, fmt.Errorf("pool is closed")
|
||||
}
|
||||
|
||||
conns := p.idle[key]
|
||||
for len(conns) > 0 {
|
||||
// Pop from the back (most recently returned)
|
||||
conn := conns[len(conns)-1]
|
||||
conns = conns[:len(conns)-1]
|
||||
p.idle[key] = conns
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
if conn.IsClosed() || p.isExpired(conn) {
|
||||
conn.Close()
|
||||
p.mu.Lock()
|
||||
conns = p.idle[key]
|
||||
continue
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.inFlight++
|
||||
p.mu.Unlock()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Check total capacity (idle + in-flight) before dialing
|
||||
totalIdle := 0
|
||||
for _, c := range p.idle {
|
||||
totalIdle += len(c)
|
||||
}
|
||||
if totalIdle+p.inFlight >= p.config.MaxTotalConnections {
|
||||
p.mu.Unlock()
|
||||
return nil, fmt.Errorf("pool capacity exhausted: %d idle + %d in-flight >= %d max", totalIdle, p.inFlight, p.config.MaxTotalConnections)
|
||||
}
|
||||
|
||||
// Reserve a slot before unlocking to dial
|
||||
p.inFlight++
|
||||
p.mu.Unlock()
|
||||
|
||||
conn, err := p.dial(key, headers)
|
||||
if err != nil {
|
||||
p.mu.Lock()
|
||||
p.inFlight--
|
||||
p.mu.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Return puts a connection back into the idle pool for reuse.
|
||||
// If the connection is expired or the pool is full, it is closed instead.
|
||||
func (p *Pool) Return(conn *UpstreamConn) {
|
||||
if conn == nil || conn.IsClosed() {
|
||||
return
|
||||
}
|
||||
if p.isExpired(conn) {
|
||||
conn.Close()
|
||||
p.mu.Lock()
|
||||
p.inFlight--
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
key := PoolKey{
|
||||
Provider: conn.provider,
|
||||
KeyID: conn.keyID,
|
||||
Endpoint: conn.endpoint,
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.inFlight--
|
||||
|
||||
if p.closed {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
conns := p.idle[key]
|
||||
if len(conns) >= p.config.MaxIdlePerKey {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
p.idle[key] = append(conns, conn)
|
||||
}
|
||||
|
||||
// Discard closes a connection without returning it to the pool.
|
||||
func (p *Pool) Discard(conn *UpstreamConn) {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
p.mu.Lock()
|
||||
p.inFlight--
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the pool and closes all idle connections.
|
||||
func (p *Pool) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.closed = true
|
||||
close(p.done)
|
||||
|
||||
for key, conns := range p.idle {
|
||||
for _, conn := range conns {
|
||||
conn.Close()
|
||||
}
|
||||
delete(p.idle, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) dial(key PoolKey, headers map[string]string) (*UpstreamConn, error) {
|
||||
wsConn, _, err := Dial(key.Endpoint, headers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial upstream websocket %s: %w", key.Endpoint, err)
|
||||
}
|
||||
return newUpstreamConn(wsConn, key.Provider, key.KeyID, key.Endpoint), nil
|
||||
}
|
||||
|
||||
func (p *Pool) isExpired(conn *UpstreamConn) bool {
|
||||
maxLifetime := time.Duration(p.config.MaxConnectionLifetimeSeconds) * time.Second
|
||||
if conn.Age() >= maxLifetime {
|
||||
return true
|
||||
}
|
||||
idleTimeout := time.Duration(p.config.IdleTimeoutSeconds) * time.Second
|
||||
return time.Since(conn.LastUsed()) >= idleTimeout
|
||||
}
|
||||
|
||||
// evictLoop periodically removes expired idle connections.
|
||||
func (p *Pool) evictLoop() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.evictExpired()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) evictExpired() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
for key, conns := range p.idle {
|
||||
alive := conns[:0]
|
||||
for _, conn := range conns {
|
||||
if conn.IsClosed() || p.isExpired(conn) {
|
||||
conn.Close()
|
||||
} else {
|
||||
alive = append(alive, conn)
|
||||
}
|
||||
}
|
||||
if len(alive) == 0 {
|
||||
delete(p.idle, key)
|
||||
} else {
|
||||
p.idle[key] = alive
|
||||
}
|
||||
}
|
||||
}
|
||||
160
transports/bifrost-http/websocket/pool_test.go
Normal file
160
transports/bifrost-http/websocket/pool_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
ws "github.com/fasthttp/websocket"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func startTestWSServer(t *testing.T) *httptest.Server {
|
||||
t.Helper()
|
||||
upgrader := ws.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
conn.WriteMessage(mt, msg)
|
||||
}
|
||||
}))
|
||||
return server
|
||||
}
|
||||
|
||||
func TestPoolGetAndReturn(t *testing.T) {
|
||||
server := startTestWSServer(t)
|
||||
defer server.Close()
|
||||
|
||||
config := &schemas.WSPoolConfig{
|
||||
MaxIdlePerKey: 5,
|
||||
MaxTotalConnections: 10,
|
||||
IdleTimeoutSeconds: 300,
|
||||
MaxConnectionLifetimeSeconds: 3600,
|
||||
}
|
||||
pool := NewPool(config)
|
||||
defer pool.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
|
||||
|
||||
// Get a new connection (pool is empty, should dial)
|
||||
conn, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
assert.Equal(t, schemas.OpenAI, conn.Provider())
|
||||
assert.Equal(t, "test-key", conn.KeyID())
|
||||
assert.False(t, conn.IsClosed())
|
||||
|
||||
// Return to pool
|
||||
pool.Return(conn)
|
||||
|
||||
// Get again — should reuse the same connection
|
||||
conn2, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
assert.Same(t, conn, conn2)
|
||||
pool.Return(conn2)
|
||||
}
|
||||
|
||||
func TestPoolMaxIdlePerKey(t *testing.T) {
|
||||
server := startTestWSServer(t)
|
||||
defer server.Close()
|
||||
|
||||
config := &schemas.WSPoolConfig{
|
||||
MaxIdlePerKey: 2,
|
||||
MaxTotalConnections: 10,
|
||||
IdleTimeoutSeconds: 300,
|
||||
MaxConnectionLifetimeSeconds: 3600,
|
||||
}
|
||||
pool := NewPool(config)
|
||||
defer pool.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
|
||||
|
||||
// Get 3 connections
|
||||
var conns []*UpstreamConn
|
||||
for range 3 {
|
||||
conn, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
|
||||
// Return all 3 — only 2 should be kept (MaxIdlePerKey=2)
|
||||
for _, conn := range conns {
|
||||
pool.Return(conn)
|
||||
}
|
||||
|
||||
pool.mu.Lock()
|
||||
idleCount := len(pool.idle[key])
|
||||
pool.mu.Unlock()
|
||||
|
||||
assert.Equal(t, 2, idleCount)
|
||||
}
|
||||
|
||||
func TestPoolClose(t *testing.T) {
|
||||
server := startTestWSServer(t)
|
||||
defer server.Close()
|
||||
|
||||
config := &schemas.WSPoolConfig{}
|
||||
config.CheckAndSetDefaults()
|
||||
pool := NewPool(config)
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
|
||||
|
||||
conn, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
pool.Return(conn)
|
||||
|
||||
pool.Close()
|
||||
|
||||
// Getting from a closed pool should fail
|
||||
_, err = pool.Get(key, nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestPoolExpiredConnection(t *testing.T) {
|
||||
server := startTestWSServer(t)
|
||||
defer server.Close()
|
||||
|
||||
config := &schemas.WSPoolConfig{
|
||||
MaxIdlePerKey: 5,
|
||||
MaxTotalConnections: 10,
|
||||
IdleTimeoutSeconds: 1,
|
||||
MaxConnectionLifetimeSeconds: 1,
|
||||
}
|
||||
pool := NewPool(config)
|
||||
defer pool.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
|
||||
|
||||
conn, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
pool.Return(conn)
|
||||
|
||||
// Wait for connection to expire
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
// Get should dial a new connection (old one expired)
|
||||
conn2, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
assert.NotSame(t, conn, conn2)
|
||||
pool.Discard(conn2)
|
||||
}
|
||||
450
transports/bifrost-http/websocket/session.go
Normal file
450
transports/bifrost-http/websocket/session.go
Normal file
@@ -0,0 +1,450 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
ws "github.com/fasthttp/websocket"
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// Session tracks the binding between a client WebSocket connection and its upstream state.
|
||||
// For Responses WS mode, it tracks previous_response_id → upstream connection pinning.
|
||||
type Session struct {
|
||||
mu sync.RWMutex
|
||||
writeMu sync.Mutex // serializes all WriteMessage calls to clientConn
|
||||
|
||||
id string
|
||||
|
||||
// Client connection
|
||||
clientConn *ws.Conn
|
||||
|
||||
// Upstream connection currently pinned to this session (for native WS mode).
|
||||
// nil when using HTTP bridge.
|
||||
upstream *UpstreamConn
|
||||
|
||||
// LastResponseID tracks the most recent response ID for previous_response_id chaining.
|
||||
lastResponseID string
|
||||
|
||||
// providerSessionID tracks the upstream provider's session identifier when exposed.
|
||||
providerSessionID string
|
||||
|
||||
// realtimeOutputText accumulates assistant/provider turn text until the terminal event.
|
||||
realtimeOutputText string
|
||||
|
||||
// realtimeTurnInputs accumulates finalized user/tool inputs in arrival order so the
|
||||
// completed assistant turn can persist the full turn history instead of only the
|
||||
// latest finalized input event.
|
||||
realtimeTurnInputs []RealtimeTurnInput
|
||||
|
||||
// realtimeConsumedTurnItemIDs tracks finalized item IDs that have already been
|
||||
// attached to a persisted turn, so late transcript updates do not pollute later turns.
|
||||
realtimeConsumedTurnItemIDs map[string]struct{}
|
||||
|
||||
// realtimeTurnHooks tracks the active turn-scoped plugin pipeline between
|
||||
// response.create and response.done.
|
||||
realtimeTurnHooks *RealtimeTurnPluginState
|
||||
realtimeTurnBusy bool
|
||||
|
||||
closed bool
|
||||
}
|
||||
|
||||
// RealtimeToolOutput carries a tool result as a summary plus its raw form.
type RealtimeToolOutput struct {
	Summary, Raw string
}
|
||||
|
||||
// RealtimeTurnInput is one pending input of a realtime turn.
type RealtimeTurnInput struct {
	// ItemID is the item identifier the input is keyed by; empty when the
	// input was recorded without one.
	ItemID string
	// Role is the chat message role of the input (user or tool).
	Role string
	// Summary and Raw hold the summarized and raw forms of the input.
	Summary string
	Raw     string
}
|
||||
|
||||
type RealtimeTurnPluginState struct {
|
||||
PostHookRunner schemas.PostHookRunner
|
||||
Cleanup func()
|
||||
RequestID string
|
||||
StartedAt time.Time
|
||||
PreHookValues map[any]any
|
||||
}
|
||||
|
||||
// NewSession creates a new session for a client WebSocket connection.
|
||||
func NewSession(clientConn *ws.Conn) *Session {
|
||||
return &Session{
|
||||
id: uuid.NewString(),
|
||||
clientConn: clientConn,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the stable Bifrost session identifier for this websocket session.
|
||||
func (s *Session) ID() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.id
|
||||
}
|
||||
|
||||
// ClientConn returns the client's WebSocket connection.
|
||||
func (s *Session) ClientConn() *ws.Conn {
|
||||
return s.clientConn
|
||||
}
|
||||
|
||||
// WriteMessage sends a message to the client WebSocket connection.
|
||||
// It serializes concurrent writes via writeMu to prevent panics from
|
||||
// simultaneous goroutine writes (e.g., heartbeat vs streaming relay).
|
||||
func (s *Session) WriteMessage(messageType int, data []byte) error {
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
return s.clientConn.WriteMessage(messageType, data)
|
||||
}
|
||||
|
||||
// SetUpstream pins an upstream connection to this session.
|
||||
func (s *Session) SetUpstream(conn *UpstreamConn) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.closed {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
if s.upstream != nil && s.upstream != conn {
|
||||
s.upstream.Close()
|
||||
}
|
||||
s.upstream = conn
|
||||
}
|
||||
|
||||
// Upstream returns the currently pinned upstream connection, or nil.
|
||||
func (s *Session) Upstream() *UpstreamConn {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.upstream
|
||||
}
|
||||
|
||||
// SetLastResponseID updates the last response ID for chaining.
|
||||
func (s *Session) SetLastResponseID(id string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.lastResponseID = id
|
||||
}
|
||||
|
||||
// LastResponseID returns the last response ID.
|
||||
func (s *Session) LastResponseID() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.lastResponseID
|
||||
}
|
||||
|
||||
// SetProviderSessionID stores the upstream provider session identifier when available.
|
||||
func (s *Session) SetProviderSessionID(id string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.providerSessionID = id
|
||||
}
|
||||
|
||||
// ProviderSessionID returns the upstream provider session identifier when known.
|
||||
func (s *Session) ProviderSessionID() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.providerSessionID
|
||||
}
|
||||
|
||||
// AppendRealtimeOutputText appends provider output content for the current realtime turn.
|
||||
func (s *Session) AppendRealtimeOutputText(text string) {
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.realtimeOutputText += text
|
||||
}
|
||||
|
||||
// ConsumeRealtimeOutputText returns the accumulated provider output and clears it.
|
||||
func (s *Session) ConsumeRealtimeOutputText() string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
text := s.realtimeOutputText
|
||||
s.realtimeOutputText = ""
|
||||
return text
|
||||
}
|
||||
|
||||
// AddRealtimeInput stores a finalized user turn event in arrival order.
|
||||
func (s *Session) AddRealtimeInput(summary, raw string) {
|
||||
if summary == "" && raw == "" {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{
|
||||
Role: string(schemas.ChatMessageRoleUser),
|
||||
Summary: summary,
|
||||
Raw: raw,
|
||||
})
|
||||
}
|
||||
|
||||
// RecordRealtimeInput stores or updates a finalized user turn event keyed by item ID.
|
||||
// Late updates for items already attached to a completed turn are ignored.
|
||||
func (s *Session) RecordRealtimeInput(itemID, summary, raw string) {
|
||||
s.recordRealtimeTurnInput(itemID, string(schemas.ChatMessageRoleUser), summary, raw)
|
||||
}
|
||||
|
||||
// AddRealtimeToolOutput stores a pending tool result for the next assistant turn.
|
||||
func (s *Session) AddRealtimeToolOutput(summary, raw string) {
|
||||
if summary == "" && raw == "" {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{
|
||||
Role: string(schemas.ChatMessageRoleTool),
|
||||
Summary: summary,
|
||||
Raw: raw,
|
||||
})
|
||||
}
|
||||
|
||||
// RecordRealtimeToolOutput stores or updates a finalized tool result keyed by item ID.
|
||||
// Late updates for items already attached to a completed turn are ignored.
|
||||
func (s *Session) RecordRealtimeToolOutput(itemID, summary, raw string) {
|
||||
s.recordRealtimeTurnInput(itemID, string(schemas.ChatMessageRoleTool), summary, raw)
|
||||
}
|
||||
|
||||
func (s *Session) recordRealtimeTurnInput(itemID, role, summary, raw string) {
|
||||
if summary == "" && raw == "" {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
itemID = strings.TrimSpace(itemID)
|
||||
if itemID != "" {
|
||||
if _, consumed := s.realtimeConsumedTurnItemIDs[itemID]; consumed {
|
||||
return
|
||||
}
|
||||
for idx := range s.realtimeTurnInputs {
|
||||
if s.realtimeTurnInputs[idx].ItemID != itemID || s.realtimeTurnInputs[idx].Role != role {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(summary) != "" {
|
||||
s.realtimeTurnInputs[idx].Summary = summary
|
||||
}
|
||||
if strings.TrimSpace(raw) != "" {
|
||||
existingRaw := strings.TrimSpace(s.realtimeTurnInputs[idx].Raw)
|
||||
incomingRaw := strings.TrimSpace(raw)
|
||||
switch {
|
||||
case existingRaw == "":
|
||||
s.realtimeTurnInputs[idx].Raw = raw
|
||||
case incomingRaw == "" || existingRaw == incomingRaw:
|
||||
default:
|
||||
s.realtimeTurnInputs[idx].Raw = existingRaw + "\n\n" + incomingRaw
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{
|
||||
ItemID: itemID,
|
||||
Role: role,
|
||||
Summary: summary,
|
||||
Raw: raw,
|
||||
})
|
||||
}
|
||||
|
||||
// ConsumeRealtimeTurnInputs returns pending realtime turn inputs and clears them.
|
||||
func (s *Session) ConsumeRealtimeTurnInputs() []RealtimeTurnInput {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
inputs := append([]RealtimeTurnInput(nil), s.realtimeTurnInputs...)
|
||||
if len(inputs) > 0 {
|
||||
if s.realtimeConsumedTurnItemIDs == nil {
|
||||
s.realtimeConsumedTurnItemIDs = make(map[string]struct{}, len(inputs))
|
||||
}
|
||||
for _, input := range inputs {
|
||||
if strings.TrimSpace(input.ItemID) != "" {
|
||||
s.realtimeConsumedTurnItemIDs[input.ItemID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.realtimeTurnInputs = nil
|
||||
return inputs
|
||||
}
|
||||
|
||||
// PeekRealtimeTurnInputs returns pending realtime turn inputs without clearing them.
|
||||
func (s *Session) PeekRealtimeTurnInputs() []RealtimeTurnInput {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return append([]RealtimeTurnInput(nil), s.realtimeTurnInputs...)
|
||||
}
|
||||
|
||||
// SetRealtimeTurnHooks stores the active turn-scoped plugin pipeline.
|
||||
func (s *Session) SetRealtimeTurnHooks(state *RealtimeTurnPluginState) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.realtimeTurnHooks != nil && s.realtimeTurnHooks.Cleanup != nil {
|
||||
s.realtimeTurnHooks.Cleanup()
|
||||
}
|
||||
s.realtimeTurnBusy = false
|
||||
if s.closed {
|
||||
if state != nil && state.Cleanup != nil {
|
||||
state.Cleanup()
|
||||
}
|
||||
s.realtimeTurnHooks = nil
|
||||
return
|
||||
}
|
||||
s.realtimeTurnHooks = state
|
||||
}
|
||||
|
||||
// TryBeginRealtimeTurnHooks reserves the single active turn slot.
|
||||
func (s *Session) TryBeginRealtimeTurnHooks() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.closed || s.realtimeTurnBusy || s.realtimeTurnHooks != nil {
|
||||
return false
|
||||
}
|
||||
s.realtimeTurnBusy = true
|
||||
return true
|
||||
}
|
||||
|
||||
// AbortRealtimeTurnHooks releases a reserved turn slot without installing hooks.
|
||||
func (s *Session) AbortRealtimeTurnHooks() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.realtimeTurnBusy = false
|
||||
}
|
||||
|
||||
// PeekRealtimeTurnHooks returns the active turn-scoped plugin pipeline without clearing it.
|
||||
func (s *Session) PeekRealtimeTurnHooks() *RealtimeTurnPluginState {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.realtimeTurnHooks
|
||||
}
|
||||
|
||||
// ConsumeRealtimeTurnHooks returns the active turn-scoped plugin pipeline and clears it.
|
||||
func (s *Session) ConsumeRealtimeTurnHooks() *RealtimeTurnPluginState {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
state := s.realtimeTurnHooks
|
||||
s.realtimeTurnHooks = nil
|
||||
s.realtimeTurnBusy = false
|
||||
return state
|
||||
}
|
||||
|
||||
// ClearRealtimeTurnHooks cleans up and clears any active turn-scoped plugin pipeline.
|
||||
func (s *Session) ClearRealtimeTurnHooks() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.realtimeTurnHooks != nil && s.realtimeTurnHooks.Cleanup != nil {
|
||||
s.realtimeTurnHooks.Cleanup()
|
||||
}
|
||||
s.realtimeTurnHooks = nil
|
||||
s.realtimeTurnBusy = false
|
||||
}
|
||||
|
||||
// Close closes the session and its upstream connection if pinned.
|
||||
func (s *Session) Close() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.closed {
|
||||
return
|
||||
}
|
||||
s.closed = true
|
||||
if s.realtimeTurnHooks != nil {
|
||||
if s.realtimeTurnHooks.Cleanup != nil {
|
||||
s.realtimeTurnHooks.Cleanup()
|
||||
}
|
||||
s.realtimeTurnHooks = nil
|
||||
}
|
||||
s.realtimeTurnBusy = false
|
||||
if s.clientConn != nil {
|
||||
_ = s.clientConn.Close()
|
||||
}
|
||||
if s.upstream != nil {
|
||||
s.upstream.Close()
|
||||
s.upstream = nil
|
||||
}
|
||||
}
|
||||
|
||||
// SessionManager tracks active sessions for connection limiting and cleanup.
|
||||
type SessionManager struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[*ws.Conn]*Session
|
||||
maxConns int
|
||||
}
|
||||
|
||||
// NewSessionManager creates a new session manager.
|
||||
func NewSessionManager(maxConns int) *SessionManager {
|
||||
return &SessionManager{
|
||||
sessions: make(map[*ws.Conn]*Session),
|
||||
maxConns: maxConns,
|
||||
}
|
||||
}
|
||||
|
||||
// Create creates and registers a new session for the given client connection.
|
||||
// Returns an error if the connection limit would be exceeded.
|
||||
func (m *SessionManager) Create(clientConn *ws.Conn) (*Session, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.maxConns > 0 && len(m.sessions) >= m.maxConns {
|
||||
return nil, ErrConnectionLimitReached
|
||||
}
|
||||
|
||||
session := NewSession(clientConn)
|
||||
m.sessions[clientConn] = session
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Get returns the session for the given client connection.
|
||||
func (m *SessionManager) Get(clientConn *ws.Conn) *Session {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.sessions[clientConn]
|
||||
}
|
||||
|
||||
// Remove removes and closes a session.
|
||||
func (m *SessionManager) Remove(clientConn *ws.Conn) {
|
||||
m.mu.Lock()
|
||||
session, ok := m.sessions[clientConn]
|
||||
if ok {
|
||||
delete(m.sessions, clientConn)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
if session != nil {
|
||||
session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Count returns the number of active sessions.
|
||||
func (m *SessionManager) Count() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.sessions)
|
||||
}
|
||||
|
||||
// CloseAll closes all active sessions.
|
||||
func (m *SessionManager) CloseAll() {
|
||||
m.mu.Lock()
|
||||
sessions := m.sessions
|
||||
m.sessions = make(map[*ws.Conn]*Session)
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, session := range sessions {
|
||||
session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Snapshot returns a copy of the currently tracked sessions.
|
||||
func (m *SessionManager) Snapshot() []*Session {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
sessions := make([]*Session, 0, len(m.sessions))
|
||||
for _, session := range m.sessions {
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
return sessions
|
||||
}
|
||||
156
transports/bifrost-http/websocket/session_test.go
Normal file
156
transports/bifrost-http/websocket/session_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
ws "github.com/fasthttp/websocket"
|
||||
)
|
||||
|
||||
func TestSessionManagerCreateAndGet(t *testing.T) {
|
||||
manager := NewSessionManager(2)
|
||||
conn := newTestConn()
|
||||
|
||||
session, err := manager.Create(conn)
|
||||
if err != nil {
|
||||
t.Fatalf("Create() unexpected error: %v", err)
|
||||
}
|
||||
if session == nil {
|
||||
t.Fatal("Create() returned nil session")
|
||||
}
|
||||
if got := manager.Get(conn); got != session {
|
||||
t.Fatal("Get() did not return the created session")
|
||||
}
|
||||
if got := manager.Count(); got != 1 {
|
||||
t.Fatalf("Count() = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionManagerConnectionLimit(t *testing.T) {
|
||||
manager := NewSessionManager(1)
|
||||
|
||||
if _, err := manager.Create(newTestConn()); err != nil {
|
||||
t.Fatalf("first Create() unexpected error: %v", err)
|
||||
}
|
||||
if _, err := manager.Create(newTestConn()); err != ErrConnectionLimitReached {
|
||||
t.Fatalf("second Create() error = %v, want %v", err, ErrConnectionLimitReached)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionManagerRemove(t *testing.T) {
|
||||
manager := NewSessionManager(2)
|
||||
conn := newTestConn()
|
||||
|
||||
session, err := manager.Create(conn)
|
||||
if err != nil {
|
||||
t.Fatalf("Create() unexpected error: %v", err)
|
||||
}
|
||||
|
||||
manager.Remove(conn)
|
||||
|
||||
if got := manager.Get(conn); got != nil {
|
||||
t.Fatal("Get() should return nil after Remove()")
|
||||
}
|
||||
if got := manager.Count(); got != 0 {
|
||||
t.Fatalf("Count() = %d, want 0", got)
|
||||
}
|
||||
if !session.closed {
|
||||
t.Fatal("expected removed session to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionLastResponseID(t *testing.T) {
|
||||
session := NewSession(newTestConn())
|
||||
session.SetLastResponseID("resp-123")
|
||||
|
||||
if got := session.LastResponseID(); got != "resp-123" {
|
||||
t.Fatalf("LastResponseID() = %q, want %q", got, "resp-123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionManagerCloseAll(t *testing.T) {
|
||||
manager := NewSessionManager(4)
|
||||
connA := newTestConn()
|
||||
connB := newTestConn()
|
||||
|
||||
sessionA, err := manager.Create(connA)
|
||||
if err != nil {
|
||||
t.Fatalf("Create(connA) unexpected error: %v", err)
|
||||
}
|
||||
sessionB, err := manager.Create(connB)
|
||||
if err != nil {
|
||||
t.Fatalf("Create(connB) unexpected error: %v", err)
|
||||
}
|
||||
|
||||
manager.CloseAll()
|
||||
|
||||
if got := manager.Count(); got != 0 {
|
||||
t.Fatalf("Count() = %d, want 0", got)
|
||||
}
|
||||
if !sessionA.closed || !sessionB.closed {
|
||||
t.Fatal("expected all sessions to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionRealtimeState(t *testing.T) {
|
||||
session := NewSession(newTestConn())
|
||||
if session.ID() == "" {
|
||||
t.Fatal("expected session ID to be populated")
|
||||
}
|
||||
|
||||
session.SetProviderSessionID("provider-session-1")
|
||||
if got := session.ProviderSessionID(); got != "provider-session-1" {
|
||||
t.Fatalf("ProviderSessionID() = %q, want %q", got, "provider-session-1")
|
||||
}
|
||||
|
||||
session.AppendRealtimeOutputText("hello")
|
||||
session.AppendRealtimeOutputText(" world")
|
||||
if got := session.ConsumeRealtimeOutputText(); got != "hello world" {
|
||||
t.Fatalf("ConsumeRealtimeOutputText() = %q, want %q", got, "hello world")
|
||||
}
|
||||
if got := session.ConsumeRealtimeOutputText(); got != "" {
|
||||
t.Fatalf("ConsumeRealtimeOutputText() after clear = %q, want empty string", got)
|
||||
}
|
||||
|
||||
session.AddRealtimeInput("hello", `{"type":"conversation.item.create","item":{"role":"user"}}`)
|
||||
session.AddRealtimeToolOutput("tool result", `{"type":"conversation.item.create","item":{"type":"function_call_output"}}`)
|
||||
turnInputs := session.ConsumeRealtimeTurnInputs()
|
||||
if len(turnInputs) != 2 {
|
||||
t.Fatalf("len(ConsumeRealtimeTurnInputs()) = %d, want 2", len(turnInputs))
|
||||
}
|
||||
if turnInputs[0].Role != "user" || turnInputs[0].Summary != "hello" {
|
||||
t.Fatalf("turnInputs[0] = %+v, want user hello", turnInputs[0])
|
||||
}
|
||||
if turnInputs[1].Role != "tool" || turnInputs[1].Summary != "tool result" {
|
||||
t.Fatalf("turnInputs[1] = %+v, want tool result", turnInputs[1])
|
||||
}
|
||||
if got := session.ConsumeRealtimeTurnInputs(); len(got) != 0 {
|
||||
t.Fatalf("len(ConsumeRealtimeTurnInputs()) after clear = %d, want 0", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionRecordRealtimeInputUpdatesPendingItemAndIgnoresConsumedLateUpdate(t *testing.T) {
|
||||
session := NewSession(newTestConn())
|
||||
|
||||
session.RecordRealtimeInput("item_1", "[Audio transcription unavailable]", `{"type":"conversation.item.done","item":{"id":"item_1"}}`)
|
||||
session.RecordRealtimeInput("item_1", "Hello.", `{"type":"conversation.item.input_audio_transcription.completed","item_id":"item_1","transcript":"Hello."}`)
|
||||
|
||||
turnInputs := session.ConsumeRealtimeTurnInputs()
|
||||
if len(turnInputs) != 1 {
|
||||
t.Fatalf("len(ConsumeRealtimeTurnInputs()) = %d, want 1", len(turnInputs))
|
||||
}
|
||||
if turnInputs[0].ItemID != "item_1" {
|
||||
t.Fatalf("turnInputs[0].ItemID = %q, want %q", turnInputs[0].ItemID, "item_1")
|
||||
}
|
||||
if turnInputs[0].Summary != "Hello." {
|
||||
t.Fatalf("turnInputs[0].Summary = %q, want %q", turnInputs[0].Summary, "Hello.")
|
||||
}
|
||||
|
||||
session.RecordRealtimeInput("item_1", "Hello.", `{"type":"conversation.item.input_audio_transcription.completed","item_id":"item_1","transcript":"Hello."}`)
|
||||
if got := session.ConsumeRealtimeTurnInputs(); len(got) != 0 {
|
||||
t.Fatalf("len(ConsumeRealtimeTurnInputs()) after late consumed update = %d, want 0", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func newTestConn() *ws.Conn {
|
||||
return &ws.Conn{}
|
||||
}
|
||||
Reference in New Issue
Block a user