Skip to content

Commit e95261f

Browse files
committed
moved trigger task cancellation to reaper task
1 parent 56271cd commit e95261f

File tree

4 files changed

+68
-99
lines changed

4 files changed

+68
-99
lines changed

custom_components/pyscript/function.py

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from .const import LOGGER_PATH
99

10-
_LOGGER = logging.getLogger(LOGGER_PATH + ".handler")
10+
_LOGGER = logging.getLogger(LOGGER_PATH + ".function")
1111

1212

1313
class Function:
@@ -73,18 +73,22 @@ def init(cls, hass):
7373
# start a task which is a reaper for canceled tasks, since some # functions
7474
# like TrigInfo.stop() can't be async (it's called from a __del__ method)
7575
#
76-
async def task_await_reaper(reaper_q):
77-
try:
78-
while True:
79-
task = await reaper_q.get()
80-
await task
81-
except asyncio.CancelledError:
82-
raise
83-
except Exception:
84-
pass
76+
async def task_cancel_reaper(reaper_q):
77+
while True:
78+
try:
79+
try:
80+
task = await reaper_q.get()
81+
task.cancel()
82+
await task
83+
except asyncio.CancelledError:
84+
pass
85+
except asyncio.CancelledError:
86+
raise
87+
except Exception:
88+
_LOGGER.error("task_cancel_reaper: got exception %s", traceback.format_exc(-1))
8589

8690
cls.task_reaper_q = asyncio.Queue(0)
87-
cls.task_await_repeaer = Function.create_task(task_await_reaper(cls.task_reaper_q))
91+
cls.task_cancel_repeaer = Function.create_task(task_cancel_reaper(cls.task_reaper_q))
8892

8993
@classmethod
9094
async def async_sleep(cls, duration):
@@ -101,20 +105,13 @@ async def task_unique(cls, name, kill_me=False):
101105
"""Implement task.unique()."""
102106
if name in cls.unique_name2task:
103107
if kill_me:
104-
task = asyncio.current_task()
105-
106-
# it seems we need to use another task to cancel ourselves
107-
# I'm sure there is a better way to cancel ourselves...
108-
async def cancel_self():
109-
try:
110-
task.cancel()
111-
await task
112-
except asyncio.CancelledError:
113-
pass
114-
115-
asyncio.create_task(cancel_self())
116-
# ugh - wait to be canceled
117-
await asyncio.sleep(10000)
108+
#
109+
# it seems we can't cancel ourselves, so we
110+
# tell the repeaer task to cancel us
111+
#
112+
Function.task_cancel(asyncio.current_task())
113+
# wait to be canceled
114+
await asyncio.sleep(100000)
118115
else:
119116
task = cls.unique_name2task[name]
120117
if task in cls.our_tasks:
@@ -216,14 +213,14 @@ async def run_coro(cls, coro):
216213
#
217214
# Add a placeholder for the new task so we know it's one we started
218215
#
219-
task = asyncio.current_task()
220-
cls.our_tasks.add(task)
221216
try:
217+
task = asyncio.current_task()
218+
cls.our_tasks.add(task)
222219
await coro
223220
except asyncio.CancelledError:
224221
raise
225222
except Exception:
226-
_LOGGER.error("run_coro: %s", traceback.format_exc(-1))
223+
_LOGGER.error("run_coro: got exception %s", traceback.format_exc(-1))
227224
finally:
228225
if task in cls.unique_task2name:
229226
del cls.unique_name2task[cls.unique_task2name[task]]
@@ -236,6 +233,6 @@ def create_task(cls, coro):
236233
return cls.hass.loop.create_task(cls.run_coro(coro))
237234

238235
@classmethod
239-
def task_await_send(cls, task):
240-
"""Send a task that has been canceled to the reaper task so it can be awaited."""
236+
def task_cancel(cls, task):
237+
"""Send a task to be canceled by the reaper."""
241238
cls.task_reaper_q.put_nowait(task)

