Changes from all commits (23 commits)
7d50a01
:new: Define `NucleusInstanceSegmentor`
shaneahmed Oct 23, 2025
9d9bf7d
Merge branch 'dev-define-engines-abc' into dev-define-nucleus-instanc…
shaneahmed Nov 6, 2025
9d13fcb
Merge branch 'dev-define-engines-abc' into dev-define-nucleus-instanc…
shaneahmed Nov 17, 2025
03b2964
:test_tube: Add initial test for nucleus instance segmentor
shaneahmed Nov 17, 2025
dae9213
:test_tube: Test issues with raw output in patch mode
shaneahmed Nov 19, 2025
4bc33b7
:test_tube: Test issues with raw output in patch mode
shaneahmed Nov 20, 2025
2797ff9
:white_check_mark: Test patch mode with dict output
shaneahmed Nov 20, 2025
da6a1ea
:white_check_mark: Test patch mode with dict and zarr output
shaneahmed Nov 24, 2025
5e14877
:lipstick: log output if save path is requested
shaneahmed Nov 24, 2025
843841c
:goal_net: Catch error with no save_dir and output_type is zarr or an…
shaneahmed Nov 28, 2025
85f0bc7
:test_tube: Add test for patch annotations
shaneahmed Dec 1, 2025
a29f2ea
Merge branch 'dev-define-engines-abc' into dev-define-nucleus-instanc…
shaneahmed Dec 1, 2025
4fe3a49
Merge branch 'dev-define-engines-abc' into dev-define-nucleus-instanc…
shaneahmed Dec 2, 2025
caece3f
Merge remote-tracking branch 'origin/dev-define-nucleus-instance-segm…
shaneahmed Dec 4, 2025
e6dc905
:test_tube: Add failing test for annotationstore conversion
shaneahmed Dec 4, 2025
0528942
:white_check_mark: Add functionality patch mode annotations
shaneahmed Dec 4, 2025
200c24c
:pushpin: remove dask dataframe dependency
shaneahmed Dec 5, 2025
63f0589
:pushpin: Add `pyarrow` dependency
shaneahmed Dec 5, 2025
8321e0e
Merge branch 'dev-define-engines-abc' into dev-define-nucleus-instanc…
shaneahmed Dec 5, 2025
0f8d49e
:white_check_mark: Add checks for annotationstore output
shaneahmed Dec 5, 2025
a911d3c
:art: Improve structure of the test
shaneahmed Dec 5, 2025
18eb3b4
Merge branch 'dev-define-engines-abc' into dev-define-nucleus-instanc…
shaneahmed Dec 6, 2025
f69b57d
Merge branch 'dev-define-engines-abc' into dev-define-nucleus-instanc…
shaneahmed Dec 19, 2025
1 change: 1 addition & 0 deletions requirements/requirements.txt
@@ -20,6 +20,7 @@ openslide-bin>=4.0.0.2
openslide-python>=1.4.0
pandas>=2.0.0
pillow>=9.3.0
pyarrow>=22.0.0
pydicom>=2.3.1 # Used by wsidicom
pyyaml>=6.0
requests>=2.28.1
32 changes: 32 additions & 0 deletions tests/engines/test_engine_abc.py
@@ -135,6 +135,38 @@ def test_incorrect_output_type() -> NoReturn:
        )


def test_incorrect_output_type_save_dir() -> NoReturn:
    """Test EngineABC requires save_dir for zarr/annotationstore output."""
    pretrained_model = "alexnet-kather100k"

    # Test engine run without save_dir
    eng = TestEngineABC(model=pretrained_model)

    with pytest.raises(
        ValueError,
        match=r".*Please provide save_dir for output_type=zarr*",
    ):
        _ = eng.run(
            images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
            on_gpu=False,
            patch_mode=True,
            ioconfig=None,
            output_type="zarr",
        )

    with pytest.raises(
        ValueError,
        match=r".*Please provide save_dir for output_type=annotationstore*",
    ):
        _ = eng.run(
            images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
            on_gpu=False,
            patch_mode=True,
            ioconfig=None,
            output_type="annotationstore",
        )


def test_pretrained_ioconfig() -> NoReturn:
    """Test EngineABC initialization with pretrained model name in the toolbox."""
    pretrained_model = "alexnet-kather100k"
174 changes: 174 additions & 0 deletions tests/engines/test_nucleus_instance_segmentor.py
@@ -0,0 +1,174 @@
"""Test tiatoolbox.models.engine.nucleus_instance_segmentor."""

from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Any, Final

import numpy as np
import torch
import zarr

from tiatoolbox.annotation.storage import SQLiteStore
from tiatoolbox.models import NucleusInstanceSegmentor
from tiatoolbox.wsicore import WSIReader

device = "cuda:0" if torch.cuda.is_available() else "cpu"
OutputType = dict[str, Any] | Any


def assert_output_lengths(output: OutputType, expected_counts: Sequence[int]) -> None:
"""Assert lengths of output dict fields against expected counts."""
for field in ["box", "centroid", "contour", "prob", "type"]:
for i, expected in enumerate(expected_counts):
assert len(output[field][i]) == expected, f"{field}[{i}] mismatch"


def assert_output_equal(
output_a: OutputType,
output_b: OutputType,
fields: Sequence[str],
indices_a: Sequence[int],
indices_b: Sequence[int],
) -> None:
"""Assert equality of arrays across outputs for given fields/indices."""
for field in fields:
for i_a, i_b in zip(indices_a, indices_b, strict=False):
left = output_a[field][i_a]
right = output_b[field][i_b]
assert all(
np.array_equal(a, b) for a, b in zip(left, right, strict=False)
), f"{field}[{i_a}] vs {field}[{i_b}] mismatch"


def assert_predictions_and_boxes(
output: OutputType, expected_counts: Sequence[int], *, is_zarr: bool = False
) -> None:
"""Assert predictions maxima and box lengths against expected counts."""
# predictions maxima
for idx, expected in enumerate(expected_counts):
if is_zarr and idx == 2:
# zarr output doesn't store predictions for patch 2
continue
assert np.max(output["predictions"][idx][:]) == expected, (
f"predictions[{idx}] mismatch"
)

# box lengths
for idx, expected in enumerate(expected_counts):
if is_zarr and idx < 2:
# for zarr, compare boxes only for patches 0 and 1
continue
assert len(output["box"][idx]) == expected, f"box[{idx}] mismatch"


def test_functionality_patch_mode(
remote_sample: Callable, track_tmp_path: Path
) -> None:
"""Patch mode functionality test for nuclei instance segmentor."""
mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
mini_wsi = WSIReader.open(mini_wsi_svs)
size = (256, 256)
resolution = 0.25
units: Final = "mpp"

patch1 = mini_wsi.read_rect(
location=(0, 0), size=size, resolution=resolution, units=units
)
patch2 = mini_wsi.read_rect(
location=(512, 512), size=size, resolution=resolution, units=units
)
patch3 = np.zeros_like(patch1)
patches = np.stack([patch1, patch2, patch3], axis=0)

inst_segmentor = NucleusInstanceSegmentor(
batch_size=1, num_workers=0, model="hovernet_fast-pannuke"
)
output_dict = inst_segmentor.run(
images=patches, patch_mode=True, device=device, output_type="dict"
)

expected_counts = [41, 17, 0]

assert_predictions_and_boxes(output_dict, expected_counts, is_zarr=False)
assert_output_lengths(output_dict, expected_counts)

# Zarr output comparison
output_zarr = inst_segmentor.run(
images=patches,
patch_mode=True,
device=device,
output_type="zarr",
save_dir=track_tmp_path / "patch_output_zarr",
)
output_zarr = zarr.open(output_zarr, mode="r")
assert_predictions_and_boxes(output_zarr, expected_counts, is_zarr=True)

assert_output_equal(
output_zarr,
output_dict,
fields=["box", "centroid", "contour", "prob", "type"],
indices_a=[0, 1, 2],
indices_b=[0, 1, 2],
)

# AnnotationStore output comparison
output_ann = inst_segmentor.run(
images=patches,
patch_mode=True,
device=device,
output_type="annotationstore",
save_dir=track_tmp_path / "patch_output_annotationstore",
)
assert len(output_ann) == 3
assert output_ann[0] == track_tmp_path / "patch_output_annotationstore" / "0.db"

for patch_idx, db_path in enumerate(output_ann):
assert (
db_path
== track_tmp_path / "patch_output_annotationstore" / f"{patch_idx}.db"
)
store_ = SQLiteStore.open(db_path)
annotations_ = store_.values()
annotations_geometry_type = [
str(annotation_.geometry_type) for annotation_ in annotations_
]
annotations_list = list(annotations_)
if expected_counts[patch_idx] > 0:
assert "Polygon" in annotations_geometry_type

# Build result dict from annotation properties
result = {}
for ann in annotations_list:
for key, value in ann.properties.items():
result.setdefault(key, []).append(value)
result["contour"] = [
list(poly.exterior.coords)
for poly in (a.geometry for a in annotations_list)
]

            # Wrap it to make it compatible with assert_output_lengths
            result_ = {
                field: [result[field]]
                for field in ["box", "centroid", "contour", "prob", "type"]
            }

            # Lengths and equality checks for this patch
            assert_output_lengths(result_, [expected_counts[patch_idx]])
            assert_output_equal(
                result_,
                output_dict,
                fields=["box", "centroid", "prob", "type"],
                indices_a=[0],
                indices_b=[patch_idx],
            )

            # Contour check (discard last point)
            assert all(
                np.array_equal(np.array(a[:-1], dtype=int), np.array(b, dtype=int))
                for a, b in zip(
                    result["contour"], output_dict["contour"][patch_idx], strict=False
                )
            )
        else:
            assert annotations_geometry_type == []
            assert annotations_list == []
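
Reviewer note: the block below is a minimal usage sketch (not part of the diff) of the patch-mode API this test exercises. It assumes the same `hovernet_fast-pannuke` weights as the test; the zero-filled `patches` array and the `patch_output` directory are placeholders for illustration only.

```python
# Sketch only: run NucleusInstanceSegmentor in patch mode and read back the
# per-patch AnnotationStore databases. `patches` is a placeholder input.
from pathlib import Path

import numpy as np

from tiatoolbox.annotation.storage import SQLiteStore
from tiatoolbox.models import NucleusInstanceSegmentor

patches = np.zeros((2, 256, 256, 3), dtype=np.uint8)  # stand-in for real RGB patches

segmentor = NucleusInstanceSegmentor(
    batch_size=1, num_workers=0, model="hovernet_fast-pannuke"
)

# Dict output keeps everything in memory; no save_dir is needed.
out = segmentor.run(images=patches, patch_mode=True, device="cpu", output_type="dict")
print(len(out["box"][0]))  # number of nuclei detected in the first patch

# AnnotationStore output requires save_dir and writes one <patch_idx>.db per patch.
db_paths = segmentor.run(
    images=patches,
    patch_mode=True,
    device="cpu",
    output_type="annotationstore",
    save_dir=Path("patch_output"),  # hypothetical output directory
)
store = SQLiteStore.open(db_paths[0])
for annotation in store.values():
    print(annotation.geometry_type, annotation.properties.get("type"))
```

As in the test, each returned path points at an SQLiteStore whose annotation properties mirror the `box`/`centroid`/`prob`/`type` fields of the dict output.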
33 changes: 31 additions & 2 deletions tiatoolbox/models/architecture/hovernet.py
@@ -6,7 +6,11 @@
from collections import OrderedDict

import cv2
import dask
import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F # noqa: N812
from scipy import ndimage
@@ -22,6 +26,8 @@
from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.utils.misc import get_bounding_box

dask.config.set({"dataframe.convert-string": False})


class TFSamepaddingLayer(nn.Module):
"""To align with tensorflow `same` padding.
@@ -776,11 +782,34 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]:
        tp_map = None
        np_map, hv_map = raw_maps

        pred_type = tp_map
        np_map = np_map.compute() if isinstance(np_map, dask.array.Array) else np_map
        hv_map = hv_map.compute() if isinstance(hv_map, dask.array.Array) else hv_map
        pred_type = tp_map.compute() if isinstance(tp_map, dask.array.Array) else tp_map
        pred_inst = HoVerNet._proc_np_hv(np_map, hv_map)
        nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type)

        return pred_inst, nuc_inst_info_dict
        if not nuc_inst_info_dict:
            nuc_inst_info_dict = {  # inst_id should start at 1
                "box": da.empty(shape=0),
                "centroid": da.empty(shape=0),
                "contour": da.empty(shape=0),
                "prob": da.empty(shape=0),
                "type": da.empty(shape=0),
            }
            return pred_inst, nuc_inst_info_dict

        # dask dataframe does not support transpose
        nuc_inst_info_df = pd.DataFrame(nuc_inst_info_dict).transpose()

        # create dask dataframe
        nuc_inst_info_dd = dd.from_pandas(nuc_inst_info_df)

        # reinitialize nuc_inst_info_dict
        nuc_inst_info_dict_ = {}
        for key in nuc_inst_info_df.columns:
            nuc_inst_info_dict_[key] = nuc_inst_info_dd[key].to_dask_array(lengths=True)

        return pred_inst, nuc_inst_info_dict_

    @staticmethod
    def infer_batch(  # skipcq: PYL-W0221
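
Reviewer note: `postproc` now routes the per-instance dict through pandas because dask dataframes do not support `transpose`. The sketch below (not part of the diff) walks the same dict → DataFrame → dask dataframe → per-column dask array conversion on toy data; the explicit `npartitions=1` is an assumption here, while the diff relies on dask's default.

```python
# Sketch of the conversion pattern used in HoVerNet.postproc (toy data only).
import dask
import dask.dataframe as dd
import numpy as np
import pandas as pd

# Mirror the diff: keep object columns as plain Python objects.
dask.config.set({"dataframe.convert-string": False})

# Per-instance info keyed by instance id (ids start at 1), roughly the shape
# that get_instance_info produces.
nuc_inst_info_dict = {
    1: {
        "box": np.array([0, 0, 10, 10]),
        "centroid": np.array([5.0, 5.0]),
        "prob": 0.9,
        "type": 1,
    },
    2: {
        "box": np.array([20, 20, 32, 31]),
        "centroid": np.array([26.0, 25.0]),
        "prob": 0.8,
        "type": 2,
    },
}

# Transpose on the pandas side so each field ("box", "centroid", ...) becomes a column.
info_df = pd.DataFrame(nuc_inst_info_dict).transpose()

# Wrap in a dask dataframe, then expose every column as a lazy dask array.
info_dd = dd.from_pandas(info_df, npartitions=1)
info_dask = {key: info_dd[key].to_dask_array(lengths=True) for key in info_df.columns}

print(info_dask["prob"].compute())  # object array: [0.9 0.8]
```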
42 changes: 34 additions & 8 deletions tiatoolbox/models/engine/engine_abc.py
@@ -46,6 +46,7 @@
import zarr
from dask import compute
from dask.diagnostics import ProgressBar
from numcodecs import Pickle
from torch import nn
from typing_extensions import Unpack

@@ -71,6 +72,8 @@
from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.type_hints import IntPair, Resolution, Units

dask.config.set({"dataframe.convert-string": False})


class EngineABCRunParams(TypedDict, total=False):
"""Parameters for configuring the :func:`EngineABC.run()` method.
@@ -518,7 +521,7 @@ def infer_patches(
        coordinates = []

        # Main output dictionary
        raw_predictions = dict(zip(keys, [[]] * len(keys), strict=False))
        raw_predictions = {key: [] for key in keys}

        # Inference loop
        tqdm = get_tqdm()
@@ -704,13 +707,29 @@ def save_predictions(
        keys_to_compute = [k for k in keys_to_compute if k not in zarr_group]
        write_tasks = []
        for key in keys_to_compute:
            dask_array = processed_predictions[key].rechunk("auto")
            task = dask_array.to_zarr(
                url=save_path,
                component=key,
                compute=False,
            )
            write_tasks.append(task)
            dask_output = processed_predictions[key]
            if isinstance(dask_output, da.Array):
                dask_output = dask_output.rechunk("auto")
                task = dask_output.to_zarr(
                    url=save_path, component=key, compute=False, object_codec=None
                )
                write_tasks.append(task)

            if isinstance(dask_output, list) and all(
                isinstance(dask_array, da.Array) for dask_array in dask_output
            ):
                for i, dask_array in enumerate(dask_output):
                    object_codec = (
                        Pickle() if dask_array.dtype == "object" else None
                    )
                    task = dask_array.to_zarr(
                        url=save_path,
                        component=f"{key}/{i}",
                        compute=False,
                        object_codec=object_codec,
                    )
                    write_tasks.append(task)

        msg = f"Saving output to {save_path}."
        logger.info(msg=msg)
        with ProgressBar():
@@ -1174,6 +1193,9 @@ def _update_run_params(
                If an unsupported output_type is provided.
            ValueError:
                If required configuration or input parameters are missing.
            ValueError:
                If save_dir is not provided and output_type is "zarr"
                or "annotationstore".

        """
        for key in kwargs:
@@ -1214,6 +1236,10 @@
            )
            logger.info(msg)

        if save_dir is None and output_type.lower() in ["zarr", "annotationstore"]:
            msg = f"Please provide save_dir for output_type={output_type}"
            raise ValueError(msg)

        self.images = self._validate_images_masks(images=images)

        if masks is not None:
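
Reviewer note: the `object_codec` branch in `save_predictions` exists because ragged per-patch outputs (variable-length contours, boxes, etc.) become object-dtype dask arrays, which zarr cannot store without an explicit codec. The sketch below (not part of the diff) shows the underlying call; it assumes the zarr v2 API, and `demo.zarr` / `contour/0` are hypothetical names used only for illustration.

```python
# Sketch: write an object-dtype dask array to a zarr component with Pickle,
# mirroring the per-key/per-index writes added in save_predictions.
import dask.array as da
import numpy as np
import zarr
from numcodecs import Pickle

# Ragged data: one variable-length contour per instance -> object dtype.
contours = np.empty(2, dtype=object)
contours[0] = np.array([[0, 0], [0, 5], [5, 5]])
contours[1] = np.array([[10, 10], [10, 20], [20, 20], [20, 10]])

dask_contours = da.from_array(contours, chunks=1)

# Pickle is only needed for object dtype; plain numeric arrays pass object_codec=None.
object_codec = Pickle() if dask_contours.dtype == "object" else None
dask_contours.to_zarr(
    url="demo.zarr",        # hypothetical save path
    component="contour/0",  # one component per key/index, as in the diff
    compute=True,
    object_codec=object_codec,
)

print(zarr.open("demo.zarr", mode="r")["contour/0"][1])
```

Numeric outputs take the `object_codec=None` path, which is why the diff only builds a `Pickle()` codec when `dask_array.dtype == "object"`.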