451 lines
12 KiB
Go
451 lines
12 KiB
Go
package websocket
|
|
|
|
import (
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
ws "github.com/fasthttp/websocket"
|
|
"github.com/google/uuid"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// Session tracks the binding between a client WebSocket connection and its upstream state.
|
|
// For Responses WS mode, it tracks previous_response_id → upstream connection pinning.
|
|
type Session struct {
|
|
mu sync.RWMutex
|
|
writeMu sync.Mutex // serializes all WriteMessage calls to clientConn
|
|
|
|
id string
|
|
|
|
// Client connection
|
|
clientConn *ws.Conn
|
|
|
|
// Upstream connection currently pinned to this session (for native WS mode).
|
|
// nil when using HTTP bridge.
|
|
upstream *UpstreamConn
|
|
|
|
// LastResponseID tracks the most recent response ID for previous_response_id chaining.
|
|
lastResponseID string
|
|
|
|
// providerSessionID tracks the upstream provider's session identifier when exposed.
|
|
providerSessionID string
|
|
|
|
// realtimeOutputText accumulates assistant/provider turn text until the terminal event.
|
|
realtimeOutputText string
|
|
|
|
// realtimeTurnInputs accumulates finalized user/tool inputs in arrival order so the
|
|
// completed assistant turn can persist the full turn history instead of only the
|
|
// latest finalized input event.
|
|
realtimeTurnInputs []RealtimeTurnInput
|
|
|
|
// realtimeConsumedTurnItemIDs tracks finalized item IDs that have already been
|
|
// attached to a persisted turn, so late transcript updates do not pollute later turns.
|
|
realtimeConsumedTurnItemIDs map[string]struct{}
|
|
|
|
// realtimeTurnHooks tracks the active turn-scoped plugin pipeline between
|
|
// response.create and response.done.
|
|
realtimeTurnHooks *RealtimeTurnPluginState
|
|
realtimeTurnBusy bool
|
|
|
|
closed bool
|
|
}
|
|
|
|
// RealtimeToolOutput pairs the two textual forms of a realtime tool result.
type RealtimeToolOutput struct {
	// Summary is the condensed form and Raw the unprocessed form of the output.
	// NOTE(review): the exact producer of these values is outside this file.
	Summary, Raw string
}
|
|
|
|
// RealtimeTurnInput is one finalized input (user message or tool result)
// queued for the current realtime turn.
type RealtimeTurnInput struct {
	ItemID  string // item identifier used for de-duplication; empty when the input is not keyed
	Role    string // chat message role of the input (user or tool in this file)
	Summary string // condensed text of the input
	Raw     string // raw payload of the input
}
|
|
|
|
type RealtimeTurnPluginState struct {
|
|
PostHookRunner schemas.PostHookRunner
|
|
Cleanup func()
|
|
RequestID string
|
|
StartedAt time.Time
|
|
PreHookValues map[any]any
|
|
}
|
|
|
|
// NewSession creates a new session for a client WebSocket connection.
|
|
func NewSession(clientConn *ws.Conn) *Session {
|
|
return &Session{
|
|
id: uuid.NewString(),
|
|
clientConn: clientConn,
|
|
}
|
|
}
|
|
|
|
// ID returns the stable Bifrost session identifier for this websocket session.
|
|
func (s *Session) ID() string {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.id
|
|
}
|
|
|
|
// ClientConn returns the client's WebSocket connection.
|
|
func (s *Session) ClientConn() *ws.Conn {
|
|
return s.clientConn
|
|
}
|
|
|
|
// WriteMessage sends a message to the client WebSocket connection.
|
|
// It serializes concurrent writes via writeMu to prevent panics from
|
|
// simultaneous goroutine writes (e.g., heartbeat vs streaming relay).
|
|
func (s *Session) WriteMessage(messageType int, data []byte) error {
|
|
s.writeMu.Lock()
|
|
defer s.writeMu.Unlock()
|
|
return s.clientConn.WriteMessage(messageType, data)
|
|
}
|
|
|
|
// SetUpstream pins an upstream connection to this session.
|
|
func (s *Session) SetUpstream(conn *UpstreamConn) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.closed {
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
return
|
|
}
|
|
if s.upstream != nil && s.upstream != conn {
|
|
s.upstream.Close()
|
|
}
|
|
s.upstream = conn
|
|
}
|
|
|
|
// Upstream returns the currently pinned upstream connection, or nil.
|
|
func (s *Session) Upstream() *UpstreamConn {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.upstream
|
|
}
|
|
|
|
// SetLastResponseID updates the last response ID for chaining.
|
|
func (s *Session) SetLastResponseID(id string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.lastResponseID = id
|
|
}
|
|
|
|
// LastResponseID returns the last response ID.
|
|
func (s *Session) LastResponseID() string {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.lastResponseID
|
|
}
|
|
|
|
// SetProviderSessionID stores the upstream provider session identifier when available.
|
|
func (s *Session) SetProviderSessionID(id string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.providerSessionID = id
|
|
}
|
|
|
|
// ProviderSessionID returns the upstream provider session identifier when known.
|
|
func (s *Session) ProviderSessionID() string {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.providerSessionID
|
|
}
|
|
|
|
// AppendRealtimeOutputText appends provider output content for the current realtime turn.
|
|
func (s *Session) AppendRealtimeOutputText(text string) {
|
|
if text == "" {
|
|
return
|
|
}
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.realtimeOutputText += text
|
|
}
|
|
|
|
// ConsumeRealtimeOutputText returns the accumulated provider output and clears it.
|
|
func (s *Session) ConsumeRealtimeOutputText() string {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
text := s.realtimeOutputText
|
|
s.realtimeOutputText = ""
|
|
return text
|
|
}
|
|
|
|
// AddRealtimeInput stores a finalized user turn event in arrival order.
|
|
func (s *Session) AddRealtimeInput(summary, raw string) {
|
|
if summary == "" && raw == "" {
|
|
return
|
|
}
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{
|
|
Role: string(schemas.ChatMessageRoleUser),
|
|
Summary: summary,
|
|
Raw: raw,
|
|
})
|
|
}
|
|
|
|
// RecordRealtimeInput stores or updates a finalized user turn event keyed by item ID.
|
|
// Late updates for items already attached to a completed turn are ignored.
|
|
func (s *Session) RecordRealtimeInput(itemID, summary, raw string) {
|
|
s.recordRealtimeTurnInput(itemID, string(schemas.ChatMessageRoleUser), summary, raw)
|
|
}
|
|
|
|
// AddRealtimeToolOutput stores a pending tool result for the next assistant turn.
|
|
func (s *Session) AddRealtimeToolOutput(summary, raw string) {
|
|
if summary == "" && raw == "" {
|
|
return
|
|
}
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{
|
|
Role: string(schemas.ChatMessageRoleTool),
|
|
Summary: summary,
|
|
Raw: raw,
|
|
})
|
|
}
|
|
|
|
// RecordRealtimeToolOutput stores or updates a finalized tool result keyed by item ID.
|
|
// Late updates for items already attached to a completed turn are ignored.
|
|
func (s *Session) RecordRealtimeToolOutput(itemID, summary, raw string) {
|
|
s.recordRealtimeTurnInput(itemID, string(schemas.ChatMessageRoleTool), summary, raw)
|
|
}
|
|
|
|
func (s *Session) recordRealtimeTurnInput(itemID, role, summary, raw string) {
|
|
if summary == "" && raw == "" {
|
|
return
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
itemID = strings.TrimSpace(itemID)
|
|
if itemID != "" {
|
|
if _, consumed := s.realtimeConsumedTurnItemIDs[itemID]; consumed {
|
|
return
|
|
}
|
|
for idx := range s.realtimeTurnInputs {
|
|
if s.realtimeTurnInputs[idx].ItemID != itemID || s.realtimeTurnInputs[idx].Role != role {
|
|
continue
|
|
}
|
|
if strings.TrimSpace(summary) != "" {
|
|
s.realtimeTurnInputs[idx].Summary = summary
|
|
}
|
|
if strings.TrimSpace(raw) != "" {
|
|
existingRaw := strings.TrimSpace(s.realtimeTurnInputs[idx].Raw)
|
|
incomingRaw := strings.TrimSpace(raw)
|
|
switch {
|
|
case existingRaw == "":
|
|
s.realtimeTurnInputs[idx].Raw = raw
|
|
case incomingRaw == "" || existingRaw == incomingRaw:
|
|
default:
|
|
s.realtimeTurnInputs[idx].Raw = existingRaw + "\n\n" + incomingRaw
|
|
}
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{
|
|
ItemID: itemID,
|
|
Role: role,
|
|
Summary: summary,
|
|
Raw: raw,
|
|
})
|
|
}
|
|
|
|
// ConsumeRealtimeTurnInputs returns pending realtime turn inputs and clears them.
|
|
func (s *Session) ConsumeRealtimeTurnInputs() []RealtimeTurnInput {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
inputs := append([]RealtimeTurnInput(nil), s.realtimeTurnInputs...)
|
|
if len(inputs) > 0 {
|
|
if s.realtimeConsumedTurnItemIDs == nil {
|
|
s.realtimeConsumedTurnItemIDs = make(map[string]struct{}, len(inputs))
|
|
}
|
|
for _, input := range inputs {
|
|
if strings.TrimSpace(input.ItemID) != "" {
|
|
s.realtimeConsumedTurnItemIDs[input.ItemID] = struct{}{}
|
|
}
|
|
}
|
|
}
|
|
s.realtimeTurnInputs = nil
|
|
return inputs
|
|
}
|
|
|
|
// PeekRealtimeTurnInputs returns pending realtime turn inputs without clearing them.
|
|
func (s *Session) PeekRealtimeTurnInputs() []RealtimeTurnInput {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return append([]RealtimeTurnInput(nil), s.realtimeTurnInputs...)
|
|
}
|
|
|
|
// SetRealtimeTurnHooks stores the active turn-scoped plugin pipeline.
|
|
func (s *Session) SetRealtimeTurnHooks(state *RealtimeTurnPluginState) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.realtimeTurnHooks != nil && s.realtimeTurnHooks.Cleanup != nil {
|
|
s.realtimeTurnHooks.Cleanup()
|
|
}
|
|
s.realtimeTurnBusy = false
|
|
if s.closed {
|
|
if state != nil && state.Cleanup != nil {
|
|
state.Cleanup()
|
|
}
|
|
s.realtimeTurnHooks = nil
|
|
return
|
|
}
|
|
s.realtimeTurnHooks = state
|
|
}
|
|
|
|
// TryBeginRealtimeTurnHooks reserves the single active turn slot.
|
|
func (s *Session) TryBeginRealtimeTurnHooks() bool {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.closed || s.realtimeTurnBusy || s.realtimeTurnHooks != nil {
|
|
return false
|
|
}
|
|
s.realtimeTurnBusy = true
|
|
return true
|
|
}
|
|
|
|
// AbortRealtimeTurnHooks releases a reserved turn slot without installing hooks.
|
|
func (s *Session) AbortRealtimeTurnHooks() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.realtimeTurnBusy = false
|
|
}
|
|
|
|
// PeekRealtimeTurnHooks returns the active turn-scoped plugin pipeline without clearing it.
|
|
func (s *Session) PeekRealtimeTurnHooks() *RealtimeTurnPluginState {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.realtimeTurnHooks
|
|
}
|
|
|
|
// ConsumeRealtimeTurnHooks returns the active turn-scoped plugin pipeline and clears it.
|
|
func (s *Session) ConsumeRealtimeTurnHooks() *RealtimeTurnPluginState {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
state := s.realtimeTurnHooks
|
|
s.realtimeTurnHooks = nil
|
|
s.realtimeTurnBusy = false
|
|
return state
|
|
}
|
|
|
|
// ClearRealtimeTurnHooks cleans up and clears any active turn-scoped plugin pipeline.
|
|
func (s *Session) ClearRealtimeTurnHooks() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.realtimeTurnHooks != nil && s.realtimeTurnHooks.Cleanup != nil {
|
|
s.realtimeTurnHooks.Cleanup()
|
|
}
|
|
s.realtimeTurnHooks = nil
|
|
s.realtimeTurnBusy = false
|
|
}
|
|
|
|
// Close closes the session and its upstream connection if pinned.
|
|
func (s *Session) Close() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.closed {
|
|
return
|
|
}
|
|
s.closed = true
|
|
if s.realtimeTurnHooks != nil {
|
|
if s.realtimeTurnHooks.Cleanup != nil {
|
|
s.realtimeTurnHooks.Cleanup()
|
|
}
|
|
s.realtimeTurnHooks = nil
|
|
}
|
|
s.realtimeTurnBusy = false
|
|
if s.clientConn != nil {
|
|
_ = s.clientConn.Close()
|
|
}
|
|
if s.upstream != nil {
|
|
s.upstream.Close()
|
|
s.upstream = nil
|
|
}
|
|
}
|
|
|
|
// SessionManager tracks active sessions for connection limiting and cleanup.
|
|
type SessionManager struct {
|
|
mu sync.RWMutex
|
|
sessions map[*ws.Conn]*Session
|
|
maxConns int
|
|
}
|
|
|
|
// NewSessionManager creates a new session manager.
|
|
func NewSessionManager(maxConns int) *SessionManager {
|
|
return &SessionManager{
|
|
sessions: make(map[*ws.Conn]*Session),
|
|
maxConns: maxConns,
|
|
}
|
|
}
|
|
|
|
// Create creates and registers a new session for the given client connection.
|
|
// Returns an error if the connection limit would be exceeded.
|
|
func (m *SessionManager) Create(clientConn *ws.Conn) (*Session, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if m.maxConns > 0 && len(m.sessions) >= m.maxConns {
|
|
return nil, ErrConnectionLimitReached
|
|
}
|
|
|
|
session := NewSession(clientConn)
|
|
m.sessions[clientConn] = session
|
|
return session, nil
|
|
}
|
|
|
|
// Get returns the session for the given client connection.
|
|
func (m *SessionManager) Get(clientConn *ws.Conn) *Session {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
return m.sessions[clientConn]
|
|
}
|
|
|
|
// Remove removes and closes a session.
|
|
func (m *SessionManager) Remove(clientConn *ws.Conn) {
|
|
m.mu.Lock()
|
|
session, ok := m.sessions[clientConn]
|
|
if ok {
|
|
delete(m.sessions, clientConn)
|
|
}
|
|
m.mu.Unlock()
|
|
|
|
if session != nil {
|
|
session.Close()
|
|
}
|
|
}
|
|
|
|
// Count returns the number of active sessions.
|
|
func (m *SessionManager) Count() int {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
return len(m.sessions)
|
|
}
|
|
|
|
// CloseAll closes all active sessions.
|
|
func (m *SessionManager) CloseAll() {
|
|
m.mu.Lock()
|
|
sessions := m.sessions
|
|
m.sessions = make(map[*ws.Conn]*Session)
|
|
m.mu.Unlock()
|
|
|
|
for _, session := range sessions {
|
|
session.Close()
|
|
}
|
|
}
|
|
|
|
// Snapshot returns a copy of the currently tracked sessions.
|
|
func (m *SessionManager) Snapshot() []*Session {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
sessions := make([]*Session, 0, len(m.sessions))
|
|
for _, session := range m.sessions {
|
|
sessions = append(sessions, session)
|
|
}
|
|
return sessions
|
|
}
|