| # Copyright 2023 The Chromium Authors |
| # Use of this source code is governed by a BSD-style license that can be |
| # found in the LICENSE file. |
| """Code to allow tests to communicate via a websocket server.""" |
| |
| import logging |
| import threading |
| |
| import websockets # pylint: disable=import-error |
| import websockets.sync.server as sync_server # pylint: disable=import-error |
| |
# How long StartServer() waits for the server thread to report its port.
WEBSOCKET_PORT_TIMEOUT_SECONDS = 10
# Default time WaitForConnection() waits for a client to connect.
WEBSOCKET_SETUP_TIMEOUT_SECONDS = 5
# How long ClearCurrentConnection() waits for the connection handler to signal
# that it has stopped.
WEBSOCKET_CLOSE_TIMEOUT_SECONDS = 2
# How long StopServer() waits for the server thread to finish.
SERVER_SHUTDOWN_TIMEOUT_SECONDS = 5

# The client (Chrome) should never be closing the connection. If it does, it's
# indicative of something going wrong like a renderer crash.
# NOTE(review): ConnectionClosedOK presumably only covers clean closes; an
# abnormal close may surface as ConnectionClosedError instead - confirm whether
# callers need to handle that case as well.
ClientClosedConnectionError = websockets.exceptions.ConnectionClosedOK

# Alias for readability.
WebsocketReceiveMessageTimeoutError = TimeoutError
| |
| |
| class WebsocketServer(): |
| |
| def __init__(self): |
| """Server that abstracts the websocket library under the hood. |
| |
| Only supports one active connection at a time. |
| """ |
| self.server_port = None |
| self.websocket = None |
| self.connection_stopper_event = None |
| self.connection_closed_event = None |
| self.port_set_event = threading.Event() |
| self.connection_received_event = threading.Event() |
| self._server_thread = None |
| |
| def StartServer(self) -> None: |
| """Starts the websocket server on a separate thread.""" |
| assert self._server_thread is None, 'Server already running' |
| self._server_thread = _ServerThread(self) |
| self._server_thread.daemon = True |
| self._server_thread.start() |
| got_port = self.port_set_event.wait(WEBSOCKET_PORT_TIMEOUT_SECONDS) |
| if not got_port: |
| raise RuntimeError('Websocket server did not provide a port') |
| # Note: We don't need to set up any port forwarding for remote platforms |
| # after this point due to Telemetry's use of --proxy-server to send all |
| # traffic through the TsProxyServer. This causes network traffic to pop out |
| # on the host, which means that using the websocket server's port directly |
| # works. |
| |
| def ClearCurrentConnection(self) -> None: |
| if self.connection_stopper_event: |
| self.connection_stopper_event.set() |
| closed = self.connection_closed_event.wait( |
| WEBSOCKET_CLOSE_TIMEOUT_SECONDS) |
| if not closed: |
| raise RuntimeError('Websocket connection did not close') |
| self.connection_stopper_event = None |
| self.connection_closed_event = None |
| self.websocket = None |
| self.connection_received_event.clear() |
| |
| def WaitForConnection(self, timeout: float | None = None) -> None: |
| if self.websocket: |
| return |
| timeout = timeout or WEBSOCKET_SETUP_TIMEOUT_SECONDS |
| self.connection_received_event.wait(timeout) |
| if not self.websocket: |
| raise RuntimeError('Websocket connection was not established') |
| |
| def StopServer(self) -> None: |
| self.ClearCurrentConnection() |
| self._server_thread.shutdown() |
| self._server_thread.join(SERVER_SHUTDOWN_TIMEOUT_SECONDS) |
| if self._server_thread.is_alive(): |
| logging.error( |
| 'Websocket server did not shut down properly - this might be ' |
| 'indicative of an issue in the test harness') |
| |
| def Send(self, message: str) -> None: |
| self.websocket.send(message) |
| |
| def Receive(self, timeout: int) -> str: |
| try: |
| return self.websocket.recv(timeout) |
| except TimeoutError as e: |
| raise WebsocketReceiveMessageTimeoutError( |
| f'Timed out after {timeout} seconds waiting for websocket message' |
| ) from e |
| |
| |
class _ServerThread(threading.Thread):
  """Thread that runs the websocket server for a WebsocketServer."""

  def __init__(self, server_instance: WebsocketServer, *args, **kwargs):
    """Creates the thread without starting it.

    Args:
      server_instance: The WebsocketServer whose state the server running on
          this thread will populate.
      *args: Forwarded to threading.Thread.
      **kwargs: Forwarded to threading.Thread.
    """
    super().__init__(*args, **kwargs)
    self._server_instance = server_instance
    # The underlying server object. Assigned by StartWebsocketServer() once
    # the server has been created; None until then.
    self.websocket_server = None

  def run(self) -> None:
    """Thread entry point: runs the websocket server on this thread."""
    StartWebsocketServer(self, self._server_instance)

  def shutdown(self) -> None:
    """Asks the underlying server to shut down, if it was ever created."""
    server = self.websocket_server
    if server is None:
      # The server was never created (e.g. startup failed), so there is
      # nothing to shut down. Returning here avoids masking the original
      # startup failure with an AttributeError during cleanup.
      return
    server.shutdown()
| |
| |
def StartWebsocketServer(server_thread: _ServerThread,
                         server_instance: WebsocketServer) -> None:
  """Creates a websocket server on localhost and serves on the calling thread.

  The server binds to 127.0.0.1 on an OS-assigned port. Once it is created,
  the server object is stored on |server_thread| and the chosen port is
  published on |server_instance| before entering the serve loop.

  Args:
    server_thread: The thread object that receives a reference to the created
        server via its |websocket_server| attribute.
    server_instance: The WebsocketServer whose port and connection state are
        populated.
  """

  def _OnConnection(websocket: sync_server.ServerConnection) -> None:
    """Publishes |websocket| on |server_instance| and waits to be stopped."""
    # We only allow one active connection - if there are multiple, something is
    # wrong.
    for attribute in ('connection_stopper_event', 'connection_closed_event',
                      'websocket'):
      assert getattr(server_instance, attribute) is None

    # Hold local references to the events in case the server clears its own
    # references before this handler is done with them.
    stop_requested = threading.Event()
    server_instance.connection_stopper_event = stop_requested
    handler_finished = threading.Event()
    server_instance.connection_closed_event = handler_finished

    # The connection must be published before waiters are woken up.
    server_instance.websocket = websocket
    server_instance.connection_received_event.set()

    # Keep the handler (and with it the connection) alive until asked to stop,
    # then acknowledge the request.
    stop_requested.wait()
    handler_finished.set()

  with sync_server.serve(_OnConnection, '127.0.0.1', 0) as server:
    server_thread.websocket_server = server
    bound_address = server.socket.getsockname()
    server_instance.server_port = bound_address[1]
    server_instance.port_set_event.set()
    server.serve_forever()