diff --git a/openeo/extra/job_management/_manager.py b/openeo/extra/job_management/_manager.py
index 5b0c1dba1..b4b2bc0e5 100644
--- a/openeo/extra/job_management/_manager.py
+++ b/openeo/extra/job_management/_manager.py
@@ -32,6 +32,7 @@ from openeo.extra.job_management._thread_worker import (
     _JobManagerWorkerThreadPool,
     _JobStartTask,
+    _JobDownloadTask,
 )
 from openeo.rest import OpenEoApiError
 from openeo.rest.auth.auth import BearerAuth
@@ -175,6 +176,7 @@ def start_job(
 
     .. versionchanged:: 0.47.0
         Added ``download_results`` parameter.
+
     """
 
     # Expected columns in the job DB dataframes.
@@ -373,6 +375,9 @@ def run_loop():
                     ).values()
                 )
                 > 0
+                or self._worker_pool.number_pending_tasks() > 0
                 and not self._stop_thread
             ):
                 self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats)
@@ -398,7 +403,10 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET):
 
         .. versionadded:: 0.32.0
         """
-        self._worker_pool.shutdown()
+        if self._worker_pool is not None:
+            self._worker_pool.shutdown()
+            self._worker_pool = None
+
         if self._thread is not None:
             self._stop_thread = True
@@ -504,13 +512,15 @@ def run_jobs(
 
         self._worker_pool = _JobManagerWorkerThreadPool()
 
         while (
             sum(
                 job_db.count_by_status(
                     statuses=["not_started", "created", "queued_for_start", "queued", "running"]
                 ).values()
             )
             > 0
+            or self._worker_pool.number_pending_tasks() > 0
         ):
             self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats)
             stats["run_jobs loop"] += 1
@@ -520,8 +530,10 @@ def run_jobs(
             time.sleep(self.poll_sleep)
             stats["sleep"] += 1
 
-        # TODO; run post process after shutdown once more to ensure completion?
+        self._worker_pool.shutdown()
+        self._worker_pool = None
 
         return stats
@@ -567,7 +579,9 @@ def _job_update_loop(
                     stats["job_db persist"] += 1
                     total_added += 1
 
-        self._process_threadworker_updates(self._worker_pool, job_db=job_db, stats=stats)
+        if self._worker_pool is not None:
+            self._process_threadworker_updates(worker_pool=self._worker_pool, job_db=job_db, stats=stats)
+
         # TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads?
         for job, row in jobs_done:
@@ -579,6 +593,7 @@ def _job_update_loop(
         for job, row in jobs_cancel:
             self.on_job_cancel(job, row)
 
+
     def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = None):
         """Helper method for launching jobs
@@ -643,7 +658,7 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
                     df_idx=i,
                 )
                 _log.info(f"Submitting task {task} to thread pool")
-                self._worker_pool.submit_task(task)
+                self._worker_pool.submit_task(task=task, pool_name="job_start")
 
                 stats["job_queued_for_start"] += 1
                 df.loc[i, "status"] = "queued_for_start"
@@ -689,7 +704,7 @@ def _process_threadworker_updates(
         :param stats: Dictionary accumulating statistic counters
         """
         # Retrieve completed task results immediately
-        results, _ = worker_pool.process_futures(timeout=0)
+        results = worker_pool.process_futures(timeout=0)
 
         # Collect update dicts
         updates: List[Dict[str, Any]] = []
@@ -735,17 +750,28 @@ def on_job_done(self, job: BatchJob, row):
         :param job: The job that has finished.
         :param row: DataFrame row containing the job's metadata.
         """
-        # TODO: param `row` is never accessed in this method. Remove it? Is this intended for future use?
         if self._download_results:
-            job_metadata = job.describe()
-            job_dir = self.get_job_dir(job.job_id)
-            metadata_path = self.get_job_metadata_path(job.job_id)
+            job_dir = self.get_job_dir(job.job_id)
 
             self.ensure_job_dir_exists(job.job_id)
-            job.get_results().download_files(target=job_dir)
 
-            with metadata_path.open("w", encoding="utf-8") as f:
-                json.dump(job_metadata, f, ensure_ascii=False)
+            # Proactively refresh the bearer token (a task running in a worker thread will not be able to do that itself).
+            job_con = job.connection
+            self._refresh_bearer_token(connection=job_con)
+
+            task = _JobDownloadTask(
+                job_id=job.job_id,
+                # Note: this is the index in the "not started" dataframe; not an issue, as there is no db update for a download task.
+                df_idx=row.name,
+                root_url=job_con.root_url,
+                bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None,
+                download_dir=job_dir,
+            )
+            _log.info(f"Submitting download task {task} to download thread pool")
+
+            if self._worker_pool is None:
+                self._worker_pool = _JobManagerWorkerThreadPool()
+
+            self._worker_pool.submit_task(task=task, pool_name="job_download")
 
     def on_job_error(self, job: BatchJob, row):
         """
@@ -797,6 +823,7 @@ def _cancel_prolonged_job(self, job: BatchJob, row):
         except Exception as e:
             _log.error(f"Unexpected error while handling job {job.job_id}: {e}")
 
+    # TODO: pull this functionality away from the manager into a general utility class? Job dir creation could be reused for the job download task.
     def get_job_dir(self, job_id: str) -> Path:
         """Path to directory where job metadata, results and error logs are be saved."""
         return self._root_dir / f"job_{job_id}"
diff --git a/openeo/extra/job_management/_thread_worker.py b/openeo/extra/job_management/_thread_worker.py
index 6040fade1..c65d58287 100644
--- a/openeo/extra/job_management/_thread_worker.py
+++ b/openeo/extra/job_management/_thread_worker.py
@@ -7,7 +7,9 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Tuple, Union
+from pathlib import Path
+import json
 
 import urllib3.util
 
 import openeo
@@ -99,7 +101,7 @@ def get_connection(self, retry: Union[urllib3.util.Retry, dict, bool, None] = No
             connection.authenticate_bearer_token(self.bearer_token)
         return connection
 
-
+@dataclass(frozen=True)
 class _JobStartTask(ConnectedTask):
     """
     Task for starting an openEO batch job (the `POST /jobs//result` request).
@@ -139,9 +141,51 @@ def execute(self) -> _TaskResult:
                 db_update={"status": "start_failed"},
                 stats_update={"start_job error": 1},
             )
+
+
+@dataclass(frozen=True)
+class _JobDownloadTask(ConnectedTask):
+    """
+    Task for downloading job results and metadata.
+
+    :param download_dir:
+        Root directory where job results and metadata will be downloaded.
+    """
+
+    download_dir: Path = field(default=None, repr=False)
 
-class _JobManagerWorkerThreadPool:
+    def execute(self) -> _TaskResult:
+        try:
+            job = self.get_connection(retry=True).job(self.job_id)
+
+            # Count assets (files to download)
+            file_count = len(job.get_results().get_assets())
+
+            # Download results
+            job.get_results().download_files(target=self.download_dir)
+
+            # Download metadata
+            job_metadata = job.describe()
+            metadata_path = self.download_dir / f"job_{self.job_id}.json"
+            with metadata_path.open("w", encoding="utf-8") as f:
+                json.dump(job_metadata, f, ensure_ascii=False)
+
+            _log.info(f"Job {self.job_id!r} results downloaded successfully")
+            return _TaskResult(
+                job_id=self.job_id,
+                df_idx=self.df_idx,
+                db_update={},  # TODO: consider db updates?
+                stats_update={"job download": 1, "files downloaded": file_count},
+            )
+        except Exception as e:
+            _log.error(f"Failed to download results for job {self.job_id!r}: {e!r}")
+            return _TaskResult(
+                job_id=self.job_id,
+                df_idx=self.df_idx,
+                db_update={},
+                stats_update={"job download error": 1, "files downloaded": 0},
+            )
+
+
+class _TaskThreadPool:
     """
     Thread pool-based worker that manages the execution of asynchronous tasks.
@@ -150,12 +194,13 @@ class _JobManagerWorkerThreadPool:
 
     :param max_workers:
         Maximum number of concurrent threads to use for execution.
-        Defaults to 2.
+        Defaults to 1.
     """
 
-    def __init__(self, max_workers: int = 2):
+    def __init__(self, max_workers: int = 1, name: str = "default"):
         self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
         self._future_task_pairs: List[Tuple[concurrent.futures.Future, Task]] = []
+        self._name = name
 
     def submit_task(self, task: Task) -> None:
         """
@@ -206,9 +251,90 @@ def process_futures(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskRe
         _log.info("process_futures: %d tasks done, %d tasks remaining", len(results), len(to_keep))
         self._future_task_pairs = to_keep
 
-        return results, len(to_keep)
+        return results
+
+    def number_pending_tasks(self) -> int:
+        """Return the number of tasks that are still pending (not completed)."""
+        return len(self._future_task_pairs)
 
     def shutdown(self) -> None:
         """Shuts down the thread pool gracefully."""
         _log.info("Shutting down thread pool")
         self._executor.shutdown(wait=True)
+
+
+class _JobManagerWorkerThreadPool:
+    """
+    Wrapper that manages multiple named thread pools (one per kind of task).
+    """
+
+    def __init__(self, pool_configs: Optional[Dict[str, int]] = None):
+        """
+        :param pool_configs: Dict of pool name -> max_workers.
+            Example: {"_JobStartTask": 1, "_JobDownloadTask": 2}
+        """
+        self._pools: Dict[str, _TaskThreadPool] = {}
+        self._pool_configs = pool_configs or {}
+
+    def _get_pool_name_for_task(self, task: Task) -> str:
+        """Get pool name from task class name."""
+        return task.__class__.__name__
+
+    def submit_task(self, task: Task, pool_name: str = "default") -> None:
+        """
+        Submit a task to a specific pool.
+        Creates the pool dynamically if it doesn't exist yet.
+
+        :param task: The task to execute
+        :param pool_name: Which pool to use (e.g. "default", "job_download", ...)
+        """
+        if pool_name not in self._pools:
+            # Create pool on-demand
+            max_workers = self._pool_configs.get(pool_name, 1)  # Default: 1 worker
+            self._pools[pool_name] = _TaskThreadPool(max_workers=max_workers)
+            _log.info(f"Created pool '{pool_name}' with {max_workers} workers")
+
+        self._pools[pool_name].submit_task(task)
+
+    def process_futures(self, timeout: Union[float, None] = 0) -> List[_TaskResult]:
+        """
+        Process completed futures from all pools and return the combined list of results.
+        """
+        all_results = []
+        for pool in self._pools.values():
+            results = pool.process_futures(timeout)
+            all_results.extend(results)
+        return all_results
+
+    def number_pending_tasks(self, pool_name: Optional[str] = None) -> int:
+        if pool_name:
+            pool = self._pools.get(pool_name)
+            return pool.number_pending_tasks() if pool else 0
+        else:
+            return sum(pool.number_pending_tasks() for pool in self._pools.values())
+
+    def shutdown(self, pool_name: Optional[str] = None) -> None:
+        """
+        Shutdown pools.
+        If pool_name is None, all pools are shut down.
+        """
+        if pool_name:
+            if pool_name in self._pools:
+                self._pools[pool_name].shutdown()
+                del self._pools[pool_name]
+        else:
+            for pool in self._pools.values():
+                pool.shutdown()
+            self._pools.clear()
+
+    def list_pools(self) -> List[str]:
+        """List all active pool names."""
+        return list(self._pools.keys())
diff --git a/tests/extra/job_management/test_manager.py b/tests/extra/job_management/test_manager.py
index 1d02afb1c..2c4162974 100644
--- a/tests/extra/job_management/test_manager.py
+++ b/tests/extra/job_management/test_manager.py
@@ -729,7 +729,7 @@ def get_status(job_id, current_status):
         assert isinstance(rfc3339.parse_datetime(filled_running_start_time), datetime.datetime)
 
     def test_process_threadworker_updates(self, tmp_path, caplog):
-        pool = _JobManagerWorkerThreadPool(max_workers=2)
+        pool = _JobManagerWorkerThreadPool()
         stats = collections.defaultdict(int)
 
         # Submit tasks covering all cases
@@ -769,7 +769,7 @@ def test_process_threadworker_updates(self, tmp_path, caplog):
         assert caplog.messages == []
 
     def test_process_threadworker_updates_unknown(self, tmp_path, caplog):
-        pool = _JobManagerWorkerThreadPool(max_workers=2)
+        pool = _JobManagerWorkerThreadPool()
         stats = collections.defaultdict(int)
 
         pool.submit_task(DummyResultTask("j-123", df_idx=0, db_update={"status": "queued"}, stats_update={"queued": 1}))
@@ -806,7 +806,7 @@ def test_process_threadworker_updates_unknown(self, tmp_path, caplog):
         assert caplog.messages == [dirty_equals.IsStr(regex=".*Ignoring unknown.*indices.*4.*")]
 
     def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog):
-        pool = _JobManagerWorkerThreadPool(max_workers=2)
+        pool = _JobManagerWorkerThreadPool()
         stats = collections.defaultdict(int)
 
         df_initial = pd.DataFrame({"id": ["j-0"], "status": ["created"]})
@@ -820,7 +820,7 @@ def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog):
         assert stats == {}
 
     def test_logs_on_invalid_update(self, tmp_path, caplog):
-        pool = _JobManagerWorkerThreadPool(max_workers=2)
+        pool = _JobManagerWorkerThreadPool()
         stats = collections.defaultdict(int)
 
         # Malformed db_update (not a dict unpackable via **)
diff --git a/tests/extra/job_management/test_thread_worker.py b/tests/extra/job_management/test_thread_worker.py
index 52ee833f1..d0d684e3b 100644
--- a/tests/extra/job_management/test_thread_worker.py
+++ b/tests/extra/job_management/test_thread_worker.py
@@ -3,14 +3,18 @@
 import time
 from dataclasses import dataclass
 from typing import Iterator
+from pathlib import Path
 
+from requests_mock import Mocker
 import pytest
 
 from openeo.extra.job_management._thread_worker import (
     Task,
+    _TaskThreadPool,
     _JobManagerWorkerThreadPool,
     _JobStartTask,
     _TaskResult,
+    _JobDownloadTask,
 )
 from openeo.rest._testing import DummyBackend
@@ -79,6 +83,98 @@ def test_hide_token(self, serializer):
         assert "job-123" in serialized
         assert secret not in serialized
 
+
+class TestJobDownloadTask:
+    def test_job_download_success(self, requests_mock: Mocker, tmp_path: Path):
+        """
+        Test a successful job download and verify file content and stats update.
+        """
+        job_id = "job-007"
+        df_idx = 42
+
+        # Set up a dummy backend to simulate the job results and assert the expected calls are triggered
+        backend = DummyBackend.at_url("https://openeo.dummy.test/", requests_mock=requests_mock)
+        backend.next_result = b"The downloaded file content."
+        backend.batch_jobs[job_id] = {"job_id": job_id, "pg": {}, "status": "created"}
+        backend._set_job_status(job_id=job_id, status="finished")
+        backend.batch_jobs[job_id]["status"] = "finished"
+
+        download_dir = tmp_path / job_id / "results"
+        download_dir.mkdir(parents=True)
+
+        # Create the task instance
+        task = _JobDownloadTask(
+            root_url="https://openeo.dummy.test/",
+            bearer_token="dummy-token-7",
+            job_id=job_id,
+            df_idx=df_idx,
+            download_dir=download_dir,
+        )
+
+        # Execute the task
+        result = task.execute()
+
+        # Verify TaskResult structure
+        assert isinstance(result, _TaskResult)
+        assert result.job_id == job_id
+        assert result.df_idx == df_idx
+
+        # Verify stats update for the MultiBackendJobManager
+        assert result.stats_update == {"files downloaded": 1, "job download": 1}
+
+        # Verify download content (crucial part of the unit test)
+        downloaded_file = download_dir / "result.data"
+        assert downloaded_file.exists()
+        assert downloaded_file.read_bytes() == b"The downloaded file content."
+
+    def test_job_download_failure(self, requests_mock: Mocker, tmp_path: Path):
+        """
+        Test a failed download (e.g., bad connection) and verify error reporting.
+        """
+        job_id = "job-008"
+        df_idx = 55
+
+        # Set up dummy backend to simulate failure during results listing
+        backend = DummyBackend.at_url("https://openeo.dummy.test/", requests_mock=requests_mock)
+
+        # Simulate an error when downloading the results
+        requests_mock.get(
+            f"https://openeo.dummy.test/jobs/{job_id}/results",
+            status_code=500,
+            json={"code": "InternalError", "message": "Failed to list results"},
+        )
+
+        backend.batch_jobs[job_id] = {"job_id": job_id, "pg": {}, "status": "created"}
+        backend._set_job_status(job_id=job_id, status="finished")
+        backend.batch_jobs[job_id]["finished"] = "error"
+
+        download_dir = tmp_path / job_id / "results"
+        download_dir.mkdir(parents=True)
+
+        # Create the task instance
+        task = _JobDownloadTask(
+            root_url="https://openeo.dummy.test/",
+            bearer_token="dummy-token-8",
+            job_id=job_id,
+            df_idx=df_idx,
+            download_dir=download_dir,
+        )
+
+        # Execute the task
+        result = task.execute()
+
+        # Verify TaskResult structure
+        assert isinstance(result, _TaskResult)
+        assert result.job_id == job_id
+        assert result.df_idx == df_idx
+
+        # Verify stats update for the MultiBackendJobManager
+        assert result.stats_update == {"files downloaded": 0, "job download error": 1}
+
+        # Verify no file was created (or only empty/failed files)
+        assert not any(p.is_file() for p in download_dir.glob("*"))
+
 
 class NopTask(Task):
     """Do Nothing"""
@@ -116,40 +212,37 @@ def execute(self) -> _TaskResult:
         return _TaskResult(job_id=self.job_id, df_idx=self.df_idx, db_update={"status": "all fine"})
 
 
-class TestJobManagerWorkerThreadPool:
+class TestTaskThreadPool:
     @pytest.fixture
-    def worker_pool(self) -> Iterator[_JobManagerWorkerThreadPool]:
+    def worker_pool(self) -> Iterator[_TaskThreadPool]:
         """Fixture for creating and cleaning up a worker thread pool."""
-        pool = _JobManagerWorkerThreadPool(max_workers=2)
+        pool = _TaskThreadPool()
         yield pool
         pool.shutdown()
 
     def test_no_tasks(self, worker_pool):
-        results, remaining = worker_pool.process_futures(timeout=10)
+        results = worker_pool.process_futures(timeout=10)
         assert results == []
-        assert remaining == 0
 
     def test_submit_and_process(self, worker_pool):
         worker_pool.submit_task(DummyTask(job_id="j-123", df_idx=0))
-        results, remaining = worker_pool.process_futures(timeout=10)
+        results = worker_pool.process_futures(timeout=10)
         assert results == [
             _TaskResult(job_id="j-123", df_idx=0, db_update={"status": "dummified"}, stats_update={"dummy": 1}),
         ]
-        assert remaining == 0
 
     def test_submit_and_process_zero_timeout(self, worker_pool):
         worker_pool.submit_task(DummyTask(job_id="j-123", df_idx=0))
         # Trigger context switch
         time.sleep(0.1)
-        results, remaining = worker_pool.process_futures(timeout=0)
+        results = worker_pool.process_futures(timeout=0)
         assert results == [
             _TaskResult(job_id="j-123", df_idx=0, db_update={"status": "dummified"}, stats_update={"dummy": 1}),
         ]
-        assert remaining == 0
 
     def test_submit_and_process_with_error(self, worker_pool):
         worker_pool.submit_task(DummyTask(job_id="j-666", df_idx=0))
-        results, remaining = worker_pool.process_futures(timeout=10)
+        results = worker_pool.process_futures(timeout=10)
         assert results == [
             _TaskResult(
                 job_id="j-666",
@@ -158,20 +251,18 @@ def test_submit_and_process_with_error(self, worker_pool):
                 stats_update={"threaded task failed": 1},
             ),
         ]
-        assert remaining == 0
 
     def test_submit_and_process_iterative(self, worker_pool):
         worker_pool.submit_task(NopTask(job_id="j-1", df_idx=1))
-        results, remaining = worker_pool.process_futures(timeout=1)
+        results = worker_pool.process_futures(timeout=1)
         assert results == [_TaskResult(job_id="j-1", df_idx=1)]
-        assert remaining == 0
 
         # Add some more
         worker_pool.submit_task(NopTask(job_id="j-22", df_idx=22))
         worker_pool.submit_task(NopTask(job_id="j-222", df_idx=222))
-        results, remaining = worker_pool.process_futures(timeout=1)
+        results = worker_pool.process_futures(timeout=1)
         assert results == [_TaskResult(job_id="j-22", df_idx=22), _TaskResult(job_id="j-222", df_idx=222)]
-        assert remaining == 0
 
     def test_submit_multiple_simple(self, worker_pool):
         # A bunch of dummy tasks
@@ -179,7 +270,7 @@ def test_submit_multiple_simple(self, worker_pool):
             worker_pool.submit_task(NopTask(job_id=f"j-{j}", df_idx=j))
 
         # Process all of them (non-zero timeout, which should be plenty of time for all of them to finish)
-        results, remaining = worker_pool.process_futures(timeout=1)
+        results = worker_pool.process_futures(timeout=1)
         expected = [_TaskResult(job_id=f"j-{j}", df_idx=j) for j in range(5)]
         assert sorted(results, key=lambda r: r.job_id) == expected
 
@@ -200,25 +291,24 @@ def test_submit_multiple_blocking_and_failing(self, worker_pool):
             )
 
         # Initial state: nothing happened yet
-        results, remaining = worker_pool.process_futures(timeout=0)
-        assert (results, remaining) == ([], n)
+        results = worker_pool.process_futures(timeout=0)
+        assert results == []
 
         # No changes even after timeout
-        results, remaining = worker_pool.process_futures(timeout=0.1)
-        assert (results, remaining) == ([], n)
+        results = worker_pool.process_futures(timeout=0.1)
+        assert results == []
 
         # Set one event and wait for corresponding result
         events[0].set()
-        results, remaining = worker_pool.process_futures(timeout=0.1)
+        results = worker_pool.process_futures(timeout=0.1)
         assert results == [
             _TaskResult(job_id="j-0", df_idx=0, db_update={"status": "all fine"}),
         ]
-        assert remaining == n - 1
 
         # Release all but one event
         for j in range(n - 1):
             events[j].set()
-        results, remaining = worker_pool.process_futures(timeout=0.1)
+        results = worker_pool.process_futures(timeout=0.1)
         assert results == [
             _TaskResult(job_id="j-1", df_idx=1, db_update={"status": "all fine"}),
             _TaskResult(job_id="j-2", df_idx=2, db_update={"status": "all fine"}),
@@ -229,22 +319,20 @@
                 stats_update={"threaded task failed": 1},
             ),
         ]
-        assert remaining == 1
 
         # Release all events
         for j in range(n):
             events[j].set()
-        results, remaining = worker_pool.process_futures(timeout=0.1)
+        results = worker_pool.process_futures(timeout=0.1)
         assert results == [
             _TaskResult(job_id="j-4", df_idx=4, db_update={"status": "all fine"}),
         ]
-        assert remaining == 0
 
     def test_shutdown(self, worker_pool):
         # Before shutdown
         worker_pool.submit_task(NopTask(job_id="j-123", df_idx=0))
-        results, remaining = worker_pool.process_futures(timeout=0.1)
-        assert (results, remaining) == ([_TaskResult(job_id="j-123", df_idx=0)], 0)
+        results = worker_pool.process_futures(timeout=0.1)
+        assert results == [_TaskResult(job_id="j-123", df_idx=0)]
 
         worker_pool.shutdown()
 
@@ -258,7 +346,7 @@ def test_job_start_task(self, worker_pool, dummy_backend, caplog):
         task = _JobStartTask(job_id=job.job_id, df_idx=0, root_url=dummy_backend.connection.root_url, bearer_token=None)
         worker_pool.submit_task(task)
 
-        results, remaining = worker_pool.process_futures(timeout=1)
+        results = worker_pool.process_futures(timeout=1)
         assert results == [
             _TaskResult(
                 job_id="job-000",
@@ -267,7 +355,6 @@
                 stats_update={"job start": 1},
             )
         ]
-        assert remaining == 0
         assert caplog.messages == []
 
     def test_job_start_task_failure(self, worker_pool, dummy_backend, caplog):
@@ -278,13 +365,351 @@ def test_job_start_task_failure(self, worker_pool, dummy_backend, caplog):
         task = _JobStartTask(job_id=job.job_id, df_idx=0, root_url=dummy_backend.connection.root_url, bearer_token=None)
         worker_pool.submit_task(task)
 
-        results, remaining = worker_pool.process_futures(timeout=1)
+        results = worker_pool.process_futures(timeout=1)
         assert results == [
             _TaskResult(
                 job_id="job-000", df_idx=0, db_update={"status": "start_failed"}, stats_update={"start_job error": 1}
             )
         ]
-        assert remaining == 0
         assert caplog.messages == [
             "Failed to start job 'job-000': OpenEoApiError('[500] Internal: No job starting for you, buddy')"
         ]
+
+
+class TestJobManagerWorkerThreadPool:
+    @pytest.fixture
+    def thread_pool(self) -> Iterator[_JobManagerWorkerThreadPool]:
+        """Fixture for creating and cleaning up a thread pool manager."""
+        pool = _JobManagerWorkerThreadPool()
+        yield pool
+        pool.shutdown()
+
+    @pytest.fixture
+    def configured_pool(self) -> Iterator[_JobManagerWorkerThreadPool]:
+        """Fixture with pre-configured pools."""
+        pool = _JobManagerWorkerThreadPool(
+            pool_configs={
+                "NopTask": 2,
+                "DummyTask": 3,
+                "BlockingTask": 1,
+            }
+        )
+        yield pool
+        pool.shutdown()
+
+    def test_init_empty_config(self):
+        """Test initialization with empty config."""
+        pool = _JobManagerWorkerThreadPool()
+        assert pool._pools == {}
+        assert pool._pool_configs == {}
+        pool.shutdown()
+
+    def test_init_with_config(self):
+        """Test initialization with pool configurations."""
+        pool = _JobManagerWorkerThreadPool({
+            "NopTask": 2,
+            "DummyTask": 3,
+        })
+        # Pools should NOT be created until first use
+        assert pool._pools == {}
+        assert pool._pool_configs == {
+            "NopTask": 2,
+            "DummyTask": 3,
+        }
+        pool.shutdown()
+
+    def test_submit_task_creates_pool(self, thread_pool):
+        """Test that submitting a task creates a pool dynamically."""
+        task = NopTask(job_id="j-1", df_idx=1)
+
+        assert thread_pool.list_pools() == []
+
+        # Submit task - should create pool
+        thread_pool.submit_task(task)
+
+        # Pool should be created
+        assert thread_pool.list_pools() == ["default"]
+        assert "default" in thread_pool._pools
+
+        # Process to complete the task
+        results = thread_pool.process_futures(timeout=0.1)
+        assert len(results) == 1
+        assert results[0].job_id == "j-1"
+
+    def test_submit_task_uses_config(self, configured_pool):
+        """Test that pool creation uses the configuration."""
+        task = NopTask(job_id="j-1", df_idx=1)
+
+        # Submit task - should create pool with configured workers
+        configured_pool.submit_task(task, "NopTask")
+
+        assert "NopTask" in configured_pool._pools
+        assert "NopTask" in configured_pool.list_pools()
+        assert "DummyTask" not in configured_pool.list_pools()
+
+    def test_submit_multiple_task_types(self, thread_pool):
+        """Test submitting different task types to different pools."""
+        # Submit different task types
+        task1 = NopTask(job_id="j-1", df_idx=1)
+        task2 = DummyTask(job_id="j-2", df_idx=2)
+        task3 = DummyTask(job_id="j-3", df_idx=3)
+
+        thread_pool.submit_task(task1)  # Goes to "default" pool
+        thread_pool.submit_task(task2)  # Goes to "default" pool
+        thread_pool.submit_task(task3, "separate")  # Goes to "separate" pool
+
+        # Should have 2 pools
+        pools = sorted(thread_pool.list_pools())
+        assert pools == ["default", "separate"]
+
+        # Check pending tasks
+        assert thread_pool.number_pending_tasks() == 3
+        assert thread_pool.number_pending_tasks("default") == 2
+        assert thread_pool.number_pending_tasks("separate") == 1
+
+    def test_process_futures_updates_empty(self, thread_pool):
+        """Test process futures with no pools."""
+        results = thread_pool.process_futures(timeout=0)
+        assert results == []
+
+    def test_process_futures_updates_multiple_pools(self, thread_pool):
+        """Test processing updates across multiple submitted tasks."""
+        # Submit tasks (all go to the "default" pool)
+        thread_pool.submit_task(NopTask(job_id="j-1", df_idx=1))
+        thread_pool.submit_task(NopTask(job_id="j-2", df_idx=2))
+        thread_pool.submit_task(DummyTask(job_id="j-3", df_idx=3))
+
+        results = thread_pool.process_futures(timeout=0.1)
+
+        # All tasks should be completed
+        assert len(results) == 3
+
+        nop_results = [r for r in results if r.job_id in ["j-1", "j-2"]]
+        dummy_results = [r for r in results if r.job_id == "j-3"]
+        assert len(nop_results) == 2
+        assert len(dummy_results) == 1
+
+    def test_process_futures_updates_partial_completion(self):
+        """Test processing when some tasks are still running."""
+        # Use a pool with blocking tasks
+        pool = _JobManagerWorkerThreadPool()
+
+        # Create a blocking task
+        event = threading.Event()
+        blocking_task = BlockingTask(job_id="j-block", df_idx=0, event=event, success=True)
+
+        # Create a quick task
+        quick_task = NopTask(job_id="j-quick", df_idx=1)
+
+        pool.submit_task(blocking_task, "blocking")  # Goes to "blocking" pool
+        pool.submit_task(quick_task, "quick")  # Goes to "quick" pool
+
+        # Process with timeout=0 - only quick task should complete
+        results = pool.process_futures(timeout=0)
+
+        # Only quick task completed
+        assert len(results) == 1
+        assert results[0].job_id == "j-quick"
+
+        # Blocking task still pending
+        assert pool.number_pending_tasks() == 1
+        assert pool.number_pending_tasks("blocking") == 1
+
+        # Release blocking task and process again
+        event.set()
+        results2 = pool.process_futures(timeout=0.1)
+
+        assert len(results2) == 1
+        assert results2[0].job_id == "j-block"
+
+        pool.shutdown()
+
+    def test_num_pending_tasks(self, thread_pool):
+        """Test counting pending tasks."""
+        # Initially empty
+        assert thread_pool.number_pending_tasks() == 0
+        assert thread_pool.number_pending_tasks("default") == 0
+
+        # Add some tasks
+        thread_pool.submit_task(NopTask(job_id="j-1", df_idx=1))
+        thread_pool.submit_task(NopTask(job_id="j-2", df_idx=2))
+        thread_pool.submit_task(DummyTask(job_id="j-3", df_idx=3), "dummy")
+
+        # Check totals
+        assert thread_pool.number_pending_tasks() == 3
+        assert thread_pool.number_pending_tasks("default") == 2
+        assert thread_pool.number_pending_tasks("dummy") == 1
+
+        # Process all
+        thread_pool.process_futures(timeout=0.1)
+
+        # Should be empty
+        assert thread_pool.number_pending_tasks() == 0
+        assert thread_pool.number_pending_tasks("default") == 0
+
+    def test_shutdown_specific_pool(self):
+        """Test shutting down a specific pool."""
+        # Create fresh pool for destructive test
+        pool = _JobManagerWorkerThreadPool()
+
+        # Create two pools
+        pool.submit_task(NopTask(job_id="j-1", df_idx=1), "notask")  # "notask" pool
+        pool.submit_task(DummyTask(job_id="j-2", df_idx=2), "dummy")  # "dummy" pool
+
+        assert sorted(pool.list_pools()) == ["dummy", "notask"]
+
+        # Shut down the "notask" pool only
+        pool.shutdown("notask")
+
+        # Only the "dummy" pool should remain
+        assert pool.list_pools() == ["dummy"]
+
+        # Submitting again simply creates a new pool, since the old one was deleted
+        pool.submit_task(NopTask(job_id="j-3", df_idx=3))  # Creates a new "default" pool
+        assert sorted(pool.list_pools()) == ["default", "dummy"]
+
+        pool.shutdown()
+
+    def test_shutdown_all(self):
+        """Test shutting down all pools."""
+        # Create fresh pool for destructive test
+        pool = _JobManagerWorkerThreadPool()
+
+        # Create multiple pools
+        pool.submit_task(NopTask(job_id="j-1", df_idx=1), "notask")
+        pool.submit_task(DummyTask(job_id="j-2", df_idx=2), "dummy")
+
+        assert len(pool.list_pools()) == 2
+
+        # Shutdown all
+        pool.shutdown()
+
+        assert pool.list_pools() == []
+        assert len(pool._pools) == 0
+
+    def test_custom_get_pool_name(self):
+        """Test submitting a custom task class to an explicitly named pool."""
+
+        @dataclass(frozen=True)
+        class CustomTask(Task):
+            def execute(self) -> _TaskResult:
+                return _TaskResult(job_id=self.job_id, df_idx=self.df_idx)
+
+        pool = _JobManagerWorkerThreadPool()
+
+        task = CustomTask(job_id="j-1", df_idx=1)
+        pool.submit_task(task, "custom_pool")
+
+        # Pool should use the given pool name
+        assert pool.list_pools() == ["custom_pool"]
+        assert pool.number_pending_tasks() == 1
+
+        # Process it
+        results = pool.process_futures(timeout=0.1)
+        assert len(results) == 1
+        assert results[0].job_id == "j-1"
+
+        pool.shutdown()
+
+    def test_concurrent_submissions(self, thread_pool):
+        """Test concurrent task submissions to same pool."""
+        import concurrent.futures
+
+        def submit_tasks(start_idx: int):
+            for i in range(5):
+                thread_pool.submit_task(NopTask(job_id=f"j-{start_idx + i}", df_idx=start_idx + i))
+
+        # Submit tasks from multiple threads
+        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
+            futures = [executor.submit(submit_tasks, i * 10) for i in range(3)]
+            concurrent.futures.wait(futures)
+
+        # Should have all tasks in one pool
+        assert thread_pool.list_pools() == ["default"]
+        assert thread_pool.number_pending_tasks() == 15
+
+        # Process them all
+        results = thread_pool.process_futures(timeout=0.5)
+
+        assert len(results) == 15
+
+    def test_pool_parallelism_with_blocking_tasks(self):
+        """Test that multiple workers allow parallel execution."""
+        pool = _JobManagerWorkerThreadPool({
+            "BlockingTask": 3,  # 3 workers for blocking tasks
+        })
+
+        # Create multiple blocking tasks
+        events = [threading.Event() for _ in range(5)]
+
+        for i, event in enumerate(events):
+            pool.submit_task(
+                BlockingTask(
+                    job_id=f"j-block-{i}",
+                    df_idx=i,
+                    event=event,
+                    success=True,
+                ),
+                "BlockingTask",
+            )
+
+        # Initially all pending
+        assert pool.number_pending_tasks() == 5
+
+        # Release all events at once
+        for event in events:
+            event.set()
+
+        results = pool.process_futures(timeout=0.5)
+        assert len(results) == 5
+
+        for result in results:
+            assert result.job_id.startswith("j-block-")
+
+        pool.shutdown()
+
+    def test_task_with_error_handling(self, thread_pool):
+        """Test that task errors are properly handled in the pool."""
+        # Submit a failing DummyTask (j-666 fails)
+        thread_pool.submit_task(DummyTask(job_id="j-666", df_idx=0))
+
+        # Process it
+        results = thread_pool.process_futures(timeout=0.1)
+
+        # Should get error result
+        assert len(results) == 1
+        result = results[0]
+        assert result.job_id == "j-666"
+        assert result.db_update == {"status": "threaded task failed"}
+        assert result.stats_update == {"threaded task failed": 1}
+
+    def test_mixed_success_and_error_tasks(self, thread_pool):
+        """Test mix of successful and failing tasks."""
+        # Submit mix of tasks
+        thread_pool.submit_task(DummyTask(job_id="j-1", df_idx=1))  # Success
+        thread_pool.submit_task(DummyTask(job_id="j-666", df_idx=2))  # Failure
+        thread_pool.submit_task(DummyTask(job_id="j-3", df_idx=3))  # Success
+
+        # Process all
+        results = thread_pool.process_futures(timeout=0.1)
+
+        # Should get 3 results
+        assert len(results) == 3
+
+        # Check results
+        success_results = [r for r in results if r.job_id != "j-666"]
+        error_results = [r for r in results if r.job_id == "j-666"]
+
+        assert len(success_results) == 2
+        assert len(error_results) == 1
+
+        # Verify success results
+        for result in success_results:
+            assert result.db_update == {"status": "dummified"}
+            assert result.stats_update == {"dummy": 1}
+
+        # Verify error result
+        error_result = error_results[0]
+        assert error_result.db_update == {"status": "threaded task failed"}
+        assert error_result.stats_update == {"threaded task failed": 1}
\ No newline at end of file