@@ -32,9 +32,9 @@ def msg_id():
3232 return str (uuid .uuid4 ())
3333
3434
35- def msg_sign (msg_lst ):
35+ def msg_sign (msg_lst , secret_key = SECRET_KEY ):
3636 """Sign a message with a secure signature."""
37- auth_hmac = hmac .HMAC (SECRET_KEY , digestmod = hashlib .sha256 )
37+ auth_hmac = hmac .HMAC (secret_key , digestmod = hashlib .sha256 )
3838 for msg in msg_lst :
3939 auth_hmac .update (msg )
4040 return str_to_bytes (auth_hmac .hexdigest ())
@@ -73,27 +73,27 @@ def new_header(msg_type):
7373
7474
7575async def send (
76- zmq_sock , msg_type , content = None , parent_header = None , metadata = None , identities = None ,
76+ zmq_sock ,
77+ msg_type ,
78+ content = None ,
79+ parent_header = None ,
80+ metadata = None ,
81+ identities = None ,
82+ secret_key = SECRET_KEY ,
7783):
7884 """Send message to the Jupyter client."""
7985 header = new_header (msg_type )
80- if content is None :
81- content = {}
82- if parent_header is None :
83- parent_header = {}
84- if metadata is None :
85- metadata = {}
8686
8787 def encode (msg ):
8888 return str_to_bytes (json .dumps (msg ))
8989
9090 msg_lst = [
9191 encode (header ),
92- encode (parent_header ),
93- encode (metadata ),
94- encode (content ),
92+ encode (parent_header if parent_header else {} ),
93+ encode (metadata if metadata else {} ),
94+ encode (content if content else {} ),
9595 ]
96- signature = msg_sign (msg_lst )
96+ signature = msg_sign (msg_lst , secret_key = secret_key )
9797 parts = [DELIM , signature , msg_lst [0 ], msg_lst [1 ], msg_lst [2 ], msg_lst [3 ]]
9898 if identities :
9999 parts = identities + parts
@@ -106,7 +106,7 @@ def encode(msg):
106106PORT_NAMES = ["hb_port" , "stdin_port" , "shell_port" , "iopub_port" , "control_port" ]
107107
108108
109- async def setup_script (hass , now , source ):
109+ async def setup_script (hass , now , source , no_connect = False ):
110110 """Initialize and load the given pyscript."""
111111
112112 scripts = [
@@ -149,6 +149,8 @@ def return_next_time():
149149 "signature_scheme" : "hmac-sha256" ,
150150 "state_var" : kernel_state_var ,
151151 }
152+ if no_connect :
153+ kernel_cfg ["no_connect_timeout" ] = 0.0
152154 await hass .services .async_call ("pyscript" , "jupyter_kernel_start" , kernel_cfg )
153155
154156 while True :
@@ -160,6 +162,10 @@ def return_next_time():
160162 port_nums = json .loads (ports_state .state )
161163
162164 sock = {}
165+
166+ if no_connect :
167+ return sock , port_nums
168+
163169 for name in PORT_NAMES :
164170 kernel_reader , kernel_writer = await asyncio .open_connection ("127.0.0.1" , port_nums [name ])
165171 sock [name ] = ZmqSocket (kernel_reader , kernel_writer , "ROUTER" )
@@ -314,6 +320,10 @@ async def test_jupyter_kernel_msgs(hass, caplog):
314320 assert reply ["header" ]["msg_type" ] == "is_complete_reply"
315321 assert reply ["content" ]["status" ] == "invalid"
316322
323+ reply = await shell_msg (sock , "is_complete_request" , {"code" : "if 1:\n " })
324+ assert reply ["header" ]["msg_type" ] == "is_complete_reply"
325+ assert reply ["content" ]["status" ] == "incomplete"
326+
317327 #
318328 # test code execution
319329 #
@@ -398,7 +408,23 @@ async def test_jupyter_kernel_port_close(hass, caplog):
398408 assert await sock ["hb_port" ].recv () == msg
399409
400410 hass .bus .async_fire (EVENT_HOMEASSISTANT_STARTED )
401- await shutdown (sock )
411+
412+ #
413+ # shut down the session via signature mismatch with bad key
414+ #
415+ await send (
416+ sock ["control_port" ], "shutdown_request" , {}, parent_header = {}, identities = {}, secret_key = b"bad_key"
417+ )
418+
419+ #
420+ # wait until the session ends, so the log receives the error message we check below
421+ #
422+ try :
423+ await sock ["iopub_port" ].recv ()
424+ except EOFError :
425+ pass
426+
427+ assert "signature mismatch: check_sig=" in caplog .text
402428
403429
404430async def test_jupyter_kernel_redefine_func (hass , caplog ):
@@ -474,3 +500,10 @@ async def test_jupyter_kernel_stdout(hass, caplog):
474500 assert stdout_msg ["content" ]["text" ] == "hello\n "
475501
476502 await shutdown (sock )
503+
504+
505+ async def test_jupyter_kernel_no_connection_timeout (hass , caplog ):
506+ """Test Jupyter kernel timeout on no connection."""
507+ sock , port_nums = await setup_script (hass , [dt (2020 , 7 , 1 , 11 , 0 , 0 , 0 )], "" , no_connect = True )
508+
509+ assert "No connections to session " in caplog .text
0 commit comments