custom_components/pyscript/trigger.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,6 @@ async def wait_until(
221221
Event.notify_del(event_trigger[0], notify_q)
222222
if exc:
223223
raise exc
224-
_LOGGER.debug("trigger %s wait_until returning %s", ast_ctx.name, ret)
225224
return ret
226225

227226
@classmethod
@@ -461,8 +460,6 @@ def __init__(
461460
self.setup_ok = False
462461
self.run_on_startup = False
463462

464-
_LOGGER.debug("trigger %s event_trigger = %s", self.name, self.event_trigger)
465-
466463
if self.state_active is not None:
467464
self.active_expr = AstEval(
468465
f"{self.name} @state_active()", self.global_ctx, logger_name=self.name
@@ -519,19 +516,13 @@ def stop(self):
519516
if self.event_trigger is not None:
520517
Event.notify_del(self.event_trigger[0], self.notify_q)
521518
if self.task:
522-
try:
523-
self.task.cancel()
524-
Function.task_await_send(self.task)
525-
except asyncio.CancelledError:
526-
pass
527-
self.task = None
528-
_LOGGER.debug("trigger %s is stopped", self.name)
519+
Function.task_cancel(self.task)
529520

530521
def start(self):
531522
"""Start this trigger task."""
532523
if not self.task and self.setup_ok:
533524
self.task = Function.create_task(self.trigger_watch())
534-
_LOGGER.debug("trigger %s is active", self.name)
525+
_LOGGER.debug("trigger %s is active %s", self.name, self.task)
535526

536527
async def trigger_watch(self):
537528
"""Task that runs for each trigger, waiting for the next trigger and calling the function."""
@@ -655,7 +646,7 @@ async def do_func_call(func, ast_ctx, task_unique, kwargs=None):
655646
func_args,
656647
)
657648
Function.create_task(
658-
do_func_call(self.action, self.action_ast_ctx, self.task_unique, kwargs=func_args,)
649+
do_func_call(self.action, self.action_ast_ctx, self.task_unique, kwargs=func_args)
659650
)
660651

661652
except asyncio.CancelledError:

docs/reference.rst

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -247,21 +247,21 @@ In ``@time_trigger``, each string specification ``time_spec`` can take one of fo
247247
according to Linux-style crontab. Each of the five entries are separated by spaces and correspond
248248
to minutes, hours, day-of-month, month, day-of-week (0 = sunday):
249249

250-
============ ==============
251-
field allowed values
252-
============ ==============
253-
minute 0-59
254-
hour 0-23
255-
day of month 1-31
256-
month 1-12
257-
day of week 0-6 (0 is Sun)
258-
============ ==============
259-
260-
Each field can be a ``*`` (which means “all”), a single number, a range or comma-separated list of
261-
numbers or ranges (no spaces). Ranges are inclusive. For example, if you specify hours as
262-
``6,10-13`` that means hours of 6,10,11,12,13. The trigger happens on the next minute, hour, day
263-
that matches the specification. See any Linux documentation for examples and more details (note:
264-
names for days of week and months are not supported; only their integer values are).
250+
============ ==============
251+
field allowed values
252+
============ ==============
253+
minute 0-59
254+
hour 0-23
255+
day of month 1-31
256+
month 1-12
257+
day of week 0-6 (0 is Sun)
258+
============ ==============
259+
260+
Each field can be a ``*`` (which means “all”), a single number, a range or comma-separated list of
261+
numbers or ranges (no spaces). Ranges are inclusive. For example, if you specify hours as
262+
``6,10-13`` that means hours of 6,10,11,12,13. The trigger happens on the next minute, hour, day
263+
that matches the specification. See any Linux documentation for examples and more details (note:
264+
names for days of week and months are not supported; only their integer values are).
265265

266266
When the ``@time_trigger`` occurs and the function is called, the keyword argument ``trigger_type``
267267
is set to ``"time"``, and ``trigger_time`` is the exact ``datetime`` of the time specification that

tests/test_function.py

Lines changed: 23 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ def func_trig(var_name=None, value=None):
513513
514514
return func_trig
515515
516-
f = [factory(50), factory(51), factory(52), factory(53)]
516+
f = [factory(50), factory(51), factory(52), factory(53), factory(54)]
517517
""",
518518
)
519519
seq_num = 0
@@ -524,49 +524,30 @@ def func_trig(var_name=None, value=None):
524524
assert literal_eval(await wait_until_done(notify_q)) == seq_num
525525

526526
#
527-
# trigger them one at a time
527+
# trigger them one at a time to make sure each is working
528528
#
529-
for i in range(3):
529+
for i in range(5):
530530
seq_num += 1
531531
hass.states.async_set("pyscript.var1", 50 + i)
532532
assert literal_eval(await wait_until_done(notify_q)) == seq_num
533533

534-
#
535-
# trigger all three together; we don't know the order, so just check
536-
# we got all 3
537-
#
538-
hass.states.async_set("pyscript.var1", 100)
539-
seqs = set()
540-
expect = set()
541-
for i in range(4):
542-
seqs.add(literal_eval(await wait_until_done(notify_q)))
543-
expect.add(seq_num + 50 + i)
544-
assert seqs == expect
545-
546-
#
547-
# now trigger all again, but just the first deletes the last
548-
# trigger function and replies
549-
#
550-
seq_num += 1
551-
hass.states.async_set("pyscript.var1", 101)
552-
assert literal_eval(await wait_until_done(notify_q)) == seq_num
553-
554-
#
555-
# now trigger all again, and confirm we only get two
556-
#
557-
hass.states.async_set("pyscript.var1", 100)
558-
seqs = set()
559-
expect = set()
560-
for i in range(3):
561-
seqs.add(literal_eval(await wait_until_done(notify_q)))
562-
expect.add(seq_num + 50 + i)
563-
assert seqs == expect
564-
565-
#
566-
# now trigger all again, but just the first deletes the last
567-
# trigger function and replies, just to make sure the last
568-
# one didn't trigger
569-
#
570-
seq_num += 1
571-
hass.states.async_set("pyscript.var1", 101)
572-
assert literal_eval(await wait_until_done(notify_q)) == seq_num
534+
for num_func in range(5, 1, -1):
535+
#
536+
# trigger all together; we don't know the order, so just check
537+
# we got all of the remaining ones
538+
#
539+
hass.states.async_set("pyscript.var1", 100)
540+
seqs = set()
541+
expect = set()
542+
for i in range(num_func):
543+
seqs.add(literal_eval(await wait_until_done(notify_q)))
544+
expect.add(seq_num + 50 + i)
545+
assert seqs == expect
546+
547+
#
548+
# now trigger all again, but just the first deletes the last
549+
# trigger function and replies with the next seq number
550+
#
551+
seq_num += 1
552+
hass.states.async_set("pyscript.var1", 101)
553+
assert literal_eval(await wait_until_done(notify_q)) == seq_num

0 commit comments

Comments
 (0)