// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. // This file contains WebSocket handlers for real-time log streaming. package handlers import ( "context" "strings" "sync" "time" "github.com/bytedance/sonic" "github.com/fasthttp/router" "github.com/fasthttp/websocket" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) // WebSocketClient represents a connected WebSocket client with its own mutex type WebSocketClient struct { conn *websocket.Conn mu sync.Mutex // Per-connection mutex for thread-safe writes } // WebSocketHandler manages WebSocket connections for real-time updates type WebSocketHandler struct { ctx context.Context allowedOrigins []string clients map[*websocket.Conn]*WebSocketClient mu sync.RWMutex stopChan chan struct{} // Channel to signal heartbeat goroutine to stop done chan struct{} // Channel to signal when heartbeat goroutine has stopped } // NewWebSocketHandler creates a new WebSocket handler instance func NewWebSocketHandler(ctx context.Context, allowedOrigins []string) *WebSocketHandler { return &WebSocketHandler{ ctx: ctx, allowedOrigins: allowedOrigins, clients: make(map[*websocket.Conn]*WebSocketClient), stopChan: make(chan struct{}), done: make(chan struct{}), } } // RegisterRoutes registers all WebSocket-related routes func (h *WebSocketHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { r.GET("/ws", lib.ChainMiddlewares(h.connectStream, middlewares...)) } // getUpgrader returns a WebSocket upgrader configured with the current allowed origins func (h *WebSocketHandler) getUpgrader() websocket.FastHTTPUpgrader { return websocket.FastHTTPUpgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(ctx *fasthttp.RequestCtx) bool { origin := string(ctx.Request.Header.Peek("Origin")) if origin == "" { // If no Origin header, check the Host header for direct connections host := string(ctx.Request.Header.Peek("Host")) return isLocalhost(host) } // Check if origin is allowed (localhost always allowed + configured origins) return IsOriginAllowed(origin, h.allowedOrigins) }, } } // isLocalhost checks if the given host is localhost func isLocalhost(host string) bool { // Remove port if present if idx := strings.LastIndex(host, ":"); idx != -1 { host = host[:idx] } // Check for localhost variations return host == "localhost" || host == "127.0.0.1" || host == "::1" || host == "" } // connectStream handles WebSocket connections for real-time streaming func (h *WebSocketHandler) connectStream(ctx *fasthttp.RequestCtx) { upgrader := h.getUpgrader() err := upgrader.Upgrade(ctx, func(ws *websocket.Conn) { // Read safety & liveness ws.SetReadLimit(50 << 20) // 50 MiB ws.SetReadDeadline(time.Now().Add(60 * time.Second)) ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(60 * time.Second)) return nil }) // Create a new client with its own mutex client := &WebSocketClient{ conn: ws, } // Register new client h.mu.Lock() h.clients[ws] = client h.mu.Unlock() // Clean up on disconnect defer func() { h.mu.Lock() delete(h.clients, ws) h.mu.Unlock() ws.Close() }() // Keep connection alive and handle client messages // This loop continuously reads and discards incoming WebSocket messages to: // 1. Keep the connection alive by processing client pings and control frames // 2. Detect when the client disconnects by watching for close frames or errors // 3. Maintain proper WebSocket protocol handling without accumulating messages for { _, _, err := ws.ReadMessage() if err != nil { // Only log unexpected close errors if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) { logger.Error("websocket read error: %v", err) } break } } }) if err != nil { logger.Error("websocket upgrade error: %v", err) return } } // sendMessageSafely sends a message to a client with proper locking and error handling func (h *WebSocketHandler) sendMessageSafely(client *WebSocketClient, messageType int, data []byte) error { client.mu.Lock() defer client.mu.Unlock() // Set a write deadline to prevent hanging connections client.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) defer client.conn.SetWriteDeadline(time.Time{}) // Clear the deadline err := client.conn.WriteMessage(messageType, data) if err != nil { // Remove the client from the map if write fails go func() { h.mu.Lock() delete(h.clients, client.conn) h.mu.Unlock() client.conn.Close() }() } return err } // BroadcastUpdatesToClients sends a store update notification to all connected WebSocket clients // The tags parameter should match RTK Query tagTypes (e.g., "Providers", "VirtualKeys", "MCPClients") func (h *WebSocketHandler) BroadcastUpdatesToClients(tags []string) { message := struct { Type string `json:"type"` Tags []string `json:"tags"` }{ Type: "store_update", Tags: tags, } data, err := sonic.Marshal(message) if err != nil { logger.Error("failed to marshal store update: %v", err) return } h.BroadcastMarshaledMessage(data) } // BroadcastEvent sends a typed event to all connected WebSocket clients. // Any subsystem can use this to push real-time updates to the frontend. func (h *WebSocketHandler) BroadcastEvent(eventType string, data interface{}) { message := struct { Type string `json:"type"` Data interface{} `json:"data"` }{ Type: eventType, Data: data, } bytes, err := sonic.Marshal(message) if err != nil { logger.Error("failed to marshal event %s: %v", eventType, err) return } h.BroadcastMarshaledMessage(bytes) } // BroadcastMarshaledMessage sends an adaptive routing update to all connected WebSocket clients func (h *WebSocketHandler) BroadcastMarshaledMessage(data []byte) { // Get a snapshot of clients to avoid holding the lock during writes h.mu.RLock() clients := make([]*WebSocketClient, 0, len(h.clients)) for _, client := range h.clients { clients = append(clients, client) } h.mu.RUnlock() // Send message to each client safely for _, client := range clients { if err := h.sendMessageSafely(client, websocket.TextMessage, data); err != nil { logger.Error("failed to send message to client: %v", err) } } } // StartHeartbeat starts sending periodic heartbeat messages to keep connections alive func (h *WebSocketHandler) StartHeartbeat() { ticker := time.NewTicker(30 * time.Second) go func() { defer func() { ticker.Stop() close(h.done) }() for { select { case <-h.ctx.Done(): logger.Info("got context cancel(), stopping webserver") return case <-ticker.C: // Get a snapshot of clients to avoid holding the lock during writes h.mu.RLock() clients := make([]*WebSocketClient, 0, len(h.clients)) for _, client := range h.clients { clients = append(clients, client) } h.mu.RUnlock() // Send heartbeat to each client safely for _, client := range clients { if err := h.sendMessageSafely(client, websocket.PingMessage, nil); err != nil { logger.Error("failed to send heartbeat: %v", err) } } case <-h.stopChan: return } } }() } // Stop gracefully shuts down the WebSocket handler func (h *WebSocketHandler) Stop() { close(h.stopChan) // Signal heartbeat goroutine to stop <-h.done // Wait for heartbeat goroutine to finish // Close all client connections h.mu.Lock() for _, client := range h.clients { client.conn.Close() } h.clients = make(map[*websocket.Conn]*WebSocketClient) h.mu.Unlock() }