diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 207006d..da232ee 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,10 +16,9 @@ repos: - id: ruff-format - repo: local hooks: - - id: ty-check + - id: ty name: ty-check + entry: uv run ty check language: python - entry: ty check - pass_filenames: false - args: [--python=.venv/] - additional_dependencies: [ty] + types: [python] + pass_filenames: true diff --git a/tilebox-workflows/tests/runner/test_runner.py b/tilebox-workflows/tests/runner/test_runner.py index 2e28100..0945fe0 100644 --- a/tilebox-workflows/tests/runner/test_runner.py +++ b/tilebox-workflows/tests/runner/test_runner.py @@ -9,7 +9,7 @@ from tilebox.workflows import ExecutionContext, Task from tilebox.workflows.cache import InMemoryCache, JobCache from tilebox.workflows.client import Client -from tilebox.workflows.data import JobState, ProgressIndicator, RunnerContext +from tilebox.workflows.data import JobState, ProgressIndicator, RunnerContext, TaskState from tilebox.workflows.runner.task_runner import TaskRunner @@ -243,3 +243,79 @@ def test_runner_disallow_duplicate_task_identifiers() -> None: ), ): runner.register(ExplicitIdentifierTaskV2) + + +class OptionalSubbranch(Task): + def execute(self, context: ExecutionContext) -> None: + context.submit_subtask(OptionalSubtasks(False), optional=True) + context.submit_subtask(SucceedingTask()) + + +class OptionalSubtasks(Task): + failing_task_optional: bool + + def execute(self, context: ExecutionContext) -> None: + f = context.submit_subtask(FailingTask(), optional=self.failing_task_optional) + context.submit_subtask(SucceedingTask(), depends_on=[f]) + + +class FailingTask(Task): + def execute(self, context: ExecutionContext) -> None: + cache = context.job_cache # ty: ignore[unresolved-attribute] + cache["failing_task"] = b"1" # to make sure it actually ran + raise ValueError("This task always fails") + + +class SucceedingTask(Task): + def execute(self, context: ExecutionContext) -> None: + cache = context.job_cache # ty: ignore[unresolved-attribute] + cache["succeeding_task"] = b"1" # to make sure it actually ran + + +def test_runner_optional_subbranch() -> None: + client = replay_client("optional_subbranch.rpcs.bin") + job_client = client.jobs() + + with patch("tilebox.workflows.jobs.client.get_trace_parent_of_current_span") as get_trace_parent_mock: + # we hardcode the trace parent for the job, which allows us to assert that every single outgoing request + # matches exactly byte for byte + get_trace_parent_mock.return_value = "00-42fe17a0cc6752adf16a5a326d37f51c-795dd6a3bc5a0b81-01" + job = client.jobs().submit("optional-subbranch-test", OptionalSubbranch()) + + cache = InMemoryCache() + runner = client.runner(tasks=[OptionalSubbranch, OptionalSubtasks, FailingTask, SucceedingTask], cache=cache) + + runner.run_all() + job = job_client.find(job) # load current job state + assert job.state == JobState.COMPLETED + + assert job.execution_stats.tasks_by_state[TaskState.COMPUTED] == 3 + assert job.execution_stats.tasks_by_state[TaskState.FAILED_OPTIONAL] == 1 + assert job.execution_stats.tasks_by_state[TaskState.SKIPPED] == 1 + + assert cache.group(str(job.id))["failing_task"] == b"1" + assert cache.group(str(job.id))["succeeding_task"] == b"1" + + +def test_runner_optional_subtask() -> None: + client = replay_client("optional_subtask.rpcs.bin") + job_client = client.jobs() + + with patch("tilebox.workflows.jobs.client.get_trace_parent_of_current_span") as get_trace_parent_mock: + # we hardcode the trace parent for the job, which allows us to assert that every single outgoing request + # matches exactly byte for byte + get_trace_parent_mock.return_value = "00-154ffe629cc5b746584825bfbb37963d-3ed10512af70309c-01" + job = client.jobs().submit("optional-subtasks-test", OptionalSubtasks(True)) + + cache = InMemoryCache() + runner = client.runner(tasks=[OptionalSubtasks, FailingTask, SucceedingTask], cache=cache) + + runner.run_all() + job = job_client.find(job) # load current job state + assert job.state == JobState.COMPLETED + + assert job.execution_stats.tasks_by_state[TaskState.COMPUTED] == 2 + assert job.execution_stats.tasks_by_state[TaskState.FAILED_OPTIONAL] == 1 + + assert cache.group(str(job.id))["failing_task"] == b"1" + assert cache.group(str(job.id))["succeeding_task"] == b"1" diff --git a/tilebox-workflows/tests/runner/testdata/recordings/optional_subbranch.rpcs.bin b/tilebox-workflows/tests/runner/testdata/recordings/optional_subbranch.rpcs.bin new file mode 100644 index 0000000..f52efd2 --- /dev/null +++ b/tilebox-workflows/tests/runner/testdata/recordings/optional_subbranch.rpcs.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cac0cd1cd31ec5c62949704159a113d2f69a249c35997eac113a2ac695d196d5 +size 4430 diff --git a/tilebox-workflows/tests/runner/testdata/recordings/optional_subtask.rpcs.bin b/tilebox-workflows/tests/runner/testdata/recordings/optional_subtask.rpcs.bin new file mode 100644 index 0000000..51df487 --- /dev/null +++ b/tilebox-workflows/tests/runner/testdata/recordings/optional_subtask.rpcs.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3eee76fab8b6018fbba6e0332a8a5a193a1095ceab3e4a0b2e2b586a5d520f9b +size 3526 diff --git a/tilebox-workflows/tilebox/workflows/runner/task_runner.py b/tilebox-workflows/tilebox/workflows/runner/task_runner.py index a19b3c1..a264535 100644 --- a/tilebox-workflows/tilebox/workflows/runner/task_runner.py +++ b/tilebox-workflows/tilebox/workflows/runner/task_runner.py @@ -232,7 +232,7 @@ def _external_interrupt_handler(self, signum: int, frame: FrameType | None) -> N self._service.task_failed( self._task, RunnerShutdown("Task was interrupted"), - cancel_job=False, + was_workflow_error=False, progress_updates=progress, ) @@ -441,9 +441,11 @@ def _execute(self, task: Task, shutdown_context: _GracefulShutdown) -> Task | Id self.logger.exception(f"Task {task_repr} failed!") task_failed_retry = _retry_backoff(self._service.task_failed, stop=shutdown_context.stop_if_shutting_down()) - cancel_job = True - progress_updates = _finalize_mutable_progress_trackers(context._progress_indicators) # noqa: SLF001 - task_failed_retry(task, e, cancel_job, progress_updates) + was_workflow_error = True + progress_updates: list[ProgressIndicator] = _finalize_mutable_progress_trackers( + context._progress_indicators # noqa: SLF001 + ) + task_failed_retry(task, e, was_workflow_error, progress_updates) return None diff --git a/tilebox-workflows/tilebox/workflows/runner/task_service.py b/tilebox-workflows/tilebox/workflows/runner/task_service.py index 1283ee2..e9c8643 100644 --- a/tilebox-workflows/tilebox/workflows/runner/task_service.py +++ b/tilebox-workflows/tilebox/workflows/runner/task_service.py @@ -49,7 +49,7 @@ def next_task(self, task_to_run: NextTaskToRun | None, computed_task: ComputedTa return None def task_failed( - self, task: Task, error: Exception, cancel_job: bool, progress_updates: list[ProgressIndicator] + self, task: Task, error: Exception, was_workflow_error: bool, progress_updates: list[ProgressIndicator] ) -> None: # job ouptut is limited to 1KB, so truncate the error message if necessary error_message = repr(error)[: (1024 - len(task.display or "None") - 1)] @@ -57,7 +57,7 @@ def task_failed( request = TaskFailedRequest( task_id=uuid_to_uuid_message(task.id), - cancel_job=cancel_job, + was_workflow_error=was_workflow_error, display=display, progress_updates=[progress.to_message() for progress in progress_updates], ) diff --git a/tilebox-workflows/tilebox/workflows/workflows/v1/task_pb2.py b/tilebox-workflows/tilebox/workflows/workflows/v1/task_pb2.py index 9e88e83..9bea554 100644 --- a/tilebox-workflows/tilebox/workflows/workflows/v1/task_pb2.py +++ b/tilebox-workflows/tilebox/workflows/workflows/v1/task_pb2.py @@ -28,7 +28,7 @@ from tilebox.workflows.workflows.v1 import core_pb2 as workflows_dot_v1_dot_core__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17workflows/v1/task.proto\x12\x0cworkflows.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1egoogle/protobuf/duration.proto\x1a\x13tilebox/v1/id.proto\x1a\x17workflows/v1/core.proto\"\xa6\x01\n\x0fNextTaskRequest\x12\x46\n\rcomputed_task\x18\x01 \x01(\x0b\x32\x1a.workflows.v1.ComputedTaskB\x05\xaa\x01\x02\x08\x01R\x0c\x63omputedTask\x12K\n\x10next_task_to_run\x18\x02 \x01(\x0b\x32\x1b.workflows.v1.NextTaskToRunB\x05\xaa\x01\x02\x08\x01R\rnextTaskToRun\"{\n\rNextTaskToRun\x12*\n\x0c\x63luster_slug\x18\x01 \x01(\tB\x07\xbaH\x04r\x02\x10\x01R\x0b\x63lusterSlug\x12>\n\x0bidentifiers\x18\x02 \x03(\x0b\x32\x1c.workflows.v1.TaskIdentifierR\x0bidentifiers\"\xaa\x02\n\x0c\x43omputedTask\x12&\n\x02id\x18\x01 \x01(\x0b\x32\x0e.tilebox.v1.IDB\x06\xbaH\x03\xc8\x01\x01R\x02id\x12\x18\n\x07\x64isplay\x18\x02 \x01(\tR\x07\x64isplay\x12:\n\tsub_tasks\x18\x05 \x01(\x0b\x32\x1d.workflows.v1.TaskSubmissionsR\x08subTasks\x12\x41\n\x10progress_updates\x18\x04 \x03(\x0b\x32\x16.workflows.v1.ProgressR\x0fprogressUpdates\x12Y\n\x10legacy_sub_tasks\x18\x03 \x03(\x0b\x32\".workflows.v1.SingleTaskSubmissionB\x0b\x18\x01\xbaH\x06\x92\x01\x03\x10\xe8\x07R\x0elegacySubTasks\"g\n\x0eIdlingResponse\x12U\n\x19suggested_idling_duration\x18\x01 \x01(\x0b\x32\x19.google.protobuf.DurationR\x17suggestedIdlingDuration\"\x93\x01\n\x10NextTaskResponse\x12/\n\tnext_task\x18\x01 \x01(\x0b\x32\x12.workflows.v1.TaskR\x08nextTask\x12\x34\n\x06idling\x18\x02 \x01(\x0b\x32\x1c.workflows.v1.IdlingResponseR\x06idling:\x18\xbaH\x15\"\x13\n\tnext_task\n\x06idling\"\xc0\x01\n\x11TaskFailedRequest\x12/\n\x07task_id\x18\x01 \x01(\x0b\x32\x0e.tilebox.v1.IDB\x06\xbaH\x03\xc8\x01\x01R\x06taskId\x12\x18\n\x07\x64isplay\x18\x02 \x01(\tR\x07\x64isplay\x12\x1d\n\ncancel_job\x18\x03 \x01(\x08R\tcancelJob\x12\x41\n\x10progress_updates\x18\x04 \x03(\x0b\x32\x16.workflows.v1.ProgressR\x0fprogressUpdates\"B\n\x11TaskStateResponse\x12-\n\x05state\x18\x01 \x01(\x0e\x32\x17.workflows.v1.TaskStateR\x05state\"\x87\x01\n\x10TaskLeaseRequest\x12/\n\x07task_id\x18\x01 \x01(\x0b\x32\x0e.tilebox.v1.IDB\x06\xbaH\x03\xc8\x01\x01R\x06taskId\x12\x42\n\x0frequested_lease\x18\x02 \x01(\x0b\x32\x19.google.protobuf.DurationR\x0erequestedLease2\xf4\x01\n\x0bTaskService\x12I\n\x08NextTask\x12\x1d.workflows.v1.NextTaskRequest\x1a\x1e.workflows.v1.NextTaskResponse\x12N\n\nTaskFailed\x12\x1f.workflows.v1.TaskFailedRequest\x1a\x1f.workflows.v1.TaskStateResponse\x12J\n\x0f\x45xtendTaskLease\x12\x1e.workflows.v1.TaskLeaseRequest\x1a\x17.workflows.v1.TaskLeaseBs\n\x10\x63om.workflows.v1B\tTaskProtoP\x01\xa2\x02\x03WXX\xaa\x02\x0cWorkflows.V1\xca\x02\x0cWorkflows\\V1\xe2\x02\x18Workflows\\V1\\GPBMetadata\xea\x02\rWorkflows::V1\x92\x03\x02\x08\x02\x62\x08\x65\x64itionsp\xe8\x07') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17workflows/v1/task.proto\x12\x0cworkflows.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1egoogle/protobuf/duration.proto\x1a\x13tilebox/v1/id.proto\x1a\x17workflows/v1/core.proto\"\xa6\x01\n\x0fNextTaskRequest\x12\x46\n\rcomputed_task\x18\x01 \x01(\x0b\x32\x1a.workflows.v1.ComputedTaskB\x05\xaa\x01\x02\x08\x01R\x0c\x63omputedTask\x12K\n\x10next_task_to_run\x18\x02 \x01(\x0b\x32\x1b.workflows.v1.NextTaskToRunB\x05\xaa\x01\x02\x08\x01R\rnextTaskToRun\"{\n\rNextTaskToRun\x12*\n\x0c\x63luster_slug\x18\x01 \x01(\tB\x07\xbaH\x04r\x02\x10\x01R\x0b\x63lusterSlug\x12>\n\x0bidentifiers\x18\x02 \x03(\x0b\x32\x1c.workflows.v1.TaskIdentifierR\x0bidentifiers\"\xaa\x02\n\x0c\x43omputedTask\x12&\n\x02id\x18\x01 \x01(\x0b\x32\x0e.tilebox.v1.IDB\x06\xbaH\x03\xc8\x01\x01R\x02id\x12\x18\n\x07\x64isplay\x18\x02 \x01(\tR\x07\x64isplay\x12:\n\tsub_tasks\x18\x05 \x01(\x0b\x32\x1d.workflows.v1.TaskSubmissionsR\x08subTasks\x12\x41\n\x10progress_updates\x18\x04 \x03(\x0b\x32\x16.workflows.v1.ProgressR\x0fprogressUpdates\x12Y\n\x10legacy_sub_tasks\x18\x03 \x03(\x0b\x32\".workflows.v1.SingleTaskSubmissionB\x0b\x18\x01\xbaH\x06\x92\x01\x03\x10\xe8\x07R\x0elegacySubTasks\"g\n\x0eIdlingResponse\x12U\n\x19suggested_idling_duration\x18\x01 \x01(\x0b\x32\x19.google.protobuf.DurationR\x17suggestedIdlingDuration\"\x93\x01\n\x10NextTaskResponse\x12/\n\tnext_task\x18\x01 \x01(\x0b\x32\x12.workflows.v1.TaskR\x08nextTask\x12\x34\n\x06idling\x18\x02 \x01(\x0b\x32\x1c.workflows.v1.IdlingResponseR\x06idling:\x18\xbaH\x15\"\x13\n\tnext_task\n\x06idling\"\xcf\x01\n\x11TaskFailedRequest\x12/\n\x07task_id\x18\x01 \x01(\x0b\x32\x0e.tilebox.v1.IDB\x06\xbaH\x03\xc8\x01\x01R\x06taskId\x12\x18\n\x07\x64isplay\x18\x02 \x01(\tR\x07\x64isplay\x12,\n\x12was_workflow_error\x18\x03 \x01(\x08R\x10wasWorkflowError\x12\x41\n\x10progress_updates\x18\x04 \x03(\x0b\x32\x16.workflows.v1.ProgressR\x0fprogressUpdates\"B\n\x11TaskStateResponse\x12-\n\x05state\x18\x01 \x01(\x0e\x32\x17.workflows.v1.TaskStateR\x05state\"\x87\x01\n\x10TaskLeaseRequest\x12/\n\x07task_id\x18\x01 \x01(\x0b\x32\x0e.tilebox.v1.IDB\x06\xbaH\x03\xc8\x01\x01R\x06taskId\x12\x42\n\x0frequested_lease\x18\x02 \x01(\x0b\x32\x19.google.protobuf.DurationR\x0erequestedLease2\xf4\x01\n\x0bTaskService\x12I\n\x08NextTask\x12\x1d.workflows.v1.NextTaskRequest\x1a\x1e.workflows.v1.NextTaskResponse\x12N\n\nTaskFailed\x12\x1f.workflows.v1.TaskFailedRequest\x1a\x1f.workflows.v1.TaskStateResponse\x12J\n\x0f\x45xtendTaskLease\x12\x1e.workflows.v1.TaskLeaseRequest\x1a\x17.workflows.v1.TaskLeaseBs\n\x10\x63om.workflows.v1B\tTaskProtoP\x01\xa2\x02\x03WXX\xaa\x02\x0cWorkflows.V1\xca\x02\x0cWorkflows\\V1\xe2\x02\x18Workflows\\V1\\GPBMetadata\xea\x02\rWorkflows::V1\x92\x03\x02\x08\x02\x62\x08\x65\x64itionsp\xe8\x07') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -63,11 +63,11 @@ _globals['_NEXTTASKRESPONSE']._serialized_start=849 _globals['_NEXTTASKRESPONSE']._serialized_end=996 _globals['_TASKFAILEDREQUEST']._serialized_start=999 - _globals['_TASKFAILEDREQUEST']._serialized_end=1191 - _globals['_TASKSTATERESPONSE']._serialized_start=1193 - _globals['_TASKSTATERESPONSE']._serialized_end=1259 - _globals['_TASKLEASEREQUEST']._serialized_start=1262 - _globals['_TASKLEASEREQUEST']._serialized_end=1397 - _globals['_TASKSERVICE']._serialized_start=1400 - _globals['_TASKSERVICE']._serialized_end=1644 + _globals['_TASKFAILEDREQUEST']._serialized_end=1206 + _globals['_TASKSTATERESPONSE']._serialized_start=1208 + _globals['_TASKSTATERESPONSE']._serialized_end=1274 + _globals['_TASKLEASEREQUEST']._serialized_start=1277 + _globals['_TASKLEASEREQUEST']._serialized_end=1412 + _globals['_TASKSERVICE']._serialized_start=1415 + _globals['_TASKSERVICE']._serialized_end=1659 # @@protoc_insertion_point(module_scope) diff --git a/tilebox-workflows/tilebox/workflows/workflows/v1/task_pb2.pyi b/tilebox-workflows/tilebox/workflows/workflows/v1/task_pb2.pyi index 4931d22..15f6770 100644 --- a/tilebox-workflows/tilebox/workflows/workflows/v1/task_pb2.pyi +++ b/tilebox-workflows/tilebox/workflows/workflows/v1/task_pb2.pyi @@ -55,16 +55,16 @@ class NextTaskResponse(_message.Message): def __init__(self, next_task: _Optional[_Union[_core_pb2.Task, _Mapping]] = ..., idling: _Optional[_Union[IdlingResponse, _Mapping]] = ...) -> None: ... class TaskFailedRequest(_message.Message): - __slots__ = ("task_id", "display", "cancel_job", "progress_updates") + __slots__ = ("task_id", "display", "was_workflow_error", "progress_updates") TASK_ID_FIELD_NUMBER: _ClassVar[int] DISPLAY_FIELD_NUMBER: _ClassVar[int] - CANCEL_JOB_FIELD_NUMBER: _ClassVar[int] + WAS_WORKFLOW_ERROR_FIELD_NUMBER: _ClassVar[int] PROGRESS_UPDATES_FIELD_NUMBER: _ClassVar[int] task_id: _id_pb2.ID display: str - cancel_job: bool + was_workflow_error: bool progress_updates: _containers.RepeatedCompositeFieldContainer[_core_pb2.Progress] - def __init__(self, task_id: _Optional[_Union[_id_pb2.ID, _Mapping]] = ..., display: _Optional[str] = ..., cancel_job: bool = ..., progress_updates: _Optional[_Iterable[_Union[_core_pb2.Progress, _Mapping]]] = ...) -> None: ... + def __init__(self, task_id: _Optional[_Union[_id_pb2.ID, _Mapping]] = ..., display: _Optional[str] = ..., was_workflow_error: bool = ..., progress_updates: _Optional[_Iterable[_Union[_core_pb2.Progress, _Mapping]]] = ...) -> None: ... class TaskStateResponse(_message.Message): __slots__ = ("state",)