package realtime import ( "context" "log/slog" "net/http" "strings" "sync" "github.com/golang-jwt/jwt/v5" "github.com/gorilla/websocket" "github.com/multica-ai/multica/server/internal/auth" ) // MembershipChecker verifies a user belongs to a workspace. type MembershipChecker interface { IsMember(ctx context.Context, userID, workspaceID string) bool } var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { // TODO: Restrict origins in production return true }, } // Client represents a single WebSocket connection with identity. type Client struct { hub *Hub conn *websocket.Conn send chan []byte userID string workspaceID string } // Hub manages WebSocket connections organized by workspace rooms. type Hub struct { rooms map[string]map[*Client]bool // workspaceID -> clients broadcast chan []byte // global broadcast (daemon events) register chan *Client unregister chan *Client mu sync.RWMutex } // NewHub creates a new Hub instance. func NewHub() *Hub { return &Hub{ rooms: make(map[string]map[*Client]bool), broadcast: make(chan []byte), register: make(chan *Client), unregister: make(chan *Client), } } // Run starts the hub event loop. func (h *Hub) Run() { for { select { case client := <-h.register: h.mu.Lock() room := client.workspaceID if h.rooms[room] == nil { h.rooms[room] = make(map[*Client]bool) } h.rooms[room][client] = true total := 0 for _, r := range h.rooms { total += len(r) } h.mu.Unlock() slog.Info("ws client connected", "workspace_id", room, "total_clients", total) case client := <-h.unregister: h.mu.Lock() room := client.workspaceID if clients, ok := h.rooms[room]; ok { if _, exists := clients[client]; exists { delete(clients, client) close(client.send) if len(clients) == 0 { delete(h.rooms, room) } } } total := 0 for _, r := range h.rooms { total += len(r) } h.mu.Unlock() slog.Info("ws client disconnected", "workspace_id", room, "total_clients", total) case message := <-h.broadcast: // Global broadcast for daemon events (no workspace filtering) h.mu.RLock() var slow []*Client for _, clients := range h.rooms { for client := range clients { select { case client.send <- message: default: slow = append(slow, client) } } } h.mu.RUnlock() if len(slow) > 0 { h.mu.Lock() for _, client := range slow { room := client.workspaceID if clients, ok := h.rooms[room]; ok { if _, exists := clients[client]; exists { delete(clients, client) close(client.send) if len(clients) == 0 { delete(h.rooms, room) } } } } h.mu.Unlock() } } } } // BroadcastToWorkspace sends a message only to clients in the given workspace. func (h *Hub) BroadcastToWorkspace(workspaceID string, message []byte) { h.mu.RLock() clients := h.rooms[workspaceID] var slow []*Client for client := range clients { select { case client.send <- message: default: slow = append(slow, client) } } h.mu.RUnlock() // Remove slow clients under write lock if len(slow) > 0 { h.mu.Lock() for _, client := range slow { if room, ok := h.rooms[workspaceID]; ok { if _, exists := room[client]; exists { delete(room, client) close(client.send) if len(room) == 0 { delete(h.rooms, workspaceID) } } } } h.mu.Unlock() } } // Broadcast sends a message to all connected clients (used for daemon events). func (h *Hub) Broadcast(message []byte) { h.broadcast <- message } // HandleWebSocket upgrades an HTTP connection to WebSocket with JWT auth. func HandleWebSocket(hub *Hub, mc MembershipChecker, w http.ResponseWriter, r *http.Request) { tokenStr := r.URL.Query().Get("token") workspaceID := r.URL.Query().Get("workspace_id") if tokenStr == "" || workspaceID == "" { http.Error(w, `{"error":"token and workspace_id required"}`, http.StatusUnauthorized) return } // Validate JWT token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (any, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, jwt.ErrSignatureInvalid } return auth.JWTSecret(), nil }) if err != nil || !token.Valid { http.Error(w, `{"error":"invalid token"}`, http.StatusUnauthorized) return } claims, ok := token.Claims.(jwt.MapClaims) if !ok { http.Error(w, `{"error":"invalid claims"}`, http.StatusUnauthorized) return } userID, ok := claims["sub"].(string) if !ok || strings.TrimSpace(userID) == "" { http.Error(w, `{"error":"invalid claims"}`, http.StatusUnauthorized) return } // Verify user is a member of the workspace if !mc.IsMember(r.Context(), userID, workspaceID) { http.Error(w, `{"error":"not a member of this workspace"}`, http.StatusForbidden) return } conn, err := upgrader.Upgrade(w, r, nil) if err != nil { slog.Error("websocket upgrade failed", "error", err) return } client := &Client{ hub: hub, conn: conn, send: make(chan []byte, 256), userID: userID, workspaceID: workspaceID, } hub.register <- client go client.writePump() go client.readPump() } func (c *Client) readPump() { defer func() { c.hub.unregister <- c c.conn.Close() }() for { _, _, err := c.conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { slog.Debug("websocket read error", "error", err, "user_id", c.userID, "workspace_id", c.workspaceID) } break } // TODO: Route inbound messages to appropriate handlers slog.Debug("ws message received", "user_id", c.userID, "workspace_id", c.workspaceID) } } func (c *Client) writePump() { defer c.conn.Close() for message := range c.send { if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil { slog.Warn("websocket write error", "error", err) return } } }