first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,140 @@
// Package websocket provides upstream WebSocket connection management for the Bifrost gateway.
// It manages pooled connections to provider WebSocket APIs (e.g., OpenAI Responses WS mode,
// Realtime API) and client session bindings.
package websocket
import (
"net/http"
"sync"
"sync/atomic"
"time"
ws "github.com/fasthttp/websocket"
"github.com/maximhq/bifrost/core/schemas"
)
// UpstreamConn wraps a WebSocket connection to an upstream provider.
// Thread-safe for concurrent read/write via separate mutexes.
type UpstreamConn struct {
conn *ws.Conn
provider schemas.ModelProvider
keyID string
endpoint string
createdAt time.Time
lastUsed atomic.Int64 // unix nano
writeMu sync.Mutex
readMu sync.Mutex
closed atomic.Bool
}
// newUpstreamConn creates a new UpstreamConn wrapping the given websocket connection.
func newUpstreamConn(conn *ws.Conn, provider schemas.ModelProvider, keyID, endpoint string) *UpstreamConn {
uc := &UpstreamConn{
conn: conn,
provider: provider,
keyID: keyID,
endpoint: endpoint,
createdAt: time.Now(),
}
uc.lastUsed.Store(time.Now().UnixNano())
return uc
}
// WriteMessage sends a message to the upstream provider. Thread-safe.
func (c *UpstreamConn) WriteMessage(messageType int, data []byte) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
c.lastUsed.Store(time.Now().UnixNano())
return c.conn.WriteMessage(messageType, data)
}
// WriteJSON sends a JSON-encoded message to the upstream provider. Thread-safe.
func (c *UpstreamConn) WriteJSON(v interface{}) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
c.lastUsed.Store(time.Now().UnixNano())
return c.conn.WriteJSON(v)
}
// ReadMessage reads a message from the upstream provider. Thread-safe.
func (c *UpstreamConn) ReadMessage() (messageType int, p []byte, err error) {
c.readMu.Lock()
defer c.readMu.Unlock()
c.lastUsed.Store(time.Now().UnixNano())
return c.conn.ReadMessage()
}
// Close closes the underlying WebSocket connection.
func (c *UpstreamConn) Close() error {
if c.closed.CompareAndSwap(false, true) {
return c.conn.Close()
}
return nil
}
// IsClosed returns whether the connection has been closed.
func (c *UpstreamConn) IsClosed() bool {
return c.closed.Load()
}
// Provider returns the provider this connection is for.
func (c *UpstreamConn) Provider() schemas.ModelProvider {
return c.provider
}
// KeyID returns the API key ID used for this connection.
func (c *UpstreamConn) KeyID() string {
return c.keyID
}
// CreatedAt returns when this connection was established.
func (c *UpstreamConn) CreatedAt() time.Time {
return c.createdAt
}
// LastUsed returns the last time this connection was used.
func (c *UpstreamConn) LastUsed() time.Time {
return time.Unix(0, c.lastUsed.Load())
}
// Age returns how long this connection has been alive.
func (c *UpstreamConn) Age() time.Duration {
return time.Since(c.createdAt)
}
// SetReadDeadline sets the read deadline on the underlying connection.
func (c *UpstreamConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline on the underlying connection.
func (c *UpstreamConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
// SetPongHandler sets a handler for pong messages received from the upstream.
func (c *UpstreamConn) SetPongHandler(h func(appData string) error) {
c.conn.SetPongHandler(h)
}
// WritePing sends a ping message to the upstream. Thread-safe.
func (c *UpstreamConn) WritePing(data []byte) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
c.lastUsed.Store(time.Now().UnixNano())
return c.conn.WriteMessage(ws.PingMessage, data)
}
// Dial creates a new WebSocket connection to the given URL with the provided headers.
func Dial(url string, headers map[string]string) (*ws.Conn, *http.Response, error) {
dialer := ws.Dialer{
HandshakeTimeout: 10 * time.Second,
}
h := http.Header{}
for k, v := range headers {
h.Set(k, v)
}
return dialer.Dial(url, h)
}

View File

@@ -0,0 +1,8 @@
package websocket
import "errors"
var (
ErrConnectionLimitReached = errors.New("websocket connection limit reached")
ErrPoolClosed = errors.New("websocket pool is closed")
)

View File

@@ -0,0 +1,221 @@
package websocket
import (
"fmt"
"sync"
"time"
"github.com/maximhq/bifrost/core/schemas"
)
// PoolKey uniquely identifies a group of upstream connections.
type PoolKey struct {
Provider schemas.ModelProvider
KeyID string
Endpoint string
}
// Pool manages a pool of upstream WebSocket connections keyed by (provider, keyID, endpoint).
// Idle connections are cached for reuse. Connections exceeding max lifetime are discarded.
type Pool struct {
mu sync.Mutex
idle map[PoolKey][]*UpstreamConn
inFlight int
config *schemas.WSPoolConfig
closed bool
done chan struct{}
}
// NewPool creates a new upstream WebSocket connection pool.
func NewPool(config *schemas.WSPoolConfig) *Pool {
if config == nil {
config = &schemas.WSPoolConfig{}
}
config.CheckAndSetDefaults()
p := &Pool{
idle: make(map[PoolKey][]*UpstreamConn),
config: config,
done: make(chan struct{}),
}
go p.evictLoop()
return p
}
// Get retrieves an idle connection for the given key, or dials a new one.
// The returned connection is removed from the idle pool and must be returned
// via Return or discarded via Discard.
func (p *Pool) Get(key PoolKey, headers map[string]string) (*UpstreamConn, error) {
p.mu.Lock()
if p.closed {
p.mu.Unlock()
return nil, fmt.Errorf("pool is closed")
}
conns := p.idle[key]
for len(conns) > 0 {
// Pop from the back (most recently returned)
conn := conns[len(conns)-1]
conns = conns[:len(conns)-1]
p.idle[key] = conns
p.mu.Unlock()
if conn.IsClosed() || p.isExpired(conn) {
conn.Close()
p.mu.Lock()
conns = p.idle[key]
continue
}
p.mu.Lock()
p.inFlight++
p.mu.Unlock()
return conn, nil
}
// Check total capacity (idle + in-flight) before dialing
totalIdle := 0
for _, c := range p.idle {
totalIdle += len(c)
}
if totalIdle+p.inFlight >= p.config.MaxTotalConnections {
p.mu.Unlock()
return nil, fmt.Errorf("pool capacity exhausted: %d idle + %d in-flight >= %d max", totalIdle, p.inFlight, p.config.MaxTotalConnections)
}
// Reserve a slot before unlocking to dial
p.inFlight++
p.mu.Unlock()
conn, err := p.dial(key, headers)
if err != nil {
p.mu.Lock()
p.inFlight--
p.mu.Unlock()
return nil, err
}
return conn, nil
}
// Return puts a connection back into the idle pool for reuse.
// If the connection is expired or the pool is full, it is closed instead.
func (p *Pool) Return(conn *UpstreamConn) {
if conn == nil || conn.IsClosed() {
return
}
if p.isExpired(conn) {
conn.Close()
p.mu.Lock()
p.inFlight--
p.mu.Unlock()
return
}
key := PoolKey{
Provider: conn.provider,
KeyID: conn.keyID,
Endpoint: conn.endpoint,
}
p.mu.Lock()
defer p.mu.Unlock()
p.inFlight--
if p.closed {
conn.Close()
return
}
conns := p.idle[key]
if len(conns) >= p.config.MaxIdlePerKey {
conn.Close()
return
}
p.idle[key] = append(conns, conn)
}
// Discard closes a connection without returning it to the pool.
func (p *Pool) Discard(conn *UpstreamConn) {
if conn != nil {
conn.Close()
p.mu.Lock()
p.inFlight--
p.mu.Unlock()
}
}
// Close shuts down the pool and closes all idle connections.
func (p *Pool) Close() {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return
}
p.closed = true
close(p.done)
for key, conns := range p.idle {
for _, conn := range conns {
conn.Close()
}
delete(p.idle, key)
}
}
func (p *Pool) dial(key PoolKey, headers map[string]string) (*UpstreamConn, error) {
wsConn, _, err := Dial(key.Endpoint, headers)
if err != nil {
return nil, fmt.Errorf("failed to dial upstream websocket %s: %w", key.Endpoint, err)
}
return newUpstreamConn(wsConn, key.Provider, key.KeyID, key.Endpoint), nil
}
func (p *Pool) isExpired(conn *UpstreamConn) bool {
maxLifetime := time.Duration(p.config.MaxConnectionLifetimeSeconds) * time.Second
if conn.Age() >= maxLifetime {
return true
}
idleTimeout := time.Duration(p.config.IdleTimeoutSeconds) * time.Second
return time.Since(conn.LastUsed()) >= idleTimeout
}
// evictLoop periodically removes expired idle connections.
func (p *Pool) evictLoop() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-p.done:
return
case <-ticker.C:
p.evictExpired()
}
}
}
func (p *Pool) evictExpired() {
p.mu.Lock()
defer p.mu.Unlock()
for key, conns := range p.idle {
alive := conns[:0]
for _, conn := range conns {
if conn.IsClosed() || p.isExpired(conn) {
conn.Close()
} else {
alive = append(alive, conn)
}
}
if len(alive) == 0 {
delete(p.idle, key)
} else {
p.idle[key] = alive
}
}
}

