diff --git a/bluesky_httpserver/app.py b/bluesky_httpserver/app.py index 4779f4c..f09acb3 100644 --- a/bluesky_httpserver/app.py +++ b/bluesky_httpserver/app.py @@ -16,7 +16,7 @@ from fastapi.openapi.utils import get_openapi from .authentication import Mode -from .console_output import CollectPublishedConsoleOutput +from .console_output import CollectPublishedConsoleOutput, ConsoleOutputStream, SystemInfoStream from .core import PatchedStreamingResponse from .database.core import purge_expired from .resources import SERVER_RESOURCES as SR @@ -346,6 +346,11 @@ async def purge_expired_sessions_and_api_keys(): SR.set_console_output_loader(CollectPublishedConsoleOutput(rm_ref=RM)) SR.console_output_loader.start() + SR.set_console_output_stream(ConsoleOutputStream(rm_ref=RM)) + SR.console_output_stream.start() + SR.console_output_loader.subscribe(SR.console_output_stream.add_message) + SR.set_system_info_stream(SystemInfoStream(rm_ref=RM)) + SR.system_info_stream.start() # Import module with custom code module_names_str = os.getenv("QSERVER_CUSTOM_MODULES", None) @@ -387,6 +392,8 @@ async def purge_expired_sessions_and_api_keys(): async def shutdown_event(): await SR.RM.close() await SR.console_output_loader.stop() + await SR.console_output_stream.stop() + await SR.system_info_stream.stop() @lru_cache(1) def override_get_authenticators(): diff --git a/bluesky_httpserver/console_output.py b/bluesky_httpserver/console_output.py index 2743f8b..24b142d 100644 --- a/bluesky_httpserver/console_output.py +++ b/bluesky_httpserver/console_output.py @@ -1,4 +1,5 @@ import asyncio +import inspect import json import logging import queue @@ -51,6 +52,9 @@ def __init__(self, *, rm_ref): self._background_task_stopped = asyncio.Event() self._background_task_stopped.set() + self._callbacks = [] + self._callbacks_async = [] + @property def queues_set(self): """ @@ -67,6 +71,22 @@ def text_buffer_uid(self): async def get_text_buffer(self, n_lines): return await 
self._RM.console_monitor.text(n_lines) + def subscribe(self, cb): + """ + Add a function or a coroutine to the list of callbacks. The callbacks must accept + message as a parameter: cb(msg) + """ + if inspect.iscoroutinefunction(cb): + self._callbacks_async.append(cb) + else: + self._callbacks.append(cb) + + def unsubscribe(self, cb): + if inspect.iscoroutinefunction(cb): + self._callbacks_async.remove(cb) + else: + self._callbacks.remove(cb) + def get_new_msgs(self, last_msg_uid): msg_list = [] try: @@ -94,6 +114,10 @@ async def _load_msgs_task(self): try: msg = await self._RM.console_monitor.next_msg(timeout=0.5) self._add_message(msg=msg) + for cb in self._callbacks: + cb(msg) + for cb in self._callbacks_async: + await cb(msg) except self._RM.RequestTimeoutError: pass self._background_task_stopped.set() @@ -167,3 +191,142 @@ def __init__(self, content_class, *args, **kwargs): def __del__(self): del self._content + + +class ConsoleOutputStream: + def __init__(self, *, rm_ref): + self._queues = {} + self._queue_max_size = 1000 + + @property + def queues(self): + return self._queues + + def add_queue(self, key): + """ + Add a new queue to the dictionary of queues. The key is a reference to the socket for + connection with the client. + """ + queue = asyncio.Queue(maxsize=self._queue_max_size) + self._queues[key] = queue + return queue + + def remove_queue(self, key): + """ + Remove the queue identified by the key from the dictionary of queues. + """ + if key in self._queues: + del self._queues[key] + + async def add_message(self, msg): + msg_json = json.dumps(msg) + for q in self._queues.values(): + # Protect from overflow. It's ok to discard old messages. 
+ if q.full(): + q.get_nowait() + await q.put(msg_json) + + def start(self): + pass + + async def stop(self): + pass + + +class SystemInfoStream: + def __init__(self, *, rm_ref): + self._RM = rm_ref + self._queues_status = {} + self._queues_info = {} + self._background_task = None + self._background_task_running = False + self._background_task_stopped = asyncio.Event() + self._background_task_stopped.set() + self._num = 0 + self._queue_max_size = 1000 + + @property + def background_task_running(self): + return self._background_task_running + + @property + def queues_status(self): + return self._queues_status + + @property + def queues_info(self): + return self._queues_info + + def add_queue_status(self, key): + """ + Add a new queue to the dictionary of 'status' queues. The key is a reference to the socket for + connection with the client. + """ + queue = asyncio.Queue(maxsize=self._queue_max_size) + self._queues_status[key] = queue + return queue + + def add_queue_info(self, key): + """ + Add a new queue to the dictionary of 'info' queues. The key is a reference to the socket for + connection with the client. + """ + queue = asyncio.Queue(maxsize=self._queue_max_size) + self._queues_info[key] = queue + return queue + + def remove_queue_status(self, key): + """ + Remove the queue identified by the key from the dictionary of 'status' queues. + """ + if key in self._queues_status: + del self._queues_status[key] + + def remove_queue_info(self, key): + """ + Remove the queue identified by the key from the dictionary of 'info' queues. 
+ """ + if key in self._queues_info: + del self._queues_info[key] + + def _start_background_task(self): + if not self._background_task_running: + self._background_task = asyncio.create_task(self._load_msgs_task()) + + async def _stop_background_task(self): + self._background_task_running = False + await self._background_task_stopped.wait() + + async def _load_msgs_task(self): + self._background_task_stopped.clear() + self._background_task_running = True + while self._background_task_running: + try: + msg = await self._RM.system_info_monitor.next_msg(timeout=0.5) + + if isinstance(msg, dict) and "msg" in msg: + msg_json = json.dumps(msg) + # ALL 'info' messages + for q in self._queues_info.values(): + # Protect from overflow. It's ok to discard old messages. + if q.full(): + q.get_nowait() + await q.put(msg_json) + if isinstance(msg["msg"], dict) and "status" in msg["msg"]: + # ONLY 'status' messages + for q in self._queues_status.values(): + # Protect from overflow. It's ok to discard old messages. 
+ if q.full(): + q.get_nowait() + await q.put(msg_json) + except self._RM.RequestTimeoutError: + pass + self._background_task_stopped.set() + + def start(self): + self._RM.system_info_monitor.enable() + self._start_background_task() + + async def stop(self): + await self._stop_background_task() + await self._RM.system_info_monitor.disable_wait() diff --git a/bluesky_httpserver/resources.py b/bluesky_httpserver/resources.py index c352a8d..1dca2ca 100644 --- a/bluesky_httpserver/resources.py +++ b/bluesky_httpserver/resources.py @@ -3,6 +3,9 @@ def __init__(self): self._RM = None self._custom_code_modules = [] self._console_output_loader = None + self._stop_server = False + self._console_output_stream = None + self._system_info_stream = None def set_RM(self, RM): self._RM = RM @@ -37,5 +40,27 @@ def console_output_loader(self): def console_output_loader(self, _): raise RuntimeError("Attempting to set read-only property 'console_output_loader'") + def set_console_output_stream(self, console_output_stream): + self._console_output_stream = console_output_stream + + @property + def console_output_stream(self): + return self._console_output_stream + + @console_output_stream.setter + def console_output_stream(self, _): + raise RuntimeError("Attempting to set read-only property 'console_output_stream'") + + def set_system_info_stream(self, system_info_stream): + self._system_info_stream = system_info_stream + + @property + def system_info_stream(self): + return self._system_info_stream + + @system_info_stream.setter + def system_info_stream(self, _): + raise RuntimeError("Attempting to set read-only property 'system_info_stream'") + SERVER_RESOURCES = _ServerResources() diff --git a/bluesky_httpserver/routers/core_api.py b/bluesky_httpserver/routers/core_api.py index 3cd00aa..7eaa74e 100644 --- a/bluesky_httpserver/routers/core_api.py +++ b/bluesky_httpserver/routers/core_api.py @@ -6,7 +6,7 @@ import pydantic from bluesky_queueserver.manager.conversions import simplify_plan_descriptions, spreadsheet_to_plan_list -from fastapi import 
APIRouter, Depends, File, Form, Request, Security, UploadFile +from fastapi import APIRouter, Depends, File, Form, Request, Security, UploadFile, WebSocket, WebSocketDisconnect from packaging import version if version.parse(pydantic.__version__) < version.parse("2.0.0"): @@ -1098,3 +1098,98 @@ def console_output_update(payload: dict, principal=Security(get_current_principa process_exception() return response + + +class WebSocketMonitor: + """ + Works for sockets that only send data to clients (not receive). + + The class monitors the status of a socket connection. The property 'is_alive' returns True + until the socket is disconnected. The purpose of the class is to break the loop in the + implementation of the socket that only sends data to a client when the application + is closed. If there is no data to send, the loop continues to run indefinitely and + prevents the application from closing properly. No better solution was found. + """ + + def __init__(self, websocket): + self._websocket = websocket + self._is_alive = True + self._task_ref = None + + async def _task(self): + while True: + try: + await asyncio.sleep(1) + try: + # The following will raise an exception if the socket is disconnected. + await asyncio.wait_for(self._websocket.receive(), timeout=0.01) + except asyncio.TimeoutError: + # The socket is still connected. 
+ pass + except Exception: + self._is_alive = False + break + + def start(self): + self._task_ref = asyncio.create_task(self._task()) + + @property + def is_alive(self): + return self._is_alive + + +@router.websocket("/console_output/ws") +async def console_output_ws(websocket: WebSocket): + await websocket.accept() + q = SR.console_output_stream.add_queue(websocket) + wsmon = WebSocketMonitor(websocket) + wsmon.start() + try: + while wsmon.is_alive: + try: + msg = await asyncio.wait_for(q.get(), timeout=1) + await websocket.send_text(msg) + except asyncio.TimeoutError: + pass + except WebSocketDisconnect: + pass + finally: + SR.console_output_stream.remove_queue(websocket) + + +@router.websocket("/status/ws") +async def status_ws(websocket: WebSocket): + await websocket.accept() + q = SR.system_info_stream.add_queue_status(websocket) + wsmon = WebSocketMonitor(websocket) + wsmon.start() + try: + while wsmon.is_alive: + try: + msg = await asyncio.wait_for(q.get(), timeout=1) + await websocket.send_text(msg) + except asyncio.TimeoutError: + pass + except WebSocketDisconnect: + pass + finally: + SR.system_info_stream.remove_queue_status(websocket) + + +@router.websocket("/info/ws") +async def info_ws(websocket: WebSocket): + await websocket.accept() + q = SR.system_info_stream.add_queue_info(websocket) + wsmon = WebSocketMonitor(websocket) + wsmon.start() + try: + while wsmon.is_alive: + try: + msg = await asyncio.wait_for(q.get(), timeout=1) + await websocket.send_text(msg) + except asyncio.TimeoutError: + pass + except WebSocketDisconnect: + pass + finally: + SR.system_info_stream.remove_queue_info(websocket) diff --git a/bluesky_httpserver/tests/test_console_output.py b/bluesky_httpserver/tests/test_console_output.py index df3e877..1b87e53 100644 --- a/bluesky_httpserver/tests/test_console_output.py +++ b/bluesky_httpserver/tests/test_console_output.py @@ -7,6 +7,7 @@ import pytest import requests from bluesky_queueserver.manager.tests.common import re_manager_cmd 
# noqa F401 +from websockets.sync.client import connect from bluesky_httpserver.tests.conftest import ( # noqa F401 API_KEY_FOR_TESTS, @@ -336,3 +337,101 @@ def test_http_server_console_output_update_1( assert resp7["success"] is True, pprint.pformat(resp7) assert wait_for_environment_to_be_closed(timeout=10) + + +class _ReceiveConsoleOutputSocket(threading.Thread): + """ + Catch streaming console output by connecting to /console_output/ws socket and + save messages to the buffer. + """ + + def __init__(self, api_key=API_KEY_FOR_TESTS, **kwargs): + super().__init__(**kwargs) + self.received_data_buffer = [] + self._exit = False + self._api_key = api_key + + def run(self): + websocket_uri = f"ws://{SERVER_ADDRESS}:{SERVER_PORT}/api/console_output/ws" + with connect(websocket_uri) as websocket: + while not self._exit: + try: + msg_json = websocket.recv(timeout=0.1, decode=False) + try: + msg = json.loads(msg_json) + self.received_data_buffer.append(msg) + except json.JSONDecodeError: + pass + except TimeoutError: + pass + + def stop(self): + """ + Call this method to stop the thread. Then send a request to the server so that some output + is printed in ``stdout``. 
+ """ + self._exit = True + + def __del__(self): + self.stop() + + +@pytest.mark.parametrize("zmq_port", (None, 60619)) +def test_http_server_console_output_socket_1( + monkeypatch, re_manager_cmd, fastapi_server_fs, zmq_port # noqa F811 +): + """ + Test for ``/console_output/ws`` websocket + """ + # Start HTTP Server + if zmq_port is not None: + monkeypatch.setenv("QSERVER_ZMQ_INFO_ADDRESS", f"tcp://localhost:{zmq_port}") + fastapi_server_fs() + + # Start RE Manager + params = ["--zmq-publish-console", "ON"] + if zmq_port is not None: + params.extend(["--zmq-info-addr", f"tcp://*:{zmq_port}"]) + re_manager_cmd(params) + + rsc = _ReceiveConsoleOutputSocket() + rsc.start() + ttime.sleep(1) # Wait until the client connects to the socket + + resp1 = request_to_json( + "post", + "/queue/item/add", + json={"item": {"name": "count", "args": [["det1", "det2"]], "item_type": "plan"}}, + ) + assert resp1["success"] is True + assert resp1["qsize"] == 1 + assert resp1["item"]["name"] == "count" + assert resp1["item"]["args"] == [["det1", "det2"]] + assert "item_uid" in resp1["item"] + + # Wait until capture is complete (at least 2 messages are expected) or timeout expires + ttime.sleep(10) + rsc.stop() + # Note, that some output from the server is needed in order to exit the loop in the thread. + + resp2 = request_to_json("get", "/queue/get") + assert resp2["items"] != [] + assert len(resp2["items"]) == 1 + assert resp2["items"][0] == resp1["item"] + assert resp2["running_item"] == {} + + rsc.join() + + assert len(rsc.received_data_buffer) >= 2, pprint.pformat(rsc.received_data_buffer) + + # Verify that expected messages ('strings') are contained in captured messages. 
+ expected_messages = {"Adding new item to the queue", "Item added"} + buffer = rsc.received_data_buffer + for msg in buffer: + for emsg in expected_messages.copy(): + if emsg in msg["msg"]: + expected_messages.remove(emsg) + + assert ( + not expected_messages + ), f"Messages {expected_messages} were not found in captured output: {pprint.pformat(buffer)}" diff --git a/bluesky_httpserver/tests/test_system_info_socket.py b/bluesky_httpserver/tests/test_system_info_socket.py new file mode 100644 index 0000000..b20c98c --- /dev/null +++ b/bluesky_httpserver/tests/test_system_info_socket.py @@ -0,0 +1,122 @@ +import json +import pprint +import threading +import time as ttime + +import pytest +from bluesky_queueserver.manager.tests.common import re_manager_cmd # noqa F401 +from websockets.sync.client import connect + +from bluesky_httpserver.tests.conftest import ( # noqa F401 + API_KEY_FOR_TESTS, + SERVER_ADDRESS, + SERVER_PORT, + fastapi_server_fs, + request_to_json, + set_qserver_zmq_encoding, + wait_for_environment_to_be_closed, + wait_for_environment_to_be_created, + wait_for_manager_state_idle, +) + + +class _ReceiveSystemInfoSocket(threading.Thread): + """ + Catch streaming system info messages by connecting to the socket at ``endpoint`` and + save messages to the buffer. + """ + + def __init__(self, *, endpoint, api_key=API_KEY_FOR_TESTS, **kwargs): + super().__init__(**kwargs) + self.received_data_buffer = [] + self._exit = False + self._api_key = api_key + self._endpoint = endpoint + + def run(self): + websocket_uri = f"ws://{SERVER_ADDRESS}:{SERVER_PORT}/api{self._endpoint}" + with connect(websocket_uri) as websocket: + while not self._exit: + try: + msg_json = websocket.recv(timeout=0.1, decode=False) + try: + msg = json.loads(msg_json) + self.received_data_buffer.append(msg) + except json.JSONDecodeError: + pass + except TimeoutError: + pass + + def stop(self): + """ + Call this method to stop the thread. 
Then send a request to the server so that some output + is printed in ``stdout``. + """ + self._exit = True + + def __del__(self): + self.stop() + + +@pytest.mark.parametrize("zmq_port", (None, 60619)) +@pytest.mark.parametrize("endpoint", ["/info/ws", "/status/ws"]) +def test_http_server_system_info_socket_1( + monkeypatch, re_manager_cmd, fastapi_server_fs, zmq_port, endpoint # noqa F811 +): + """ + Test for ``/info/ws`` and ``/status/ws`` websockets + """ + # Start HTTP Server + if zmq_port is not None: + monkeypatch.setenv("QSERVER_ZMQ_INFO_ADDRESS", f"tcp://localhost:{zmq_port}") + fastapi_server_fs() + + # Start RE Manager + params = ["--zmq-publish-console", "ON"] + if zmq_port is not None: + params.extend(["--zmq-info-addr", f"tcp://*:{zmq_port}"]) + re_manager_cmd(params) + + rsc = _ReceiveSystemInfoSocket(endpoint=endpoint) + rsc.start() + ttime.sleep(1) # Wait until the client connects to the socket + + resp1 = request_to_json("post", "/environment/open") + assert resp1["success"] is True, pprint.pformat(resp1) + + assert wait_for_environment_to_be_created(timeout=10) + + resp2b = request_to_json("post", "/environment/close") + assert resp2b["success"] is True, pprint.pformat(resp2b) + + assert wait_for_environment_to_be_closed(timeout=10) + + # Wait until capture is complete + ttime.sleep(2) + rsc.stop() + rsc.join() + + buffer = rsc.received_data_buffer + assert len(buffer) > 0 + for msg in buffer: + assert "time" in msg, msg + assert isinstance(msg["time"], float), msg + assert "msg" in msg + assert isinstance(msg["msg"], dict) + + if endpoint == "/status/ws": + for msg in buffer: + assert "status" in msg["msg"], msg + assert isinstance(msg["msg"]["status"], dict), msg + elif endpoint == "/info/ws": + for msg in buffer: + if "status" in msg["msg"]: + assert isinstance(msg["msg"]["status"], dict), msg + else: + assert False, f"Unknown endpoint: {endpoint}" + + # In the test we opened and then closed the environment, so let's check if it is reflected in 
+ # the collected streamed status. + wrk_env_exists = [_["msg"]["status"]["worker_environment_exists"] for _ in buffer if "status" in _["msg"]] + assert wrk_env_exists.count(True) > 0, wrk_env_exists + assert wrk_env_exists.count(False) > 0, wrk_env_exists diff --git a/docs/source/control_re_manager.rst b/docs/source/control_re_manager.rst index 4a9c0e9..a082120 100644 --- a/docs/source/control_re_manager.rst +++ b/docs/source/control_re_manager.rst @@ -412,3 +412,35 @@ and generate the text output locally without repeatedly reloading the text buffer with each buffer update as in the case of ``/console_output`` API. :: http GET http://localhost:60610/api/console_output_update last_msg_uid= + + +WebSockets for streaming System Info and Console Output +------------------------------------------------------- + +The following WebSockets are currently implemented: + +- ``/console_output/ws`` - streaming of console output; + +- ``/info/ws`` - streaming of system info messages from RE Manager; + +- ``/status/ws`` - streaming of status messages from RE Manager. Status messages are sent each + time status is updated at RE Manager or at least once per second. Check status UID (part of + RE Manager status) to detect changes in status. + +For example, the console output stream may be received by connecting to the socket with +``ws://localhost:60610/api/console_output/ws`` URI. + +Currently ``/info/ws`` and ``/status/ws`` sockets are streaming the same sequence of RE Manager +status messages. Additional messages may be added to the system info stream in the future. + +Message format for console output messages:: + + {"time": <timestamp>, "msg": <text>} + +Message format for system info messages:: + + {"time": <timestamp>, "msg": {<msg-type>: <msg-content>}} + +For example, the following format is used for status messages:: + + {"time": <timestamp>, "msg": {"status": {<RE Manager status>}}}