Changes from all commits
24 commits
f607d24
adding download task and creating separate download pool
HansVRP Nov 26, 2025
7478990
include initial unit testing
HansVRP Nov 28, 2025
8e3ae8b
updated unit tests
HansVRP Dec 10, 2025
24989cf
including two simple unit tests and unifying pool usage
HansVRP Dec 11, 2025
3293327
changes to job manager
HansVRP Dec 11, 2025
12277de
adding easy callback to check number of pending tasks on thread worke…
HansVRP Dec 11, 2025
bade858
process updates through job update loop
HansVRP Dec 11, 2025
14585c9
remove folder creation logic from thread to respect optional downloa…
HansVRP Dec 11, 2025
d0a7fbf
fix stop_job_thread
HansVRP Dec 11, 2025
d4d0110
working on fix for indefinite loop
HansVRP Dec 11, 2025
086a30b
fix infinite loop
HansVRP Dec 11, 2025
2603c30
wrapper to abstract multiple threadpools
HansVRP Dec 15, 2025
188ab5d
coupling task type to separate pool
HansVRP Dec 15, 2025
24cf000
include unit test for dict of pools
HansVRP Dec 15, 2025
8a3aa20
tmp_path usage and renaming
HansVRP Dec 16, 2025
2e0d008
fix documentation
HansVRP Dec 16, 2025
1894ef6
keep track of number of assets
HansVRP Dec 19, 2025
1d2020b
avoid abbreviation of number
HansVRP Dec 19, 2025
62a4bf4
do not expose number of remaining jobs
HansVRP Dec 19, 2025
fdcd047
abstract task name in thread pool
HansVRP Dec 19, 2025
9b978d3
do not use remaining in unit test
HansVRP Dec 19, 2025
d629637
fix unit tests
HansVRP Dec 19, 2025
2ab7741
fix
HansVRP Dec 19, 2025
c999d1f
move towards get_results to avoid deprecation
HansVRP Jan 6, 2026
57 changes: 42 additions & 15 deletions openeo/extra/job_management/_manager.py
@@ -32,6 +32,7 @@
from openeo.extra.job_management._thread_worker import (
_JobManagerWorkerThreadPool,
_JobStartTask,
_JobDownloadTask,
)
from openeo.rest import OpenEoApiError
from openeo.rest.auth.auth import BearerAuth
@@ -175,6 +176,7 @@ def start_job(

.. versionchanged:: 0.47.0
Added ``download_results`` parameter.

"""

# Expected columns in the job DB dataframes.
@@ -373,6 +375,9 @@ def run_loop():
).values()
)
> 0
or self._worker_pool.number_pending_tasks() > 0
) and not self._stop_thread:
self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats)
@@ -398,7 +403,10 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET):

.. versionadded:: 0.32.0
"""
self._worker_pool.shutdown()
if self._worker_pool is not None:
self._worker_pool.shutdown()
self._worker_pool = None


if self._thread is not None:
self._stop_thread = True
@@ -504,13 +512,15 @@ def run_jobs(

self._worker_pool = _JobManagerWorkerThreadPool()


while (
sum(
job_db.count_by_status(
statuses=["not_started", "created", "queued_for_start", "queued", "running"]
).values()
)
> 0
).values()) > 0

or (self._worker_pool.number_pending_tasks() > 0)

):
self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats)
stats["run_jobs loop"] += 1
@@ -520,8 +530,10 @@ def run_jobs(
time.sleep(self.poll_sleep)
stats["sleep"] += 1

# TODO: run post-processing after shutdown once more to ensure completion?


self._worker_pool.shutdown()
self._worker_pool = None

return stats

@@ -567,7 +579,9 @@ def _job_update_loop(
stats["job_db persist"] += 1
total_added += 1

self._process_threadworker_updates(self._worker_pool, job_db=job_db, stats=stats)
if self._worker_pool is not None:
self._process_threadworker_updates(worker_pool=self._worker_pool, job_db=job_db, stats=stats)


# TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads?
for job, row in jobs_done:
@@ -579,6 +593,7 @@ def _job_update_loop(
for job, row in jobs_cancel:
self.on_job_cancel(job, row)


def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = None):
"""Helper method for launching jobs

@@ -643,7 +658,7 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
df_idx=i,
)
_log.info(f"Submitting task {task} to thread pool")
self._worker_pool.submit_task(task)
self._worker_pool.submit_task(task=task, pool_name="job_start")

stats["job_queued_for_start"] += 1
df.loc[i, "status"] = "queued_for_start"
@@ -689,7 +704,7 @@ def _process_threadworker_updates(
:param stats: Dictionary accumulating statistic counters
"""
# Retrieve completed task results immediately
results, _ = worker_pool.process_futures(timeout=0)
results = worker_pool.process_futures(timeout=0)

# Collect update dicts
updates: List[Dict[str, Any]] = []
@@ -735,17 +750,28 @@ def on_job_done(self, job: BatchJob, row):
:param job: The job that has finished.
:param row: DataFrame row containing the job's metadata.
"""
# TODO: param `row` is never accessed in this method. Remove it? Is this intended for future use?
if self._download_results:
job_metadata = job.describe()
job_dir = self.get_job_dir(job.job_id)
metadata_path = self.get_job_metadata_path(job.job_id)

job_dir = self.get_job_dir(job.job_id)
self.ensure_job_dir_exists(job.job_id)
job.get_results().download_files(target=job_dir)

with metadata_path.open("w", encoding="utf-8") as f:
json.dump(job_metadata, f, ensure_ascii=False)
# Proactively refresh the bearer token (because the task in the thread will not be able to do that)
job_con = job.connection
self._refresh_bearer_token(connection=job_con)

task = _JobDownloadTask(
job_id=job.job_id,
df_idx=row.name,  # this is the index in the not_started dataframe; should not be an issue as there is no db update for the download task
root_url=job_con.root_url,
bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None,
download_dir=job_dir,
)
_log.info(f"Submitting download task {task} to download thread pool")

if self._worker_pool is None:
self._worker_pool = _JobManagerWorkerThreadPool()

self._worker_pool.submit_task(task=task, pool_name="job_download")

def on_job_error(self, job: BatchJob, row):
"""
@@ -797,6 +823,7 @@ def _cancel_prolonged_job(self, job: BatchJob, row):
except Exception as e:
_log.error(f"Unexpected error while handling job {job.job_id}: {e}")

# TODO: pull this functionality away from the manager into a general utility class? Job dir creation could be reused for the _JobDownloadTask
def get_job_dir(self, job_id: str) -> Path:
"""Path to directory where job metadata, results and error logs are be saved."""
return self._root_dir / f"job_{job_id}"
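For clarity, a self-contained sketch (illustrative only, not part of this diff) of the loop-exit semantics the manager now aims for: keep polling while the job database still holds active jobs or the worker pool still has pending tasks, and stop as soon as a stop is requested. should_keep_polling is a hypothetical helper used only for illustration.

def should_keep_polling(count_by_status: dict, pending_tasks: int, stop_requested: bool) -> bool:
    # Active job statuses tracked by the job DB, as listed in run_jobs/run_loop.
    active = sum(
        count_by_status.get(s, 0)
        for s in ("not_started", "created", "queued_for_start", "queued", "running")
    )
    return (active > 0 or pending_tasks > 0) and not stop_requested

# The loop continues while jobs are active or tasks are still queued in the pool,
# but a stop request always wins.
assert should_keep_polling({"running": 1}, pending_tasks=0, stop_requested=False)
assert should_keep_polling({}, pending_tasks=3, stop_requested=False)
assert not should_keep_polling({"running": 1}, pending_tasks=3, stop_requested=True)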
136 changes: 131 additions & 5 deletions openeo/extra/job_management/_thread_worker.py
@@ -7,7 +7,9 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union
from pathlib import Path

import json
import urllib3.util

import openeo
@@ -99,7 +101,7 @@ def get_connection(self, retry: Union[urllib3.util.Retry, dict, bool, None] = No
connection.authenticate_bearer_token(self.bearer_token)
return connection


@dataclass(frozen=True)
class _JobStartTask(ConnectedTask):
"""
Task for starting an openEO batch job (the `POST /jobs/<job_id>/result` request).
@@ -139,9 +141,51 @@ def execute(self) -> _TaskResult:
db_update={"status": "start_failed"},
stats_update={"start_job error": 1},
)

@dataclass(frozen=True)
class _JobDownloadTask(ConnectedTask):
"""
Task for downloading job results and metadata.

:param download_dir:
Root directory where job results and metadata will be downloaded.
"""
download_dir: Optional[Path] = field(default=None, repr=False)

class _JobManagerWorkerThreadPool:
def execute(self) -> _TaskResult:

try:
job = self.get_connection(retry=True).job(self.job_id)

# Count assets (files to download)
file_count = len(job.get_results().get_assets())

# Download results
job.get_results().download_files(target=self.download_dir)

# Download metadata
job_metadata = job.describe()
metadata_path = self.download_dir / f"job_{self.job_id}.json"
with metadata_path.open("w", encoding="utf-8") as f:
json.dump(job_metadata, f, ensure_ascii=False)

_log.info(f"Job {self.job_id!r} results downloaded successfully")
return _TaskResult(
job_id=self.job_id,
df_idx=self.df_idx,
db_update={},  # TODO: consider db updates?
stats_update={"job download": 1, "files downloaded": file_count},
)
except Exception as e:
_log.error(f"Failed to download results for job {self.job_id!r}: {e!r}")
return _TaskResult(
job_id=self.job_id,
df_idx=self.df_idx,
db_update={},
stats_update={"job download error": 1, "files downloaded": 0},
)

class _TaskThreadPool:
"""
Thread pool-based worker that manages the execution of asynchronous tasks.

@@ -150,12 +194,13 @@ class _JobManagerWorkerThreadPool:

:param max_workers:
Maximum number of concurrent threads to use for execution.
Defaults to 2.
Defaults to 1.
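:param name:
    Name used to identify this pool.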
"""

def __init__(self, max_workers: int = 2):
def __init__(self, max_workers: int = 1, name: str = 'default'):
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
self._future_task_pairs: List[Tuple[concurrent.futures.Future, Task]] = []
self._name = name

def submit_task(self, task: Task) -> None:
"""
@@ -206,9 +251,90 @@ def process_futures(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskRe
_log.info("process_futures: %d tasks done, %d tasks remaining", len(results), len(to_keep))

self._future_task_pairs = to_keep
return results, len(to_keep)
return results

def number_pending_tasks(self) -> int:
"""Return the number of tasks that are still pending (not completed)."""
return len(self._future_task_pairs)

def shutdown(self) -> None:
"""Shuts down the thread pool gracefully."""
_log.info("Shutting down thread pool")
self._executor.shutdown(wait=True)


class _JobManagerWorkerThreadPool:

"""
Generic wrapper that manages multiple named thread pools, keyed by pool name.
"""

def __init__(self, pool_configs: Optional[Dict[str, int]] = None):
"""
:param pool_configs: Dict mapping pool name -> max_workers for that pool.
    Example: {"job_start": 1, "job_download": 2}
"""
self._pools: Dict[str, _TaskThreadPool] = {}
self._pool_configs = pool_configs or {}

def _get_pool_name_for_task(self, task: Task) -> str:
"""
Get pool name from task class name.
"""
return task.__class__.__name__

def submit_task(self, task: Task, pool_name: str = "default") -> None:
"""
Submit a task to a specific pool.
Creates pool dynamically if it doesn't exist.

:param task: The task to execute
:param pool_name: Which pool to use (e.g. "job_start", "job_download")
"""
if pool_name not in self._pools:
# Create pool on-demand
max_workers = self._pool_configs.get(pool_name, 1) # Default 1 worker
self._pools[pool_name] = _TaskThreadPool(max_workers=max_workers)
_log.info(f"Created pool '{pool_name}' with {max_workers} workers")

self._pools[pool_name].submit_task(task)

def process_futures(self, timeout: Union[float, None] = 0) -> List[_TaskResult]:
"""
Process updates from ALL pools.
Returns: list of completed task results collected from all pools.
"""
all_results = []

for pool in self._pools.values():
results = pool.process_futures(timeout)
all_results.extend(results)

return all_results

def number_pending_tasks(self, pool_name: Optional[str] = None) -> int:
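"""Return the number of pending tasks, either for one named pool or summed over all pools."""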
if pool_name:
pool = self._pools.get(pool_name)
return pool.number_pending_tasks() if pool else 0
else:
return sum(pool.number_pending_tasks() for pool in self._pools.values())

def shutdown(self, pool_name: Optional[str] = None) -> None:
"""
Shutdown pools.
If pool_name is None, shuts down all pools.
"""
if pool_name:
if pool_name in self._pools:
self._pools[pool_name].shutdown()
del self._pools[pool_name]
else:
for pool_name, pool in list(self._pools.items()):
pool.shutdown()
del self._pools[pool_name]

def list_pools(self) -> List[str]:
"""List all active pool names."""
return list(self._pools.keys())
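A self-contained usage sketch (not part of the diff) of the multi-pool wrapper above. EchoTask is a hypothetical stand-in for a real Task implementation such as _JobStartTask (a real task would normally derive from the module's Task base class): named pools are created lazily on first submit, results are drained with process_futures(), and remaining work is tracked via number_pending_tasks().

from openeo.extra.job_management._thread_worker import (
    _JobManagerWorkerThreadPool,
    _TaskResult,
)

class EchoTask:
    # Hypothetical task: just echoes back a result, mimicking the Task.execute() contract.
    def __init__(self, job_id: str, df_idx: int):
        self.job_id = job_id
        self.df_idx = df_idx

    def execute(self) -> _TaskResult:
        return _TaskResult(
            job_id=self.job_id,
            df_idx=self.df_idx,
            db_update={"status": "queued"},
            stats_update={"start_job call": 1},
        )

pool = _JobManagerWorkerThreadPool(pool_configs={"job_start": 1, "job_download": 2})
pool.submit_task(EchoTask("j-1", 0), pool_name="job_start")    # creates the "job_start" pool
pool.submit_task(EchoTask("j-2", 1), pool_name="job_download")  # creates the "job_download" pool

while pool.number_pending_tasks() > 0:
    for result in pool.process_futures(timeout=0.1):
        print(result.job_id, result.db_update, result.stats_update)

pool.shutdown()  # shuts down and removes all pools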


8 changes: 4 additions & 4 deletions tests/extra/job_management/test_manager.py
@@ -729,7 +729,7 @@ def get_status(job_id, current_status):
assert isinstance(rfc3339.parse_datetime(filled_running_start_time), datetime.datetime)

def test_process_threadworker_updates(self, tmp_path, caplog):
pool = _JobManagerWorkerThreadPool(max_workers=2)
pool = _JobManagerWorkerThreadPool()
stats = collections.defaultdict(int)

# Submit tasks covering all cases
@@ -769,7 +769,7 @@ def test_process_threadworker_updates(self, tmp_path, caplog):
assert caplog.messages == []

def test_process_threadworker_updates_unknown(self, tmp_path, caplog):
pool = _JobManagerWorkerThreadPool(max_workers=2)
pool = _JobManagerWorkerThreadPool()
stats = collections.defaultdict(int)

pool.submit_task(DummyResultTask("j-123", df_idx=0, db_update={"status": "queued"}, stats_update={"queued": 1}))
@@ -806,7 +806,7 @@ def test_process_threadworker_updates_unknown(self, tmp_path, caplog):
assert caplog.messages == [dirty_equals.IsStr(regex=".*Ignoring unknown.*indices.*4.*")]

def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog):
pool = _JobManagerWorkerThreadPool(max_workers=2)
pool = _JobManagerWorkerThreadPool()
stats = collections.defaultdict(int)

df_initial = pd.DataFrame({"id": ["j-0"], "status": ["created"]})
Expand All @@ -820,7 +820,7 @@ def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog):
assert stats == {}

def test_logs_on_invalid_update(self, tmp_path, caplog):
pool = _JobManagerWorkerThreadPool(max_workers=2)
pool = _JobManagerWorkerThreadPool()
stats = collections.defaultdict(int)

# Malformed db_update (not a dict unpackable via **)
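A possible follow-up test, sketched here under the assumption that DummyResultTask behaves as in the existing tests above: it checks that a named pool is created on demand and that the pending count drops to zero once results are drained.

def test_named_pool_created_on_demand():
    pool = _JobManagerWorkerThreadPool(pool_configs={"job_download": 2})
    pool.submit_task(
        DummyResultTask("j-1", df_idx=0, db_update={"status": "finished"}, stats_update={"job download": 1}),
        pool_name="job_download",
    )
    assert pool.list_pools() == ["job_download"]
    assert pool.number_pending_tasks(pool_name="job_download") == 1

    results = pool.process_futures(timeout=5)

    assert [r.job_id for r in results] == ["j-1"]
    assert pool.number_pending_tasks() == 0
    pool.shutdown()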