View File

@@ -0,0 +1,160 @@
package websocket
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
ws "github.com/fasthttp/websocket"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func startTestWSServer(t *testing.T) *httptest.Server {
t.Helper()
upgrader := ws.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, msg, err := conn.ReadMessage()
if err != nil {
break
}
conn.WriteMessage(mt, msg)
}
}))
return server
}
func TestPoolGetAndReturn(t *testing.T) {
server := startTestWSServer(t)
defer server.Close()
config := &schemas.WSPoolConfig{
MaxIdlePerKey: 5,
MaxTotalConnections: 10,
IdleTimeoutSeconds: 300,
MaxConnectionLifetimeSeconds: 3600,
}
pool := NewPool(config)
defer pool.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
// Get a new connection (pool is empty, should dial)
conn, err := pool.Get(key, nil)
require.NoError(t, err)
require.NotNil(t, conn)
assert.Equal(t, schemas.OpenAI, conn.Provider())
assert.Equal(t, "test-key", conn.KeyID())
assert.False(t, conn.IsClosed())
// Return to pool
pool.Return(conn)
// Get again — should reuse the same connection
conn2, err := pool.Get(key, nil)
require.NoError(t, err)
require.NotNil(t, conn2)
assert.Same(t, conn, conn2)
pool.Return(conn2)
}
func TestPoolMaxIdlePerKey(t *testing.T) {
server := startTestWSServer(t)
defer server.Close()
config := &schemas.WSPoolConfig{
MaxIdlePerKey: 2,
MaxTotalConnections: 10,
IdleTimeoutSeconds: 300,
MaxConnectionLifetimeSeconds: 3600,
}
pool := NewPool(config)
defer pool.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
// Get 3 connections
var conns []*UpstreamConn
for range 3 {
conn, err := pool.Get(key, nil)
require.NoError(t, err)
conns = append(conns, conn)
}
// Return all 3 — only 2 should be kept (MaxIdlePerKey=2)
for _, conn := range conns {
pool.Return(conn)
}
pool.mu.Lock()
idleCount := len(pool.idle[key])
pool.mu.Unlock()
assert.Equal(t, 2, idleCount)
}
func TestPoolClose(t *testing.T) {
server := startTestWSServer(t)
defer server.Close()
config := &schemas.WSPoolConfig{}
config.CheckAndSetDefaults()
pool := NewPool(config)
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
conn, err := pool.Get(key, nil)
require.NoError(t, err)
pool.Return(conn)
pool.Close()
// Getting from a closed pool should fail
_, err = pool.Get(key, nil)
assert.Error(t, err)
}
func TestPoolExpiredConnection(t *testing.T) {
server := startTestWSServer(t)
defer server.Close()
config := &schemas.WSPoolConfig{
MaxIdlePerKey: 5,
MaxTotalConnections: 10,
IdleTimeoutSeconds: 1,
MaxConnectionLifetimeSeconds: 1,
}
pool := NewPool(config)
defer pool.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
conn, err := pool.Get(key, nil)
require.NoError(t, err)
pool.Return(conn)
// Wait for connection to expire
time.Sleep(1500 * time.Millisecond)
// Get should dial a new connection (old one expired)
conn2, err := pool.Get(key, nil)
require.NoError(t, err)
require.NotNil(t, conn2)
assert.NotSame(t, conn, conn2)
pool.Discard(conn2)
}

