diff --git a/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/client.py b/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/client.py index af7481f0f..eb2c30d54 100644 --- a/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/client.py +++ b/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/client.py @@ -1,14 +1,15 @@ import sys +import time from contextlib import contextmanager from typing import Optional import click -from anyio import BrokenResourceError, EndOfStream, create_task_group, open_file +from anyio import BrokenResourceError, EndOfStream, create_task_group, open_file, sleep, to_thread from anyio.streams.file import FileReadStream from jumpstarter_driver_network.adapters import PexpectAdapter from pexpect.fdpexpect import fdspawn -from .console import Console +from .console import Console, ConsoleStreamDrop from jumpstarter.client import DriverClient from jumpstarter.client.decorators import driver_click_group @@ -125,6 +126,31 @@ async def _stdin_to_serial(self, stream) -> tuple[int, int]: return bytes_read, bytes_sent + def _find_power_client(self): + root = getattr(self, 'root', None) + if root is None: + return None + return self._search_power(root) + + def _search_power(self, client): + for child in client.children.values(): + if hasattr(child, "cycle") or (hasattr(child, "on") and hasattr(child, "off")): + return child + result = self._search_power(child) + if result is not None: + return result + return None + + def _make_power_cycle(self, power_client): + async def _cycle(): + if hasattr(power_client, "cycle"): + await to_thread.run_sync(power_client.cycle) + else: + await to_thread.run_sync(power_client.off) + await sleep(2) + await to_thread.run_sync(power_client.on) + return _cycle + def cli(self): # noqa: C901 @driver_click_group(self) def base(): @@ -134,9 +160,23 @@ def base(): @base.command() def start_console(): """Start serial port console""" + power_client = self._find_power_client() + on_power_cycle = self._make_power_cycle(power_client) if power_client is not None else None click.echo("\nStarting serial port console ... exit with CTRL+B x 3 times\n") - console = Console(serial_client=self) - console.run() + if on_power_cycle is not None: + click.echo("Power cycle: CTRL+] x 3 times\n") + retries = 0 + while retries < 30: + console = Console(serial_client=self, on_power_cycle=on_power_cycle) + try: + console.run() + break + except ConsoleStreamDrop: + click.echo("\r\nSerial connection lost, reconnecting...\n", err=True) + retries += 1 + time.sleep(1) + else: + click.echo("\nSerial connection lost (reconnect attempts exhausted).\n", err=True) @base.command() @click.option( diff --git a/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/client_test.py b/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/client_test.py new file mode 100644 index 000000000..dfefb4df9 --- /dev/null +++ b/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/client_test.py @@ -0,0 +1,32 @@ +import threading +from unittest.mock import MagicMock + +from .driver import PySerial +from jumpstarter.common.utils import serve + + +def test_find_power_client_no_root(): + with serve(PySerial(url="loop://")) as client: + assert client._find_power_client() is None + + +def test_find_power_client_with_cycle(): + power = MagicMock(spec=["cycle", "children"]) + power.children = {} + root = MagicMock(spec=["children"]) + root.children = {"power": power} + + with serve(PySerial(url="loop://")) as client: + object.__setattr__(client, "root", root) + assert client._find_power_client() is power + + +def test_make_power_cycle_calls_cycle(): + called = threading.Event() + power = MagicMock() + power.cycle = MagicMock(side_effect=lambda: called.set()) + + with serve(PySerial(url="loop://")) as client: + cycle_fn = client._make_power_cycle(power) + client.portal.call(cycle_fn) + assert called.is_set() diff --git a/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/console.py b/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/console.py index b315b2f55..f27e2de12 100644 --- a/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/console.py +++ b/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/console.py @@ -1,9 +1,10 @@ import sys import termios import tty +from collections.abc import Awaitable, Callable from contextlib import contextmanager -from anyio import create_task_group +from anyio import EndOfStream, create_task_group from anyio.streams.file import FileReadStream, FileWriteStream from jumpstarter.client import DriverClient @@ -13,9 +14,15 @@ class ConsoleExit(Exception): pass +class ConsoleStreamDrop(Exception): + """Serial stream dropped; caller may reconnect.""" + pass + + class Console: - def __init__(self, serial_client: DriverClient): + def __init__(self, serial_client: DriverClient, on_power_cycle: Callable[[], Awaitable[None]] | None = None): self.serial_client = serial_client + self.on_power_cycle = on_power_cycle def run(self): with self.setraw(): @@ -31,32 +38,50 @@ def setraw(self): termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, original) async def __run(self): - async with self.serial_client.stream_async(method="connect") as stream: - try: - async with create_task_group() as tg: - tg.start_soon(self.__serial_to_stdout, stream) - tg.start_soon(self.__stdin_to_serial, stream) - except* ConsoleExit: - pass + try: + async with self.serial_client.stream_async(method="connect") as stream: + try: + async with create_task_group() as tg: + tg.start_soon(self.__serial_to_stdout, stream) + tg.start_soon(self.__stdin_to_serial, stream) + except* ConsoleExit: + pass + except* ConsoleStreamDrop: + raise ConsoleStreamDrop() from None + except EndOfStream: + raise ConsoleStreamDrop() from None async def __serial_to_stdout(self, stream): stdout = FileWriteStream(sys.stdout.buffer) - while True: - data = await stream.receive() - await stdout.send(data) - sys.stdout.flush() + try: + while True: + data = await stream.receive() + await stdout.send(data) + sys.stdout.flush() + except EndOfStream: + raise ConsoleStreamDrop() from None async def __stdin_to_serial(self, stream): stdin = FileReadStream(sys.stdin.buffer) ctrl_b_count = 0 + ctrl_bracket_count = 0 # Ctrl-] x3 triggers power cycle while True: data = await stdin.receive(max_bytes=1) if not data: continue if data == b"\x02": # Ctrl-B ctrl_b_count += 1 + ctrl_bracket_count = 0 if ctrl_b_count == 3: raise ConsoleExit + elif data == b"\x1d": # Ctrl-] + ctrl_bracket_count += 1 + ctrl_b_count = 0 + if ctrl_bracket_count == 3 and self.on_power_cycle is not None: + await self.on_power_cycle() + ctrl_bracket_count = 0 + continue else: ctrl_b_count = 0 + ctrl_bracket_count = 0 await stream.send(data) diff --git a/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/console_test.py b/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/console_test.py new file mode 100644 index 000000000..4e8f691d4 --- /dev/null +++ b/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/console_test.py @@ -0,0 +1,91 @@ +import os +import threading +import time +from unittest.mock import MagicMock, patch + +from .console import Console +from .driver import PySerial +from jumpstarter.common.utils import serve + + +def _start_console(client, on_power_cycle=None): + """Run Console.run() in a thread with a PTY substituted for stdin. + + Returns (master_fd, thread, result_dict). Write keypresses to master_fd; + the result dict gets an 'exc' key if the console thread raises. + """ + master_fd, slave_fd = os.openpty() + slave_file = os.fdopen(slave_fd, "rb", buffering=0) + + mock_stdin = MagicMock() + mock_stdin.fileno.return_value = slave_fd + mock_stdin.buffer = slave_file + + result = {} + + def _run(): + with patch("sys.stdin", mock_stdin): + console = Console(serial_client=client, on_power_cycle=on_power_cycle) + try: + console.run() + except Exception as e: + result["exc"] = e + slave_file.close() + + t = threading.Thread(target=_run, daemon=True) + t.start() + return master_fd, t, result + + +def test_ctrl_b_exits(): + with serve(PySerial(url="loop://")) as client: + master_fd, t, result = _start_console(client) + try: + time.sleep(0.1) + os.write(master_fd, b"a") + os.write(master_fd, b"\x02\x02\x02") + t.join(timeout=5) + finally: + os.close(master_fd) + + assert not t.is_alive(), "console did not exit after Ctrl-B x3" + assert "exc" not in result + + +def test_ctrl_bracket_triggers_power_cycle(): + power_cycled = threading.Event() + + async def on_power_cycle(): + power_cycled.set() + + with serve(PySerial(url="loop://")) as client: + master_fd, t, result = _start_console(client, on_power_cycle=on_power_cycle) + try: + time.sleep(0.1) + os.write(master_fd, b"\x1d\x1d\x1d") + assert power_cycled.wait(timeout=5), "power cycle was not triggered" + assert t.is_alive(), "console exited after power cycle" + os.write(master_fd, b"\x02\x02\x02") + t.join(timeout=5) + finally: + os.close(master_fd) + + assert not t.is_alive() + assert "exc" not in result + + +def test_ctrl_bracket_without_power_client(): + with serve(PySerial(url="loop://")) as client: + master_fd, t, result = _start_console(client, on_power_cycle=None) + try: + time.sleep(0.1) + os.write(master_fd, b"\x1d\x1d\x1d") + time.sleep(0.1) + assert t.is_alive(), "console exited unexpectedly on Ctrl-] without power client" + os.write(master_fd, b"\x02\x02\x02") + t.join(timeout=5) + finally: + os.close(master_fd) + + assert not t.is_alive() + assert "exc" not in result diff --git a/python/packages/jumpstarter/jumpstarter/client/client.py b/python/packages/jumpstarter/jumpstarter/client/client.py index 980f2da39..8cb212ed8 100644 --- a/python/packages/jumpstarter/jumpstarter/client/client.py +++ b/python/packages/jumpstarter/jumpstarter/client/client.py @@ -97,7 +97,6 @@ async def client_from_channel( stub = MultipathExporterStub([channel]) response = await stub.GetReport(empty_pb2.Empty()) - for index, report in enumerate(response.reports): topo[index] = [] @@ -134,4 +133,14 @@ async def client_from_channel( clients[index] = client + root_client = next(reversed(clients.values())) + + def _iter_all(client): + yield client + for child in client.children.values(): + yield from _iter_all(child) + + for c in _iter_all(root_client): + object.__setattr__(c, 'root', root_client) + return clients.popitem(last=True)[1]