Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
__pycache__/
docs/

.vscode
.git
.mypy_cache
.ruff_cache
.pytype
.coverage
.coverage.*
.coverage/
coverage.xml
.readthedocs.yml
*.toml

!README.md
24 changes: 24 additions & 0 deletions .github/workflows/pythonapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ jobs:
matrix:
opt: ["codeformat", "pytype", "mypy"]
steps:
- name: Clean unused tools
run: |
find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc /usr/local/.ghcup
sudo docker system prune -f

- uses: actions/checkout@v6
- name: Set up Python 3.9
uses: actions/setup-python@v6
Expand Down Expand Up @@ -129,6 +137,14 @@ jobs:
QUICKTEST: True
shell: bash
steps:
- name: Clean unused tools
run: |
find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc /usr/local/.ghcup
sudo docker system prune -f

- uses: actions/checkout@v6
with:
fetch-depth: 0
Expand Down Expand Up @@ -213,6 +229,14 @@ jobs:
build-docs:
runs-on: ubuntu-latest
steps:
- name: Clean unused tools
run: |
find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc /usr/local/.ghcup
sudo docker system prune -f

- uses: actions/checkout@v6
- name: Set up Python 3.9
uses: actions/setup-python@v6
Expand Down
89 changes: 89 additions & 0 deletions Dockerfile.slim
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# To build with a different base image
# please run `docker build` using the `--build-arg IMAGE=...` flag.
ARG IMAGE=debian:12-slim

FROM ${IMAGE} AS build

ARG TORCH_CUDA_ARCH_LIST="7.5 8.0 8.6 8.9 9.0+PTX"

ENV DEBIAN_FRONTEND=noninteractive
ENV APT_INSTALL="apt install -y --no-install-recommends"