View File

@@ -0,0 +1,450 @@
package websocket
import (
"strings"
"sync"
"time"
ws "github.com/fasthttp/websocket"
"github.com/google/uuid"
"github.com/maximhq/bifrost/core/schemas"
)
// Session tracks the binding between a client WebSocket connection and its upstream state.
// For Responses WS mode, it tracks previous_response_id → upstream connection pinning.
type Session struct {
mu sync.RWMutex
writeMu sync.Mutex // serializes all WriteMessage calls to clientConn
id string
// Client connection
clientConn *ws.Conn
// Upstream connection currently pinned to this session (for native WS mode).
// nil when using HTTP bridge.
upstream *UpstreamConn
// LastResponseID tracks the most recent response ID for previous_response_id chaining.
lastResponseID string
// providerSessionID tracks the upstream provider's session identifier when exposed.
providerSessionID string
// realtimeOutputText accumulates assistant/provider turn text until the terminal event.
realtimeOutputText string
// realtimeTurnInputs accumulates finalized user/tool inputs in arrival order so the
// completed assistant turn can persist the full turn history instead of only the
// latest finalized input event.
realtimeTurnInputs []RealtimeTurnInput
// realtimeConsumedTurnItemIDs tracks finalized item IDs that have already been
// attached to a persisted turn, so late transcript updates do not pollute later turns.
realtimeConsumedTurnItemIDs map[string]struct{}
// realtimeTurnHooks tracks the active turn-scoped plugin pipeline between
// response.create and response.done.
realtimeTurnHooks *RealtimeTurnPluginState
realtimeTurnBusy bool
closed bool
}
type RealtimeToolOutput struct {
Summary string
Raw string
}
type RealtimeTurnInput struct {
ItemID string
Role string
Summary string
Raw string
}
type RealtimeTurnPluginState struct {
PostHookRunner schemas.PostHookRunner
Cleanup func()
RequestID string
StartedAt time.Time
PreHookValues map[any]any
}
// NewSession creates a new session for a client WebSocket connection.
func NewSession(clientConn *ws.Conn) *Session {
return &Session{
id: uuid.NewString(),
clientConn: clientConn,
}
}
// ID returns the stable Bifrost session identifier for this websocket session.
func (s *Session) ID() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.id
}
// ClientConn returns the client's WebSocket connection.
func (s *Session) ClientConn() *ws.Conn {
return s.clientConn
}
// WriteMessage sends a message to the client WebSocket connection.
// It serializes concurrent writes via writeMu to prevent panics from
// simultaneous goroutine writes (e.g., heartbeat vs streaming relay).
func (s *Session) WriteMessage(messageType int, data []byte) error {
s.writeMu.Lock()
defer s.writeMu.Unlock()
return s.clientConn.WriteMessage(messageType, data)
}
// SetUpstream pins an upstream connection to this session.
func (s *Session) SetUpstream(conn *UpstreamConn) {
s.mu.Lock()
defer s.mu.Unlock()
if s.closed {
if conn != nil {
conn.Close()
}
return
}
if s.upstream != nil && s.upstream != conn {
s.upstream.Close()
}
s.upstream = conn
}
// Upstream returns the currently pinned upstream connection, or nil.
func (s *Session) Upstream() *UpstreamConn {
s.mu.RLock()
defer s.mu.RUnlock()
return s.upstream
}
// SetLastResponseID updates the last response ID for chaining.
func (s *Session) SetLastResponseID(id string) {
s.mu.Lock()
defer s.mu.Unlock()
s.lastResponseID = id
}
// LastResponseID returns the last response ID.
func (s *Session) LastResponseID() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.lastResponseID
}
// SetProviderSessionID stores the upstream provider session identifier when available.
func (s *Session) SetProviderSessionID(id string) {
s.mu.Lock()
defer s.mu.Unlock()
s.providerSessionID = id
}
// ProviderSessionID returns the upstream provider session identifier when known.
func (s *Session) ProviderSessionID() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.providerSessionID
}
// AppendRealtimeOutputText appends provider output content for the current realtime turn.
func (s *Session) AppendRealtimeOutputText(text string) {
if text == "" {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.realtimeOutputText += text
}
// ConsumeRealtimeOutputText returns the accumulated provider output and clears it.
func (s *Session) ConsumeRealtimeOutputText() string {
s.mu.Lock()
defer s.mu.Unlock()
text := s.realtimeOutputText
s.realtimeOutputText = ""
return text
}
// AddRealtimeInput stores a finalized user turn event in arrival order.
func (s *Session) AddRealtimeInput(summary, raw string) {
if summary == "" && raw == "" {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{
Role: string(schemas.ChatMessageRoleUser),
Summary: summary,
Raw: raw,
})
}
// RecordRealtimeInput stores or updates a finalized user turn event keyed by item ID.
// Late updates for items already attached to a completed turn are ignored.
func (s *Session) RecordRealtimeInput(itemID, summary, raw string) {
s.recordRealtimeTurnInput(itemID, string(schemas.ChatMessageRoleUser), summary, raw)
}
// AddRealtimeToolOutput stores a pending tool result for the next assistant turn.
func (s *Session) AddRealtimeToolOutput(summary, raw string) {
if summary == "" && raw == "" {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{
Role: string(schemas.ChatMessageRoleTool),
Summary: summary,
Raw: raw,
})
}
// RecordRealtimeToolOutput stores or updates a finalized tool result keyed by item ID.
// Late updates for items already attached to a completed turn are ignored.
func (s *Session) RecordRealtimeToolOutput(itemID, summary, raw string) {
s.recordRealtimeTurnInput(itemID, string(schemas.ChatMessageRoleTool), summary, raw)
}
func (s *Session) recordRealtimeTurnInput(itemID, role, summary, raw string) {
if summary == "" && raw == "" {
return
}
s.mu.Lock()
defer s.mu.Unlock()
itemID = strings.TrimSpace(itemID)
if itemID != "" {
if _, consumed := s.realtimeConsumedTurnItemIDs[itemID]; consumed {
return
}
for idx := range s.realtimeTurnInputs {
if s.realtimeTurnInputs[idx].ItemID != itemID || s.realtimeTurnInputs[idx].Role != role {
continue
}
if strings.TrimSpace(summary) != "" {
s.realtimeTurnInputs[idx].Summary = summary
}
if strings.TrimSpace(raw) != "" {
existingRaw := strings.TrimSpace(s.realtimeTurnInputs[idx].Raw)
incomingRaw := strings.TrimSpace(raw)
switch {
case existingRaw == "":
s.realtimeTurnInputs[idx].Raw = raw
case incomingRaw == "" || existingRaw == incomingRaw:
default:
s.realtimeTurnInputs[idx].Raw = existingRaw + "\n\n" + incomingRaw
}
}
return
}
}
s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{
ItemID: itemID,
Role: role,
Summary: summary,
Raw: raw,
})
}
// ConsumeRealtimeTurnInputs returns pending realtime turn inputs and clears them.
func (s *Session) ConsumeRealtimeTurnInputs() []RealtimeTurnInput {
s.mu.Lock()
defer s.mu.Unlock()
inputs := append([]RealtimeTurnInput(nil), s.realtimeTurnInputs...)
if len(inputs) > 0 {
if s.realtimeConsumedTurnItemIDs == nil {
s.realtimeConsumedTurnItemIDs = make(map[string]struct{}, len(inputs))
}
for _, input := range inputs {
if strings.TrimSpace(input.ItemID) != "" {
s.realtimeConsumedTurnItemIDs[input.ItemID] = struct{}{}
}
}
}
s.realtimeTurnInputs = nil
return inputs
}
// PeekRealtimeTurnInputs returns pending realtime turn inputs without clearing them.
func (s *Session) PeekRealtimeTurnInputs() []RealtimeTurnInput {
s.mu.RLock()
defer s.mu.RUnlock()
return append([]RealtimeTurnInput(nil), s.realtimeTurnInputs...)
}
// SetRealtimeTurnHooks stores the active turn-scoped plugin pipeline.
func (s *Session) SetRealtimeTurnHooks(state *RealtimeTurnPluginState) {
s.mu.Lock()
defer s.mu.Unlock()
if s.realtimeTurnHooks != nil && s.realtimeTurnHooks.Cleanup != nil {
s.realtimeTurnHooks.Cleanup()
}
s.realtimeTurnBusy = false
if s.closed {
if state != nil && state.Cleanup != nil {
state.Cleanup()
}
s.realtimeTurnHooks = nil
return
}
s.realtimeTurnHooks = state
}
// TryBeginRealtimeTurnHooks reserves the single active turn slot.
func (s *Session) TryBeginRealtimeTurnHooks() bool {
s.mu.Lock()
defer s.mu.Unlock()
if s.closed || s.realtimeTurnBusy || s.realtimeTurnHooks != nil {
return false
}
s.realtimeTurnBusy = true
return true
}
// AbortRealtimeTurnHooks releases a reserved turn slot without installing hooks.
func (s *Session) AbortRealtimeTurnHooks() {
s.mu.Lock()
defer s.mu.Unlock()
s.realtimeTurnBusy = false
}
// PeekRealtimeTurnHooks returns the active turn-scoped plugin pipeline without clearing it.
func (s *Session) PeekRealtimeTurnHooks() *RealtimeTurnPluginState {
s.mu.RLock()
defer s.mu.RUnlock()
return s.realtimeTurnHooks
}
// ConsumeRealtimeTurnHooks returns the active turn-scoped plugin pipeline and clears it.
func (s *Session) ConsumeRealtimeTurnHooks() *RealtimeTurnPluginState {
s.mu.Lock()
defer s.mu.Unlock()
state := s.realtimeTurnHooks
s.realtimeTurnHooks = nil
s.realtimeTurnBusy = false
return state
}
// ClearRealtimeTurnHooks cleans up and clears any active turn-scoped plugin pipeline.
func (s *Session) ClearRealtimeTurnHooks() {
s.mu.Lock()
defer s.mu.Unlock()
if s.realtimeTurnHooks != nil && s.realtimeTurnHooks.Cleanup != nil {
s.realtimeTurnHooks.Cleanup()
}
s.realtimeTurnHooks = nil
s.realtimeTurnBusy = false
}
// Close closes the session and its upstream connection if pinned.
func (s *Session) Close() {
s.mu.Lock()
defer s.mu.Unlock()
if s.closed {
return
}
s.closed = true
if s.realtimeTurnHooks != nil {
if s.realtimeTurnHooks.Cleanup != nil {
s.realtimeTurnHooks.Cleanup()
}
s.realtimeTurnHooks = nil
}
s.realtimeTurnBusy = false
if s.clientConn != nil {
_ = s.clientConn.Close()
}
if s.upstream != nil {
s.upstream.Close()
s.upstream = nil
}
}
// SessionManager tracks active sessions for connection limiting and cleanup.
type SessionManager struct {
mu sync.RWMutex
sessions map[*ws.Conn]*Session
maxConns int
}
// NewSessionManager creates a new session manager.
func NewSessionManager(maxConns int) *SessionManager {
return &SessionManager{
sessions: make(map[*ws.Conn]*Session),
maxConns: maxConns,
}
}
// Create creates and registers a new session for the given client connection.
// Returns an error if the connection limit would be exceeded.
func (m *SessionManager) Create(clientConn *ws.Conn) (*Session, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.maxConns > 0 && len(m.sessions) >= m.maxConns {
return nil, ErrConnectionLimitReached
}
session := NewSession(clientConn)
m.sessions[clientConn] = session
return session, nil
}
// Get returns the session for the given client connection.
func (m *SessionManager) Get(clientConn *ws.Conn) *Session {
m.mu.RLock()
defer m.mu.RUnlock()
return m.sessions[clientConn]
}
// Remove removes and closes a session.
func (m *SessionManager) Remove(clientConn *ws.Conn) {
m.mu.Lock()
session, ok := m.sessions[clientConn]
if ok {
delete(m.sessions, clientConn)
}
m.mu.Unlock()
if session != nil {
session.Close()
}
}
// Count returns the number of active sessions.
func (m *SessionManager) Count() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.sessions)
}
// CloseAll closes all active sessions.
func (m *SessionManager) CloseAll() {
m.mu.Lock()
sessions := m.sessions
m.sessions = make(map[*ws.Conn]*Session)
m.mu.Unlock()
for _, session := range sessions {
session.Close()
}
}
// Snapshot returns a copy of the currently tracked sessions.
func (m *SessionManager) Snapshot() []*Session {
m.mu.RLock()
defer m.mu.RUnlock()
sessions := make([]*Session, 0, len(m.sessions))
for _, session := range m.sessions {
sessions = append(sessions, session)
}
return sessions
}

