Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 61 additions & 46 deletions src/comet/emulator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
lcr:
model: urn:comet:model:keysight:e4980a
port: 11002
# User specific emulator
my_instr:
module: local_project.my_instr_emulator
port: 12001
```

Loading a configuration filename (default filenames are `emulators.yaml` and `emulators.yml`).
Expand All @@ -28,19 +24,19 @@
"""

import argparse
import asyncio
import contextlib
import logging
import os
import signal
import threading
from typing import Any

import schema
import yaml

from .. import __version__

from .emulator import emulator_factory
from .tcpserver import TCPServer, TCPServerThread, TCPServerContext
from .tcpserver import TCPServer, TCPServerContext

default_config_filenames: list[str] = ["emulators.yaml", "emulators.yml"]
default_host: str = "localhost"
Expand Down Expand Up @@ -68,7 +64,7 @@ def normalize_termination(value: str) -> str:

config_schema = schema.Schema(
{
schema.Optional("version"): str, # deprecated
schema.Optional("version"): str,
"emulators": {
str: {
schema.Optional(schema.Or("model", "module")): str,
Expand All @@ -93,10 +89,15 @@ def load_config(filename: str) -> dict[str, Any]:
with open(filename) as fp:
data = yaml.safe_load(fp)
config = validate_config(data or {})
# Set defaults
for params in config.get("emulators", {}).values():
for name, params in config.get("emulators", {}).items():
if "model" in params and "module" in params:
raise KeyError("keys 'model' and 'module' are exclusive")
if "module" in params:
logging.warning(
"Emulator %r uses deprecated config key 'module'; "
"use 'model' instead. Support exists only for backward compatibility.",
name,
)
params.setdefault("host", default_host)
params.setdefault("termination", default_termination)
params.setdefault("request_delay", default_request_delay)
Expand Down Expand Up @@ -132,64 +133,78 @@ def locate_config_filename() -> str:
raise RuntimeError("No config file found.")


def event_loop() -> None:
"""Blocks execution until termination or interrupt from keyboard signal."""
e = threading.Event()

def handle_event(signum, frame):
e.set()

signal.signal(signal.SIGTERM, handle_event)
signal.signal(signal.SIGINT, handle_event)
e.wait()


def main() -> None:
async def main() -> None:
args = parse_args()

logging.basicConfig(level=logging.INFO)

config = load_config(args.filename or locate_config_filename())

threads = []
servers: list[TCPServer] = []

for name, params in config.get("emulators", {}).items():
model = params.get("model") or params.get("module") # fallback for comet<1.5
model = params.get("model") or params.get("module") # fallback for comet<1.5
host = params.get("host")
port = params.get("port")
termination_bytes = params.get("termination").encode()
request_delay = params.get("request_delay")
options = params.get("options", {})
address = host, port

emulator = emulator_factory(model)()
emulator.options.update(options)

context = TCPServerContext(
name=name,
emulator=emulator,
termination=termination_bytes,
request_delay=request_delay,
logger=logging.getLogger(name),
)
server = TCPServer(address, context)
threads.append(TCPServerThread(server))

for thread in threads:
host, port, *_ = thread.server.server_address # IPv4/IPv6
thread.server.context.logger.info("starting... %s:%s", host, port)
thread.start()

def handle_event(signum, frame):
for thread in threads:
host, port, *_ = thread.server.server_address # IPv4/IPv6
thread.server.context.logger.info("stopping... %s:%s", host, port)
thread.shutdown()

signal.signal(signal.SIGTERM, handle_event)
signal.signal(signal.SIGINT, handle_event)

for thread in threads:
thread.join()
server = TCPServer((host, port), context)
await server.start()
servers.append(server)

for server in servers:
host, port = server.server_address
server.context.logger.info("starting... %s:%s", host, port)

stop_event = asyncio.Event()

def request_shutdown() -> None:
if not stop_event.is_set():
for server in servers:
host, port = server.server_address
server.context.logger.info("stopping... %s:%s", host, port)
stop_event.set()

loop = asyncio.get_running_loop()

for sig in (signal.SIGTERM, signal.SIGINT):
try:
loop.add_signal_handler(sig, request_shutdown)
except NotImplementedError:
# Not supported on some platforms (notably parts of Windows).
# In that case, asyncio.run() will still surface Ctrl+C as
# KeyboardInterrupt, handled outside main().
...

tasks = [asyncio.create_task(server.serve_forever()) for server in servers]

try:
await stop_event.wait()
finally:
for server in servers:
await server.shutdown()
for task in tasks:
task.cancel()
for task in tasks:
with contextlib.suppress(asyncio.CancelledError):
await task


if __name__ == "__main__":
main()
try:
asyncio.run(main())
except KeyboardInterrupt:
# Fallback for platforms where asyncio signal handlers are unavailable.
...
Loading
Loading