diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index f6d707a9e..ebe4e5eff 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -20,6 +20,7 @@ openslide-bin>=4.0.0.2
 openslide-python>=1.4.0
 pandas>=2.0.0
 pillow>=9.3.0
+pyarrow>=22.0.0
 pydicom>=2.3.1 # Used by wsidicom
 pyyaml>=6.0
 requests>=2.28.1
diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py
index e59462bba..bfa63090e 100644
--- a/tests/engines/test_engine_abc.py
+++ b/tests/engines/test_engine_abc.py
@@ -135,6 +135,38 @@ def test_incorrect_output_type() -> NoReturn:
     )
 
 
+def test_incorrect_output_type_save_dir() -> NoReturn:
+    """Test that output_type zarr/annotationstore raises ValueError without save_dir."""
+    pretrained_model = "alexnet-kather100k"
+
+    # Run the engine without a save_dir.
+    eng = TestEngineABC(model=pretrained_model)
+
+    with pytest.raises(
+        ValueError,
+        match=r".*Please provide save_dir for output_type=zarr.*",
+    ):
+        _ = eng.run(
+            images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
+            on_gpu=False,
+            patch_mode=True,
+            ioconfig=None,
+            output_type="zarr",
+        )
+
+    with pytest.raises(
+        ValueError,
+        match=r".*Please provide save_dir for output_type=annotationstore.*",
+    ):
+        _ = eng.run(
+            images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
+            on_gpu=False,
+            patch_mode=True,
+            ioconfig=None,
+            output_type="annotationstore",
+        )
+
+
 def test_pretrained_ioconfig() -> NoReturn:
     """Test EngineABC initialization with pretrained model name in the toolbox."""
     pretrained_model = "alexnet-kather100k"
diff --git a/tests/engines/test_nucleus_instance_segmentor.py b/tests/engines/test_nucleus_instance_segmentor.py
new file mode 100644
index 000000000..75eecda19
--- /dev/null
+++ b/tests/engines/test_nucleus_instance_segmentor.py
@@ -0,0 +1,174 @@
+"""Test tiatoolbox.models.engine.nucleus_instance_segmentor."""
+
+from collections.abc import Callable, Sequence
+from pathlib import Path
+from typing import Any, Final
+
+import numpy as np
+import torch
+import zarr
+
+from tiatoolbox.annotation.storage import SQLiteStore
+from tiatoolbox.models import NucleusInstanceSegmentor
+from tiatoolbox.wsicore import WSIReader
+
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+OutputType = dict[str, Any] | Any
+
+
+def assert_output_lengths(output: OutputType, expected_counts: Sequence[int]) -> None:
+    """Assert lengths of output dict fields against expected counts."""
+    for field in ["box", "centroid", "contour", "prob", "type"]:
+        for i, expected in enumerate(expected_counts):
+            assert len(output[field][i]) == expected, f"{field}[{i}] mismatch"
+
+
+def assert_output_equal(
+    output_a: OutputType,
+    output_b: OutputType,
+    fields: Sequence[str],
+    indices_a: Sequence[int],
+    indices_b: Sequence[int],
+) -> None:
+    """Assert equality of arrays across outputs for given fields/indices."""
+    for field in fields:
+        for i_a, i_b in zip(indices_a, indices_b, strict=False):
+            left = output_a[field][i_a]
+            right = output_b[field][i_b]
+            assert all(
+                np.array_equal(a, b) for a, b in zip(left, right, strict=False)
+            ), f"{field}[{i_a}] vs {field}[{i_b}] mismatch"
+
+
+def assert_predictions_and_boxes(
+    output: OutputType, expected_counts: Sequence[int], *, is_zarr: bool = False
+) -> None:
+    """Assert predictions maxima and box lengths against expected counts."""
+    # predictions maxima
+    for idx, expected in enumerate(expected_counts):
+        if is_zarr and idx == 2:
+            # zarr output doesn't store predictions for patch 2
+            continue
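+        # The instance map labels nuclei 1..N, so its maximum equals the
+        # expected instance count (and 0 for an empty patch).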
+        assert np.max(output["predictions"][idx][:]) == expected, (
+            f"predictions[{idx}] mismatch"
+        )
+
+    # box lengths
+    for idx, expected in enumerate(expected_counts):
+        if is_zarr and idx < 2:
+            # for zarr, compare boxes only for patches 0 and 1
+            continue
+        assert len(output["box"][idx]) == expected, f"box[{idx}] mismatch"
+
+
+def test_functionality_patch_mode(
+    remote_sample: Callable, track_tmp_path: Path
+) -> None:
+    """Patch mode functionality test for the nucleus instance segmentor."""
+    mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
+    mini_wsi = WSIReader.open(mini_wsi_svs)
+    size = (256, 256)
+    resolution = 0.25
+    units: Final = "mpp"
+
+    patch1 = mini_wsi.read_rect(
+        location=(0, 0), size=size, resolution=resolution, units=units
+    )
+    patch2 = mini_wsi.read_rect(
+        location=(512, 512), size=size, resolution=resolution, units=units
+    )
+    patch3 = np.zeros_like(patch1)
+    patches = np.stack([patch1, patch2, patch3], axis=0)
+
+    inst_segmentor = NucleusInstanceSegmentor(
+        batch_size=1, num_workers=0, model="hovernet_fast-pannuke"
+    )
+    output_dict = inst_segmentor.run(
+        images=patches, patch_mode=True, device=device, output_type="dict"
+    )
+
+    expected_counts = [41, 17, 0]
+
+    assert_predictions_and_boxes(output_dict, expected_counts, is_zarr=False)
+    assert_output_lengths(output_dict, expected_counts)
+
+    # Zarr output comparison
+    output_zarr = inst_segmentor.run(
+        images=patches,
+        patch_mode=True,
+        device=device,
+        output_type="zarr",
+        save_dir=track_tmp_path / "patch_output_zarr",
+    )
+    output_zarr = zarr.open(output_zarr, mode="r")
+    assert_predictions_and_boxes(output_zarr, expected_counts, is_zarr=True)
+
+    assert_output_equal(
+        output_zarr,
+        output_dict,
+        fields=["box", "centroid", "contour", "prob", "type"],
+        indices_a=[0, 1, 2],
+        indices_b=[0, 1, 2],
+    )
+
+    # AnnotationStore output comparison
+    output_ann = inst_segmentor.run(
+        images=patches,
+        patch_mode=True,
+        device=device,
+        output_type="annotationstore",
+        save_dir=track_tmp_path / "patch_output_annotationstore",
+    )
+    assert len(output_ann) == 3
+    assert output_ann[0] == track_tmp_path / "patch_output_annotationstore" / "0.db"
+
+    for patch_idx, db_path in enumerate(output_ann):
+        assert (
+            db_path
+            == track_tmp_path / "patch_output_annotationstore" / f"{patch_idx}.db"
+        )
+        store_ = SQLiteStore.open(db_path)
+        # Materialise the values once; iterating a generator twice would
+        # silently yield nothing on the second pass.
+        annotations_list = list(store_.values())
+        annotations_geometry_type = [
+            str(annotation_.geometry_type) for annotation_ in annotations_list
+        ]
+        if expected_counts[patch_idx] > 0:
+            assert "Polygon" in annotations_geometry_type
+
+            # Build result dict from annotation properties
+            result = {}
+            for ann in annotations_list:
+                for key, value in ann.properties.items():
+                    result.setdefault(key, []).append(value)
+            result["contour"] = [
+                list(poly.exterior.coords)
+                for poly in (a.geometry for a in annotations_list)
+            ]
+
+            # Wrap it to make it compatible with assert_output_lengths.
+            result_ = {
+                field: [result[field]]
+                for field in ["box", "centroid", "contour", "prob", "type"]
+            }
+
+            # Lengths and equality checks for this patch
+            assert_output_lengths(result_, [expected_counts[patch_idx]])
+            assert_output_equal(
+                result_,
+                output_dict,
+                fields=["box", "centroid", "prob", "type"],
+                indices_a=[0],
+                indices_b=[patch_idx],
+            )
+
+            # Contour check (drop the closing point that shapely appends)
+            assert all(
+                np.array_equal(np.array(a[:-1], dtype=int), np.array(b, dtype=int))
+                for a, b in zip(
+                    result["contour"], output_dict["contour"][patch_idx], strict=False
+                )
+            )
+        else:
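+            # The all-zero patch must produce an empty store: no geometries
+            # and no annotations.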
+            assert annotations_geometry_type == []
+            assert annotations_list == []
diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py
index 19d02e7a5..6aa592ad0 100644
--- a/tiatoolbox/models/architecture/hovernet.py
+++ b/tiatoolbox/models/architecture/hovernet.py
@@ -6,7 +6,11 @@
 from collections import OrderedDict
 
 import cv2
+import dask
+import dask.array as da
+import dask.dataframe as dd
 import numpy as np
+import pandas as pd
 import torch
 import torch.nn.functional as F  # noqa: N812
 from scipy import ndimage
@@ -22,6 +26,8 @@
 from tiatoolbox.models.models_abc import ModelABC
 from tiatoolbox.utils.misc import get_bounding_box
 
+dask.config.set({"dataframe.convert-string": False})
+
 
 class TFSamepaddingLayer(nn.Module):
     """To align with tensorflow `same` padding.
@@ -776,11 +782,34 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]:
             tp_map = None
             np_map, hv_map = raw_maps
 
-        pred_type = tp_map
+        np_map = np_map.compute() if isinstance(np_map, dask.array.Array) else np_map
+        hv_map = hv_map.compute() if isinstance(hv_map, dask.array.Array) else hv_map
+        pred_type = tp_map.compute() if isinstance(tp_map, dask.array.Array) else tp_map
 
         pred_inst = HoVerNet._proc_np_hv(np_map, hv_map)
         nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type)
 
-        return pred_inst, nuc_inst_info_dict
+        if not nuc_inst_info_dict:  # no instances found; return empty fields
+            nuc_inst_info_dict = {
+                "box": da.empty(shape=0),
+                "centroid": da.empty(shape=0),
+                "contour": da.empty(shape=0),
+                "prob": da.empty(shape=0),
+                "type": da.empty(shape=0),
+            }
+            return pred_inst, nuc_inst_info_dict
+
+        # dask dataframe does not support transpose
+        nuc_inst_info_df = pd.DataFrame(nuc_inst_info_dict).transpose()
+
+        # create dask dataframe
+        nuc_inst_info_dd = dd.from_pandas(nuc_inst_info_df)
+
+        # reinitialize nuc_inst_info_dict
+        nuc_inst_info_dict_ = {}
+        for key in nuc_inst_info_df.columns:
+            nuc_inst_info_dict_[key] = nuc_inst_info_dd[key].to_dask_array(lengths=True)
+
+        return pred_inst, nuc_inst_info_dict_
 
     @staticmethod
     def infer_batch(  # skipcq: PYL-W0221
diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py
index d7ac9ddfc..a6ab31787 100644
--- a/tiatoolbox/models/engine/engine_abc.py
+++ b/tiatoolbox/models/engine/engine_abc.py
@@ -46,6 +46,7 @@
 import zarr
 from dask import compute
 from dask.diagnostics import ProgressBar
+from numcodecs import Pickle
 from torch import nn
 from typing_extensions import Unpack
 
@@ -71,6 +72,8 @@
 from tiatoolbox.models.models_abc import ModelABC
 from tiatoolbox.type_hints import IntPair, Resolution, Units
 
+dask.config.set({"dataframe.convert-string": False})
+
 
 class EngineABCRunParams(TypedDict, total=False):
     """Parameters for configuring the :func:`EngineABC.run()` method.
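Note on the first engine_abc.py hunk below: dict(zip(keys, [[]] * len(keys))) binds
every key to the same list object, so an append under one key shows up under all
of them; the dict comprehension gives each key its own list. A quick illustrative
check (not part of the diff):

    keys = ["probabilities", "coordinates"]
    shared = dict(zip(keys, [[]] * len(keys), strict=False))
    shared["probabilities"].append(1)
    assert shared["coordinates"] == [1]  # aliased: both names point to one list

    fixed = {key: [] for key in keys}
    fixed["probabilities"].append(1)
    assert fixed["coordinates"] == []  # independent lists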
@@ -518,7 +521,7 @@ def infer_patches(
         coordinates = []
 
         # Main output dictionary
-        raw_predictions = dict(zip(keys, [[]] * len(keys), strict=False))
+        raw_predictions = {key: [] for key in keys}
 
         # Inference loop
         tqdm = get_tqdm()
@@ -704,13 +707,29 @@ def save_predictions(
             keys_to_compute = [k for k in keys_to_compute if k not in zarr_group]
             write_tasks = []
             for key in keys_to_compute:
-                dask_array = processed_predictions[key].rechunk("auto")
-                task = dask_array.to_zarr(
-                    url=save_path,
-                    component=key,
-                    compute=False,
-                )
-                write_tasks.append(task)
+                dask_output = processed_predictions[key]
+                if isinstance(dask_output, da.Array):
+                    dask_output = dask_output.rechunk("auto")
+                    task = dask_output.to_zarr(
+                        url=save_path, component=key, compute=False, object_codec=None
+                    )
+                    write_tasks.append(task)
+
+                if isinstance(dask_output, list) and all(
+                    isinstance(dask_array, da.Array) for dask_array in dask_output
+                ):
+                    for i, dask_array in enumerate(dask_output):
+                        object_codec = (
+                            Pickle() if dask_array.dtype == "object" else None
+                        )
+                        task = dask_array.to_zarr(
+                            url=save_path,
+                            component=f"{key}/{i}",
+                            compute=False,
+                            object_codec=object_codec,
+                        )
+                        write_tasks.append(task)
+
             msg = f"Saving output to {save_path}."
             logger.info(msg=msg)
             with ProgressBar():
@@ -1174,6 +1193,9 @@ def _update_run_params(
                 If an unsupported output_type is provided.
             ValueError:
                 If required configuration or input parameters are missing.
+            ValueError:
+                If save_dir is not provided and output_type is "zarr"
+                or "annotationstore".
 
         """
         for key in kwargs:
@@ -1214,6 +1236,10 @@ def _update_run_params(
             )
             logger.info(msg)
 
+        if save_dir is None and output_type.lower() in ["zarr", "annotationstore"]:
+            msg = f"Please provide save_dir for output_type={output_type}"
+            raise ValueError(msg)
+
         self.images = self._validate_images_masks(images=images)
 
         if masks is not None:
diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py
index ce74355ae..1293821d5 100644
--- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py
+++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py
@@ -4,23 +4,44 @@
 
 import uuid
 from collections import deque
+from pathlib import Path
 from typing import TYPE_CHECKING
 
+import dask.array as da
+
 # replace with the sql database once the PR in place
 import joblib
 import numpy as np
 import torch
 import tqdm
+import zarr
 from shapely.geometry import box as shapely_box
+from shapely.geometry import shape as feature2geometry
 from shapely.strtree import STRtree
-
-from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset
-from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
+from typing_extensions import Unpack
+
+from tiatoolbox import DuplicateFilter, logger
+from tiatoolbox.annotation import SQLiteStore
+from tiatoolbox.annotation.storage import Annotation
+from tiatoolbox.models.engine.semantic_segmentor import (
+    SemanticSegmentor,
+    SemanticSegmentorRunParams,
+)
 from tiatoolbox.tools.patchextraction import PatchExtractor
+from tiatoolbox.utils.misc import get_tqdm, make_valid_poly
+from tiatoolbox.wsicore.wsireader import is_zarr
 
 if TYPE_CHECKING:  # pragma: no cover
+    import os
     from collections.abc import Callable
 
+    from torch.utils.data import DataLoader
+
+    from tiatoolbox.annotation import AnnotationStore
+    from tiatoolbox.models.models_abc import ModelABC
+    from tiatoolbox.wsicore import WSIReader
+
+    from .engine_abc import EngineABCRunParams
     from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig
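Note on the save_predictions hunk above: contours, boxes, and other per-instance
fields are ragged, so their Dask arrays carry dtype=object, and zarr (v2
semantics) refuses object arrays unless an object codec is supplied;
numcodecs.Pickle fills that role. A minimal sketch of the failure mode and fix
(file name and values are illustrative, not from the diff):

    import dask.array as da
    import numpy as np
    from numcodecs import Pickle

    ragged = np.empty(2, dtype=object)
    ragged[0], ragged[1] = [0, 1], [0, 1, 2]  # ragged per-instance data
    arr = da.from_array(ragged, chunks=1)
    # Without an object_codec, zarr raises a ValueError for object arrays;
    # Pickle() serialises each element instead.
    arr.to_zarr("ragged.zarr", component="contour/0", object_codec=Pickle())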
@@ -372,38 +393,313 @@ class NucleusInstanceSegmentor(SemanticSegmentor):
 
     def __init__(
         self: NucleusInstanceSegmentor,
+        model: str | ModelABC,
         batch_size: int = 8,
-        num_loader_workers: int = 0,
-        num_postproc_workers: int = 0,
-        model: torch.nn.Module | None = None,
-        pretrained_model: str | None = None,
-        pretrained_weights: str | None = None,
-        dataset_class: Callable = WSIStreamDataset,
+        num_workers: int = 0,
+        weights: str | Path | None = None,
         *,
+        device: str = "cpu",
         verbose: bool = True,
-        auto_generate_mask: bool = False,
     ) -> None:
         """Initialize :class:`NucleusInstanceSegmentor`."""
         super().__init__(
-            batch_size=batch_size,
-            num_loader_workers=num_loader_workers,
-            num_postproc_workers=num_postproc_workers,
             model=model,
-            pretrained_model=pretrained_model,
-            pretrained_weights=pretrained_weights,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            weights=weights,
+            device=device,
             verbose=verbose,
-            auto_generate_mask=auto_generate_mask,
-            dataset_class=dataset_class,
         )
-        # default is None in base class and is un-settable
-        # hence we redefine the namespace here
-        self.num_postproc_workers = (
-            num_postproc_workers if num_postproc_workers > 0 else None
-        )
+
+    def infer_patches(
+        self: NucleusInstanceSegmentor,
+        dataloader: DataLoader,
+        *,
+        return_coordinates: bool = False,
+    ) -> dict[str, list[da.Array]]:
+        """Run model inference on image patches and return predictions.
+
+        This method performs batched inference using a PyTorch DataLoader
+        and accumulates predictions in Dask arrays. It supports optional
+        inclusion of coordinates and labels in the output.
+
+        Args:
+            dataloader (DataLoader):
+                PyTorch DataLoader containing image patches for inference.
+            return_coordinates (bool):
+                Whether to include coordinates in the output. Required when
+                called by `infer_wsi` and `patch_mode` is False.
+
+        Returns:
+            dict[str, list[dask.array.Array]]:
+                Dictionary containing prediction results as Dask arrays.
+                Keys include:
+                - "probabilities": Model output probabilities, one list entry
+                  per model output head.
+                - "labels": Ground truth labels (if `return_labels` is True).
+                - "coordinates": Patch coordinates (if `return_coordinates` is
+                  True).
+
+        """
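+        # "probabilities" is kept as a list with one entry per model output
+        # head, so a multi-head model such as HoVerNet accumulates each head's
+        # batches separately until the final concatenate.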
+ + """ + keys = ["probabilities"] + labels, coordinates = [], [] + + # Expected number of outputs from the model + batch_output = self.model.infer_batch( + self.model, + torch.Tensor(dataloader.dataset[0]["image"][np.newaxis, ...]), + device=self.device, ) - # adding more runtime placeholder - self._wsi_inst_info = None - self._futures = [] + num_expected_output = len(batch_output) + probabilities = [[] for _ in range(num_expected_output)] + + if return_coordinates: + keys.append("coordinates") + coordinates = [] + + # Main output dictionary + raw_predictions = {key: [] for key in keys} + raw_predictions["probabilities"] = [[] for _ in range(num_expected_output)] + + # Inference loop + tqdm = get_tqdm() + tqdm_loop = ( + tqdm(dataloader, leave=False, desc="Inferring patches") + if self.verbose + else self.dataloader + ) + + for batch_data in tqdm_loop: + batch_output = self.model.infer_batch( + self.model, + batch_data["image"], + device=self.device, + ) + + for i in range(num_expected_output): + probabilities[i].append( + da.from_array( + batch_output[i], # probabilities + ) + ) + + if return_coordinates: + coordinates.append( + da.from_array( + self._get_coordinates(batch_data), + ) + ) + + if self.return_labels: + labels.append(da.from_array(np.array(batch_data["label"]))) + + for i in range(num_expected_output): + raw_predictions["probabilities"][i] = da.concatenate( + probabilities[i], axis=0 + ) + + if return_coordinates: + raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) + + return raw_predictions + + def _run_patch_mode( + self: NucleusInstanceSegmentor, + output_type: str, + save_dir: Path, + **kwargs: EngineABCRunParams, + ) -> dict | AnnotationStore | Path: + """Run the engine in patch mode. + + This method performs inference on image patches, post-processes the predictions, + and saves the output in the specified format. + + Args: + output_type (str): + Desired output format. Supported values are "dict", "zarr", + and "annotationstore". + save_dir (Path): + Directory to save the output files. + **kwargs (EngineABCRunParams): + Additional runtime parameters including: + - output_file: Name of the output file. + - scale_factor: Scaling factor for annotations. + - class_dict: Mapping of class indices to names. + + Returns: + dict | AnnotationStore | Path: + - If output_type is "dict": returns predictions as a dictionary. + - If output_type is "zarr": returns path to saved zarr file. + - If output_type is "annotationstore": returns an AnnotationStore + or path to .db file. + + """ + save_path = None + if save_dir: + output_file = Path(kwargs.get("output_file", "output.zarr")) + save_path = save_dir / (str(output_file.stem) + ".zarr") + + duplicate_filter = DuplicateFilter() + logger.addFilter(duplicate_filter) + + self.dataloader = self.get_dataloader( + images=self.images, + masks=self.masks, + labels=self.labels, + patch_mode=True, + ioconfig=self._ioconfig, + ) + raw_predictions = self.infer_patches( + dataloader=self.dataloader, + return_coordinates=output_type == "annotationstore", + ) + + raw_predictions = self.post_process_patches( + raw_predictions=raw_predictions, + prediction_shape=None, + prediction_dtype=None, + **kwargs, + ) + + logger.removeFilter(duplicate_filter) + + out = self.save_predictions( + processed_predictions=raw_predictions, + output_type=output_type, + save_path=save_path, + **kwargs, + ) + + if save_path: + msg = f"Output file saved at {out}." 
+            logger.info(msg=msg)
+        return out
+
+    def post_process_patches(  # skipcq: PYL-R0201
+        self: NucleusInstanceSegmentor,
+        raw_predictions: dict,
+        prediction_shape: tuple[int, ...],  # noqa: ARG002
+        prediction_dtype: type,  # noqa: ARG002
+        **kwargs: Unpack[EngineABCRunParams],  # noqa: ARG002
+    ) -> dict:
+        """Post-process raw patch predictions from inference.
+
+        This method applies the model's post-processing function to each patch
+        prediction, producing an instance map per patch and the per-instance
+        information fields (box, centroid, contour, prob, type).
+
+        Args:
+            raw_predictions (dict):
+                Raw model predictions as a dictionary of Dask arrays.
+            prediction_shape (tuple[int, ...]):
+                Shape of the prediction output.
+            prediction_dtype (type):
+                Data type of the prediction output.
+            **kwargs (EngineABCRunParams):
+                Additional runtime parameters used for post-processing.
+
+        Returns:
+            dict:
+                Post-processed predictions as a dictionary of Dask arrays.
+
+        """
+        probabilities = raw_predictions["probabilities"]
+        predictions = [[] for _ in range(probabilities[0].shape[0])]
+        inst_dict = [[{}] for _ in range(probabilities[0].shape[0])]
+        for idx in range(probabilities[0].shape[0]):
+            predictions[idx], inst_dict[idx] = self.model.postproc_func(
+                [probabilities[0][idx], probabilities[1][idx], probabilities[2][idx]]
+            )
+
+        raw_predictions["predictions"] = da.stack(predictions, axis=0)
+        for key in inst_dict[0]:
+            raw_predictions[key] = [d[key] for d in inst_dict]
+
+        return raw_predictions
+
+    def save_predictions(
+        self: NucleusInstanceSegmentor,
+        processed_predictions: dict,
+        output_type: str,
+        save_path: Path | None = None,
+        **kwargs: Unpack[SemanticSegmentorRunParams],
+    ) -> dict | AnnotationStore | Path | list[Path]:
+        """Save nucleus instance segmentation predictions to disk or return them in memory."""
+        # Conversion to annotationstore uses a different function
+        # for NucleusInstanceSegmentor.
+        if output_type.lower() != "annotationstore":
+            return super().save_predictions(
+                processed_predictions, output_type, save_path=save_path, **kwargs
+            )
+
+        return_probabilities = kwargs.get("return_probabilities", False)
+        output_type_ = (
+            "zarr"
+            if is_zarr(save_path.with_suffix(".zarr")) or return_probabilities
+            else "dict"
+        )
+
+        # This runs dask.compute and returns numpy arrays
+        # for saving annotationstore output.
+        processed_predictions = super().save_predictions(
+            processed_predictions,
+            output_type=output_type_,
+            save_path=save_path.with_suffix(".zarr"),
+            **kwargs,
+        )
+
+        if isinstance(processed_predictions, Path):
+            processed_predictions = zarr.open(str(processed_predictions), mode="r")
+
+        # scale_factor set from kwargs
+        scale_factor = kwargs.get("scale_factor", (1.0, 1.0))
+        # class_dict set from kwargs
+        class_dict = kwargs.get("class_dict")
+
+        # TODO: add support for zarr conversion.
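+        # One SQLiteStore (.db) file is written per input patch below; the
+        # returned list of paths preserves the order of self.images.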
+        save_paths = []
+
+        logger.info("Saving predictions as AnnotationStore.")
+
+        # Not required for annotationstore
+        processed_predictions.pop("predictions")
+        if self.patch_mode:
+            for i, predictions in enumerate(
+                zip(*processed_predictions.values(), strict=False)
+            ):
+                predictions_ = dict(
+                    zip(processed_predictions.keys(), predictions, strict=False)
+                )
+                if isinstance(self.images[i], Path):
+                    output_path = save_path.parent / (self.images[i].stem + ".db")
+                else:
+                    output_path = save_path.parent / (str(i) + ".db")
+
+                origin = predictions_.pop("coordinates")[:2]
+                store = SQLiteStore()
+                store = dict_to_store(
+                    store=store,
+                    processed_predictions=predictions_,
+                    class_dict=class_dict,
+                    scale_factor=scale_factor,
+                    origin=origin,
+                )
+
+                store.commit()
+                store.dump(output_path)
+
+                save_paths.append(output_path)
+
+        if return_probabilities:
+            msg = (
+                f"Probability maps cannot be saved as AnnotationStore. "
+                f"To visualise heatmaps in the TIAToolbox Visualization tool, "
+                f"convert the heatmaps in {save_path} to ome.tiff using "
+                f"tiatoolbox.utils.misc.write_probability_heatmap_as_ome_tiff."
+            )
+            logger.info(msg)
+
+        return save_paths
 
     @staticmethod
     def _get_tile_info(
@@ -812,3 +1108,70 @@ def callback(new_inst_dict: dict, remove_uuid_list: list) -> None:
             # manually call the callback rather than
             # attaching it when receiving/creating the future
             callback(*result)
+
+    def run(
+        self: NucleusInstanceSegmentor,
+        images: list[os.PathLike | Path | WSIReader] | np.ndarray,
+        masks: list[os.PathLike | Path] | np.ndarray | None = None,
+        labels: list | None = None,
+        ioconfig: IOSegmentorConfig | None = None,
+        *,
+        patch_mode: bool = True,
+        save_dir: os.PathLike | Path | None = None,
+        overwrite: bool = False,
+        output_type: str = "dict",
+        **kwargs: Unpack[SemanticSegmentorRunParams],
+    ) -> AnnotationStore | Path | str | dict | list[Path]:
+        """Run the nucleus instance segmentor engine on input images."""
+        return super().run(
+            images=images,
+            masks=masks,
+            labels=labels,
+            ioconfig=ioconfig,
+            patch_mode=patch_mode,
+            save_dir=save_dir,
+            overwrite=overwrite,
+            output_type=output_type,
+            **kwargs,
+        )
+
+
+def dict_to_store(
+    store: SQLiteStore,
+    processed_predictions: dict,
+    class_dict: dict | None = None,
+    origin: tuple[float, float] = (0, 0),
+    scale_factor: tuple[float, float] = (1, 1),
+) -> AnnotationStore:
+    """Convert a dictionary of predictions to an annotation store."""
+    contour = processed_predictions.pop("contour")
+
+    ann = []
+    for i, contour_ in enumerate(contour):
+        ann_ = Annotation(
+            make_valid_poly(
+                feature2geometry(
+                    {
+                        "type": processed_predictions.get("geom_type", "Polygon"),
+                        "coordinates": scale_factor * np.array([contour_]),
+                    },
+                ),
+                tuple(origin),
+            ),
+            {
+                prop: (
+                    class_dict[processed_predictions[prop][i]]
+                    if prop == "type" and class_dict is not None
+                    # Convert array values to plain lists; int or float values
+                    # are wrapped in an array first so tolist() applies uniformly.
+                    else np.array(processed_predictions[prop][i]).tolist()
+                )
+                for prop in processed_predictions
+            },
+        )
+        ann.append(ann_)
+    logger.info("Added %d annotations.", len(ann))
+    store.append_many(ann)
+
+    return store
diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py
index e429940d7..bc41f6cea 100644
--- a/tiatoolbox/models/engine/semantic_segmentor.py
+++ b/tiatoolbox/models/engine/semantic_segmentor.py
@@ -587,7 +587,7 @@ def save_predictions(
         output_type: str,
         save_path: Path | None = None,
         **kwargs: Unpack[SemanticSegmentorRunParams],
-    ) -> dict | AnnotationStore | Path:
+    ) -> dict | AnnotationStore | Path | list[Path]:
         """Save semantic segmentation predictions to disk or return them in memory.
 
         This method saves predictions in one of the supported formats:
@@ -645,11 +645,11 @@ def save_predictions(
                 Whether to enable verbose logging.
 
         Returns:
-            dict | AnnotationStore | Path:
+            dict | AnnotationStore | Path | list[Path]:
                 - If output_type is "dict": returns predictions as a dictionary.
                 - If output_type is "zarr": returns path to saved Zarr file.
                 - If output_type is "annotationstore": returns AnnotationStore
-                  or path to .db file.
+                  or a path or list of paths to .db files.
 
         """
         # Conversion to annotationstore uses a different function for SemanticSegmentor
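For reference, the pandas-to-Dask round trip that the hovernet.py postproc change
earlier in this diff performs can be sketched as follows (toy instance-info dict;
names and values are illustrative, not from the diff):

    import dask.dataframe as dd
    import pandas as pd

    # Per-instance properties keyed by instance id, as returned by
    # HoVerNet.get_instance_info().
    info = {1: {"prob": 0.9, "type": 2}, 2: {"prob": 0.7, "type": 1}}
    df = pd.DataFrame(info).transpose()  # rows = instances, columns = fields
    ddf = dd.from_pandas(df, npartitions=1)
    # lengths=True materialises chunk sizes so downstream rechunk/to_zarr work.
    arrays = {col: ddf[col].to_dask_array(lengths=True) for col in df.columns}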