View File

@@ -0,0 +1,156 @@
package websocket
import (
"testing"
ws "github.com/fasthttp/websocket"
)
func TestSessionManagerCreateAndGet(t *testing.T) {
manager := NewSessionManager(2)
conn := newTestConn()
session, err := manager.Create(conn)
if err != nil {
t.Fatalf("Create() unexpected error: %v", err)
}
if session == nil {
t.Fatal("Create() returned nil session")
}
if got := manager.Get(conn); got != session {
t.Fatal("Get() did not return the created session")
}
if got := manager.Count(); got != 1 {
t.Fatalf("Count() = %d, want 1", got)
}
}
func TestSessionManagerConnectionLimit(t *testing.T) {
manager := NewSessionManager(1)
if _, err := manager.Create(newTestConn()); err != nil {
t.Fatalf("first Create() unexpected error: %v", err)
}
if _, err := manager.Create(newTestConn()); err != ErrConnectionLimitReached {
t.Fatalf("second Create() error = %v, want %v", err, ErrConnectionLimitReached)
}
}
func TestSessionManagerRemove(t *testing.T) {
manager := NewSessionManager(2)
conn := newTestConn()
session, err := manager.Create(conn)
if err != nil {
t.Fatalf("Create() unexpected error: %v", err)
}
manager.Remove(conn)
if got := manager.Get(conn); got != nil {
t.Fatal("Get() should return nil after Remove()")
}
if got := manager.Count(); got != 0 {
t.Fatalf("Count() = %d, want 0", got)
}
if !session.closed {
t.Fatal("expected removed session to be closed")
}
}
func TestSessionLastResponseID(t *testing.T) {
session := NewSession(newTestConn())
session.SetLastResponseID("resp-123")
if got := session.LastResponseID(); got != "resp-123" {
t.Fatalf("LastResponseID() = %q, want %q", got, "resp-123")
}
}
func TestSessionManagerCloseAll(t *testing.T) {
manager := NewSessionManager(4)
connA := newTestConn()
connB := newTestConn()
sessionA, err := manager.Create(connA)
if err != nil {
t.Fatalf("Create(connA) unexpected error: %v", err)
}
sessionB, err := manager.Create(connB)
if err != nil {
t.Fatalf("Create(connB) unexpected error: %v", err)
}
manager.CloseAll()
if got := manager.Count(); got != 0 {
t.Fatalf("Count() = %d, want 0", got)
}
if !sessionA.closed || !sessionB.closed {
t.Fatal("expected all sessions to be closed")
}
}
func TestSessionRealtimeState(t *testing.T) {
session := NewSession(newTestConn())
if session.ID() == "" {
t.Fatal("expected session ID to be populated")
}
session.SetProviderSessionID("provider-session-1")
if got := session.ProviderSessionID(); got != "provider-session-1" {
t.Fatalf("ProviderSessionID() = %q, want %q", got, "provider-session-1")
}
session.AppendRealtimeOutputText("hello")
session.AppendRealtimeOutputText(" world")
if got := session.ConsumeRealtimeOutputText(); got != "hello world" {
t.Fatalf("ConsumeRealtimeOutputText() = %q, want %q", got, "hello world")
}
if got := session.ConsumeRealtimeOutputText(); got != "" {
t.Fatalf("ConsumeRealtimeOutputText() after clear = %q, want empty string", got)
}
session.AddRealtimeInput("hello", `{"type":"conversation.item.create","item":{"role":"user"}}`)
session.AddRealtimeToolOutput("tool result", `{"type":"conversation.item.create","item":{"type":"function_call_output"}}`)
turnInputs := session.ConsumeRealtimeTurnInputs()
if len(turnInputs) != 2 {
t.Fatalf("len(ConsumeRealtimeTurnInputs()) = %d, want 2", len(turnInputs))
}
if turnInputs[0].Role != "user" || turnInputs[0].Summary != "hello" {
t.Fatalf("turnInputs[0] = %+v, want user hello", turnInputs[0])
}
if turnInputs[1].Role != "tool" || turnInputs[1].Summary != "tool result" {
t.Fatalf("turnInputs[1] = %+v, want tool result", turnInputs[1])
}
if got := session.ConsumeRealtimeTurnInputs(); len(got) != 0 {
t.Fatalf("len(ConsumeRealtimeTurnInputs()) after clear = %d, want 0", len(got))
}
}
func TestSessionRecordRealtimeInputUpdatesPendingItemAndIgnoresConsumedLateUpdate(t *testing.T) {
session := NewSession(newTestConn())
session.RecordRealtimeInput("item_1", "[Audio transcription unavailable]", `{"type":"conversation.item.done","item":{"id":"item_1"}}`)
session.RecordRealtimeInput("item_1", "Hello.", `{"type":"conversation.item.input_audio_transcription.completed","item_id":"item_1","transcript":"Hello."}`)
turnInputs := session.ConsumeRealtimeTurnInputs()
if len(turnInputs) != 1 {
t.Fatalf("len(ConsumeRealtimeTurnInputs()) = %d, want 1", len(turnInputs))
}
if turnInputs[0].ItemID != "item_1" {
t.Fatalf("turnInputs[0].ItemID = %q, want %q", turnInputs[0].ItemID, "item_1")
}
if turnInputs[0].Summary != "Hello." {
t.Fatalf("turnInputs[0].Summary = %q, want %q", turnInputs[0].Summary, "Hello.")
}
session.RecordRealtimeInput("item_1", "Hello.", `{"type":"conversation.item.input_audio_transcription.completed","item_id":"item_1","transcript":"Hello."}`)
if got := session.ConsumeRealtimeTurnInputs(); len(got) != 0 {
t.Fatalf("len(ConsumeRealtimeTurnInputs()) after late consumed update = %d, want 0", len(got))
}
}
func newTestConn() *ws.Conn {
return &ws.Conn{}
}