diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index ffd82fba..3287d010 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -233,6 +233,9 @@ class Server: """ + connections: set[ServerConnection] + """Set of active connections.""" + def __init__( self, socket: socket.socket, @@ -246,6 +249,8 @@ def __init__( self.logger = logger if sys.platform != "win32": self.shutdown_watcher, self.shutdown_notifier = os.pipe() + self.connections = set() + self.connections_lock = threading.Lock() def serve_forever(self) -> None: """ @@ -293,6 +298,10 @@ def shutdown(self) -> None: self.socket.close() if sys.platform != "win32": os.write(self.shutdown_notifier, b"x") + # Close all active connections gracefully. + with self.connections_lock: + for connection in self.connections: + connection.close(CloseCode.GOING_AWAY) def fileno(self) -> int: """ @@ -520,6 +529,9 @@ def handler(websocket): # Define request handler + # Use a list to allow conn_handler to access server before it's assigned. + server_ref: list[Server] = [] + def conn_handler(sock: socket.socket, addr: Any) -> None: # Calculate timeouts on the TLS and WebSocket handshakes. # The TLS timeout must be set on the socket, then removed @@ -587,6 +599,11 @@ def protocol_select_subprotocol( sock.close() return + # Register connection for tracking. + server = server_ref[0] + with server.connections_lock: + server.connections.add(connection) + try: try: connection.handshake( @@ -618,10 +635,16 @@ def protocol_select_subprotocol( except Exception: # pragma: no cover # Don't leak sockets on unexpected errors. sock.close() + finally: + # Unregister connection. + with server.connections_lock: + server.connections.discard(connection) # Initialize server - return Server(sock, conn_handler, logger) + server = Server(sock, conn_handler, logger) + server_ref.append(server) + return server def unix_serve( diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index d04d1859..0a6ba3be 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -288,6 +288,24 @@ def test_shutdown(self): with self.assertRaises(OSError): server.socket.accept() + def test_shutdown_closes_connections(self): + """Server closes active connections on shutdown.""" + with run_server() as server: + with connect(get_uri(server)) as client: + # Connection is open + self.assertEval(client, "ws.protocol.state.name", "OPEN") + # Check connection is tracked + self.assertEqual(len(server.connections), 1) + # Shutdown server + server.shutdown() + # Connection should be closed with GOING_AWAY + with self.assertRaises(ConnectionClosedOK) as raised: + client.recv() + self.assertEqual( + str(raised.exception), + "received 1001 (going away); then sent 1001 (going away)", + ) + def test_handshake_fails(self): """Server receives connection from client but the handshake fails."""