package main import ( "encoding/json" "log" "net/http" "github.com/gorilla/websocket" ) var upgrader = websocket.Upgrader{} func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Printf("Upgrading connection to WS failed: %v.", err) return } defer conn.Close() sendChan := make(chan ServerWebsocketPacket) defer close(sendChan) // Goroutine for sending packets. go func(sendChan <-chan ServerWebsocketPacket) { for packet := range sendChan { message := struct { Type string `json:"type"` Payload any `json:"payload"` }{ Type: packet.Type(), Payload: packet, } messageData, err := json.Marshal(message) if err != nil { log.Printf("Failed to marshal websocket packet: %v.", err) continue } conn.WriteMessage(websocket.TextMessage, messageData) } }(sendChan) // Register listener on document queue updates. s.Documents.RegisterListener(sendChan) defer s.Documents.UnregisterListener(sendChan) // Main loop that receives packets. for { messageType, data, err := conn.ReadMessage() if err != nil { log.Printf("Reading WS message failed: %v.", err) break } switch messageType { case websocket.CloseMessage: log.Printf("Connection %v closed.", conn.LocalAddr()) return case websocket.TextMessage: //log.Printf("Message from %v: %s.", conn.LocalAddr(), data) var message struct { Type string `json:"type"` Payload json.RawMessage `json:"payload"` } if err := json.Unmarshal(data, &message); err != nil { log.Printf("Failed to marshal websocket packet from client %v: %v.", conn.LocalAddr(), err) return } prototype, ok := serverWebsocketPacketRegistry[message.Type] if !ok { log.Printf("Unknown websocket packet type %q from client %v.", message.Type, conn.LocalAddr()) return } if err := json.Unmarshal(message.Payload, prototype); err != nil { log.Printf("Failed to marshal websocket packet payload from client %v: %v.", conn.LocalAddr(), err) return } switch packet := prototype.(type) { case *ServerWebsocketPacketQueueDelete: s.Documents.Lock() s.Documents.Delete(packet.IDs...) s.Documents.Unlock() case *ServerWebsocketPacketQueueShift: s.Documents.Lock() s.Documents.Shift(packet.Offset, packet.IDs...) s.Documents.Unlock() case *ServerWebsocketPacketQueueSplit: s.Documents.Lock() s.Documents.Split(packet.IDs...) s.Documents.Unlock() default: log.Printf("Websocket client %q sent unsupported packet type %T.", conn.LocalAddr(), prototype) return } } } }