Skip to content

Commit c6244a7

Browse files
committed
more tests
1 parent ad5d7cd commit c6244a7

File tree

2 files changed

+56
-29
lines changed

2 files changed

+56
-29
lines changed

custom_components/pyscript/jupyter_kernel.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ def __init__(self, config, ast_ctx, global_ctx_name):
196196
self.ast_ctx = ast_ctx
197197

198198
self.secure_key = str_to_bytes(self.config["key"])
199+
self.no_connect_timeout = self.config.get("no_connect_timeout", 30)
199200
self.signature_schemes = {"hmac-sha256": hashlib.sha256}
200201
self.auth = hmac.HMAC(
201202
self.secure_key, digestmod=self.signature_schemes[self.config["signature_scheme"]],
@@ -297,21 +298,15 @@ async def send(
297298
):
298299
"""Send message to the Jupyter client."""
299300
header = self.new_header(msg_type)
300-
if content is None:
301-
content = {}
302-
if parent_header is None:
303-
parent_header = {}
304-
if metadata is None:
305-
metadata = {}
306301

307302
def encode(msg):
308303
return str_to_bytes(json.dumps(msg))
309304

310305
msg_lst = [
311306
encode(header),
312-
encode(parent_header),
313-
encode(metadata),
314-
encode(content),
307+
encode(parent_header if parent_header else {}),
308+
encode(metadata if metadata else {}),
309+
encode(content if content else {}),
315310
]
316311
signature = self.msg_sign(msg_lst)
317312
parts = [DELIM, signature, msg_lst[0], msg_lst[1], msg_lst[2], msg_lst[3]]
@@ -447,6 +442,8 @@ async def shell_handler(self, shell_socket, wire_msg):
447442
)
448443

