Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mipcandy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from mipcandy.data import *
from mipcandy.evaluation import EvalCase, EvalResult, Evaluator
from mipcandy.frontend import *
from mipcandy.inference import parse_predictant, Predictor
from mipcandy.inference import parse_predictant, Predictor, StreamPredictor
from mipcandy.layer import batch_int_multiply, batch_int_divide, LayerT, HasDevice, auto_device, WithPaddingModule, \
WithNetwork
from mipcandy.metrics import do_reduction, binary_dice, dice_similarity_coefficient, soft_dice
Expand Down
59 changes: 58 additions & 1 deletion mipcandy/inference.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import ABCMeta
from collections.abc import Generator
from math import log, ceil
from os import PathLike, listdir
from os.path import isdir, basename, exists
from typing import Sequence
from typing import Sequence, override

import torch
from torch import nn
Expand Down Expand Up @@ -102,3 +103,59 @@ def predict_to_files(self, x: SupportedPredictant | UnsupervisedDataset,

def __call__(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]:
return self.predict(x)


class StreamPredictor(Predictor):
    """A `Predictor` variant that yields predictions lazily, one case at a time.

    Unlike `Predictor`, whose `predict` returns a fully materialized
    `list[torch.Tensor]`, this class streams results through generators so
    large datasets can be processed without holding every prediction in
    memory at once.
    """

    def _stream(self, x: SupportedPredictant | UnsupervisedDataset) -> Generator[
        tuple[torch.Tensor, str | None], None, None]:
        """Yield `(prediction, filename)` pairs for every case in `x`.

        The filename component is `None` when the input carries no path
        information (bare tensors, or a plain `UnsupervisedDataset`).

        :param x: a dataset, a path to a file or folder, a tensor, or a
            sequence mixing paths and tensors
        :raises TypeError: if an element of a sequence input is neither a
            path string nor a tensor
        """
        if isinstance(x, PathBasedUnsupervisedDataset):
            # Path-aware dataset: pair each prediction with its source path.
            for case, path in zip(x, x.paths()):
                yield self.predict_image(case), path
            return
        if isinstance(x, UnsupervisedDataset):
            # Generic dataset: no path information available.
            for case in x:
                yield self.predict_image(case), None
            return
        if isinstance(x, str):
            if isdir(x):
                # Folder: predict every file it contains, keyed by filename.
                for case in listdir(x):
                    yield self.predict_image(Loader.do_load(f"{x}/{case}")), case
            else:
                yield self.predict_image(Loader.do_load(x)), basename(x)
            return
        if isinstance(x, torch.Tensor):
            yield self.predict_image(x), None
            return
        # Fallback: a sequence mixing path strings and tensors.
        for case in x:
            if isinstance(case, str):
                # `basename` for consistency with the single-path branch above;
                # it also handles OS-specific separators, unlike slicing on "/".
                yield self.predict_image(Loader.do_load(case)), basename(case)
            elif isinstance(case, torch.Tensor):
                yield self.predict_image(case), None
            else:
                raise TypeError(f"Unexpected type of element {type(case)}")

    @override
    def predict(self, x: SupportedPredictant | UnsupervisedDataset) -> Generator[torch.Tensor, None, None]:
        """Lazily yield a prediction tensor for each case in `x`."""
        for output, _ in self._stream(x):
            yield output

    @override
    def predict_to_files(self, x: SupportedPredictant | UnsupervisedDataset,
                         folder: str | PathLike[str]) -> list[str] | None:
        """Predict every case in `x` and save each result under `folder`.

        :param x: see `_stream` for the accepted input forms
        :param folder: an existing destination folder
        :returns: the list of filenames used for named inputs, or `None` when
            no input carried a name
        :raises FileNotFoundError: if `folder` does not exist

        NOTE(review): if `x` mixes named and unnamed cases (possible via the
        sequence input form), the returned list covers only the named ones and
        its length will not match the number of predictions — confirm callers
        don't rely on positional correspondence.
        """
        if not exists(folder):
            raise FileNotFoundError(f"Folder {folder} does not exist")
        result: list[str] | None = None
        for i, (output, name) in enumerate(self._stream(x)):
            if name is not None:
                if result is None:
                    result = []
                result.append(name)
            else:
                # Unnamed case: synthesize a filename; PNG only fits 2D images
                # with 1 or 3 channels, otherwise fall back to MHA.
                ext = "png" if output.ndim == 3 and output.shape[0] in (1, 3) else "mha"
                name = f"prediction_{i}.{ext}"
            self.save_prediction(output, f"{folder}/{name}")
        return result

    @override
    def __call__(self, x: SupportedPredictant | UnsupervisedDataset) -> Generator[torch.Tensor, None, None]:
        """Alias for `predict`; yields predictions lazily."""
        return self.predict(x)