diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 9c22c05ce..125f2a373 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1739,10 +1739,13 @@ async def workflow_sleep( else None ) fut = self.create_future() - self._timer_impl( + timer_handle = self._timer_impl( duration, _TimerOptions(user_metadata=user_metadata), - lambda: fut.set_result(None), + lambda: fut.set_result(None) if not fut.done() else None, + ) + fut.add_done_callback( + lambda f: timer_handle.cancel() if f.cancelled() else None ) await fut diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index deedae964..068716e3f 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -3431,6 +3431,66 @@ async def test_workflow_cancel_signal_and_timer_fired_in_same_task( await result_task +@workflow.defn +class CancelWorkflowSleepTaskWorkflow: + """Like CancelSignalAndTimerFiredInSameTaskWorkflow but uses workflow.sleep.""" + + _ready = False + timer_task: asyncio.Task[None] # type: ignore[reportUninitializedInstanceVariable] + + @workflow.run + async def run(self) -> str: + self.timer_task = asyncio.create_task(workflow.sleep(60 * 60)) + self._ready = True + try: + await self.timer_task + return "timer_completed" + except asyncio.CancelledError: + return "timer_cancelled" + + @workflow.query + def ready(self) -> bool: + return self._ready + + @workflow.signal + def cancel_timer(self) -> None: + self.timer_task.cancel() + + +async def test_workflow_sleep_task_cancellation( + client: Client, +): + async with new_worker( + client, + CancelWorkflowSleepTaskWorkflow, + ) as worker: + handle = await client.start_workflow( + CancelWorkflowSleepTaskWorkflow.run, + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + async def ready() -> bool: + return await handle.query(CancelWorkflowSleepTaskWorkflow.ready) + + await assert_eq_eventually(True, ready) + await handle.signal(CancelWorkflowSleepTaskWorkflow.cancel_timer) + result = await handle.result() + + assert result == "timer_cancelled" + # Verify the Temporal timer was actually cancelled on the server + resp = await client.workflow_service.get_workflow_execution_history( + GetWorkflowExecutionHistoryRequest( + namespace=client.namespace, + execution=WorkflowExecution(workflow_id=handle.id), + ) + ) + timer_canceled = any( + e.event_type == EventType.EVENT_TYPE_TIMER_CANCELED for e in resp.history.events + ) + assert timer_canceled, "Expected TimerCanceled event in history" + + class MyCustomError(ApplicationError): def __init__(self, message: str) -> None: super().__init__(message, type="MyCustomError", non_retryable=True)