From cdf9da92c7053ad44e4419fdbe47ad254de40cca Mon Sep 17 00:00:00 2001 From: Steven Chen Date: Mon, 23 Feb 2026 23:24:14 -0500 Subject: [PATCH 1/2] Streamline `Predictor` to use lazy evaluation for memory efficiency (#228) --- mipcandy/inference.py | 49 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/mipcandy/inference.py b/mipcandy/inference.py index b493ed1..16c5be5 100644 --- a/mipcandy/inference.py +++ b/mipcandy/inference.py @@ -1,4 +1,5 @@ from abc import ABCMeta +from collections.abc import Generator from math import log, ceil from os import PathLike, listdir from os.path import isdir, basename, exists @@ -68,16 +69,36 @@ def predict_image(self, image: torch.Tensor, *, batch: bool = False) -> torch.Te output = restoring_module(output) return output if batch else output.squeeze(0) - def _predict(self, x: SupportedPredictant | UnsupervisedDataset) -> tuple[list[torch.Tensor], list[str] | None]: + def _predict(self, x: SupportedPredictant | UnsupervisedDataset) -> Generator[ + tuple[torch.Tensor, str | None], None, None]: if isinstance(x, PathBasedUnsupervisedDataset): - return [self.predict_image(case) for case in x], x.paths() + for case, path in zip(x, x.paths()): + yield self.predict_image(case), path + return if isinstance(x, UnsupervisedDataset): - return [self.predict_image(case) for case in x], None - images, filenames = parse_predictant(x, Loader) - return [self.predict_image(image) for image in images], filenames + for case in x: + yield self.predict_image(case), None + return + if isinstance(x, str): + if isdir(x): + for case in listdir(x): + yield self.predict_image(Loader.do_load(f"{x}/{case}")), case + else: + yield self.predict_image(Loader.do_load(x)), basename(x) + return + if isinstance(x, torch.Tensor): + yield self.predict_image(x), None + return + for case in x: + if isinstance(case, str): + yield self.predict_image(Loader.do_load(case)), case[case.rfind("/") + 1:] + elif isinstance(case, torch.Tensor): + yield self.predict_image(case), None + else: + raise TypeError(f"Unexpected type of element {type(case)}") def predict(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]: - return self._predict(x)[0] + return [output for output, _ in self._predict(x)] @staticmethod def save_prediction(output: torch.Tensor, path: str | PathLike[str]) -> None: @@ -96,9 +117,19 @@ def save_predictions(self, outputs: Sequence[torch.Tensor], folder: str | PathLi def predict_to_files(self, x: SupportedPredictant | UnsupervisedDataset, folder: str | PathLike[str]) -> list[str] | None: - outputs, filenames = self._predict(x) - self.save_predictions(outputs, folder, filenames=filenames) - return filenames + if not exists(folder): + raise FileNotFoundError(f"Folder {folder} does not exist") + result: list[str] | None = None + for i, (output, name) in enumerate(self._predict(x)): + if name is not None: + if result is None: + result = [] + result.append(name) + else: + ext = "png" if output.ndim == 3 and output.shape[0] in (1, 3) else "mha" + name = f"prediction_{i}.{ext}" + self.save_prediction(output, f"{folder}/{name}") + return result def __call__(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]: return self.predict(x) From a6645a35f23e8ad82a4fca55d27b544abfed3252 Mon Sep 17 00:00:00 2001 From: Steven Chen Date: Sun, 15 Mar 2026 14:39:52 -0400 Subject: [PATCH 2/2] Add `StreamPredictor` for memory-efficient lazy evaluation (#228) Instead of modifying `Predictor` directly, introduce a new `StreamPredictor` subclass that uses generator-based streaming to avoid holding all predictions in memory simultaneously. --- mipcandy/__init__.py | 2 +- mipcandy/inference.py | 68 ++++++++++++++++++++++++++++++------------- 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/mipcandy/__init__.py b/mipcandy/__init__.py index f9bafee..411b77a 100644 --- a/mipcandy/__init__.py +++ b/mipcandy/__init__.py @@ -4,7 +4,7 @@ from mipcandy.data import * from mipcandy.evaluation import EvalCase, EvalResult, Evaluator from mipcandy.frontend import * -from mipcandy.inference import parse_predictant, Predictor +from mipcandy.inference import parse_predictant, Predictor, StreamPredictor from mipcandy.layer import batch_int_multiply, batch_int_divide, LayerT, HasDevice, auto_device, WithPaddingModule, \ WithNetwork from mipcandy.metrics import do_reduction, binary_dice, dice_similarity_coefficient, soft_dice diff --git a/mipcandy/inference.py b/mipcandy/inference.py index 16c5be5..2aab385 100644 --- a/mipcandy/inference.py +++ b/mipcandy/inference.py @@ -3,7 +3,7 @@ from math import log, ceil from os import PathLike, listdir from os.path import isdir, basename, exists -from typing import Sequence +from typing import Sequence, override import torch from torch import nn @@ -69,7 +69,44 @@ def predict_image(self, image: torch.Tensor, *, batch: bool = False) -> torch.Te output = restoring_module(output) return output if batch else output.squeeze(0) - def _predict(self, x: SupportedPredictant | UnsupervisedDataset) -> Generator[ + def _predict(self, x: SupportedPredictant | UnsupervisedDataset) -> tuple[list[torch.Tensor], list[str] | None]: + if isinstance(x, PathBasedUnsupervisedDataset): + return [self.predict_image(case) for case in x], x.paths() + if isinstance(x, UnsupervisedDataset): + return [self.predict_image(case) for case in x], None + images, filenames = parse_predictant(x, Loader) + return [self.predict_image(image) for image in images], filenames + + def predict(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]: + return self._predict(x)[0] + + @staticmethod + def save_prediction(output: torch.Tensor, path: str | PathLike[str]) -> None: + save_image(output, path) + + def save_predictions(self, outputs: Sequence[torch.Tensor], folder: str | PathLike[str], *, + filenames: Sequence[str | PathLike[str]] | None = None) -> None: + if not exists(folder): + raise FileNotFoundError(f"Folder {folder} does not exist") + if not filenames: + num_digits = ceil(log(len(outputs))) + filenames = [f"prediction_{str(i).zfill(num_digits)}.{ + "png" if output.ndim == 3 and output.shape[0] in (1, 3) else "mha"}" for i, output in enumerate(outputs)] + for i, prediction in enumerate(outputs): + self.save_prediction(prediction, f"{folder}/{filenames[i]}") + + def predict_to_files(self, x: SupportedPredictant | UnsupervisedDataset, + folder: str | PathLike[str]) -> list[str] | None: + outputs, filenames = self._predict(x) + self.save_predictions(outputs, folder, filenames=filenames) + return filenames + + def __call__(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]: + return self.predict(x) + + +class StreamPredictor(Predictor): + def _stream(self, x: SupportedPredictant | UnsupervisedDataset) -> Generator[ tuple[torch.Tensor, str | None], None, None]: if isinstance(x, PathBasedUnsupervisedDataset): for case, path in zip(x, x.paths()): @@ -97,30 +134,18 @@ def _predict(self, x: SupportedPredictant | UnsupervisedDataset) -> Generator[ else: raise TypeError(f"Unexpected type of element {type(case)}") - def predict(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]: - return [output for output, _ in self._predict(x)] - - @staticmethod - def save_prediction(output: torch.Tensor, path: str | PathLike[str]) -> None: - save_image(output, path) - - def save_predictions(self, outputs: Sequence[torch.Tensor], folder: str | PathLike[str], *, - filenames: Sequence[str | PathLike[str]] | None = None) -> None: - if not exists(folder): - raise FileNotFoundError(f"Folder {folder} does not exist") - if not filenames: - num_digits = ceil(log(len(outputs))) - filenames = [f"prediction_{str(i).zfill(num_digits)}.{ - "png" if output.ndim == 3 and output.shape[0] in (1, 3) else "mha"}" for i, output in enumerate(outputs)] - for i, prediction in enumerate(outputs): - self.save_prediction(prediction, f"{folder}/{filenames[i]}") + @override + def predict(self, x: SupportedPredictant | UnsupervisedDataset) -> Generator[torch.Tensor, None, None]: + for output, _ in self._stream(x): + yield output + @override def predict_to_files(self, x: SupportedPredictant | UnsupervisedDataset, folder: str | PathLike[str]) -> list[str] | None: if not exists(folder): raise FileNotFoundError(f"Folder {folder} does not exist") result: list[str] | None = None - for i, (output, name) in enumerate(self._predict(x)): + for i, (output, name) in enumerate(self._stream(x)): if name is not None: if result is None: result = [] @@ -131,5 +156,6 @@ def predict_to_files(self, x: SupportedPredictant | UnsupervisedDataset, self.save_prediction(output, f"{folder}/{name}") return result - def __call__(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]: + @override + def __call__(self, x: SupportedPredictant | UnsupervisedDataset) -> Generator[torch.Tensor, None, None]: return self.predict(x)