449444
elif msg["header"]["msg_type"] == "complete_request":
445+
root = ""
446+
words = set()
450447
code = msg["content"]["code"]
451448
posn = msg["content"]["cursor_pos"]
452449
match = self.completion_re.match(code[0:posn].lower())
@@ -456,9 +453,6 @@ async def shell_handler(self, shell_socket, wire_msg):
456453
words = words.union(await Function.service_completions(root))
457454
words = words.union(await Function.func_completions(root))
458455
words = words.union(self.ast_ctx.completions(root))
459-
else:
460-
root = ""
461-
words = set()
462456
# _LOGGER.debug(f"complete_request code={code}, posn={posn}, root={root}, words={words}")
463457
content = {
464458
"status": "ok",
@@ -708,8 +702,8 @@ async def housekeep_run(self):
708702
async def startup_timeout(self):
709703
"""Shut down the session if nothing connects after 30 seconds."""
710704
await self.housekeep_q.put(["register", "startup_timeout", asyncio.current_task()])
711-
await asyncio.sleep(30)
712-
if self.task_cnt_max == 1:
705+
await asyncio.sleep(self.no_connect_timeout)
706+
if self.task_cnt_max <= 1:
713707
#
714708
# nothing started other than us, so shut down the session
715709
#

tests/test_jupyter.py

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ def msg_id():
3232
return str(uuid.uuid4())
3333

3434

35-
def msg_sign(msg_lst):
35+
def msg_sign(msg_lst, secret_key=SECRET_KEY):
3636
"""Sign a message with a secure signature."""
37-
auth_hmac = hmac.HMAC(SECRET_KEY, digestmod=hashlib.sha256)
37+
auth_hmac = hmac.HMAC(secret_key, digestmod=hashlib.sha256)
3838
for msg in msg_lst:
3939
auth_hmac.update(msg)
4040
return str_to_bytes(auth_hmac.hexdigest())
@@ -73,27 +73,27 @@ def new_header(msg_type):
7373

7474

7575
async def send(
76-
zmq_sock, msg_type, content=None, parent_header=None, metadata=None, identities=None,
76+
zmq_sock,
77+
msg_type,
78+
content=None,
79+
parent_header=None,
80+
metadata=None,
81+
identities=None,
82+
secret_key=SECRET_KEY,
7783
):
7884
"""Send message to the Jupyter client."""
7985
header = new_header(msg_type)
80-
if content is None:
81-
content = {}
82-
if parent_header is None:
83-
parent_header = {}
84-
if metadata is None:
85-
metadata = {}
8686

8787
def encode(msg):
8888
return str_to_bytes(json.dumps(msg))
8989

9090
msg_lst = [
9191
encode(header),
92-
encode(parent_header),
93-
encode(metadata),
94-
encode(content),
92+
encode(parent_header if parent_header else {}),
93+
encode(metadata if metadata else {}),
94+
encode(content if content else {}),
9595
]
96-
signature = msg_sign(msg_lst)
96+
signature = msg_sign(msg_lst, secret_key=secret_key)
9797
parts = [DELIM, signature, msg_lst[0], msg_lst[1], msg_lst[2], msg_lst[3]]
9898
if identities:
9999
parts = identities + parts
@@ -106,7 +106,7 @@ def encode(msg):
106106
PORT_NAMES = ["hb_port", "stdin_port", "shell_port", "iopub_port", "control_port"]
107107

108108

109-
async def setup_script(hass, now, source):
109+
async def setup_script(hass, now, source, no_connect=False):
110110
"""Initialize and load the given pyscript."""
111111

112112
scripts = [
@@ -149,6 +149,8 @@ def return_next_time():
149149
"signature_scheme": "hmac-sha256",
150150
"state_var": kernel_state_var,
151151
}
152+
if no_connect:
153+
kernel_cfg["no_connect_timeout"] = 0.0
152154
await hass.services.async_call("pyscript", "jupyter_kernel_start", kernel_cfg)
153155

154156
while True:
@@ -160,6 +162,10 @@ def return_next_time():
160162
port_nums = json.loads(ports_state.state)
161163

162164
sock = {}
165+
166+
if no_connect:
167+
return sock, port_nums
168+
163169
for name in PORT_NAMES:
164170
kernel_reader, kernel_writer = await asyncio.open_connection("127.0.0.1", port_nums[name])
165171
sock[name] = ZmqSocket(kernel_reader, kernel_writer, "ROUTER")
@@ -314,6 +320,10 @@ async def test_jupyter_kernel_msgs(hass, caplog):
314320
assert reply["header"]["msg_type"] == "is_complete_reply"
315321
assert reply["content"]["status"] == "invalid"
316322

323+
reply = await shell_msg(sock, "is_complete_request", {"code": "if 1:\n"})
324+
assert reply["header"]["msg_type"] == "is_complete_reply"
325+
assert reply["content"]["status"] == "incomplete"
326+
317327
#
318328
# test code execution
319329
#
@@ -398,7 +408,23 @@ async def test_jupyter_kernel_port_close(hass, caplog):
398408
assert await sock["hb_port"].recv() == msg
399409

400410
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
401-
await shutdown(sock)
411+
412+
#
413+
# shut down the session via signature mismatch with bad key
414+
#
415+
await send(
416+
sock["control_port"], "shutdown_request", {}, parent_header={}, identities={}, secret_key=b"bad_key"
417+
)
418+
419+
#
420+
# wait until the session ends, so the log receives the error message we check below
421+
#
422+
try:
423+
await sock["iopub_port"].recv()
424+
except EOFError:
425+
pass
426+
427+
assert "signature mismatch: check_sig=" in caplog.text
402428

403429

404430
async def test_jupyter_kernel_redefine_func(hass, caplog):
@@ -474,3 +500,10 @@ async def test_jupyter_kernel_stdout(hass, caplog):
474500
assert stdout_msg["content"]["text"] == "hello\n"
475501

476502
await shutdown(sock)
503+
504+
505+
async def test_jupyter_kernel_no_connection_timeout(hass, caplog):
506+
"""Test Jupyter kernel timeout on no connection."""
507+
sock, port_nums = await setup_script(hass, [dt(2020, 7, 1, 11, 0, 0, 0)], "", no_connect=True)
508+
509+
assert "No connections to session " in caplog.text

0 commit comments

Comments
 (0)