RUN apt update && apt upgrade -y && \
${APT_INSTALL} ca-certificates python3-pip python-is-python3 git wget libopenslide0 unzip python3-dev && \
wget https://developer.download.nvidia.com/compute/cuda/repos/debian12/x86_64/cuda-keyring_1.1-1_all.deb && \
dpkg -i cuda-keyring_1.1-1_all.deb && \
apt update && \
${APT_INSTALL} cuda-toolkit-12 && \
rm -rf /usr/lib/python*/EXTERNALLY-MANAGED /var/lib/apt/lists/* && \
python -m pip install --upgrade --no-cache-dir pip

# TODO: remark for issue [revise the dockerfile](https://github.com/zarr-developers/numcodecs/issues/431)
RUN if [[ $(uname -m) =~ "aarch64" ]]; then \
CFLAGS="-O3" DISABLE_NUMCODECS_SSE2=true DISABLE_NUMCODECS_AVX2=true python -m pip install numcodecs; \
fi

# NGC Client
WORKDIR /opt/tools
ARG NGC_CLI_URI="https://ngc.nvidia.com/downloads/ngccli_linux.zip"
RUN wget -q ${NGC_CLI_URI} && unzip ngccli_linux.zip && chmod u+x ngc-cli/ngc && \
find ngc-cli/ -type f -exec md5sum {} + | LC_ALL=C sort | md5sum -c ngc-cli.md5 && \
rm -rf ngccli_linux.zip ngc-cli.md5

WORKDIR /opt/monai

# copy relevant parts of repo
COPY requirements.txt requirements-min.txt requirements-dev.txt versioneer.py setup.py setup.cfg pyproject.toml ./
COPY LICENSE CHANGELOG.md CODE_OF_CONDUCT.md CONTRIBUTING.md README.md MANIFEST.in runtests.sh ./
COPY tests ./tests
COPY monai ./monai

# install full deps
RUN python -m pip install --no-cache-dir -r requirements-dev.txt

# compile ext
RUN CUDA_HOME=/usr/local/cuda FORCE_CUDA=1 USE_COMPILED=1 BUILD_MONAI=1 python setup.py develop

# recreate the image without the installed CUDA packages then copy the installed MONAI and Python directories
FROM ${IMAGE} AS build2

ENV DEBIAN_FRONTEND=noninteractive
ENV APT_INSTALL="apt install -y --no-install-recommends"

RUN apt update && apt upgrade -y && \
${APT_INSTALL} ca-certificates python3-pip python-is-python3 git libopenslide0 && \
apt clean && \
rm -rf /usr/lib/python*/EXTERNALLY-MANAGED /var/lib/apt/lists/* && \
python -m pip install --upgrade --no-cache-dir pip

COPY --from=build /opt/monai /opt/monai
COPY --from=build /opt/tools /opt/tools
COPY --from=build /usr/local/lib/python3.11/dist-packages /usr/local/lib/python3.11/dist-packages
COPY --from=build /usr/local/bin /usr/local/bin

RUN rm -rf /opt/monai/build /opt/monai/monai.egg-info && \
find / -name __pycache__ | xargs rm -rf

# flatten all layers down to one
FROM ${IMAGE}
LABEL maintainer="monai.contact@gmail.com"

COPY --from=build2 / /

WORKDIR /opt/monai

ENV PATH=${PATH}:/opt/tools:/opt/tools/ngc-cli
ENV POLYGRAPHY_AUTOINSTALL_DEPS=1
ENV CUDA_HOME=/usr/local/cuda
ENV BUILD_MONAI=1
4 changes: 2 additions & 2 deletions monai/apps/vista3d/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ def point_based_window_inferer(
for j in range(len(ly_)):
for k in range(len(lz_)):
lx, rx, ly, ry, lz, rz = (lx_[i], rx_[i], ly_[j], ry_[j], lz_[k], rz_[k])
unravel_slice = [
unravel_slice = (
slice(None),
slice(None),
slice(int(lx), int(rx)),
slice(int(ly), int(ry)),
slice(int(lz), int(rz)),
]
)
batch_image = image[unravel_slice]
output = predictor(
batch_image,
Expand Down
12 changes: 4 additions & 8 deletions monai/networks/nets/vista3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,14 +243,10 @@ def connected_components_combine(
_logits = logits[mapping_index]
inside = []
for i in range(_logits.shape[0]):
inside.append(
np.any(
[
_logits[i, 0, p[0], p[1], p[2]].item() > 0
for p in point_coords[i].cpu().numpy().round().astype(int)
]
)
)
p_coord = point_coords[i].cpu().numpy().round().astype(int)
inside_p = [_logits[i, 0, p[0], p[1], p[2]].item() > 0 for p in p_coord]
inside.append(int(np.any(inside_p))) # convert to int to avoid typing problems with Numpy

inside_tensor = torch.tensor(inside).to(logits.device)
nan_mask = torch.isnan(_logits)
# _logits are converted to binary [B1, 1, H, W, D]
Expand Down
6 changes: 3 additions & 3 deletions monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,13 +713,13 @@ def convert_to_onnx(
torch_versioned_kwargs = {}
if use_trace:
# let torch.onnx.export to trace the model.
mode_to_export = model
model_to_export = model
torch_versioned_kwargs = kwargs
if "dynamo" in kwargs and kwargs["dynamo"] and verify:
torch_versioned_kwargs["verify"] = verify
verify = False
else:
mode_to_export = torch.jit.script(model, **kwargs)
model_to_export = torch.jit.script(model, **kwargs)

if torch.is_tensor(inputs) or isinstance(inputs, dict):
onnx_inputs = (inputs,)
Expand All @@ -733,7 +733,7 @@ def convert_to_onnx(
f = filename
print(f"torch_versioned_kwargs={torch_versioned_kwargs}")
torch.onnx.export(
mode_to_export,
model_to_export,
onnx_inputs,
f=f,
input_names=input_names,
Expand Down
4 changes: 2 additions & 2 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Full requirements for developments
-r requirements-min.txt
pytorch-ignite==0.4.11
pytorch-ignite
gdown>=4.7.3
scipy>=1.12.0; python_version >= '3.9'
itk>=5.2
Expand Down Expand Up @@ -52,8 +52,8 @@ nni==2.10.1; platform_system == "Linux" and "arm" not in platform_machine and "a
optuna
git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
onnx>=1.13.0
onnxruntime
onnxscript
onnxruntime; python_version <= '3.10'
typeguard<3 # https://github.com/microsoft/nni/issues/5457
filelock<3.12.0 # https://github.com/microsoft/nni/issues/5523
zarr
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
torch>=2.4.1; platform_system != "Windows"
torch>=2.4.1, !=2.7.0; platform_system == "Windows"
torch>=2.4.1, <2.9; platform_system != "Windows"
torch>=2.4.1, <2.9, !=2.7.0; platform_system == "Windows"
numpy>=1.24,<3.0
3 changes: 2 additions & 1 deletion tests/bundle/test_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
import tempfile
import unittest
from unittest.case import skipUnless
from unittest.case import skipIf, skipUnless
from unittest.mock import patch

import numpy as np
Expand Down Expand Up @@ -219,6 +219,7 @@ def test_monaihosting_url_download_bundle(self, bundle_files, bundle_name, url):

@parameterized.expand([TEST_CASE_5])
@skip_if_quick
@skipIf(os.getenv("NGC_API_KEY", None) is None, "NGC API key required for this test")
def test_ngc_private_source_download_bundle(self, bundle_files, bundle_name, _url):
with skip_if_downloading_fails():
# download a single file from url, also use `args_file`
Expand Down
2 changes: 1 addition & 1 deletion tests/data/meta_tensor/test_meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_pickling(self):
with tempfile.TemporaryDirectory() as tmp_dir:
fname = os.path.join(tmp_dir, "im.pt")
torch.save(m, fname)
m2 = torch.load(fname, weights_only=True)
m2 = torch.load(fname, weights_only=False)
self.check(m2, m, ids=False)

@skip_if_no_cuda
Expand Down
2 changes: 1 addition & 1 deletion tests/losses/test_multi_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class TestMultiScale(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_data, expected_val):
result = MultiScaleLoss(**input_param).forward(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4)

@parameterized.expand(
[
Expand Down
44 changes: 24 additions & 20 deletions tests/networks/test_convert_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@
from monai.networks.nets import SegResNet, UNet
from tests.test_utils import SkipIfNoModule, optional_import, skip_if_quick

if torch.cuda.is_available():
TORCH_DEVICE_OPTIONS = ["cpu", "cuda"]
else:
TORCH_DEVICE_OPTIONS = ["cpu"]
onnx, _ = optional_import("onnx")

TORCH_DEVICE_OPTIONS = ["cpu"]

# FIXME: CUDA seems to produce different model outputs during testing vs. ONNX outputs, use CPU only for now
# if torch.cuda.is_available():
# TORCH_DEVICE_OPTIONS.append("cuda")

TESTS = list(itertools.product(TORCH_DEVICE_OPTIONS, [True, False], [True, False]))
TESTS_ORT = list(itertools.product(TORCH_DEVICE_OPTIONS, [True]))

Expand All @@ -35,38 +39,38 @@
else:
rtol, atol = 1e-3, 1e-4

onnx, _ = optional_import("onnx")


@SkipIfNoModule("onnx")
@skip_if_quick
class TestConvertToOnnx(unittest.TestCase):
@parameterized.expand(TESTS)
def test_unet(self, device, use_trace, use_ort):
"""Test converting UNet to ONNX."""
if use_ort:
_, has_onnxruntime = optional_import("onnxruntime")
if not has_onnxruntime:
self.skipTest("onnxruntime is not installed probably due to python version >= 3.11.")
model = UNet(
spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
)
if use_trace:
onnx_model = convert_to_onnx(
model=model,
inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)],
input_names=["x"],
output_names=["y"],
verify=True,
device=device,
use_ort=use_ort,
use_trace=use_trace,
rtol=rtol,
atol=atol,
)
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))

onnx_model = convert_to_onnx(
model=model,
inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)],
input_names=["x"],
output_names=["y"],
verify=True,
device=device,
use_ort=use_ort,
use_trace=use_trace,
rtol=rtol,
atol=atol,
)
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))

@parameterized.expand(TESTS_ORT)
def test_seg_res_net(self, device, use_ort):
"""Test converting SetResNet to ONNX."""
if use_ort:
_, has_onnxruntime = optional_import("onnxruntime")
if not has_onnxruntime:
Expand Down
Loading