Files
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

451 lines
12 KiB
Go

package websocket
import (
"strings"
"sync"
"time"
ws "github.com/fasthttp/websocket"
"github.com/google/uuid"
"github.com/maximhq/bifrost/core/schemas"
)
// Session tracks the binding between a client WebSocket connection and its upstream state.
// For Responses WS mode, it tracks previous_response_id → upstream connection pinning.
//
// Locking: mu guards the mutable session state read and written by the accessor
// methods below. writeMu is a separate lock taken only by WriteMessage, so frame
// writes to the client do not contend with state updates. clientConn is read by
// ClientConn without taking mu.
type Session struct {
	mu sync.RWMutex
	writeMu sync.Mutex // serializes all WriteMessage calls to clientConn
	// id is the Bifrost session identifier, assigned once in NewSession.
	id string
	// Client connection
	clientConn *ws.Conn
	// Upstream connection currently pinned to this session (for native WS mode).
	// nil when using HTTP bridge.
	upstream *UpstreamConn
	// LastResponseID tracks the most recent response ID for previous_response_id chaining.
	lastResponseID string
	// providerSessionID tracks the upstream provider's session identifier when exposed.
	providerSessionID string
	// realtimeOutputText accumulates assistant/provider turn text until the terminal event.
	realtimeOutputText string
	// realtimeTurnInputs accumulates finalized user/tool inputs in arrival order so the
	// completed assistant turn can persist the full turn history instead of only the
	// latest finalized input event.
	realtimeTurnInputs []RealtimeTurnInput
	// realtimeConsumedTurnItemIDs tracks finalized item IDs that have already been
	// attached to a persisted turn, so late transcript updates do not pollute later turns.
	// NOTE(review): entries are only ever added (in ConsumeRealtimeTurnInputs), never
	// removed, so this set grows for the lifetime of the session.
	realtimeConsumedTurnItemIDs map[string]struct{}
	// realtimeTurnHooks tracks the active turn-scoped plugin pipeline between
	// response.create and response.done.
	realtimeTurnHooks *RealtimeTurnPluginState
	// realtimeTurnBusy is set by TryBeginRealtimeTurnHooks to reserve the single
	// turn slot before hooks are installed. It is cleared when hooks are set,
	// consumed, cleared, or aborted, and when the session closes.
	realtimeTurnBusy bool
	// closed is set once by Close. After that, SetUpstream and SetRealtimeTurnHooks
	// clean up their argument instead of storing it.
	closed bool
}
// RealtimeToolOutput holds the summary text and the raw text of a realtime tool result.
type RealtimeToolOutput struct {
	Summary, Raw string
}
// RealtimeTurnInput is one finalized input of a realtime turn, as accumulated
// on a Session until the turn is consumed.
type RealtimeTurnInput struct {
	ItemID, Role string // item ID (may be empty) and the chat role that produced it
	Summary, Raw string // summary text and raw text of the input
}
// RealtimeTurnPluginState holds the turn-scoped plugin pipeline state that a
// Session keeps for the active realtime turn (between response.create and
// response.done).
type RealtimeTurnPluginState struct {
	// PostHookRunner presumably runs the plugin post-hooks when the turn
	// completes — confirm at the call site.
	PostHookRunner schemas.PostHookRunner
	// Cleanup, when non-nil, is invoked by the Session in three cases: when this
	// state is replaced via SetRealtimeTurnHooks, when it is cleared via
	// ClearRealtimeTurnHooks, or when the session closes with it still installed.
	// ConsumeRealtimeTurnHooks returns the state without calling Cleanup;
	// presumably the caller then owns it.
	Cleanup func()
	// RequestID identifies the turn's request (assumed to be the Bifrost request
	// ID — verify against callers).
	RequestID string
	// StartedAt records when the turn began (presumably used for latency — confirm).
	StartedAt time.Time
	// PreHookValues presumably carries values captured by pre-hooks for use by
	// post-hooks — confirm at the call site.
	PreHookValues map[any]any
}
// NewSession builds a Session bound to clientConn and gives it a freshly
// generated UUID string as its identifier. No upstream is pinned yet.
func NewSession(clientConn *ws.Conn) *Session {
	s := new(Session)
	s.id = uuid.NewString()
	s.clientConn = clientConn
	return s
}
// ID reports the Bifrost-assigned identifier of this websocket session.
func (s *Session) ID() string {
	s.mu.RLock()
	id := s.id
	s.mu.RUnlock()
	return id
}
// ClientConn exposes the client-facing WebSocket connection given to NewSession.
// It is read without taking the session lock.
func (s *Session) ClientConn() *ws.Conn {
	conn := s.clientConn
	return conn
}
// WriteMessage forwards one frame of the given type to the client connection.
// writeMu lets only one goroutine write at a time (for example, a heartbeat
// racing a streaming relay), which avoids panics from simultaneous writes.
func (s *Session) WriteMessage(messageType int, data []byte) error {
	lock := &s.writeMu
	lock.Lock()
	defer lock.Unlock()
	conn := s.clientConn
	return conn.WriteMessage(messageType, data)
}
// SetUpstream pins conn as this session's upstream connection.
//
// If the session has already been closed, conn (when non-nil) is closed right
// away and nothing is stored. Otherwise a previously pinned upstream that differs
// from conn is closed before conn replaces it. Passing nil unpins the current
// upstream and closes it.
func (s *Session) SetUpstream(conn *UpstreamConn) {
	s.mu.Lock()
	defer s.mu.Unlock()

	switch {
	case s.closed:
		if conn != nil {
			conn.Close()
		}
		return
	case s.upstream != nil && s.upstream != conn:
		s.upstream.Close()
	}
	s.upstream = conn
}
// Upstream reports the upstream connection currently pinned to the session,
// or nil when none is pinned.
func (s *Session) Upstream() *UpstreamConn {
	s.mu.RLock()
	up := s.upstream
	s.mu.RUnlock()
	return up
}
// SetLastResponseID records id as the most recent response, used for
// previous_response_id chaining.
func (s *Session) SetLastResponseID(id string) {
	s.mu.Lock()
	s.lastResponseID = id
	s.mu.Unlock()
}
// LastResponseID reports the most recently recorded response ID, or "" if none.
func (s *Session) LastResponseID() string {
	s.mu.RLock()
	id := s.lastResponseID
	s.mu.RUnlock()
	return id
}
// SetProviderSessionID remembers the upstream provider's own session identifier.
func (s *Session) SetProviderSessionID(id string) {
	s.mu.Lock()
	s.providerSessionID = id
	s.mu.Unlock()
}
// ProviderSessionID reports the upstream provider's session identifier, or ""
// if it has not been recorded.
func (s *Session) ProviderSessionID() string {
	s.mu.RLock()
	id := s.providerSessionID
	s.mu.RUnlock()
	return id
}
// AppendRealtimeOutputText adds text to the provider output buffered for the
// current realtime turn. Empty strings are ignored without taking the lock.
func (s *Session) AppendRealtimeOutputText(text string) {
	if len(text) == 0 {
		return
	}
	s.mu.Lock()
	s.realtimeOutputText = s.realtimeOutputText + text
	s.mu.Unlock()
}
// ConsumeRealtimeOutputText hands back everything buffered by
// AppendRealtimeOutputText and resets the buffer to empty, atomically with
// respect to the session lock.
func (s *Session) ConsumeRealtimeOutputText() string {
	s.mu.Lock()
	defer s.mu.Unlock()
	var out string
	out, s.realtimeOutputText = s.realtimeOutputText, ""
	return out
}
// AddRealtimeInput appends a finalized user input without an item ID to the
// pending turn inputs. Calls where both summary and raw are empty are ignored.
func (s *Session) AddRealtimeInput(summary, raw string) {
	if summary == "" && raw == "" {
		return
	}
	entry := RealtimeTurnInput{
		Role:    string(schemas.ChatMessageRoleUser),
		Summary: summary,
		Raw:     raw,
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.realtimeTurnInputs = append(s.realtimeTurnInputs, entry)
}
// RecordRealtimeInput inserts, or merges into, the pending user input that has
// the given item ID. Updates for IDs that already belong to a consumed turn
// are dropped.
func (s *Session) RecordRealtimeInput(itemID, summary, raw string) {
	role := string(schemas.ChatMessageRoleUser)
	s.recordRealtimeTurnInput(itemID, role, summary, raw)
}
// AddRealtimeToolOutput appends a tool result without an item ID to the pending
// turn inputs, so the next assistant turn can include it. Calls where both
// summary and raw are empty are ignored.
func (s *Session) AddRealtimeToolOutput(summary, raw string) {
	if summary == "" && raw == "" {
		return
	}
	entry := RealtimeTurnInput{
		Role:    string(schemas.ChatMessageRoleTool),
		Summary: summary,
		Raw:     raw,
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.realtimeTurnInputs = append(s.realtimeTurnInputs, entry)
}
// RecordRealtimeToolOutput inserts, or merges into, the pending tool result
// that has the given item ID. Updates for IDs that already belong to a consumed
// turn are dropped.
func (s *Session) RecordRealtimeToolOutput(itemID, summary, raw string) {
	role := string(schemas.ChatMessageRoleTool)
	s.recordRealtimeTurnInput(itemID, role, summary, raw)
}
// recordRealtimeTurnInput is the shared implementation behind
// RecordRealtimeInput and RecordRealtimeToolOutput.
//
// Calls where summary and raw are both empty are ignored. The item ID is
// whitespace-trimmed, and an empty ID always appends a new entry. A non-empty
// ID that ConsumeRealtimeTurnInputs has already attached to a turn is ignored.
// Otherwise the first pending entry with the same ID and role is merged in place:
//   - a non-blank summary replaces the stored summary;
//   - a non-blank raw value is stored verbatim if the entry has no non-blank raw
//     yet, skipped if it equals the stored raw after trimming, and otherwise
//     appended after a blank line (both sides trimmed).
//
// If no pending entry matches, a new one is appended.
func (s *Session) recordRealtimeTurnInput(itemID, role, summary, raw string) {
	if summary == "" && raw == "" {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()

	id := strings.TrimSpace(itemID)
	if id == "" {
		s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{
			Role:    role,
			Summary: summary,
			Raw:     raw,
		})
		return
	}
	// Late updates for items already persisted with an earlier turn must not
	// leak into the next turn.
	if _, seen := s.realtimeConsumedTurnItemIDs[id]; seen {
		return
	}

	for i := range s.realtimeTurnInputs {
		entry := &s.realtimeTurnInputs[i]
		if entry.ItemID != id || entry.Role != role {
			continue
		}
		if strings.TrimSpace(summary) != "" {
			entry.Summary = summary
		}
		incoming := strings.TrimSpace(raw)
		if incoming == "" {
			return
		}
		existing := strings.TrimSpace(entry.Raw)
		if existing == "" {
			entry.Raw = raw
		} else if existing != incoming {
			entry.Raw = existing + "\n\n" + incoming
		}
		return
	}

	s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{
		ItemID:  id,
		Role:    role,
		Summary: summary,
		Raw:     raw,
	})
}
// ConsumeRealtimeTurnInputs drains the pending turn inputs and returns a copy
// of them in arrival order (nil when there were none). Every non-blank item ID
// in the drained set is remembered as consumed, so recordRealtimeTurnInput will
// ignore late updates for those items.
func (s *Session) ConsumeRealtimeTurnInputs() []RealtimeTurnInput {
	s.mu.Lock()
	defer s.mu.Unlock()

	pending := s.realtimeTurnInputs
	s.realtimeTurnInputs = nil
	if len(pending) == 0 {
		return nil
	}

	if s.realtimeConsumedTurnItemIDs == nil {
		s.realtimeConsumedTurnItemIDs = make(map[string]struct{}, len(pending))
	}
	for i := range pending {
		if id := pending[i].ItemID; strings.TrimSpace(id) != "" {
			s.realtimeConsumedTurnItemIDs[id] = struct{}{}
		}
	}
	return append([]RealtimeTurnInput(nil), pending...)
}
// PeekRealtimeTurnInputs returns a snapshot copy of the pending turn inputs
// (nil when there are none) and leaves the pending list untouched.
func (s *Session) PeekRealtimeTurnInputs() []RealtimeTurnInput {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if len(s.realtimeTurnInputs) == 0 {
		return nil
	}
	snapshot := make([]RealtimeTurnInput, len(s.realtimeTurnInputs))
	copy(snapshot, s.realtimeTurnInputs)
	return snapshot
}
// SetRealtimeTurnHooks installs state as the active turn-scoped plugin pipeline
// and releases any slot reserved by TryBeginRealtimeTurnHooks.
//
// A previously installed pipeline that differs from state has its Cleanup run
// before it is replaced. Re-installing the pipeline that is already active is a
// no-op for cleanup. Without this guard, the active state would be torn down and
// then kept installed, and a later Consume/Clear/Close would run its Cleanup a
// second time. This matches SetUpstream's "only close a different connection"
// rule.
//
// If the session is already closed, state is not stored: its Cleanup (when set)
// runs immediately and the active pipeline is cleared.
func (s *Session) SetRealtimeTurnHooks(state *RealtimeTurnPluginState) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if prev := s.realtimeTurnHooks; prev != nil && prev != state && prev.Cleanup != nil {
		prev.Cleanup()
	}
	s.realtimeTurnBusy = false
	if s.closed {
		if state != nil && state.Cleanup != nil {
			state.Cleanup()
		}
		s.realtimeTurnHooks = nil
		return
	}
	s.realtimeTurnHooks = state
}
// TryBeginRealtimeTurnHooks attempts to claim the session's single realtime
// turn slot. It succeeds, and marks the slot busy, only when the session is
// open, no reservation is outstanding, and no hooks are installed.
func (s *Session) TryBeginRealtimeTurnHooks() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	available := !s.closed && !s.realtimeTurnBusy && s.realtimeTurnHooks == nil
	if available {
		s.realtimeTurnBusy = true
	}
	return available
}
// AbortRealtimeTurnHooks gives back a slot reserved by TryBeginRealtimeTurnHooks
// when no hooks end up being installed. Any installed hooks are left in place.
func (s *Session) AbortRealtimeTurnHooks() {
	s.mu.Lock()
	s.realtimeTurnBusy = false
	s.mu.Unlock()
}
// PeekRealtimeTurnHooks reports the active turn-scoped plugin pipeline (or nil)
// without detaching it from the session.
func (s *Session) PeekRealtimeTurnHooks() *RealtimeTurnPluginState {
	s.mu.RLock()
	hooks := s.realtimeTurnHooks
	s.mu.RUnlock()
	return hooks
}
// ConsumeRealtimeTurnHooks detaches and returns the active turn-scoped plugin
// pipeline (or nil), and frees the turn slot. The returned state's Cleanup is
// NOT invoked here.
func (s *Session) ConsumeRealtimeTurnHooks() *RealtimeTurnPluginState {
	s.mu.Lock()
	defer s.mu.Unlock()
	detached := s.realtimeTurnHooks
	s.realtimeTurnHooks, s.realtimeTurnBusy = nil, false
	return detached
}
// ClearRealtimeTurnHooks runs the active pipeline's Cleanup (when present),
// then drops the pipeline and frees the turn slot.
func (s *Session) ClearRealtimeTurnHooks() {
	s.mu.Lock()
	defer s.mu.Unlock()
	if hooks := s.realtimeTurnHooks; hooks != nil && hooks.Cleanup != nil {
		hooks.Cleanup()
	}
	s.realtimeTurnHooks, s.realtimeTurnBusy = nil, false
}
// Close shuts the session down. It is idempotent: only the first call has
// any effect. In order, it
//  1. marks the session closed,
//  2. runs and drops any active turn pipeline, and frees the turn slot,
//  3. closes the client connection (its error is ignored), and
//  4. closes and unpins the upstream connection, if one is pinned.
func (s *Session) Close() {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.closed {
		return
	}
	s.closed = true

	if hooks := s.realtimeTurnHooks; hooks != nil {
		if hooks.Cleanup != nil {
			hooks.Cleanup()
		}
		s.realtimeTurnHooks = nil
	}
	s.realtimeTurnBusy = false

	if conn := s.clientConn; conn != nil {
		_ = conn.Close() // best effort: the session is going away regardless
	}
	if up := s.upstream; up != nil {
		up.Close()
		s.upstream = nil
	}
}
// SessionManager tracks active sessions for connection limiting and cleanup.
// Sessions are keyed by their client connection, and mu guards the map.
type SessionManager struct {
	mu sync.RWMutex
	sessions map[*ws.Conn]*Session
	// maxConns caps the number of tracked sessions; a value <= 0 disables the
	// cap (see Create).
	maxConns int
}
// NewSessionManager returns an empty manager that admits at most maxConns
// concurrent sessions (no limit when maxConns <= 0).
func NewSessionManager(maxConns int) *SessionManager {
	m := &SessionManager{maxConns: maxConns}
	m.sessions = make(map[*ws.Conn]*Session)
	return m
}
// Create registers a new Session for clientConn and returns it. When a
// positive connection limit is configured and already reached, it returns
// ErrConnectionLimitReached and registers nothing.
//
// NOTE(review): if clientConn is already registered, the new session replaces
// the old map entry without closing the old session.
func (m *SessionManager) Create(clientConn *ws.Conn) (*Session, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	limited := m.maxConns > 0
	if limited && len(m.sessions) >= m.maxConns {
		return nil, ErrConnectionLimitReached
	}
	created := NewSession(clientConn)
	m.sessions[clientConn] = created
	return created, nil
}
// Get looks up the session registered for clientConn, returning nil if none is.
func (m *SessionManager) Get(clientConn *ws.Conn) *Session {
	m.mu.RLock()
	found := m.sessions[clientConn]
	m.mu.RUnlock()
	return found
}
// Remove unregisters the session for clientConn and then closes it. The session
// is closed after the manager lock is released, so Session.Close never runs
// while m.mu is held. Unknown connections are ignored.
func (m *SessionManager) Remove(clientConn *ws.Conn) {
	m.mu.Lock()
	removed := m.sessions[clientConn]
	delete(m.sessions, clientConn)
	m.mu.Unlock()

	if removed != nil {
		removed.Close()
	}
}
// Count reports how many sessions are currently registered.
func (m *SessionManager) Count() int {
	m.mu.RLock()
	n := len(m.sessions)
	m.mu.RUnlock()
	return n
}
// CloseAll swaps in an empty registry and then closes every session that was
// registered. The sessions are closed after the manager lock is released.
func (m *SessionManager) CloseAll() {
	m.mu.Lock()
	previous := m.sessions
	m.sessions = make(map[*ws.Conn]*Session)
	m.mu.Unlock()

	for _, sess := range previous {
		sess.Close()
	}
}
// Snapshot returns the currently registered sessions as a new slice, in map
// iteration order (unspecified). Later changes to the manager do not affect it.
func (m *SessionManager) Snapshot() []*Session {
	m.mu.RLock()
	defer m.mu.RUnlock()

	out := make([]*Session, len(m.sessions))
	i := 0
	for _, sess := range m.sessions {
		out[i] = sess
		i++
	}
	return out
}