diff --git a/mipcandy/__init__.py b/mipcandy/__init__.py index f9bafee..411b77a 100644 --- a/mipcandy/__init__.py +++ b/mipcandy/__init__.py @@ -4,7 +4,7 @@ from mipcandy.data import * from mipcandy.evaluation import EvalCase, EvalResult, Evaluator from mipcandy.frontend import * -from mipcandy.inference import parse_predictant, Predictor +from mipcandy.inference import parse_predictant, Predictor, StreamPredictor from mipcandy.layer import batch_int_multiply, batch_int_divide, LayerT, HasDevice, auto_device, WithPaddingModule, \ WithNetwork from mipcandy.metrics import do_reduction, binary_dice, dice_similarity_coefficient, soft_dice diff --git a/mipcandy/inference.py b/mipcandy/inference.py index b493ed1..2aab385 100644 --- a/mipcandy/inference.py +++ b/mipcandy/inference.py @@ -1,8 +1,9 @@ from abc import ABCMeta +from collections.abc import Generator from math import log, ceil from os import PathLike, listdir from os.path import isdir, basename, exists -from typing import Sequence +from typing import Sequence, override import torch from torch import nn @@ -102,3 +103,59 @@ def predict_to_files(self, x: SupportedPredictant | UnsupervisedDataset, def __call__(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]: return self.predict(x) + + +class StreamPredictor(Predictor): + def _stream(self, x: SupportedPredictant | UnsupervisedDataset) -> Generator[ + tuple[torch.Tensor, str | None], None, None]: + if isinstance(x, PathBasedUnsupervisedDataset): + for case, path in zip(x, x.paths()): + yield self.predict_image(case), path + return + if isinstance(x, UnsupervisedDataset): + for case in x: + yield self.predict_image(case), None + return + if isinstance(x, str): + if isdir(x): + for case in listdir(x): + yield self.predict_image(Loader.do_load(f"{x}/{case}")), case + else: + yield self.predict_image(Loader.do_load(x)), basename(x) + return + if isinstance(x, torch.Tensor): + yield self.predict_image(x), None + return + for case in x: + if isinstance(case, str): + yield self.predict_image(Loader.do_load(case)), case[case.rfind("/") + 1:] + elif isinstance(case, torch.Tensor): + yield self.predict_image(case), None + else: + raise TypeError(f"Unexpected type of element {type(case)}") + + @override + def predict(self, x: SupportedPredictant | UnsupervisedDataset) -> Generator[torch.Tensor, None, None]: + for output, _ in self._stream(x): + yield output + + @override + def predict_to_files(self, x: SupportedPredictant | UnsupervisedDataset, + folder: str | PathLike[str]) -> list[str] | None: + if not exists(folder): + raise FileNotFoundError(f"Folder {folder} does not exist") + result: list[str] | None = None + for i, (output, name) in enumerate(self._stream(x)): + if name is not None: + if result is None: + result = [] + result.append(name) + else: + ext = "png" if output.ndim == 3 and output.shape[0] in (1, 3) else "mha" + name = f"prediction_{i}.{ext}" + self.save_prediction(output, f"{folder}/{name}") + return result + + @override + def __call__(self, x: SupportedPredictant | UnsupervisedDataset) -> Generator[torch.Tensor, None, None]: + return self.predict(x)