9 changes: 4 additions & 5 deletions .pre-commit-config.yaml
@@ -16,10 +16,9 @@ repos:
- id: ruff-format
- repo: local
hooks:
- id: ty-check
- id: ty
name: ty-check
entry: uv run ty check
language: python
entry: ty check
pass_filenames: false
args: [--python=.venv/]
additional_dependencies: [ty]
types: [python]
pass_filenames: true
78 changes: 77 additions & 1 deletion tilebox-workflows/tests/runner/test_runner.py
@@ -9,7 +9,7 @@
from tilebox.workflows import ExecutionContext, Task
from tilebox.workflows.cache import InMemoryCache, JobCache
from tilebox.workflows.client import Client
from tilebox.workflows.data import JobState, ProgressIndicator, RunnerContext
from tilebox.workflows.data import JobState, ProgressIndicator, RunnerContext, TaskState
from tilebox.workflows.runner.task_runner import TaskRunner


@@ -243,3 +243,79 @@ def test_runner_disallow_duplicate_task_identifiers() -> None:
),
):
runner.register(ExplicitIdentifierTaskV2)


class OptionalSubbranch(Task):
def execute(self, context: ExecutionContext) -> None:
context.submit_subtask(OptionalSubtasks(False), optional=True)
context.submit_subtask(SucceedingTask())


class OptionalSubtasks(Task):
failing_task_optional: bool

def execute(self, context: ExecutionContext) -> None:
f = context.submit_subtask(FailingTask(), optional=self.failing_task_optional)
context.submit_subtask(SucceedingTask(), depends_on=[f])


class FailingTask(Task):
def execute(self, context: ExecutionContext) -> None:
cache = context.job_cache # ty: ignore[unresolved-attribute]
cache["failing_task"] = b"1" # to make sure it actually ran
raise ValueError("This task always fails")


class SucceedingTask(Task):
def execute(self, context: ExecutionContext) -> None:
cache = context.job_cache # ty: ignore[unresolved-attribute]
cache["succeeding_task"] = b"1" # to make sure it actually ran


def test_runner_optional_subbranch() -> None:
client = replay_client("optional_subbranch.rpcs.bin")
job_client = client.jobs()

with patch("tilebox.workflows.jobs.client.get_trace_parent_of_current_span") as get_trace_parent_mock:
# we hardcode the trace parent for the job, which allows us to assert that every single outgoing request
# matches exactly byte for byte
get_trace_parent_mock.return_value = "00-42fe17a0cc6752adf16a5a326d37f51c-795dd6a3bc5a0b81-01"
job = client.jobs().submit("optional-subbranch-test", OptionalSubbranch())

cache = InMemoryCache()
runner = client.runner(tasks=[OptionalSubbranch, OptionalSubtasks, FailingTask, SucceedingTask], cache=cache)

runner.run_all()
job = job_client.find(job) # load current job state
assert job.state == JobState.COMPLETED

assert job.execution_stats.tasks_by_state[TaskState.COMPUTED] == 3
assert job.execution_stats.tasks_by_state[TaskState.FAILED_OPTIONAL] == 1
assert job.execution_stats.tasks_by_state[TaskState.SKIPPED] == 1

assert cache.group(str(job.id))["failing_task"] == b"1"
assert cache.group(str(job.id))["succeeding_task"] == b"1"


def test_runner_optional_subtask() -> None:
client = replay_client("optional_subtask.rpcs.bin")
job_client = client.jobs()

with patch("tilebox.workflows.jobs.client.get_trace_parent_of_current_span") as get_trace_parent_mock:
# we hardcode the trace parent for the job, which allows us to assert that every single outgoing request
# matches exactly byte for byte
get_trace_parent_mock.return_value = "00-154ffe629cc5b746584825bfbb37963d-3ed10512af70309c-01"
job = client.jobs().submit("optional-subtasks-test", OptionalSubtasks(True))

cache = InMemoryCache()
runner = client.runner(tasks=[OptionalSubtasks, FailingTask, SucceedingTask], cache=cache)

runner.run_all()
job = job_client.find(job) # load current job state
assert job.state == JobState.COMPLETED

assert job.execution_stats.tasks_by_state[TaskState.COMPUTED] == 2
assert job.execution_stats.tasks_by_state[TaskState.FAILED_OPTIONAL] == 1

assert cache.group(str(job.id))["failing_task"] == b"1"
assert cache.group(str(job.id))["succeeding_task"] == b"1"
Git LFS file not shown
Git LFS file not shown
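
The new tests above pin down the behaviour of the `optional` flag on `submit_subtask`: a failure in an optional subtask (or inside an optional subbranch) no longer prevents the job from reaching COMPLETED. The two Git LFS entries above are presumably the recorded RPC replay files consumed by `replay_client`, which, together with the hardcoded trace parent, lets the tests assert every outgoing request byte for byte. A condensed sketch of the directly-optional case, not part of this PR, using only the API already shown in this diff (the task names here are placeholders):

from tilebox.workflows import ExecutionContext, Task


class Flaky(Task):
    def execute(self, context: ExecutionContext) -> None:
        raise ValueError("this failure should not take the whole job down")


class Dependent(Task):
    def execute(self, context: ExecutionContext) -> None:
        pass  # per test_runner_optional_subtask, this still runs and ends up COMPUTED


class Root(Task):
    def execute(self, context: ExecutionContext) -> None:
        # Flaky is reported as FAILED_OPTIONAL instead of cancelling the job.
        flaky = context.submit_subtask(Flaky(), optional=True)
        context.submit_subtask(Dependent(), depends_on=[flaky])

The subbranch variant asserted in test_runner_optional_subbranch differs slightly: the non-optional FailingTask nested inside the optional OptionalSubtasks branch is counted as FAILED_OPTIONAL, its dependent within that branch ends up SKIPPED, and the rest of the job still completes.
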
10 changes: 6 additions & 4 deletions tilebox-workflows/tilebox/workflows/runner/task_runner.py
@@ -232,7 +232,7 @@ def _external_interrupt_handler(self, signum: int, frame: FrameType | None) -> N
self._service.task_failed(
self._task,
RunnerShutdown("Task was interrupted"),
cancel_job=False,
was_workflow_error=False,
progress_updates=progress,
)

@@ -441,9 +441,11 @@ def _execute(self, task: Task, shutdown_context: _GracefulShutdown) -> Task | Id
self.logger.exception(f"Task {task_repr} failed!")

task_failed_retry = _retry_backoff(self._service.task_failed, stop=shutdown_context.stop_if_shutting_down())
cancel_job = True
progress_updates = _finalize_mutable_progress_trackers(context._progress_indicators) # noqa: SLF001
task_failed_retry(task, e, cancel_job, progress_updates)
was_workflow_error = True
progress_updates: list[ProgressIndicator] = _finalize_mutable_progress_trackers(
context._progress_indicators # noqa: SLF001
)
task_failed_retry(task, e, was_workflow_error, progress_updates)

return None

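
Both call sites changed in this file keep the same failure-reporting shape but rename the flag: the interrupt handler now reports was_workflow_error=False, while an exception raised by Task.execute is reported with was_workflow_error=True. A hypothetical helper, not part of this PR, summarising that distinction with the task_failed signature shown below:

from tilebox.workflows import Task
from tilebox.workflows.data import ProgressIndicator


def report_task_failure(service, task: Task, error: Exception, *, interrupted: bool) -> None:
    # Mirrors the two call sites above: a runner shutdown/interrupt is not a
    # workflow error, whereas a failure inside the task's own code is.
    progress: list[ProgressIndicator] = []
    service.task_failed(task, error, was_workflow_error=not interrupted, progress_updates=progress)
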
4 changes: 2 additions & 2 deletions tilebox-workflows/tilebox/workflows/runner/task_service.py
@@ -49,15 +49,15 @@ def next_task(self, task_to_run: NextTaskToRun | None, computed_task: ComputedTa
return None

def task_failed(
self, task: Task, error: Exception, cancel_job: bool, progress_updates: list[ProgressIndicator]
self, task: Task, error: Exception, was_workflow_error: bool, progress_updates: list[ProgressIndicator]
) -> None:
# job output is limited to 1KB, so truncate the error message if necessary
error_message = repr(error)[: (1024 - len(task.display or "None") - 1)]
display = f"{task.display}" if error_message == "" else f"{task.display}\n{error_message}"

request = TaskFailedRequest(
task_id=uuid_to_uuid_message(task.id),
cancel_job=cancel_job,
was_workflow_error=was_workflow_error,
display=display,
progress_updates=[progress.to_message() for progress in progress_updates],
)
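
The truncation in task_failed above budgets the 1 KB display limit as len(display) + 1 + len(error_message), reserving one character for the separating newline. A small standalone sketch of that arithmetic, not part of this PR (the build_display helper is made up for illustration; the real code operates on task.display):

def build_display(task_display: str | None, error: Exception, limit: int = 1024) -> str:
    # Mirror of the truncation above: reserve room for the display name and the
    # separating newline, then cut the error repr to whatever is left.
    error_message = repr(error)[: (limit - len(task_display or "None") - 1)]
    return f"{task_display}" if error_message == "" else f"{task_display}\n{error_message}"


# The combined string never exceeds the limit, even for very long errors:
assert len(build_display("MyTask", ValueError("x" * 5000))) <= 1024
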
16 changes: 8 additions & 8 deletions tilebox-workflows/tilebox/workflows/workflows/v1/task_pb2.py

Some generated files are not rendered by default.

8 changes: 4 additions & 4 deletions tilebox-workflows/tilebox/workflows/workflows/v1/task_pb2.pyi
@@ -55,16 +55,16 @@ class NextTaskResponse(_message.Message):
def __init__(self, next_task: _Optional[_Union[_core_pb2.Task, _Mapping]] = ..., idling: _Optional[_Union[IdlingResponse, _Mapping]] = ...) -> None: ...

class TaskFailedRequest(_message.Message):
__slots__ = ("task_id", "display", "cancel_job", "progress_updates")
__slots__ = ("task_id", "display", "was_workflow_error", "progress_updates")
TASK_ID_FIELD_NUMBER: _ClassVar[int]
DISPLAY_FIELD_NUMBER: _ClassVar[int]
CANCEL_JOB_FIELD_NUMBER: _ClassVar[int]
WAS_WORKFLOW_ERROR_FIELD_NUMBER: _ClassVar[int]
PROGRESS_UPDATES_FIELD_NUMBER: _ClassVar[int]
task_id: _id_pb2.ID
display: str
cancel_job: bool
was_workflow_error: bool
progress_updates: _containers.RepeatedCompositeFieldContainer[_core_pb2.Progress]
def __init__(self, task_id: _Optional[_Union[_id_pb2.ID, _Mapping]] = ..., display: _Optional[str] = ..., cancel_job: bool = ..., progress_updates: _Optional[_Iterable[_Union[_core_pb2.Progress, _Mapping]]] = ...) -> None: ...
def __init__(self, task_id: _Optional[_Union[_id_pb2.ID, _Mapping]] = ..., display: _Optional[str] = ..., was_workflow_error: bool = ..., progress_updates: _Optional[_Iterable[_Union[_core_pb2.Progress, _Mapping]]] = ...) -> None: ...

class TaskStateResponse(_message.Message):
__slots__ = ("state",)