diff --git a/docs/source/mb_specification.rst b/docs/source/mb_specification.rst index 56d660e35c..61c94725d9 100644 --- a/docs/source/mb_specification.rst +++ b/docs/source/mb_specification.rst @@ -6,9 +6,9 @@ MONAI Bundle Specification Overview ======== -This is the specification for the MONAI Bundle (MB) format of portable described deep learning models. The objective of a MB is to define a packaged network or model which includes the critical information necessary to allow users and programs to understand how the model is used and for what purpose. A bundle includes the stored weights of a single network as a pickled state dictionary plus optionally a Torchscript object and/or an ONNX object. Additional JSON files are included to store metadata about the model, information for constructing training, inference, and post-processing transform sequences, plain-text description, legal information, and other data the model creator wishes to include. +This is the specification for the MONAI Bundle (MB) format of portable deep learning models. The objective of a MB is to define a packaged network or model which includes the critical information necessary to allow users and programs to understand how the model is used and for what purpose. A bundle includes the stored weights of a single network as a pickled state dictionary plus optionally an exported program (``.pt2``, via ``torch.export``) and/or an ONNX object. Additional JSON files are included to store metadata about the model, information for constructing training, inference, and post-processing transform sequences, plain-text description, legal information, and other data the model creator wishes to include. -This specification defines the directory structure a bundle must have and the necessary files it must contain. Additional files may be included and the directory packaged into a zip file or included as extra files directly in a Torchscript file. 
+This specification defines the directory structure a bundle must have and the necessary files it must contain. Additional files may be included and the directory packaged into a zip file or included as extra files directly in the exported archive. Directory Structure =================== @@ -23,7 +23,8 @@ A MONAI Bundle is defined primarily as a directory with a set of specifically na ┃ ┗━ metadata.json ┣━ models ┃ ┣━ model.pt - ┃ ┣━ *model.ts + ┃ ┣━ *model.pt2 + ┃ ┣━ *model.ts (deprecated) ┃ ┗━ *model.onnx ┗━ docs ┣━ *README.md @@ -38,7 +39,8 @@ The following files are **required** to be present with the given filenames for The following files are optional but must have these names in the directory given above: -* **model.ts**: the Torchscript saved model if the model is compatible with being saved correctly in this format. +* **model.pt2**: the ``torch.export`` exported program if the model is compatible with being exported in this format. This is the preferred format for model deployment. +* **model.ts**: the TorchScript saved model (deprecated since v1.5, will be removed in v1.7; use ``model.pt2`` instead). * **model.onnx**: the ONNX model if the model is compatible with being saved correctly in this format. * **README.md**: plain-language information on the model, how to use it, author information, etc. in Markdown format. * **license.txt**: software license attached to the data, can be left blank if no license needed. @@ -50,9 +52,13 @@ Archive Format The bundle directory and its contents can be compressed into a zip file to constitute a single file package. When unzipped into a directory this file will reproduce the above directory structure, and should itself also be named after the model it contains. For example, `ModelName.zip` would contain at least `ModelName/configs/metadata.json` and `ModelName/models/model.pt`, thus when unzipped would place files into the directory `ModelName` rather than into the current working directory. 
-The Torchscript file format is also just a zip file with a specific structure. When creating such an archive with `save_net_with_metadata` a MB-compliant Torchscript file can be created by including the contents of `metadata.json` as the `meta_values` argument of the function, and other files included as `more_extra_files` entries. These will be stored in a `extras` directory in the zip file and can be retrieved with `load_net_with_metadata` or with any other library/tool that can read zip data. In this format the `model.*` files are obviously not needed but `README.md` and `license.txt` as well as any others provided can be added as more extra files. +The ``.pt2`` file format (produced by ``torch.export``) is also a zip file with a specific structure. When creating such an archive with ``save_exported_program``, a MB-compliant exported program file can be created by including the contents of ``metadata.json`` as the ``meta_values`` argument of the function, and other files included as ``more_extra_files`` entries. These will be stored in the zip file and can be retrieved with ``load_exported_program`` or with any other library/tool that can read zip data. In this format the ``model.*`` files are not needed, but ``README.md`` and ``license.txt`` as well as any others provided can be added as more extra files. -The `bundle` submodule of MONAI contains a number of command line programs. To produce a Torchscript bundle use `ckpt_export` with a set of specified components such as the saved weights file and metadata file. Config files can be provided as JSON or YAML dictionaries defining Python constructs used by the `ConfigParser`, however regardless of format the produced bundle Torchscript object will store the files as JSON. +The ``bundle`` submodule of MONAI contains a number of command line programs. To produce an exported bundle, use ``export_checkpoint`` with a set of specified components such as the saved weights file and metadata file. 
Config files can be provided as JSON or YAML dictionaries defining Python constructs used by the ``ConfigParser``, however regardless of format the produced bundle archive will store the files as JSON. + +.. note:: + + The legacy TorchScript (``ckpt_export``, ``save_net_with_metadata``, ``load_net_with_metadata``) workflow is deprecated since v1.5 and will be removed in v1.7. Use ``export_checkpoint``, ``save_exported_program``, and ``load_exported_program`` instead. metadata.json File ================== diff --git a/docs/source/modules.md b/docs/source/modules.md index b2e95658bf..3c0ed09161 100644 --- a/docs/source/modules.md +++ b/docs/source/modules.md @@ -237,7 +237,7 @@ and [MLflow](https://github.com/Project-MONAI/tutorials/blob/main/experiment_man The objective of a MONAI bundle is to define a packaged model which includes the critical information necessary to allow users and programs to understand how the model is used and for what purpose. A bundle includes the stored weights of a -single network as a pickled state dictionary plus optionally a Torchscript object and/or an ONNX object. Additional JSON +single network as a pickled state dictionary plus optionally an exported program (`.pt2`, via `torch.export`) and/or an ONNX object. Additional JSON files are included to store metadata about the model, information for constructing training, inference, and post-processing transform sequences, plain-text description, legal information, and other data the model creator wishes to include. More details are available at [bundle specification](https://monai.readthedocs.io/en/latest/mb_specification.html). 
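The archive-format rules in the specification above (a `ModelName.zip` whose entries all live under a top-level `ModelName/` directory, containing at least `configs/metadata.json` and `models/model.pt`) can be sketched with the standard library alone. This is an illustrative sketch, not part of MONAI; `make_bundle_zip` is a hypothetical helper name:

```python
import io
import json
import zipfile


def make_bundle_zip(name: str, metadata: dict, model_bytes: bytes) -> bytes:
    """Hypothetical helper: pack a minimal MB-compliant zip whose entries
    all live under a top-level directory named after the model, so that
    unzipping creates `name/` rather than spilling into the working dir."""
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
        # the two required files per the bundle specification
        zf.writestr(f"{name}/configs/metadata.json", json.dumps(metadata))
        zf.writestr(f"{name}/models/model.pt", model_bytes)
    return buf.getvalue()


data = make_bundle_zip("ModelName", {"version": "0.1.0"}, b"\x00")
with zipfile.ZipFile(io.BytesIO(data)) as zf:
    names = sorted(zf.namelist())
print(names)  # ['ModelName/configs/metadata.json', 'ModelName/models/model.pt']
```

Optional files (`model.pt2`, `README.md`, `license.txt`, ...) would be added as further `writestr` entries under the same top-level directory.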
diff --git a/monai/apps/detection/networks/retinanet_network.py b/monai/apps/detection/networks/retinanet_network.py index f1535f9e8d..a27c269209 100644 --- a/monai/apps/detection/networks/retinanet_network.py +++ b/monai/apps/detection/networks/retinanet_network.py @@ -332,7 +332,7 @@ def forward(self, images: Tensor) -> Any: features = self.feature_extractor(images) if isinstance(features, Tensor): feature_maps = [features] - elif torch.jit.isinstance(features, dict[str, Tensor]): + elif isinstance(features, dict): feature_maps = list(features.values()) else: feature_maps = list(features) diff --git a/monai/apps/detection/utils/anchor_utils.py b/monai/apps/detection/utils/anchor_utils.py index 20f6fc6025..83497e747b 100644 --- a/monai/apps/detection/utils/anchor_utils.py +++ b/monai/apps/detection/utils/anchor_utils.py @@ -257,7 +257,7 @@ def grid_anchors(self, grid_sizes: list[list[int]], strides: list[list[Tensor]]) for axis in range(self.spatial_dims) ] - # to support torchscript, cannot directly use torch.meshgrid(shifts_centers). + # build the per-axis coordinate grids; torch.meshgrid accepts the list of center tensors directly. shifts_centers = list(torch.meshgrid(shifts_centers[: self.spatial_dims], indexing="ij")) for axis in range(self.spatial_dims): diff --git a/monai/apps/mmars/mmars.py b/monai/apps/mmars/mmars.py index 1fc0690cc9..505d8e98dc 100644 --- a/monai/apps/mmars/mmars.py +++ b/monai/apps/mmars/mmars.py @@ -205,7 +205,8 @@ def load_from_mmar( mmar_dir: : target directory to store the MMAR, default is mmars subfolder under `torch.hub get_dir()`. progress: whether to display a progress bar when downloading the content. version: version number of the MMAR. Set it to `-1` to use `item[Keys.VERSION]`. - map_location: pytorch API parameter for `torch.load` or `torch.jit.load`. + map_location: PyTorch API parameter for ``torch.load`` or ``torch.jit.load`` (legacy ``.ts`` files). + Ignored when loading ``.pt2`` (ExportedProgram) files. 
pretrained: whether to load the pretrained weights after initializing a network module. weights_only: whether to load only the weights instead of initializing the network module and assign weights. model_key: a key to search in the model file or config file for the model dictionary. @@ -232,12 +233,26 @@ def load_from_mmar( _model_file = model_dir / item.get(Keys.MODEL_FILE, model_file) logger.info(f'\n*** "{item.get(Keys.NAME)}" available at {model_dir}.') - # loading with `torch.jit.load` + # loading with `torch.export.load` for .pt2 files + if _model_file.name.endswith(".pt2"): + if not pretrained: + warnings.warn("Loading an ExportedProgram, 'pretrained' option ignored.", stacklevel=2) + if weights_only: + warnings.warn("Loading an ExportedProgram, 'weights_only' option ignored.", stacklevel=2) + return torch.export.load(str(_model_file)) + + # loading with `torch.jit.load` for legacy .ts files if _model_file.name.endswith(".ts"): + warnings.warn( + "Loading TorchScript (.ts) models is deprecated since MONAI v1.5 and will be removed in v1.7. 
" + "Use torch.export (.pt2) format instead.", + FutureWarning, + stacklevel=2, + ) if not pretrained: - warnings.warn("Loading a ScriptModule, 'pretrained' option ignored.") + warnings.warn("Loading a ScriptModule, 'pretrained' option ignored.", stacklevel=2) if weights_only: - warnings.warn("Loading a ScriptModule, 'weights_only' option ignored.") + warnings.warn("Loading a ScriptModule, 'weights_only' option ignored.", stacklevel=2) return torch.jit.load(_model_file, map_location=map_location) # loading with `torch.load` diff --git a/monai/apps/reconstruction/networks/nets/coil_sensitivity_model.py b/monai/apps/reconstruction/networks/nets/coil_sensitivity_model.py index 91a9f3d8d3..3fd72f53e5 100644 --- a/monai/apps/reconstruction/networks/nets/coil_sensitivity_model.py +++ b/monai/apps/reconstruction/networks/nets/coil_sensitivity_model.py @@ -85,32 +85,40 @@ def __init__( self.spatial_dims = spatial_dims self.coil_dim = coil_dim - def get_fully_sampled_region(self, mask: Tensor) -> tuple[int, int]: + def _compute_acr_mask(self, mask: Tensor) -> Tensor: """ - Extracts the size of the fully-sampled part of the kspace. Note that when a kspace - is under-sampled, a part of its center is fully sampled. This part is called the Auto - Calibration Region (ACR). ACR is used for sensitivity map computation. + Compute a boolean mask for the Auto Calibration Region (ACR) — the contiguous + fully-sampled center of the k-space sampling mask. + + Uses pure tensor operations (``cumprod``) instead of while-loops so that + the computation is compatible with ``torch.export``. Args: - mask: the under-sampling mask of shape (..., S, 1) where S denotes the sampling dimension + mask: the under-sampling mask of shape (..., S, 1) where S denotes the sampling dimension. Returns: - A tuple containing - (1) left index of the region - (2) right index of the region - - Note: - Suppose the mask is of shape (1,1,20,1). 
If this function returns 8,12 as left and right - indices, then it means that the fully-sampled center region has size 4 starting from 8 to 12. + A boolean tensor broadcastable to ``masked_kspace`` that is True inside the ACR. """ - left = right = mask.shape[-2] // 2 - while mask[..., right, :]: - right += 1 + s_len = mask.shape[-2] + center = s_len // 2 + + # Flatten to 1-D along the sampling axis + m = mask.reshape(-1)[:s_len].bool() + + # Count consecutive True values from center going right + right_count = torch.cumprod(m[center:].int(), dim=0).sum() + # Count consecutive True values from center going left (including center) + left_count = torch.cumprod(m[: center + 1].flip(0).int(), dim=0).sum() + num_low_freqs = left_count + right_count - 1 - while mask[..., left, :]: - left -= 1 + # Build a boolean mask over the sampling dimension + start = (s_len - num_low_freqs + 1) // 2 + freq_idx = torch.arange(s_len, device=mask.device) + acr_1d = (freq_idx >= start) & (freq_idx < start + num_low_freqs) - return left + 1, right + # Reshape to (..., S, 1) so it broadcasts against masked_kspace + result: Tensor = acr_1d.view(*([1] * (mask.ndim - 2)), s_len, 1) + return result def forward(self, masked_kspace: Tensor, mask: Tensor) -> Tensor: """ @@ -122,13 +130,10 @@ def forward(self, masked_kspace: Tensor, mask: Tensor) -> Tensor: Returns: predicted coil sensitivity maps with shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data. 
""" - left, right = self.get_fully_sampled_region(mask) - num_low_freqs = right - left # size of the fully-sampled center + acr_mask = self._compute_acr_mask(mask) # take out the fully-sampled region and set the rest of the data to zero - x = torch.zeros_like(masked_kspace) - start = (mask.shape[-2] - num_low_freqs + 1) // 2 # this marks the start of center extraction - x[..., start : start + num_low_freqs, :] = masked_kspace[..., start : start + num_low_freqs, :] + x = masked_kspace * acr_mask # apply inverse fourier to the extracted fully-sampled data x = ifftn_centered_t(x, spatial_dims=self.spatial_dims, is_complex=True) diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 3f3c8d545e..194d3deacd 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -20,6 +20,7 @@ create_workflow, download, download_large_files, + export_checkpoint, get_all_bundles_list, get_bundle_info, get_bundle_versions, diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py index 778c9ef2f0..9593296978 100644 --- a/monai/bundle/__main__.py +++ b/monai/bundle/__main__.py @@ -15,6 +15,7 @@ ckpt_export, download, download_large_files, + export_checkpoint, init_bundle, onnx_export, run, diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 9fdee6acd0..a2392fb038 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -35,7 +35,9 @@ from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow from monai.config import PathLike from monai.data import load_net_with_metadata, save_net_with_metadata +from monai.data.export_utils import load_exported_program, save_exported_program from monai.networks import ( + convert_to_export, convert_to_onnx, convert_to_torchscript, convert_to_trt, @@ -46,6 +48,7 @@ from monai.utils import ( IgniteInfo, check_parent_dir, + deprecated, ensure_tuple, get_equivalent_dtype, min_version, @@ -632,6 +635,7 @@ def load( workflow_type: str = "train", model_file: str | None = None, 
load_ts_module: bool = False, + load_exported_module: bool = False, bundle_dir: PathLike | None = None, source: str = DEFAULT_DOWNLOAD_SOURCE, repo: str | None = None, @@ -646,7 +650,7 @@ def load( net_override: dict | None = None, ) -> object | tuple[torch.nn.Module, dict, dict] | Any: """ - Load model weights or TorchScript module of a bundle. + Load model weights, TorchScript module, or exported program of a bundle. Args: name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`. @@ -664,10 +668,16 @@ def load( or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `train` for training workflow. - model_file: the relative path of the model weights or TorchScript module within bundle. - If `None`, "models/model.pt" or "models/model.ts" will be used. + model_file: the relative path of the model weights or exported module within bundle. + If `None`, "models/model.pt", "models/model.ts", or "models/model.pt2" will be used + depending on the loading mode. load_ts_module: a flag to specify if loading the TorchScript module. - bundle_dir: directory the weights/TorchScript module will be loaded from. + + .. deprecated:: 1.5 + Use ``load_exported_module=True`` instead. + + load_exported_module: a flag to specify if loading a ``torch.export`` ``.pt2`` module. + bundle_dir: directory the weights/module will be loaded from. Default is `bundle` subfolder under `torch.hub.get_dir()`. source: storage location name. This argument is used when `model_file` is not existing locally and need to be downloaded first. @@ -684,32 +694,51 @@ def load( device: target device of returned weights or module, if `None`, prefer to "cuda" if existing. key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model weights. if not nested checkpoint, no need to set. - config_files: extra filenames would be loaded. 
The argument only works when loading a TorchScript module, - see `_extra_files` in `torch.jit.load` for more details. + config_files: extra filenames to be loaded. The argument only works when loading a TorchScript + or exported module, see ``_extra_files`` in ``torch.jit.load`` / ``torch.export.load`` for details. workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow". args_file: a JSON or YAML file to provide default values for all the args in "download" function. copy_model_args: other arguments for the `monai.networks.copy_model_state` function. net_override: id-value pairs to override the parameters in the network of the bundle, default to `None`. Returns: - 1. If `load_ts_module` is `False` and `model` is `None`, + 1. If ``load_ts_module`` and ``load_exported_module`` are both ``False`` and ``model`` is ``None``, return model weights if can't find "network_def" in the bundle, else return an instantiated network that loaded the weights. - 2. If `load_ts_module` is `False` and `model` is not `None`, + 2. If ``load_ts_module`` and ``load_exported_module`` are both ``False`` and ``model`` is not ``None``, return an instantiated network that loaded the weights. - 3. If `load_ts_module` is `True`, return a triple that include a TorchScript module, + 3. If ``load_ts_module`` is ``True``, return a triple that includes a TorchScript module, the corresponding metadata dict, and extra files dict. - please check `monai.data.load_net_with_metadata` for more details. + please check ``monai.data.load_net_with_metadata`` for more details. + 4. If ``load_exported_module`` is ``True``, return a triple of + (ExportedProgram, metadata dict, extra files dict). + See :func:`monai.data.load_exported_program` for more details. 
""" bundle_dir_ = _process_bundle_dir(bundle_dir) net_override = {} if net_override is None else net_override copy_model_args = {} if copy_model_args is None else copy_model_args + if load_ts_module and load_exported_module: + raise ValueError("load_ts_module and load_exported_module are mutually exclusive.") + + if load_ts_module: + warnings.warn( + "load_ts_module is deprecated since v1.5 and will be removed in v1.7. " + "Use load_exported_module=True instead.", + FutureWarning, + stacklevel=2, + ) + if device is None: device = "cuda:0" if is_available() else "cpu" if model_file is None: - model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt") + if load_exported_module: + model_file = os.path.join("models", "model.pt2") + elif load_ts_module: + model_file = os.path.join("models", "model.ts") + else: + model_file = os.path.join("models", "model.pt") if source == "ngc": name = _add_ngc_prefix(name) if remove_prefix: @@ -727,14 +756,25 @@ def load( args_file=args_file, ) + # loading with `torch.export.load` + if load_exported_module: + return load_exported_program(full_path, more_extra_files=config_files or ()) # loading with `torch.jit.load` if load_ts_module is True: - return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files) + # Suppress the @deprecated warning from load_net_with_metadata since the user + # already received a FutureWarning about load_ts_module above. 
+ with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning, message=".*load_net_with_metadata.*") + return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files) # loading with `torch.load` model_dict = torch.load(full_path, map_location=torch.device(device), weights_only=True) if not isinstance(model_dict, Mapping): - warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.") + warnings.warn( + f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.", + category=UserWarning, + stacklevel=2, + ) model_dict = get_state_dict(model_dict) _workflow = None @@ -750,11 +790,13 @@ def load( **_net_override, ) else: - warnings.warn(f"Cannot find the config file: {bundle_config_file}, return state dict instead.") + warnings.warn( + f"Cannot find the config file: {bundle_config_file}, return state dict instead.", stacklevel=2 + ) return model_dict if _workflow is not None: if not hasattr(_workflow, "network_def"): - warnings.warn("No available network definition in the bundle, return state dict instead.") + warnings.warn("No available network definition in the bundle, return state dict instead.", stacklevel=2) return model_dict else: model = _workflow.network_def @@ -1277,7 +1319,7 @@ def _export( (extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input. parser: a ConfigParser of the bundle to be converted. net_id: ID name of the network component in the parser, it must be `torch.nn.Module`. - filepath: filepath to export, if filename has no extension, it becomes `.ts`. + filepath: filepath to export. ckpt_file: filepath of the model checkpoint to load. 
config_file: filepath of the config file to save in the converted model,the saved key in the converted model is the config filename without extension, and the saved config value is always serialized in @@ -1434,6 +1476,7 @@ def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> N ) +@deprecated(since="1.5", removed="1.7", msg_suffix="Use export_checkpoint() instead.") def ckpt_export( net_id: str | None = None, filepath: PathLike | None = None, @@ -1568,6 +1611,140 @@ def ckpt_export( ) +def export_checkpoint( + net_id: str | None = None, + filepath: PathLike | None = None, + ckpt_file: str | None = None, + meta_file: str | Sequence[str] | None = None, + config_file: str | Sequence[str] | None = None, + key_in_ckpt: str | None = None, + input_shape: Sequence[int] | None = None, + dynamic_shapes: dict | tuple | None = None, + args_file: str | None = None, + converter_kwargs: Mapping | None = None, + **override: Any, +) -> None: + """ + Export the model checkpoint to a ``.pt2`` file using :func:`torch.export.export`, with metadata and + config included. + + Typical usage examples: + + .. code-block:: bash + + python -m monai.bundle export_checkpoint network --filepath <export path> --ckpt_file <checkpoint path> ... + + Args: + net_id: ID name of the network component in the config, it must be ``torch.nn.Module``. + Default to ``"network_def"``. + filepath: filepath to export. If filename has no extension it becomes ``.pt2``. + Default to ``"models/model.pt2"`` under ``"os.getcwd()"`` if ``bundle_root`` is not specified. + ckpt_file: filepath of the model checkpoint to load. + Default to ``"models/model.pt"`` under ``"os.getcwd()"`` if ``bundle_root`` is not specified. + meta_file: filepath of the metadata file. If it is a list of file paths, contents will be merged. + Default to ``"configs/metadata.json"`` under ``"os.getcwd()"`` if ``bundle_root`` is not specified. + config_file: filepath of the config file to save in the exported model. 
The saved key is the + config filename without extension; the value is always serialized in JSON format. + It can be a single file or a list of files. If ``None``, must be provided in ``args_file``. + key_in_ckpt: for nested checkpoints like ``{"model": XXX, "optimizer": XXX, ...}``, specify the + key of model weights. If not nested, no need to set. + input_shape: a shape used to generate random input for the network, e.g. ``[N, C, H, W]`` or + ``[N, C, H, W, D]``. If not given, will try to parse from ``metadata``. + dynamic_shapes: dynamic shape specifications passed to :func:`torch.export.export`. + args_file: a JSON or YAML file to provide default values for all the parameters. + converter_kwargs: extra arguments for :func:`~monai.networks.utils.convert_to_export`, + except ones that already exist in the input parameters. + override: id-value pairs to override or add the corresponding config content. + """ + _args = update_kwargs( + args=args_file, + net_id=net_id, + filepath=filepath, + meta_file=meta_file, + config_file=config_file, + ckpt_file=ckpt_file, + key_in_ckpt=key_in_ckpt, + input_shape=input_shape, + dynamic_shapes=dynamic_shapes, + converter_kwargs=converter_kwargs, + **override, + ) + _log_input_summary(tag="export_checkpoint", args=_args) + ( + config_file_, + filepath_, + ckpt_file_, + net_id_, + meta_file_, + key_in_ckpt_, + input_shape_, + dynamic_shapes_, + converter_kwargs_, + ) = _pop_args( + _args, + "config_file", + filepath=None, + ckpt_file=None, + net_id=None, + meta_file=None, + key_in_ckpt="", + input_shape=None, + dynamic_shapes=None, + converter_kwargs={}, + ) + bundle_root = _args.get("bundle_root", os.getcwd()) + + parser = ConfigParser() + parser.read_config(f=config_file_) + meta_file_ = os.path.join(bundle_root, "configs", "metadata.json") if meta_file_ is None else meta_file_ + for mf in ensure_tuple(meta_file_): + if os.path.exists(mf): + parser.read_meta(f=mf) + + for k, v in _args.items(): + parser[k] = v + + filepath_ = 
os.path.join(bundle_root, "models", "model.pt2") if filepath_ is None else filepath_ + ckpt_file_ = os.path.join(bundle_root, "models", "model.pt") if ckpt_file_ is None else ckpt_file_ + if not os.path.exists(ckpt_file_): + raise FileNotFoundError(f'Checkpoint file "{ckpt_file_}" not found, please specify it in argument "ckpt_file".') + + net_id_ = "network_def" if net_id_ is None else net_id_ + try: + parser.get_parsed_content(net_id_) + except ValueError as e: + raise ValueError( + f'Network definition "{net_id_}" cannot be found in "{config_file_}", specify name with argument "net_id".' + ) from e + + if not input_shape_: + input_shape_ = _get_fake_input_shape(parser=parser) + + if not input_shape_: + raise ValueError( + "Cannot determine input shape automatically. " + "Please provide it explicitly via the 'input_shape' argument." + ) + + inputs_: Sequence[Any] = [torch.rand(input_shape_)] + + converter_kwargs_.update({"inputs": inputs_, "dynamic_shapes": dynamic_shapes_}) + + save_ep = partial(save_exported_program, include_config_vals=False, append_timestamp=False) + + _export( + convert_to_export, + save_ep, + parser, + net_id=net_id_, + filepath=filepath_, + ckpt_file=ckpt_file_, + config_file=config_file_, + key_in_ckpt=key_in_ckpt_, + **converter_kwargs_, + ) + + def trt_export( net_id: str | None = None, filepath: PathLike | None = None, @@ -1588,20 +1765,19 @@ def trt_export( **override: Any, ) -> None: """ - Export the model checkpoint to the given filepath as a TensorRT engine-based TorchScript. + Export the model checkpoint to the given filepath as a TensorRT engine. Currently, this API only supports converting models whose inputs are all tensors. Note: NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5. Review the TensorRT Support Matrix for which GPUs are supported. 
There are two ways to export a model: - 1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript. - 2, ONNX-TensorRT way: PyTorch module ---> TorchScript module ---> ONNX model ---> TensorRT engine ---> - TensorRT engine-based TorchScript. + 1, Torch-TensorRT way: PyTorch module ---> TensorRT engine (via ``torch.export`` on PyTorch >= 2.9, + or via TorchScript on older versions). + 2, ONNX-TensorRT way: PyTorch module ---> ONNX model ---> TensorRT engine. When exporting through the first way, some models suffer from the slowdown problem, since Torch-TensorRT may only convert a little part of the PyTorch model to the TensorRT engine. However when exporting through - the second way, some Python data structures like `dict` are not supported. And some TorchScript models are - not supported by the ONNX if exported through `torch.jit.script`. + the second way, some Python data structures like ``dict`` are not supported. Typical usage examples: @@ -1624,8 +1800,8 @@ def trt_export( precision: the weight precision of the converted TensorRT engine based TorchScript models. Should be 'fp32' or 'fp16'. input_shape: the input shape that is used to convert the model. Should be a list like [N, C, H, W] or [N, C, H, W, D]. If not given, will try to parse from the `metadata` config. - use_trace: whether using `torch.jit.trace` to convert the PyTorch model to a TorchScript model and then convert to - a TensorRT engine based TorchScript model or an ONNX model (if `use_onnx` is True). + use_trace: whether using ``torch.jit.trace`` to convert the PyTorch model to a TorchScript model + (only used on PyTorch < 2.9 when ``use_onnx`` is ``False``; on 2.9+ ``torch.export`` is used instead). dynamic_batchsize: a sequence with three elements to define the batch size range of the input for the model to be converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH]. 
After converted, the batchsize of model input should between `MIN_BATCH` and `MAX_BATCH` and the `OPT_BATCH` is the best performance batchsize @@ -1729,11 +1905,18 @@ def trt_export( } converter_kwargs_.update(trt_api_parameters) - save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False) + def _save_trt_model(trt_obj, filepath, **kwargs): + """Save TRT model, using the appropriate format for dynamo vs JIT objects.""" + if isinstance(trt_obj, torch.export.ExportedProgram): + save_exported_program(trt_obj, filepath, include_config_vals=False, append_timestamp=False, **kwargs) + else: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning, message=".*save_net_with_metadata.*") + save_net_with_metadata(trt_obj, filepath, include_config_vals=False, append_timestamp=False, **kwargs) _export( convert_to_trt, - save_ts, + _save_trt_model, parser, net_id=net_id_, filepath=filepath_, diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 5e367cc297..91cb74982e 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -47,6 +47,7 @@ load_decathlon_datalist, load_decathlon_properties, ) +from .export_utils import load_exported_program, save_exported_program from .folder_layout import FolderLayout, FolderLayoutBase from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd from .image_dataset import ImageDataset diff --git a/monai/data/export_utils.py b/monai/data/export_utils.py new file mode 100644 index 0000000000..fb0216f78b --- /dev/null +++ b/monai/data/export_utils.py @@ -0,0 +1,130 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import datetime +import json +import logging +import os +from collections.abc import Mapping, Sequence +from typing import IO, Any + +import torch + +from monai.config import get_config_values +from monai.data.torchscript_utils import METADATA_FILENAME +from monai.utils import ExportMetadataKeys + +__all__ = ["load_exported_program", "save_exported_program"] + + +def save_exported_program( + exported_program: torch.export.ExportedProgram, + filename_prefix_or_stream: str | os.PathLike | IO[bytes], + include_config_vals: bool = True, + append_timestamp: bool = False, + meta_values: Mapping[str, Any] | None = None, + more_extra_files: Mapping[str, Any] | None = None, +) -> None: + """ + Save an ``ExportedProgram`` produced by :func:`torch.export.export` with metadata included + as a JSON file inside the ``.pt2`` archive. + + Examples:: + + import torch + from monai.networks.nets import UNet + + net = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=[8, 16], strides=[2]) + ep = torch.export.export(net, args=(torch.rand(1, 1, 32, 32),)) + + meta = {"name": "Test UNet", "input_dims": 2} + save_exported_program(ep, "test", meta_values=meta) + + loaded_ep, loaded_meta, _ = load_exported_program("test.pt2") + + Args: + exported_program: an ``ExportedProgram`` returned by :func:`torch.export.export`. + filename_prefix_or_stream: filename or file-like stream object. + If a string filename has no extension it becomes ``.pt2``. + include_config_vals: if True, MONAI, PyTorch, and NumPy versions are included in metadata. 
+ append_timestamp: if True, a timestamp is appended to the filename before the extension. + meta_values: metadata values to store, compatible with JSON serialization. + more_extra_files: additional data items to include in the archive. + """ + now = datetime.datetime.now() + metadict: dict[str, Any] = {} + + if include_config_vals: + metadict.update(get_config_values()) + metadict[ExportMetadataKeys.TIMESTAMP.value] = now.astimezone().isoformat() + + if meta_values is not None: + metadict.update(meta_values) + + json_data = json.dumps(metadict) + + extra_files: dict[str, Any] = {METADATA_FILENAME: json_data} + + if more_extra_files is not None: + if METADATA_FILENAME in more_extra_files: + raise ValueError(f"'{METADATA_FILENAME}' is reserved and cannot be used in more_extra_files.") + extra_files.update(more_extra_files) + + # torch.export.save requires str values; decode bytes from legacy callers (e.g. _export helper) + extra_files = {k: v.decode() if isinstance(v, bytes) else v for k, v in extra_files.items()} + + if isinstance(filename_prefix_or_stream, (str, os.PathLike)): + filename_prefix_or_stream = str(filename_prefix_or_stream) + filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream) + if ext == "": + ext = ".pt2" + + if append_timestamp: + filename_prefix_or_stream = now.strftime(f"{filename_no_ext}_%Y%m%d%H%M%S{ext}") + else: + filename_prefix_or_stream = filename_no_ext + ext + + torch.export.save(exported_program, filename_prefix_or_stream, extra_files=extra_files) + + +def load_exported_program( + filename_prefix_or_stream: str | os.PathLike | IO[bytes], more_extra_files: Sequence[str] = () +) -> tuple[torch.export.ExportedProgram, dict, dict]: + """ + Load an ``ExportedProgram`` from a ``.pt2`` file and extract stored JSON metadata. + + Args: + filename_prefix_or_stream: filename or file-like stream object. + more_extra_files: additional extra file names to load from the archive. 
+ + Returns: + Triple of (ExportedProgram, metadata dict, extra files dict). + """ + extra_files: dict[str, Any] = dict.fromkeys(more_extra_files, "") + extra_files[METADATA_FILENAME] = "" + + exported_program = torch.export.load(filename_prefix_or_stream, extra_files=extra_files) + + extra_files = dict(extra_files) + + json_data = extra_files.pop(METADATA_FILENAME, "{}") + + try: + json_data_dict = json.loads(json_data) + except json.JSONDecodeError: + logging.getLogger(__name__).warning( + "Failed to parse metadata JSON from exported program, returning empty metadata." + ) + json_data_dict = {} + + return exported_program, json_data_dict, extra_files diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 12bd76ba60..3be6667605 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -84,13 +84,6 @@ class MetaTensor(MetaObj, torch.Tensor): assert torch.all(m2.affine == affine) Notes: - - Requires pytorch 1.9 or newer for full compatibility. - - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may - not work if `im` is of type `MetaTensor`. This can be resolved with - `torch.jit.trace(net, im.as_tensor())`. - - For pytorch < 1.8, sharing `MetaTensor` instances across processes may not be supported. - - For pytorch < 1.9, next(iter(meta_tensor)) returns a torch.Tensor. - see: https://github.com/pytorch/pytorch/issues/54457 - A warning will be raised if in the constructor `affine` is not `None` and `meta` already contains the key `affine`. - You can query whether the `MetaTensor` is a batch with the `is_batch` attribute. 
diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index 507cf411d6..3e5692bb67 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -21,10 +21,12 @@ from monai.config import get_config_values from monai.utils import JITMetadataKeys +from monai.utils.deprecate_utils import deprecated METADATA_FILENAME = "metadata.json" +@deprecated(since="1.5", removed="1.7", msg_suffix="Use monai.data.save_exported_program() instead.") def save_net_with_metadata( jit_obj: torch.nn.Module, filename_prefix_or_stream: str | IO[Any], @@ -100,6 +102,7 @@ def save_net_with_metadata( torch.jit.save(jit_obj, filename_prefix_or_stream, extra_files) +@deprecated(since="1.5", removed="1.7", msg_suffix="Use monai.data.load_exported_program() instead.") def load_net_with_metadata( filename_prefix_or_stream: str | IO[Any], map_location: torch.device | None = None, diff --git a/monai/losses/dice.py b/monai/losses/dice.py index cd76ec1323..ffa4d5472a 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -123,6 +123,11 @@ def __init__( self.smooth_dr = float(smooth_dr) self.batch = batch weight = torch.as_tensor(weight) if weight is not None else None + if weight is not None: + if weight.numel() == 0: + raise ValueError("`weight` must not be empty.") + if weight.min() < 0: + raise ValueError("the value/values of the `weight` should be no less than 0.") self.register_buffer("class_weight", weight) self.class_weight: None | torch.Tensor self.soft_label = soft_label @@ -181,7 +186,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") # reducing only spatial dimensions (not batch nor channels) - reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() + reduce_axis: list[int] = list(range(2, len(input.shape))) if self.batch: # reducing spatial dimensions and batch reduce_axis = [0] + 
reduce_axis @@ -208,9 +213,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: If `include_background=False`, the weight should not include the background category class 0.""" ) - if self.class_weight.min() < 0: - raise ValueError("the value/values of the `weight` should be no less than 0.") - # apply class_weight to loss + # apply class_weight to loss (weight values validated in __init__) f = f * self.class_weight.to(f) if self.reduction == LossReduction.MEAN.value: @@ -431,7 +434,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") # reducing only spatial dimensions (not batch nor channels) - reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() + reduce_axis: list[int] = list(range(2, len(input.shape))) if self.batch: reduce_axis = [0] + reduce_axis diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index caa237fca8..ce7254ff2c 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -122,6 +122,11 @@ def __init__( else: self.alpha = torch.as_tensor(alpha) weight = torch.as_tensor(weight) if weight is not None else None + if weight is not None: + if weight.numel() == 0: + raise ValueError("`weight` must not be empty.") + if weight.min() < 0: + raise ValueError("the value/values of the `weight` should be no less than 0.") self.register_buffer("class_weight", weight) self.class_weight: None | torch.Tensor @@ -188,9 +193,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: If `include_background=False`, the weight should not include the background category class 0.""" ) - if self.class_weight.min() < 0: - raise ValueError("the value/values of the `weight` should be no less than 0.") - # apply class_weight to loss + # apply class_weight to loss (weight values validated in __init__) self.class_weight = 
self.class_weight.to(loss) broadcast_dims = [-1] + [1] * len(target.shape[2:]) self.class_weight = self.class_weight.view(broadcast_dims) diff --git a/monai/losses/hausdorff_loss.py b/monai/losses/hausdorff_loss.py index b75433e1da..4266f6c598 100644 --- a/monai/losses/hausdorff_loss.py +++ b/monai/losses/hausdorff_loss.py @@ -190,7 +190,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: distance = pred_dt**self.alpha + target_dt**self.alpha running_f = pred_error * distance.to(device) - reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() + reduce_axis: list[int] = list(range(2, len(input.shape))) if self.batch: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index 154f34c526..6e67329130 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -143,7 +143,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") # reducing only spatial dimensions (not batch nor channels) - reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() + reduce_axis: list[int] = list(range(2, len(input.shape))) if self.batch: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 5a240021d6..4271bbd821 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -14,6 +14,7 @@ from .trt_compiler import trt_compile from .utils import ( add_casts_around_norms, + convert_to_export, convert_to_onnx, convert_to_torchscript, convert_to_trt, diff --git a/monai/networks/blocks/feature_pyramid_network.py b/monai/networks/blocks/feature_pyramid_network.py index 96083e7c0d..7428752551 100644 --- a/monai/networks/blocks/feature_pyramid_network.py +++ b/monai/networks/blocks/feature_pyramid_network.py @@ -206,7 
+206,7 @@ def __init__( def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor: """ This is equivalent to self.inner_blocks[idx](x), - but torchscript doesn't support this yet + but module indexing with a variable is used for compatibility """ num_blocks = len(self.inner_blocks) if idx < 0: @@ -220,7 +220,7 @@ def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor: def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor: """ This is equivalent to self.layer_blocks[idx](x), - but torchscript doesn't support this yet + but module indexing with a variable is used for compatibility """ num_blocks = len(self.layer_blocks) if idx < 0: diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 2791d2fb00..1ddb2f2839 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -112,13 +112,13 @@ def __init__( if use_combined_linear: self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) - self.to_q = self.to_k = self.to_v = nn.Identity() # add to enable torchscript + self.to_q = self.to_k = self.to_v = nn.Identity() # placeholder for unused code path self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) else: self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) - self.qkv = nn.Identity() # add to enable torchscript + self.qkv = nn.Identity() # placeholder for unused code path self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads) self.out_rearrange = Rearrange("b l h d -> b h (l d)") self.drop_output = nn.Dropout(dropout_rate) diff --git a/monai/networks/blocks/squeeze_and_excitation.py b/monai/networks/blocks/squeeze_and_excitation.py index 665e9020ff..41a03a69d5 100644 --- 
a/monai/networks/blocks/squeeze_and_excitation.py +++ b/monai/networks/blocks/squeeze_and_excitation.py @@ -81,8 +81,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: y = self.fc(y).view([b, c] + [1] * (x.ndim - 2)) result = x * y - # Residual connection is moved here instead of providing an override of forward in ResidualSELayer since - # Torchscript has an issue with using super(). + # Residual connection is applied here rather than in a forward override in ResidualSELayer. if self.add_residual: result += x diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 9ea181974a..dedfab7665 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -262,8 +262,7 @@ def instance_nvfuser_factory(dim): It only supports 3d tensors as the input. It also requires to use with CUDA and non-Windows OS. In this function, if the required library `apex.normalization.InstanceNorm3dNVFuser` does not exist, `nn.InstanceNorm3d` will be returned instead. - This layer is based on a customized autograd function, which is not supported in TorchScript currently. - Please switch to use `nn.InstanceNorm3d` if TorchScript is necessary. + This layer is based on a customized autograd function. 
Please check the following link for more details about how to install `apex`: https://github.com/NVIDIA/apex#installation diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 56f7192e4d..cabe3d5b79 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -163,7 +163,7 @@ def __init__(self, *shape: int) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: shape = list(self.shape) - shape[0] = x.shape[0] # done this way for Torchscript + shape[0] = x.shape[0] return x.reshape(shape) diff --git a/monai/networks/layers/vector_quantizer.py b/monai/networks/layers/vector_quantizer.py index 388f93fe2d..26ad621417 100644 --- a/monai/networks/layers/vector_quantizer.py +++ b/monai/networks/layers/vector_quantizer.py @@ -26,9 +26,8 @@ class EMAQuantizer(nn.Module): that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit 58d9a2746493717a7c9252938da7efa6006f3739. - This module is not compatible with TorchScript while working in a Distributed Data Parallelism Module. This is due - to lack of TorchScript support for torch.distributed module as per https://github.com/pytorch/pytorch/issues/41353 - on 22/10/2022. If you want to TorchScript your model, please turn set `ddp_sync` to False. + When using Distributed Data Parallelism, ``torch.distributed`` synchronization is required. + Set ``ddp_sync`` to ``False`` to disable it if not running in a distributed setting. Args: spatial_dims: number of spatial dimensions of the input. @@ -146,8 +145,7 @@ def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None: """ - TorchScript does not support torch.distributed.all_reduce. 
This function is a bypassing trick based on the - example: https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused + Synchronize codebook statistics across distributed processes using ``all_reduce``. Args: encodings_sum: The summation of one hot representation of what encoding was used for each diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index d2a655f981..1a68ca928d 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -158,7 +158,7 @@ def forward(self, x: torch.Tensor, x_e: torch.Tensor | None): """ x_0 = self.upsample(x) - if x_e is not None and torch.jit.isinstance(x_e, torch.Tensor): + if x_e is not None and isinstance(x_e, torch.Tensor): if self.is_pad: # handling spatial shapes due to the 2x maxpooling with odd edge lengths. dimensions = len(x.shape) - 2 diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py index f87fcebea8..b483da2abd 100644 --- a/monai/networks/nets/dints.py +++ b/monai/networks/nets/dints.py @@ -36,24 +36,22 @@ __all__ = ["DiNTS", "TopologyConstruction", "TopologyInstance", "TopologySearch"] -@torch.jit.interface class CellInterface(torch.nn.Module): - """interface for torchscriptable Cell""" + """Abstract interface for Cell modules used in DiNTS.""" def forward(self, x: torch.Tensor, weight: torch.Tensor | None) -> torch.Tensor: # type: ignore - pass + raise NotImplementedError -@torch.jit.interface class StemInterface(torch.nn.Module): - """interface for torchscriptable Stem""" + """Abstract interface for Stem modules used in DiNTS.""" def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore - pass + raise NotImplementedError class StemTS(StemInterface): - """wrapper for torchscriptable Stem""" + """Wrapper Stem that applies a sequential module.""" def __init__(self, *mod): super().__init__() @@ -375,6 +373,11 @@ def __init__( self.node_a = torch.ones((self.num_blocks + 1, self.num_depths)) else: self.node_a = 
node_a + # Pre-compute node activation flags as Python booleans for torch.export compatibility. + # NOTE: node_a must not be mutated after construction. + self._node_flags: list[list[bool]] = [ + [bool(self.node_a[b, d]) for d in range(self.num_depths)] for b in range(self.node_a.shape[0]) + ] # define stem operations for every block conv_type = Conv[Conv.CONV, spatial_dims] @@ -493,7 +496,7 @@ def forward(self, x: torch.Tensor): # allow multi-resolution input _mod_w: StemInterface = self.stem_down[str(d)] # type: ignore[assignment] x_out = _mod_w.forward(x) - if self.node_a[0][d]: + if self._node_flags[0][d]: inputs.append(x_out) else: inputs.append(torch.zeros_like(x_out)) @@ -507,7 +510,7 @@ def forward(self, x: torch.Tensor): _mod_up: StemInterface = self.stem_up[str(res_idx)] # type: ignore[assignment] if start: _temp = _mod_up.forward(outputs[res_idx] + _temp) - elif self.node_a[blk_idx + 1][res_idx]: + elif self._node_flags[blk_idx + 1][res_idx]: start = True _temp = _mod_up.forward(outputs[res_idx]) prediction = self.stem_finals(_temp) @@ -627,6 +630,7 @@ def __init__( def forward(self, x): """This function to be implemented by the architecture instances or search spaces.""" + raise NotImplementedError class TopologyInstance(TopologyConstruction): @@ -665,6 +669,13 @@ def __init__( use_downsample=use_downsample, device=device, ) + # Pre-compute activation flags as plain Python booleans so that the + # control flow in forward() is static and compatible with torch.export. + # NOTE: arch_code_a must not be mutated after construction; this class + # is only used at inference/re-training time, not during architecture search. 
+ self._active_flags: list[list[bool]] = [ + [bool(self.arch_code_a[b, r]) for r in range(self.arch_code_a.shape[1])] for b in range(self.num_blocks) + ] def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]: """ @@ -675,8 +686,8 @@ def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]: inputs = x for blk_idx in range(self.num_blocks): outputs = [torch.tensor(0.0, dtype=x[0].dtype, device=x[0].device)] * self.num_depths - for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data): - if activation: + for res_idx, active in enumerate(self._active_flags[blk_idx]): + if active: mod: CellInterface = self.cell_tree[str((blk_idx, res_idx))] # type: ignore[assignment] _out = mod.forward(x=inputs[self.arch_code2in[res_idx]], weight=None) outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index ec418469bb..e34bed46f0 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -27,8 +27,8 @@ class DynUNetSkipLayer(nn.Module): """ Defines a layer in the UNet topology which combines the downsample and upsample pathways with the skip connection. The member `next_layer` may refer to instances of this class or the final bottleneck layer at the bottom the UNet - structure. The purpose of using a recursive class like this is to get around the Torchscript restrictions on - looping over lists of layers and accumulating lists of output tensors which must be indexed. The `heads` list is + structure. The recursive class design avoids restrictions on looping over lists of layers and accumulating + lists of output tensors which must be indexed. The `heads` list is shared amongst all the instances of this class and is used to store the output from the supervision heads during forward passes of the network. """ @@ -112,7 +112,7 @@ class DynUNet(nn.Module): deep_supervision: whether to add deep supervision head before output. 
Defaults to ``False``. If ``True``, in training mode, the forward function will output not only the final feature map (from `output_block`), but also the feature maps that come from the intermediate up sample layers. - In order to unify the return type (the restriction of TorchScript), all intermediate + In order to unify the return type, all intermediate feature maps are interpolated into the same size as the final feature map and stacked together (with a new dimension in the first axis)into one single tensor. For instance, if there are two intermediate feature maps with shapes: (1, 2, 16, 12) and @@ -169,7 +169,7 @@ def __init__( self.output_block = self.get_output_block(0) self.deep_supervision = deep_supervision self.deep_supr_num = deep_supr_num - # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on + # initialize the typed list of supervision head outputs self.heads: list[torch.Tensor] = [torch.rand(1)] * self.deep_supr_num if self.deep_supervision: self.deep_supervision_heads = self.get_deep_supervision_heads() @@ -181,8 +181,8 @@ def __init__( def create_skips(index, downsamples, upsamples, bottleneck, superheads=None): """ Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is - done recursively from the top down since a recursive nn.Module subclass is being used to be compatible - with Torchscript. Initially the length of `downsamples` will be one more than that of `superheads` + done recursively from the top down since a recursive nn.Module subclass is being used. Initially the + length of `downsamples` will be one more than that of `superheads` since the `input_block` is passed to this function as the first item in `downsamples`, however this shouldn't be associated with a supervision head. 
""" diff --git a/monai/networks/nets/netadapter.py b/monai/networks/nets/netadapter.py index f87120ba21..c4bd64de97 100644 --- a/monai/networks/nets/netadapter.py +++ b/monai/networks/nets/netadapter.py @@ -109,7 +109,7 @@ def forward(self, x): x = self.features(x) if isinstance(x, tuple): x = x[0] # it might be a namedtuple such as torchvision.model.InceptionOutputs - elif torch.jit.isinstance(x, dict[str, torch.Tensor]): + elif isinstance(x, dict): x = x[self.node_name] # torchvision create_feature_extractor if self.pool is not None: x = self.pool(x) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index b4d93c9afe..5c0d5fc263 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -301,20 +301,21 @@ def load_from(self, weights): d.norm.weight.copy_(wstate["module.layers4.0.downsample.norm.weight"]) # type: ignore d.norm.bias.copy_(wstate["module.layers4.0.downsample.norm.bias"]) # type: ignore - @torch.jit.unused def _check_input_size(self, spatial_shape): - img_size = np.array(spatial_shape) - remainder = (img_size % np.power(self.patch_size, 5)) > 0 - if remainder.any(): - wrong_dims = (np.where(remainder)[0] + 2).tolist() + # Previously guarded by `if not torch.jit.is_scripting()` and used numpy; + # the TorchScript removal made this run on every forward call, so it uses + # plain Python arithmetic to avoid per-call tensor allocation (spatial_shape + # elements are already ints from x_in.shape[2:]). + divisor = int(self.patch_size**5) + wrong_dims = [i + 2 for i, s in enumerate(spatial_shape) if int(s) % divisor != 0] + if wrong_dims: raise ValueError( f"spatial dimensions {wrong_dims} of input image (spatial shape: {spatial_shape})" f" must be divisible by {self.patch_size}**5." 
) def forward(self, x_in): - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - self._check_input_size(x_in.shape[2:]) + self._check_input_size(x_in.shape[2:]) hidden_states_out = self.swinViT(x_in, self.normalize) enc0 = self.encoder1(x_in) enc1 = self.encoder2(hidden_states_out[0]) diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 07c5147cb2..29fd003ae5 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -27,7 +27,7 @@ class ViT(nn.Module): Vision Transformer (ViT), based on: "Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " - ViT supports Torchscript but only works for Pytorch after 1.8. + ViT supports ``torch.export`` for model serialization. """ def __init__( diff --git a/monai/networks/utils.py b/monai/networks/utils.py index f56c39dcd1..5119bbad61 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -30,8 +30,9 @@ from monai.apps.utils import get_logger from monai.config import PathLike +from monai.utils.deprecate_utils import deprecated from monai.utils.misc import ensure_tuple, save_obj, set_determinism -from monai.utils.module import look_up_option, optional_import +from monai.utils.module import look_up_option, optional_import, pytorch_after from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor onnx, _ = optional_import("onnx") @@ -57,6 +58,7 @@ "save_state", "convert_to_onnx", "convert_to_torchscript", + "convert_to_export", "convert_to_trt", "meshgrid_ij", "meshgrid_xy", @@ -793,6 +795,18 @@ def convert_to_onnx( return onnx_model +def _recursive_to(x, device): + """Recursively move tensors (and nested structures of tensors) to *device*.""" + if isinstance(x, torch.Tensor): + return x.to(device) + if isinstance(x, dict): + return {k: _recursive_to(v, device) for k, v in x.items()} + if isinstance(x, (tuple, list)): + return type(x)(_recursive_to(i, device) for i in x) + return x + + 
+@deprecated(since="1.5", removed="1.7", msg_suffix="Use convert_to_export() instead.") def convert_to_torchscript( model: nn.Module, filename_or_obj: Any | None = None, @@ -863,6 +877,84 @@ def convert_to_torchscript( return script_module +def convert_to_export( + model: nn.Module, + filename_or_obj: Any | None = None, + extra_files: dict | None = None, + verify: bool = False, + inputs: Sequence[Any] | None = None, + dynamic_shapes: dict | tuple | None = None, + device: str | torch.device | None = None, + rtol: float = 1e-4, + atol: float = 0.0, + **kwargs, +) -> torch.export.ExportedProgram: + """ + Utility to export a model using :func:`torch.export.export` and optionally save to a ``.pt2`` file, + with optional input/output data verification. + + Args: + model: source PyTorch model to export. + filename_or_obj: if not None, a file path string to save the exported program. + extra_files: map from filename to contents to store in the saved archive. + verify: whether to verify the input and output of the exported model. + If ``filename_or_obj`` is not None, loads the saved model and verifies. + inputs: input test data for export and verification. Should be a sequence of + tensors that map to positional arguments of ``model()``. + dynamic_shapes: dynamic shape specifications passed to :func:`torch.export.export`. + See PyTorch docs for format details. + device: target device to verify the model. If None, uses CUDA if available. + rtol: the relative tolerance when comparing outputs. + atol: the absolute tolerance when comparing outputs. + kwargs: additional keyword arguments for :func:`torch.export.export`. + + Returns: + A :class:`torch.export.ExportedProgram` representing the exported model. 
+ """ + if inputs is None: + raise ValueError("Input data is required for torch.export.export.") + + model.eval() + with torch.no_grad(): + export_args = tuple(inputs) + exported = torch.export.export(model, args=export_args, dynamic_shapes=dynamic_shapes, **kwargs) + + if filename_or_obj is not None: + save_extra: dict[str, Any] = {} + if extra_files is not None: + # torch.export.save requires str values; decode bytes from legacy callers + save_extra.update({k: v.decode() if isinstance(v, bytes) else v for k, v in extra_files.items()}) + torch.export.save(exported, filename_or_obj, extra_files=save_extra if save_extra else None) + + if verify: + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + verify_args = tuple(_recursive_to(i, device) for i in inputs) + + # Always verify against the in-memory export to avoid device placement + # issues that can occur when reloading from file (torch.export.load does + # not support map_location). + loaded_module = exported.module() + loaded_module.to(device) + model.to(device) + + with torch.no_grad(): + set_determinism(seed=0) + torch_out = ensure_tuple(model(*verify_args)) + set_determinism(seed=0) + export_out = ensure_tuple(loaded_module(*verify_args)) + set_determinism(seed=None) + + if len(torch_out) != len(export_out): + raise AssertionError(f"Exported model returned {len(export_out)} outputs, expected {len(torch_out)}.") + for r1, r2 in zip(torch_out, export_out): + if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor): + torch.testing.assert_close(r1, r2, rtol=rtol, atol=atol) # type: ignore + + return exported + + def _onnx_trt_compile( onnx_model, min_shape: Sequence[int], @@ -1012,9 +1104,9 @@ def convert_to_trt( convert_precision = torch.float32 if precision == "fp32" else torch.half inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)] - # convert the torch model to a TorchScript model on target device model = model.eval().to(target_device) 
min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize) + _use_dynamo = pytorch_after(2, 9) if use_onnx: # set the batch dim as dynamic @@ -1035,10 +1127,7 @@ def convert_to_trt( output_names=onnx_output_names, ) else: - ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace) - ir_model.eval() - # convert the model through the Torch-TensorRT way - ir_model.to(target_device) + # Torch-TensorRT compilation path with torch.no_grad(): with torch.cuda.device(device=device): input_placeholder = [ @@ -1046,21 +1135,41 @@ def convert_to_trt( min_shape=min_input_shape, opt_shape=opt_input_shape, max_shape=max_input_shape ) ] - trt_model = torch_tensorrt.compile( - ir_model, - inputs=input_placeholder, - enabled_precisions=convert_precision, - device=torch_tensorrt.Device(f"cuda:{device}"), - ir="torchscript", - **kwargs, - ) + # Use dynamo IR (torch.export-based) which is the default in newer torch-tensorrt + if _use_dynamo: + trt_model = torch_tensorrt.compile( + model, + inputs=input_placeholder, + enabled_precisions=convert_precision, + device=torch_tensorrt.Device(f"cuda:{device}"), + ir="dynamo", + **kwargs, + ) + else: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning) + ir_model = convert_to_torchscript( + model, device=target_device, inputs=inputs, use_trace=use_trace + ) + trt_model = torch_tensorrt.compile( + ir_model, + inputs=input_placeholder, + enabled_precisions=convert_precision, + device=torch_tensorrt.Device(f"cuda:{device}"), + ir="torchscript", + **kwargs, + ) # verify the outputs between the TensorRT model and PyTorch model if verify: if inputs is None: raise ValueError("Missing input data for verification.") - trt_model = torch.jit.load(filename_or_obj) if filename_or_obj is not None else trt_model + if filename_or_obj is not None: + if _use_dynamo: + trt_model = torch.export.load(filename_or_obj).module() + else: + 
trt_model = torch.jit.load(filename_or_obj) with torch.no_grad(): set_determinism(seed=0) @@ -1068,7 +1177,9 @@ def convert_to_trt( set_determinism(seed=0) trt_out = ensure_tuple(trt_model(*inputs)) set_determinism(seed=None) - # compare TorchScript and PyTorch results + # compare TensorRT and PyTorch results + if len(torch_out) != len(trt_out): + raise AssertionError(f"TRT model returned {len(trt_out)} outputs, expected {len(torch_out)}.") for r1, r2 in zip(torch_out, trt_out): if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor): torch.testing.assert_close(r1, r2, rtol=rtol, atol=atol) # type: ignore diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 3efc9b5e7f..bef3b0ea5a 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -31,6 +31,7 @@ DiceCEReduction, DownsampleMode, EngineStatsKeys, + ExportMetadataKeys, FastMRIKeys, ForwardMode, GanKeys, diff --git a/monai/utils/deprecate_utils.py b/monai/utils/deprecate_utils.py index 1249c51919..d2c32cee22 100644 --- a/monai/utils/deprecate_utils.py +++ b/monai/utils/deprecate_utils.py @@ -30,13 +30,6 @@ class DeprecatedError(Exception): pass -def warn_deprecated(obj, msg, warning_category=FutureWarning): - """ - Issue the warning message `msg`. 
- """ - warnings.warn(f"{obj}: {msg}", category=warning_category, stacklevel=2) - - def deprecated( since: str | None = None, removed: str | None = None, @@ -107,7 +100,7 @@ def _wrapper(*args, **kwargs): if is_removed: raise DeprecatedError(msg) if is_deprecated: - warn_deprecated(obj, msg, warning_category) + warnings.warn(f"{obj}: {msg}", category=warning_category, stacklevel=2) return call_obj(*args, **kwargs) @@ -217,7 +210,7 @@ def _wrapper(*args, **kwargs): if is_removed: raise DeprecatedError(msg) if is_deprecated: - warn_deprecated(argname, msg, warning_category) + warnings.warn(f"{argname}: {msg}", category=warning_category, stacklevel=2) return func(*args, **kwargs) @@ -317,7 +310,7 @@ def _decorator(func): def _wrapper(*args, **kwargs): if name not in sig.bind(*args, **kwargs).arguments and is_deprecated: # arg was not found so the default value is used - warn_deprecated(argname, msg, warning_category) + warnings.warn(f"{argname}: {msg}", category=warning_category, stacklevel=2) return func(*args, **kwargs) diff --git a/monai/utils/enums.py b/monai/utils/enums.py index be00b27d73..29f7cebeb0 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -62,6 +62,8 @@ "BundlePropertyConfig", "AlgoKeys", "IgniteInfo", + "JITMetadataKeys", + "ExportMetadataKeys", ] @@ -423,6 +425,9 @@ class JITMetadataKeys(StrEnum): """ Keys stored in the metadata file for saved Torchscript models. Some of these are generated by the routines and others are optionally provided by users. + + .. deprecated:: 1.5 + Use :class:`ExportMetadataKeys` instead. """ NAME = "name" @@ -431,6 +436,10 @@ class JITMetadataKeys(StrEnum): DESCRIPTION = "description" +# ExportMetadataKeys shares the same members as JITMetadataKeys; alias to avoid duplication. +ExportMetadataKeys = JITMetadataKeys + + class BoxModeName(StrEnum): """ Box mode names. 
diff --git a/tests/apps/detection/networks/test_retinanet.py b/tests/apps/detection/networks/test_retinanet.py index 3f4721a755..81ddb3d654 100644 --- a/tests/apps/detection/networks/test_retinanet.py +++ b/tests/apps/detection/networks/test_retinanet.py @@ -20,7 +20,7 @@ from monai.networks import eval_mode from monai.networks.nets import resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200 from monai.utils import ensure_tuple, optional_import -from tests.test_utils import dict_product, skip_if_quick, test_onnx_save, test_script_save +from tests.test_utils import dict_product, skip_if_quick, test_export_save, test_onnx_save _, has_torchvision = optional_import("torchvision") _, has_onnxruntime = optional_import("onnxruntime") @@ -92,7 +92,9 @@ MODEL_LIST = [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200] TEST_CASES = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=CASE_LIST)] -TEST_CASES_TS = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=[TEST_CASE_1])] +TEST_CASES_EXPORT = [ + [params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=[TEST_CASE_1]) +] @unittest.skipUnless(has_torchvision, "Requires torchvision") @@ -136,18 +138,18 @@ def test_retina_shape(self, model, input_param, input_shape): self.assertEqual(tuple(cc.shape for cc in result[net.cls_key]), expected_cls_shape) self.assertEqual(tuple(cc.shape for cc in result[net.box_reg_key]), expected_box_shape) - @parameterized.expand(TEST_CASES_TS) - def test_script(self, model, input_param, input_shape): + @parameterized.expand(TEST_CASES_EXPORT) + def test_export(self, model, input_param, input_shape): try: - idx = int(self.id().split("test_script_")[-1]) + idx = int(self.id().split("test_export_")[-1]) except BaseException: idx = 0 idx %= 3 - # test whether support torchscript + # test whether support torch.export data = torch.randn(input_shape) backbone = 
model(**input_param) if idx == 0: - test_script_save(backbone, data) + test_export_save(backbone, data) return feature_extractor = resnet_fpn_feature_extractor( backbone=backbone, @@ -157,7 +159,7 @@ def test_script(self, model, input_param, input_shape): returned_layers=[1, 2], ) if idx == 1: - test_script_save(feature_extractor, data) + test_export_save(feature_extractor, data) return net = RetinaNet( spatial_dims=input_param["spatial_dims"], @@ -167,9 +169,9 @@ def test_script(self, model, input_param, input_shape): size_divisible=32, ) if idx == 2: - test_script_save(net, data) + test_export_save(net, data) - @parameterized.expand(TEST_CASES_TS) + @parameterized.expand(TEST_CASES_EXPORT) @unittest.skipUnless(has_onnxruntime, "onnxruntime not installed") def test_onnx(self, model, input_param, input_shape): try: @@ -177,7 +179,7 @@ def test_onnx(self, model, input_param, input_shape): except BaseException: idx = 0 idx %= 3 - # test whether support torchscript + # test whether support ONNX export data = torch.randn(input_shape) backbone = model(**input_param) if idx == 0: diff --git a/tests/apps/detection/networks/test_retinanet_detector.py b/tests/apps/detection/networks/test_retinanet_detector.py index 6ac1efd734..6352ecd317 100644 --- a/tests/apps/detection/networks/test_retinanet_detector.py +++ b/tests/apps/detection/networks/test_retinanet_detector.py @@ -21,7 +21,7 @@ from monai.apps.detection.utils.anchor_utils import AnchorGeneratorWithAnchorShape from monai.networks import eval_mode, train_mode from monai.utils import optional_import -from tests.test_utils import skip_if_quick, test_script_save +from tests.test_utils import skip_if_quick, test_export_save _, has_torchvision = optional_import("torchvision") @@ -89,7 +89,7 @@ TEST_CASES = [] TEST_CASES = [TEST_CASE_1, TEST_CASE_2, TEST_CASE_2_A] -TEST_CASES_TS = [TEST_CASE_1] +TEST_CASES_EXPORT = [TEST_CASE_1] class NaiveNetwork(torch.nn.Module): @@ -183,9 +183,9 @@ def 
test_naive_retina_detector_shape(self, input_param, input_shape): targets = [one_target] * len(input_data) result = detector.forward(input_data, targets) - @parameterized.expand(TEST_CASES_TS) - def test_script(self, input_param, input_shape): - # test whether support torchscript + @parameterized.expand(TEST_CASES_EXPORT) + def test_export(self, input_param, input_shape): + # test whether support torch.export returned_layers = [1] anchor_generator = AnchorGeneratorWithAnchorShape( feature_map_scales=(1, 2), base_anchor_shapes=((8,) * input_param["spatial_dims"],) @@ -195,7 +195,7 @@ def test_script(self, input_param, input_shape): ) with eval_mode(detector): input_data = torch.randn(input_shape) - test_script_save(detector.network, input_data) + test_export_save(detector.network, input_data) if __name__ == "__main__": diff --git a/tests/apps/detection/utils/test_anchor_box.py b/tests/apps/detection/utils/test_anchor_box.py index b537c2533c..ffba9ab958 100644 --- a/tests/apps/detection/utils/test_anchor_box.py +++ b/tests/apps/detection/utils/test_anchor_box.py @@ -18,7 +18,7 @@ from monai.apps.detection.utils.anchor_utils import AnchorGenerator, AnchorGeneratorWithAnchorShape from monai.utils import optional_import -from tests.test_utils import assert_allclose, test_script_save +from tests.test_utils import assert_allclose, test_export_save _, has_torchvision = optional_import("torchvision") @@ -67,20 +67,20 @@ def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes): assert_allclose(a, a_f, type_test=True, device_test=False, atol=0.1) @parameterized.expand(TEST_CASES_2D) - def test_script_2d(self, input_param, image_shape, feature_maps_shapes): - # test whether support torchscript + def test_export_2d(self, input_param, image_shape, feature_maps_shapes): + # test whether support torch.export anchor = AnchorGenerator(**input_param, indexing="xy") images = torch.rand(image_shape) feature_maps = tuple(torch.rand(fs) for fs in feature_maps_shapes) - 
test_script_save(anchor, images, feature_maps) + test_export_save(anchor, images, feature_maps) @parameterized.expand(TEST_CASES_SHAPE_3D) - def test_script_3d(self, input_param, image_shape, feature_maps_shapes): - # test whether support torchscript + def test_export_3d(self, input_param, image_shape, feature_maps_shapes): + # test whether support torch.export anchor = AnchorGeneratorWithAnchorShape(**input_param, indexing="ij") images = torch.rand(image_shape) feature_maps = tuple(torch.rand(fs) for fs in feature_maps_shapes) - test_script_save(anchor, images, feature_maps) + test_export_save(anchor, images, feature_maps) if __name__ == "__main__": diff --git a/tests/bundle/test_bundle_export_checkpoint.py b/tests/bundle/test_bundle_export_checkpoint.py new file mode 100644 index 0000000000..d2f9e442f8 --- /dev/null +++ b/tests/bundle/test_bundle_export_checkpoint.py @@ -0,0 +1,129 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import json +import os +import tempfile +import unittest +from pathlib import Path + +from parameterized import parameterized + +from monai.bundle import ConfigParser +from monai.data import load_exported_program +from monai.networks import save_state +from tests.test_utils import command_line_tests, skip_if_windows + +TESTS_PATH = Path(__file__).parents[1] + +# key_in_ckpt +TEST_CASE_1 = [""] +TEST_CASE_2 = ["model"] + + +@skip_if_windows +class TestExportCheckpoint(unittest.TestCase): + def setUp(self): + self._orig_cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES") + + def tearDown(self): + if self._orig_cuda_env is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = self._orig_cuda_env + else: + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_export(self, key_in_ckpt): + meta_file = os.path.join(TESTS_PATH, "testing_data", "metadata.json") + config_file = os.path.join(TESTS_PATH, "testing_data", "inference.json") + with tempfile.TemporaryDirectory() as tempdir: + def_args = {"meta_file": "will be replaced by `meta_file` arg"} + def_args_file = os.path.join(tempdir, "def_args.yaml") + + ckpt_file = os.path.join(tempdir, "model.pt") + pt2_file = os.path.join(tempdir, "model.pt2") + + parser = ConfigParser() + parser.export_config_file(config=def_args, filepath=def_args_file) + parser.read_config(config_file) + net = parser.get_parsed_content("network_def") + save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file) + + cmd = [ + "coverage", + "run", + "-m", + "monai.bundle", + "export_checkpoint", + "network_def", + "--filepath", + pt2_file, + "--meta_file", + meta_file, + "--config_file", + f"['{config_file}','{def_args_file}']", + "--ckpt_file", + ckpt_file, + "--key_in_ckpt", + key_in_ckpt, + "--args_file", + def_args_file, + "--input_shape", + "[1, 1, 96, 96, 96]", + ] + command_line_tests(cmd) + self.assertTrue(os.path.exists(pt2_file)) + 
+ _, _metadata, extra_files = load_exported_program( + pt2_file, more_extra_files=["inference.json", "def_args.json"] + ) + self.assertIn("meta_file", json.loads(extra_files["def_args.json"])) + self.assertIn("network_def", json.loads(extra_files["inference.json"])) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_default_value(self, key_in_ckpt): + config_file = os.path.join(TESTS_PATH, "testing_data", "inference.json") + with tempfile.TemporaryDirectory() as tempdir: + def_args = {"meta_file": "will be replaced by `meta_file` arg"} + def_args_file = os.path.join(tempdir, "def_args.yaml") + ckpt_file = os.path.join(tempdir, "models", "model.pt") + pt2_file = os.path.join(tempdir, "models", "model.pt2") + + parser = ConfigParser() + parser.export_config_file(config=def_args, filepath=def_args_file) + parser.read_config(config_file) + net = parser.get_parsed_content("network_def") + save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file) + + # check with default value + cmd = [ + "coverage", + "run", + "-m", + "monai.bundle", + "export_checkpoint", + "--key_in_ckpt", + key_in_ckpt, + "--config_file", + config_file, + "--bundle_root", + tempdir, + "--input_shape", + "[1, 1, 96, 96, 96]", + ] + command_line_tests(cmd) + self.assertTrue(os.path.exists(pt2_file)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/data/meta_tensor/test_meta_tensor.py b/tests/data/meta_tensor/test_meta_tensor.py index c0e53fd24c..e10b35d8de 100644 --- a/tests/data/meta_tensor/test_meta_tensor.py +++ b/tests/data/meta_tensor/test_meta_tensor.py @@ -224,21 +224,22 @@ def test_get_set_meta_fns(self): self.assertTrue(get_track_meta()) @parameterized.expand(TEST_DEVICES) - def test_torchscript(self, device): + def test_export(self, device): shape = (1, 3, 10, 8) im, _ = self.get_im(shape, device=device) conv = torch.nn.Conv2d(im.shape[1], 5, 3) conv.to(device) im_conv = conv(im) - traced_fn = torch.jit.trace(conv, im.as_tensor()) + 
exported = torch.export.export(conv, args=(im.as_tensor(),)) # save it, load it, use it with tempfile.TemporaryDirectory() as tmp_dir: - fname = os.path.join(tmp_dir, "im.pt") - torch.jit.save(traced_fn, f=fname) - traced_fn = torch.jit.load(fname) - out = traced_fn(im) + fname = os.path.join(tmp_dir, "im.pt2") + torch.export.save(exported, fname) + loaded = torch.export.load(fname) + out = loaded.module()(im.as_tensor()) self.assertIsInstance(out, torch.Tensor) - self.check(out, im_conv, ids=False) + # exported module returns plain Tensor, compare values only + assert_allclose(out, im_conv) def test_pickling(self): m, _ = self.get_im() diff --git a/tests/data/test_export_utils.py b/tests/data/test_export_utils.py new file mode 100644 index 0000000000..e721b9cb30 --- /dev/null +++ b/tests/data/test_export_utils.py @@ -0,0 +1,103 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import os +import tempfile +import unittest + +import torch + +from monai.config import get_config_values +from monai.data import load_exported_program, save_exported_program +from monai.utils import ExportMetadataKeys + + +class TestModule(torch.nn.Module): + __test__ = False + + def forward(self, x): + return x + 10 + + +class TestExportUtils(unittest.TestCase): + def test_save_exported_program(self): + """Save an exported program without metadata to a file.""" + ep = torch.export.export(TestModule(), args=(torch.tensor(1.0),)) + with tempfile.TemporaryDirectory() as tempdir: + save_exported_program(ep, f"{tempdir}/test") + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt2")) + + def test_save_exported_program_ext(self): + """Save an exported program to a file with custom extension.""" + ep = torch.export.export(TestModule(), args=(torch.tensor(1.0),)) + with tempfile.TemporaryDirectory() as tempdir: + save_exported_program(ep, f"{tempdir}/test.zip") + self.assertTrue(os.path.isfile(f"{tempdir}/test.zip")) + + def test_save_with_metadata(self): + """Save an exported program with metadata to a file.""" + ep = torch.export.export(TestModule(), args=(torch.tensor(1.0),)) + test_metadata = {"foo": [1, 2], "bar": "string"} + + with tempfile.TemporaryDirectory() as tempdir: + save_exported_program(ep, f"{tempdir}/test", meta_values=test_metadata) + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt2")) + + def test_load_exported_program(self): + """Save then load an exported program with no extra metadata.""" + ep = torch.export.export(TestModule(), args=(torch.tensor(1.0),)) + + with tempfile.TemporaryDirectory() as tempdir: + save_exported_program(ep, f"{tempdir}/test") + loaded_ep, meta, extra_files = load_exported_program(f"{tempdir}/test.pt2") + + del meta[ExportMetadataKeys.TIMESTAMP.value] + self.assertEqual(meta, get_config_values()) + self.assertEqual(extra_files, {}) + + # Verify the loaded program produces the same 
output + result = loaded_ep.module()(torch.tensor(5.0)) + self.assertEqual(result.item(), 15.0) + + def test_load_with_metadata(self): + """Save then load an exported program with metadata.""" + ep = torch.export.export(TestModule(), args=(torch.tensor(1.0),)) + test_metadata = {"foo": [1, 2], "bar": "string"} + + with tempfile.TemporaryDirectory() as tempdir: + save_exported_program(ep, f"{tempdir}/test", meta_values=test_metadata) + _, meta, extra_files = load_exported_program(f"{tempdir}/test.pt2") + + del meta[ExportMetadataKeys.TIMESTAMP.value] + + test_compare = get_config_values() + test_compare.update(test_metadata) + self.assertEqual(meta, test_compare) + self.assertEqual(extra_files, {}) + + def test_save_load_more_extra_files(self): + """Save then load extra file data from an exported program.""" + ep = torch.export.export(TestModule(), args=(torch.tensor(1.0),)) + test_metadata = {"foo": [1, 2], "bar": "string"} + more_extra_files = {"test.txt": "This is test data"} + + with tempfile.TemporaryDirectory() as tempdir: + save_exported_program(ep, f"{tempdir}/test", meta_values=test_metadata, more_extra_files=more_extra_files) + self.assertTrue(os.path.isfile(f"{tempdir}/test.pt2")) + + _, _, loaded_extra_files = load_exported_program(f"{tempdir}/test.pt2", more_extra_files=("test.txt",)) + self.assertEqual(more_extra_files["test.txt"], loaded_extra_files["test.txt"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/integration/test_retinanet_predict_utils.py b/tests/integration/test_retinanet_predict_utils.py index d909699469..25d310152b 100644 --- a/tests/integration/test_retinanet_predict_utils.py +++ b/tests/integration/test_retinanet_predict_utils.py @@ -81,7 +81,7 @@ TEST_CASES = [] TEST_CASES = [TEST_CASE_1, TEST_CASE_2, TEST_CASE_2_A] -TEST_CASES_TS = [TEST_CASE_1] +TEST_CASES_EXPORT = [TEST_CASE_1] class NaiveNetwork(torch.nn.Module): diff --git a/tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py 
b/tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py index 35a24cd0ca..a1491a29db 100644 --- a/tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py +++ b/tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py @@ -18,6 +18,7 @@ from parameterized import parameterized from monai.losses.image_dissimilarity import LocalNormalizedCrossCorrelationLoss +from tests.test_utils import test_export_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -152,11 +153,11 @@ def test_ill_opts(self): with self.assertRaisesRegex(ValueError, ""): LocalNormalizedCrossCorrelationLoss(reduction=None)(pred, target) + def test_export(self): + input_param, input_data, _ = TEST_CASES[0] + loss = LocalNormalizedCrossCorrelationLoss(**input_param) + test_export_save(loss, input_data["pred"], input_data["target"]) -# def test_script(self): -# input_param, input_data, _ = TEST_CASES[0] -# loss = LocalNormalizedCrossCorrelationLoss(**input_param) -# test_script_save(loss, input_data["pred"], input_data["target"]) if __name__ == "__main__": unittest.main() diff --git a/tests/losses/test_dice_ce_loss.py b/tests/losses/test_dice_ce_loss.py index 97c7ae5050..3c20319516 100644 --- a/tests/losses/test_dice_ce_loss.py +++ b/tests/losses/test_dice_ce_loss.py @@ -18,6 +18,7 @@ from parameterized import parameterized from monai.losses import DiceCELoss +from tests.test_utils import test_export_save TEST_CASES = [ [ # shape: (2, 2, 3), (2, 1, 3) @@ -113,10 +114,10 @@ def test_ill_shape3(self): # loss = DiceCELoss(reduction="none") # loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) - # def test_script(self): - # loss = DiceCELoss() - # test_input = torch.ones(2, 2, 8, 8) - # test_script_save(loss, test_input, test_input) + def test_export(self): + loss = DiceCELoss() + test_input = torch.ones(2, 2, 8, 8) + test_export_save(loss, test_input, test_input) if __name__ == "__main__": diff --git 
a/tests/losses/test_dice_focal_loss.py b/tests/losses/test_dice_focal_loss.py index 98ea475ded..b9359e56f5 100644 --- a/tests/losses/test_dice_focal_loss.py +++ b/tests/losses/test_dice_focal_loss.py @@ -18,7 +18,7 @@ from parameterized import parameterized from monai.losses import DiceFocalLoss, DiceLoss, FocalLoss -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save class TestDiceFocalLoss(unittest.TestCase): @@ -85,10 +85,10 @@ def test_ill_lambda(self): with self.assertRaisesRegex(ValueError, ""): DiceFocalLoss(lambda_dice=-1.0) - def test_script(self): + def test_export(self): loss = DiceFocalLoss() test_input = torch.ones(2, 1, 8, 8) - test_script_save(loss, test_input, test_input) + test_export_save(loss, test_input, test_input) @parameterized.expand( [ diff --git a/tests/losses/test_dice_loss.py b/tests/losses/test_dice_loss.py index 66c038783a..12658df17c 100644 --- a/tests/losses/test_dice_loss.py +++ b/tests/losses/test_dice_loss.py @@ -18,7 +18,7 @@ from parameterized import parameterized from monai.losses import DiceLoss -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) @@ -217,10 +217,10 @@ def test_input_warnings(self): loss = DiceLoss(to_onehot_y=True) loss.forward(chn_input, chn_target) - def test_script(self): + def test_export(self): loss = DiceLoss() test_input = torch.ones(2, 1, 8, 8) - test_script_save(loss, test_input, test_input) + test_export_save(loss, test_input, test_input) if __name__ == "__main__": diff --git a/tests/losses/test_ds_loss.py b/tests/losses/test_ds_loss.py index 586826aafe..e71d976d8a 100644 --- a/tests/losses/test_ds_loss.py +++ b/tests/losses/test_ds_loss.py @@ -18,7 +18,7 @@ from parameterized import parameterized from monai.losses import DeepSupervisionLoss, DiceCELoss, DiceFocalLoss, DiceLoss -from tests.test_utils import test_script_save +from tests.test_utils import 
test_export_save TEST_CASES_DICECE = [ [ @@ -151,10 +151,10 @@ def test_ill_reduction(self): loss = DeepSupervisionLoss(DiceCELoss(reduction="none")) loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) - def test_script(self): + def test_export(self): loss = DeepSupervisionLoss(DiceCELoss()) test_input = torch.ones(2, 2, 8, 8) - test_script_save(loss, test_input, test_input) + test_export_save(loss, test_input, test_input) class TestDSLossDiceCE2(unittest.TestCase): diff --git a/tests/losses/test_focal_loss.py b/tests/losses/test_focal_loss.py index 35017ec898..3e11e0298f 100644 --- a/tests/losses/test_focal_loss.py +++ b/tests/losses/test_focal_loss.py @@ -21,7 +21,7 @@ from monai.losses import FocalLoss from monai.networks import one_hot -from tests.test_utils import TEST_DEVICES, test_script_save +from tests.test_utils import TEST_DEVICES, test_export_save TEST_CASES = [] for case in TEST_DEVICES: @@ -376,11 +376,11 @@ def test_warnings(self): loss = FocalLoss(include_background=False, use_softmax=True, alpha=0.5) loss(chn_input, chn_target) - def test_script(self): + def test_export(self): for use_softmax in [True, False]: loss = FocalLoss(use_softmax=use_softmax) test_input = torch.ones(2, 2, 8, 8) - test_script_save(loss, test_input, test_input) + test_export_save(loss, test_input, test_input) @parameterized.expand(TEST_ALPHA_BROADCASTING) def test_alpha_sequence_broadcasting(self, device, include_background, use_softmax): diff --git a/tests/losses/test_generalized_dice_focal_loss.py b/tests/losses/test_generalized_dice_focal_loss.py index 2af4aa68db..05f9e68c20 100644 --- a/tests/losses/test_generalized_dice_focal_loss.py +++ b/tests/losses/test_generalized_dice_focal_loss.py @@ -17,7 +17,7 @@ import torch from monai.losses import FocalLoss, GeneralizedDiceFocalLoss, GeneralizedDiceLoss -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save class TestGeneralizedDiceFocalLoss(unittest.TestCase): @@ -75,10 +75,10 @@ def 
test_ill_lambda(self): with self.assertRaisesRegex(ValueError, ""): GeneralizedDiceFocalLoss(lambda_gdl=-1.0) - def test_script(self): + def test_export(self): loss = GeneralizedDiceFocalLoss() test_input = torch.ones(2, 1, 8, 8) - test_script_save(loss, test_input, test_input) + test_export_save(loss, test_input, test_input) if __name__ == "__main__": diff --git a/tests/losses/test_generalized_dice_loss.py b/tests/losses/test_generalized_dice_loss.py index 8549e87482..ace84d7af2 100644 --- a/tests/losses/test_generalized_dice_loss.py +++ b/tests/losses/test_generalized_dice_loss.py @@ -18,7 +18,7 @@ from parameterized import parameterized from monai.losses import GeneralizedDiceLoss -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) @@ -211,10 +211,10 @@ def test_batch(self): loss = generalized_dice_loss(prediction, target) self.assertIsNotNone(loss.grad_fn) - def test_script(self): + def test_export(self): loss = GeneralizedDiceLoss() test_input = torch.ones(2, 1, 8, 8) - test_script_save(loss, test_input, test_input) + test_export_save(loss, test_input, test_input) if __name__ == "__main__": diff --git a/tests/losses/test_multi_scale.py b/tests/losses/test_multi_scale.py index 87ccca7676..264d4521d4 100644 --- a/tests/losses/test_multi_scale.py +++ b/tests/losses/test_multi_scale.py @@ -18,7 +18,7 @@ from monai.losses import DiceLoss from monai.losses.multi_scale import MultiScaleLoss -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save dice_loss = DiceLoss(include_background=True, sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5) device = "cuda" if torch.cuda.is_available() else "cpu" @@ -76,10 +76,10 @@ def test_ill_opts(self, kwargs, input, target): with self.assertRaisesRegex(ValueError, ""): MultiScaleLoss(**kwargs)(input, target) - def test_script(self): + def test_export(self): input_param, input_data, expected_val = 
TEST_CASES[0] loss = MultiScaleLoss(**input_param) - test_script_save(loss, input_data["y_pred"], input_data["y_true"]) + test_export_save(loss, input_data["y_pred"], input_data["y_true"]) if __name__ == "__main__": diff --git a/tests/losses/test_spectral_loss.py b/tests/losses/test_spectral_loss.py index 8a4988a30d..8b0d3d1193 100644 --- a/tests/losses/test_spectral_loss.py +++ b/tests/losses/test_spectral_loss.py @@ -18,7 +18,7 @@ from parameterized import parameterized from monai.losses import JukeboxLoss -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save TEST_CASES = [ [ @@ -76,10 +76,10 @@ def test_3d_shape(self): results = JukeboxLoss(spatial_dims=3, reduction="none").forward(**TEST_CASES[2][1]) self.assertEqual(results.shape, (1, 2, 2, 2, 3)) - def test_script(self): + def test_export(self): loss = JukeboxLoss(spatial_dims=2) test_input = torch.ones(2, 1, 8, 8) - test_script_save(loss, test_input, test_input) + test_export_save(loss, test_input, test_input) if __name__ == "__main__": diff --git a/tests/losses/test_ssim_loss.py b/tests/losses/test_ssim_loss.py index 7fa593b956..cb79e748f7 100644 --- a/tests/losses/test_ssim_loss.py +++ b/tests/losses/test_ssim_loss.py @@ -18,8 +18,7 @@ from monai.losses.ssim_loss import SSIMLoss from monai.utils import set_determinism - -# from tests.utils import test_script_save +from tests.test_utils import test_export_save class TestSSIMLoss(unittest.TestCase): @@ -49,10 +48,10 @@ def test_shape(self): expected_val = [[0.9121], [0.9971]] np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4) - # def test_script(self): - # loss = SSIMLoss(spatial_dims=2) - # test_input = torch.ones(2, 2, 16, 16) - # test_script_save(loss, test_input, test_input) + def test_export(self): + loss = SSIMLoss(spatial_dims=2) + test_input = torch.ones(2, 2, 16, 16) + test_export_save(loss, test_input, test_input) if __name__ == "__main__": diff --git 
a/tests/losses/test_tversky_loss.py b/tests/losses/test_tversky_loss.py index 32303434ca..a92c1804bb 100644 --- a/tests/losses/test_tversky_loss.py +++ b/tests/losses/test_tversky_loss.py @@ -18,7 +18,7 @@ from parameterized import parameterized from monai.losses import TverskyLoss -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) @@ -188,10 +188,10 @@ def test_input_warnings(self, include_background, softmax, to_onehot_y): loss = TverskyLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y) loss.forward(chn_input, chn_target) - def test_script(self): + def test_export(self): loss = TverskyLoss() test_input = torch.ones(2, 1, 8, 8) - test_script_save(loss, test_input, test_input) + test_export_save(loss, test_input, test_input) if __name__ == "__main__": diff --git a/tests/min_tests.py b/tests/min_tests.py index 2d68f099a7..a1390acc66 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -203,6 +203,7 @@ def run_testsuit(): "test_bundle_verify_metadata", "test_bundle_verify_net", "test_bundle_ckpt_export", + "test_bundle_export_checkpoint", "test_bundle_utils", "test_bundle_init_bundle", "test_fastmri_reader", diff --git a/tests/networks/blocks/test_dynunet_block.py b/tests/networks/blocks/test_dynunet_block.py index 4cf68912be..1a8f23b18c 100644 --- a/tests/networks/blocks/test_dynunet_block.py +++ b/tests/networks/blocks/test_dynunet_block.py @@ -18,7 +18,7 @@ from monai.networks import eval_mode from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, UnetUpBlock, get_padding -from tests.test_utils import dict_product, test_script_save +from tests.test_utils import dict_product, test_export_save TEST_CASE_RES_BASIC_BLOCK = [] for params in dict_product( @@ -83,13 +83,13 @@ def test_ill_arg(self): with self.assertRaises(AssertionError): UnetResBlock(3, 4, 2, kernel_size=1, stride=4, norm_name="batch") - def 
test_script(self): + def test_export(self): input_param, input_shape, _ = TEST_CASE_RES_BASIC_BLOCK[0] for net_type in (UnetResBlock, UnetBasicBlock): net = net_type(**input_param) test_data = torch.randn(input_shape) - test_script_save(net, test_data) + test_export_save(net, test_data) class TestUpBlock(unittest.TestCase): @@ -100,13 +100,13 @@ def test_shape(self, input_param, input_shape, expected_shape, skip_shape): result = net(torch.randn(input_shape), torch.randn(skip_shape)) self.assertEqual(result.shape, expected_shape) - def test_script(self): + def test_export(self): input_param, input_shape, _, skip_shape = TEST_UP_BLOCK[0] net = UnetUpBlock(**input_param) test_data = torch.randn(input_shape) skip_data = torch.randn(skip_shape) - test_script_save(net, test_data, skip_data) + test_export_save(net, test_data, skip_data) if __name__ == "__main__": diff --git a/tests/networks/blocks/test_se_block.py b/tests/networks/blocks/test_se_block.py index d799cd095c..9b793e4c7f 100644 --- a/tests/networks/blocks/test_se_block.py +++ b/tests/networks/blocks/test_se_block.py @@ -19,7 +19,7 @@ from monai.networks import eval_mode from monai.networks.blocks import SEBlock from monai.networks.layers.factories import Act, Norm -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -70,11 +70,11 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) - def test_script(self): + def test_export(self): input_param, input_shape, _ = TEST_CASES[0] net = SEBlock(**input_param) test_data = torch.randn(input_shape) - test_script_save(net, test_data) + test_export_save(net, test_data) def test_ill_arg(self): with self.assertRaises(ValueError): diff --git a/tests/networks/blocks/test_se_blocks.py b/tests/networks/blocks/test_se_blocks.py index b40f3a0955..034cb8f20f 100644 --- 
a/tests/networks/blocks/test_se_blocks.py +++ b/tests/networks/blocks/test_se_blocks.py @@ -18,7 +18,7 @@ from monai.networks import eval_mode from monai.networks.blocks import ChannelSELayer, ResidualSELayer -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save TEST_CASES = [ # single channel 3D, batch 16 [{"spatial_dims": 2, "in_channels": 4, "r": 3}, (7, 4, 64, 48), (7, 4, 64, 48)], # 4-channel 2D, batch 7 @@ -48,11 +48,11 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) - def test_script(self): + def test_export(self): input_param, input_shape, _ = TEST_CASES[0] net = ChannelSELayer(**input_param) test_data = torch.randn(input_shape) - test_script_save(net, test_data) + test_export_save(net, test_data) def test_ill_arg(self): with self.assertRaises(ValueError): @@ -67,11 +67,11 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) - def test_script(self): + def test_export(self): input_param, input_shape, _ = TEST_CASES[0] net = ResidualSELayer(**input_param) test_data = torch.randn(input_shape) - test_script_save(net, test_data) + test_export_save(net, test_data) if __name__ == "__main__": diff --git a/tests/networks/blocks/test_selfattention.py b/tests/networks/blocks/test_selfattention.py index af52918612..7d4bd6b5a7 100644 --- a/tests/networks/blocks/test_selfattention.py +++ b/tests/networks/blocks/test_selfattention.py @@ -22,7 +22,7 @@ from monai.networks.blocks.selfattention import SABlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import -from tests.test_utils import assert_allclose, test_script_save +from tests.test_utils import assert_allclose, test_export_save einops, has_einops = optional_import("einops") @@ -192,7 +192,7 @@ def count_sablock_params(*args, 
**kwargs): @parameterized.expand([[True, False], [True, True], [False, True], [False, False]]) @skipUnless(has_einops, "Requires einops") - def test_script(self, include_fc, use_combined_linear): + def test_export(self, include_fc, use_combined_linear): input_param = { "hidden_size": 360, "num_heads": 4, @@ -205,7 +205,7 @@ def test_script(self, include_fc, use_combined_linear): net = SABlock(**input_param) input_shape = (2, 512, 360) test_data = torch.randn(input_shape) - test_script_save(net, test_data) + test_export_save(net, test_data) @skipUnless(has_einops, "Requires einops") def test_flash_attention(self): diff --git a/tests/networks/blocks/test_subpixel_upsample.py b/tests/networks/blocks/test_subpixel_upsample.py index f4075f5099..e38d441027 100644 --- a/tests/networks/blocks/test_subpixel_upsample.py +++ b/tests/networks/blocks/test_subpixel_upsample.py @@ -20,7 +20,7 @@ from monai.networks import eval_mode from monai.networks.blocks import SubpixelUpsample from monai.networks.layers.factories import Conv -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save TEST_CASE_SUBPIXEL = [] for inch in range(1, 5): @@ -75,11 +75,11 @@ def test_subpixel_shape(self, input_param, input_shape, expected_shape): result = net.forward(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) - def test_script(self): + def test_export(self): input_param, input_shape, _ = TEST_CASE_SUBPIXEL[0] net = SubpixelUpsample(**input_param) test_data = torch.randn(input_shape) - test_script_save(net, test_data) + test_export_save(net, test_data) if __name__ == "__main__": diff --git a/tests/networks/blocks/test_unetr_block.py b/tests/networks/blocks/test_unetr_block.py index 0073efc609..25789093ea 100644 --- a/tests/networks/blocks/test_unetr_block.py +++ b/tests/networks/blocks/test_unetr_block.py @@ -19,7 +19,7 @@ from monai.networks import eval_mode from monai.networks.blocks.dynunet_block import get_padding from 
monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock -from tests.test_utils import dict_product, test_script_save +from tests.test_utils import dict_product, test_export_save def _get_out_size(params): @@ -111,12 +111,12 @@ def test_ill_arg(self): with self.assertRaises(AssertionError): UnetrBasicBlock(3, 4, 2, kernel_size=1, stride=4, norm_name="batch") - def test_script(self): + def test_export(self): input_param, input_shape, _ = TEST_CASE_UNETR_BASIC_BLOCK[0] net = UnetrBasicBlock(**input_param) with eval_mode(net): test_data = torch.randn(input_shape) - test_script_save(net, test_data) + test_export_save(net, test_data) class TestUpBlock(unittest.TestCase): @@ -127,12 +127,12 @@ def test_shape(self, input_param, input_shape, expected_shape, skip_shape): result = net(torch.randn(input_shape), torch.randn(skip_shape)) self.assertEqual(result.shape, expected_shape) - def test_script(self): + def test_export(self): input_param, input_shape, _, skip_shape = TEST_UP_BLOCK[0] net = UnetrUpBlock(**input_param) test_data = torch.randn(input_shape) skip_data = torch.randn(skip_shape) - test_script_save(net, test_data, skip_data) + test_export_save(net, test_data, skip_data) class TestPrUpBlock(unittest.TestCase): @@ -143,11 +143,11 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) - def test_script(self): + def test_export(self): input_param, input_shape, _ = TEST_PRUP_BLOCK[0] net = UnetrPrUpBlock(**input_param) test_data = torch.randn(input_shape) - test_script_save(net, test_data) + test_export_save(net, test_data) if __name__ == "__main__": diff --git a/tests/networks/nets/dints/test_dints_mixop.py b/tests/networks/nets/dints/test_dints_mixop.py index ea78514fa5..4c1fbf291b 100644 --- a/tests/networks/nets/dints/test_dints_mixop.py +++ b/tests/networks/nets/dints/test_dints_mixop.py @@ -17,7 +17,7 @@ from parameterized import 
parameterized from monai.networks.nets.dints import Cell, MixedOp -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save TEST_CASES_3D = [ [ @@ -75,9 +75,9 @@ def test_mixop_2d(self, input_param, ops, weight, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) @parameterized.expand(TEST_CASES_3D) - def test_script(self, input_param, ops, weight, input_shape, expected_shape): + def test_export(self, input_param, ops, weight, input_shape, expected_shape): net = MixedOp(ops=Cell.OPS3D, **input_param) - test_script_save(net, torch.randn(input_shape), weight) + test_export_save(net, torch.randn(input_shape), weight) if __name__ == "__main__": diff --git a/tests/networks/nets/test_ahnet.py b/tests/networks/nets/test_ahnet.py index 7facf9af24..f72bcee30a 100644 --- a/tests/networks/nets/test_ahnet.py +++ b/tests/networks/nets/test_ahnet.py @@ -19,7 +19,7 @@ from monai.networks import eval_mode from monai.networks.blocks import FCN, MCFCN from monai.networks.nets import AHNet -from tests.test_utils import skip_if_quick, test_pretrained_networks, test_script_save +from tests.test_utils import skip_if_quick, test_export_save, test_pretrained_networks device = "cuda" if torch.cuda.is_available() else "cpu" @@ -180,15 +180,15 @@ def test_ahnet_shape_3d(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) @skip_if_quick - def test_script(self): + def test_export(self): # test 2D network net = AHNet(spatial_dims=2, out_channels=2) test_data = torch.randn(1, 1, 128, 64) - test_script_save(net, test_data) + test_export_save(net, test_data) # test 3D network net = AHNet(spatial_dims=3, out_channels=2, psp_block_num=0, upsample_mode="nearest") test_data = torch.randn(1, 1, 32, 32, 64) - test_script_save(net, test_data) + test_export_save(net, test_data) class TestAHNETWithPretrain(unittest.TestCase): diff --git a/tests/networks/nets/test_autoencoder.py 
b/tests/networks/nets/test_autoencoder.py index dcf90b809a..d9623eca4c 100644 --- a/tests/networks/nets/test_autoencoder.py +++ b/tests/networks/nets/test_autoencoder.py @@ -19,7 +19,7 @@ from monai.networks import eval_mode from monai.networks.layers import Act from monai.networks.nets import AutoEncoder -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -81,10 +81,10 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) - def test_script(self): + def test_export(self): net = AutoEncoder(spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8), strides=(2, 2)) test_data = torch.randn(2, 1, 32, 32) - test_script_save(net, test_data) + test_export_save(net, test_data) def test_channel_stride_difference(self): with self.assertRaises(ValueError): diff --git a/tests/networks/nets/test_daf3d.py b/tests/networks/nets/test_daf3d.py index e707cfb272..5be6d9e9c1 100644 --- a/tests/networks/nets/test_daf3d.py +++ b/tests/networks/nets/test_daf3d.py @@ -19,7 +19,7 @@ from monai.networks import eval_mode from monai.networks.nets import DAF3D from monai.utils import optional_import -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save _, has_tv = optional_import("torchvision") @@ -51,11 +51,11 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) - @unittest.skip("daf3d: torchscript not currently supported") - def test_script(self): + @unittest.skip("daf3d: export not currently supported") + def test_export(self): net = DAF3D(in_channels=1, out_channels=1) test_data = torch.randn(16, 1, 32, 32) - test_script_save(net, test_data) + test_export_save(net, test_data) if __name__ == "__main__": 
diff --git a/tests/networks/nets/test_densenet.py b/tests/networks/nets/test_densenet.py index fe0b6c3bf0..fe62a59b6e 100644 --- a/tests/networks/nets/test_densenet.py +++ b/tests/networks/nets/test_densenet.py @@ -21,7 +21,7 @@ from monai.networks import eval_mode from monai.networks.nets import DenseNet121, Densenet169, DenseNet264, densenet201 from monai.utils import optional_import -from tests.test_utils import skip_if_downloading_fails, skip_if_quick, test_script_save +from tests.test_utils import skip_if_downloading_fails, skip_if_quick, test_export_save if TYPE_CHECKING: import torchvision @@ -55,7 +55,7 @@ for model in [DenseNet121, Densenet169, densenet201, DenseNet264]: TEST_CASES.append([model, *case]) -TEST_SCRIPT_CASES = [[model, *TEST_CASE_1] for model in [DenseNet121, Densenet169, densenet201, DenseNet264]] +TEST_EXPORT_CASES = [[model, *TEST_CASE_1] for model in [DenseNet121, Densenet169, densenet201, DenseNet264]] TEST_PRETRAINED_2D_CASE_1 = [ # 4-channel 2D, batch 2 DenseNet121, @@ -110,11 +110,11 @@ def test_densenet_shape(self, model, input_param, input_shape, expected_shape): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) - @parameterized.expand(TEST_SCRIPT_CASES) - def test_script(self, model, input_param, input_shape, expected_shape): + @parameterized.expand(TEST_EXPORT_CASES) + def test_export(self, model, input_param, input_shape, expected_shape): net = model(**input_param) test_data = torch.randn(input_shape) - test_script_save(net, test_data) + test_export_save(net, test_data) if __name__ == "__main__": diff --git a/tests/networks/nets/test_dints_network.py b/tests/networks/nets/test_dints_network.py index 80ade00db7..96352b1f80 100644 --- a/tests/networks/nets/test_dints_network.py +++ b/tests/networks/nets/test_dints_network.py @@ -19,7 +19,7 @@ from monai.networks.nets import DiNTS, TopologyInstance, TopologySearch from monai.networks.nets.dints import Cell -from 
tests.test_utils import skip_if_quick, test_script_save +from tests.test_utils import skip_if_quick, test_export_save TEST_CASES_3D = [ [ @@ -153,14 +153,22 @@ def test_dints_search(self, dints_grid_params, dints_params, input_shape, expect self.assertTrue(isinstance(net.weight_parameters(), list)) -class TestDintsTS(unittest.TestCase): +class TestDintsExport(unittest.TestCase): @parameterized.expand(TEST_CASES_3D + TEST_CASES_2D) - def test_script(self, dints_grid_params, dints_params, input_shape, _): - grid = TopologyInstance(**dints_grid_params) + def test_export(self, dints_grid_params, dints_params, input_shape, _): + num_blocks = dints_grid_params["num_blocks"] + num_depths = dints_grid_params["num_depths"] + _cell = Cell(1, 1, 0, spatial_dims=dints_grid_params["spatial_dims"]) + num_cell_ops = len(_cell.OPS) + arch_code_a = np.ones((num_blocks, 3 * num_depths - 2)) + arch_code_c = np.random.randint(num_cell_ops, size=(num_blocks, 3 * num_depths - 2)) + dints_grid_params["arch_code"] = [arch_code_a, arch_code_c] dints_grid_params["device"] = "cpu" + grid = TopologyInstance(**dints_grid_params) dints_params["dints_space"] = grid - net = DiNTS(**dints_params).to(dints_grid_params["device"]) - test_script_save(net, torch.randn(input_shape).to(dints_grid_params["device"])) + dints_params["node_a"] = torch.ones((num_blocks + 1, num_depths)) + net = DiNTS(**dints_params).to("cpu") + test_export_save(net, torch.randn(input_shape).to("cpu")) if __name__ == "__main__": diff --git a/tests/networks/nets/test_discriminator.py b/tests/networks/nets/test_discriminator.py index 8f460a2450..d866b615e2 100644 --- a/tests/networks/nets/test_discriminator.py +++ b/tests/networks/nets/test_discriminator.py @@ -18,7 +18,7 @@ from monai.networks import eval_mode from monai.networks.nets import Discriminator -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save TEST_CASE_0 = [ {"in_shape": (1, 64, 64), "channels": (2, 4, 8), "strides": (2, 
2, 2), "num_res_units": 0}, @@ -49,10 +49,10 @@ def test_shape(self, input_param, input_data, expected_shape): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) - def test_script(self): + def test_export(self): net = Discriminator(in_shape=(1, 64, 64), channels=(2, 4), strides=(2, 2), num_res_units=0) test_data = torch.rand(16, 1, 64, 64) - test_script_save(net, test_data) + test_export_save(net, test_data) if __name__ == "__main__": diff --git a/tests/networks/nets/test_dynunet.py b/tests/networks/nets/test_dynunet.py index c2c9369923..8d438bc3aa 100644 --- a/tests/networks/nets/test_dynunet.py +++ b/tests/networks/nets/test_dynunet.py @@ -20,7 +20,7 @@ from monai.networks import eval_mode from monai.networks.nets import DynUNet from monai.utils import optional_import -from tests.test_utils import assert_allclose, dict_product, skip_if_no_cuda, skip_if_windows, test_script_save +from tests.test_utils import assert_allclose, dict_product, skip_if_no_cuda, skip_if_windows, test_export_save InstanceNorm3dNVFuser, _ = optional_import("apex.normalization", name="InstanceNorm3dNVFuser") @@ -134,11 +134,11 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) - def test_script(self): + def test_export(self): input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] net = DynUNet(**input_param) test_data = torch.randn(input_shape) - test_script_save(net, test_data) + test_export_save(net, test_data) @skip_if_no_cuda diff --git a/tests/networks/nets/test_efficientnet.py b/tests/networks/nets/test_efficientnet.py index e76d5a6d5a..c2716c6522 100644 --- a/tests/networks/nets/test_efficientnet.py +++ b/tests/networks/nets/test_efficientnet.py @@ -29,7 +29,7 @@ get_efficientnet_image_size, ) from monai.utils import optional_import -from tests.test_utils import skip_if_downloading_fails, skip_if_quick, test_pretrained_networks, test_script_save 
+from tests.test_utils import skip_if_downloading_fails, skip_if_quick, test_export_save, test_pretrained_networks TESTS_PATH = Path(__file__).parents[2] @@ -370,12 +370,12 @@ def test_func_get_efficientnet_input_shape(self): expected_shape = get_expected_model_shape(model) self.assertEqual(result_shape, expected_shape) - def test_script(self): + def test_export(self): with skip_if_downloading_fails(): net = EfficientNetBN(model_name="efficientnet-b0", spatial_dims=2, in_channels=3, num_classes=1000) net.set_swish(memory_efficient=False) # at the moment custom memory efficient swish is not exportable with jit test_data = torch.randn(1, 3, 224, 224) - test_script_save(net, test_data) + test_export_save(net, test_data) class TestExtractFeatures(unittest.TestCase): diff --git a/tests/networks/nets/test_generator.py b/tests/networks/nets/test_generator.py index 9ec08194e9..4e73212113 100644 --- a/tests/networks/nets/test_generator.py +++ b/tests/networks/nets/test_generator.py @@ -18,7 +18,7 @@ from monai.networks import eval_mode from monai.networks.nets import Generator -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save TEST_CASE_0 = [ {"latent_shape": (64,), "start_shape": (8, 8, 8), "channels": (8, 4, 1), "strides": (2, 2, 2), "num_res_units": 0}, @@ -49,10 +49,10 @@ def test_shape(self, input_param, input_data, expected_shape): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) - def test_script(self): + def test_export(self): net = Generator(latent_shape=(64,), start_shape=(8, 8, 8), channels=(8, 1), strides=(2, 2), num_res_units=2) test_data = torch.rand(16, 64) - test_script_save(net, test_data) + test_export_save(net, test_data) if __name__ == "__main__": diff --git a/tests/networks/nets/test_highresnet.py b/tests/networks/nets/test_highresnet.py index 1384dfaeff..354153e22d 100644 --- a/tests/networks/nets/test_highresnet.py +++ b/tests/networks/nets/test_highresnet.py @@ -18,7 +18,7 @@ 
from monai.networks import eval_mode from monai.networks.nets import HighResNet -from tests.test_utils import DistTestCase, TimedCall, test_script_save +from tests.test_utils import DistTestCase, TimedCall, test_export_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -56,11 +56,11 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) @TimedCall(seconds=800, force_quit=True) - def test_script(self): + def test_export(self): input_param, input_shape, expected_shape = TEST_CASE_1 net = HighResNet(**input_param) test_data = torch.randn(input_shape) - test_script_save(net, test_data, rtol=1e-4, atol=1e-4) + test_export_save(net, test_data, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/networks/nets/test_hovernet.py b/tests/networks/nets/test_hovernet.py index 58657e6dea..23f6f052d1 100644 --- a/tests/networks/nets/test_hovernet.py +++ b/tests/networks/nets/test_hovernet.py @@ -19,7 +19,7 @@ from monai.networks import eval_mode, train_mode from monai.networks.nets import HoVerNet from monai.networks.nets.hovernet import _DenseLayerDecoder -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -178,11 +178,11 @@ def test_decoder_padding_shape(self, input_param, input_shape, expected_shapes): else: pass - def test_script(self): + def test_export(self): for padding_flag in [True, False]: net = HoVerNet(mode=HoVerNet.Mode.FAST, decoder_padding=padding_flag) test_data = torch.randn(1, 3, 256, 256) - test_script_save(net, test_data) + test_export_save(net, test_data) def test_ill_input_shape(self): net = HoVerNet(mode=HoVerNet.Mode.FAST) diff --git a/tests/networks/nets/test_milmodel.py b/tests/networks/nets/test_milmodel.py index 15fda15a11..0a3c5e2fd9 100644 --- a/tests/networks/nets/test_milmodel.py +++ b/tests/networks/nets/test_milmodel.py @@ -19,7 +19,7 @@ from monai.networks 
import eval_mode from monai.networks.nets import MILModel from monai.utils.module import optional_import -from tests.test_utils import skip_if_downloading_fails, test_script_save +from tests.test_utils import skip_if_downloading_fails, test_export_save models, _ = optional_import("torchvision.models") @@ -81,11 +81,11 @@ def test_ill_args(self): mil_mode="att_trans_pyramid", ) - def test_script(self): + def test_export(self): input_param, input_shape, expected_shape = TEST_CASE_MILMODEL[0] net = MILModel(**input_param) test_data = torch.randn(input_shape, dtype=torch.float) - test_script_save(net, test_data) + test_export_save(net, test_data) if __name__ == "__main__": diff --git a/tests/networks/nets/test_net_adapter.py b/tests/networks/nets/test_net_adapter.py index 08344900e4..5a9a622269 100644 --- a/tests/networks/nets/test_net_adapter.py +++ b/tests/networks/nets/test_net_adapter.py @@ -18,7 +18,7 @@ from monai.networks import eval_mode from monai.networks.nets import NetAdapter, resnet18 -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -54,14 +54,14 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) @parameterized.expand([TEST_CASE_0]) - def test_script(self, input_param, input_shape, expected_shape): + def test_export(self, input_param, input_shape, expected_shape): spatial_dims = input_param["dim"] stride = (1, 2, 2)[:spatial_dims] model = resnet18(spatial_dims=spatial_dims, conv1_t_stride=stride) input_param["model"] = model net = NetAdapter(**input_param).to("cpu") test_data = torch.randn(input_shape).to("cpu") - test_script_save(net, test_data) + test_export_save(net, test_data) if __name__ == "__main__": diff --git a/tests/networks/nets/test_patch_gan_dicriminator.py b/tests/networks/nets/test_patch_gan_dicriminator.py index 184f76fa9d..dfd434678e 100644 --- 
a/tests/networks/nets/test_patch_gan_dicriminator.py +++ b/tests/networks/nets/test_patch_gan_dicriminator.py @@ -18,7 +18,7 @@ from monai.networks import eval_mode from monai.networks.nets import MultiScalePatchDiscriminator, PatchDiscriminator -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save TEST_PATCHGAN = [ [ @@ -124,7 +124,7 @@ def test_shape(self, input_param, input_data, expected_shape_feature, expected_s self.assertEqual(tuple(result[0].shape), expected_shape_feature) self.assertEqual(tuple(result[-1].shape), expected_shape_output) - def test_script(self): + def test_export(self): net = PatchDiscriminator( num_layers_d=3, spatial_dims=2, @@ -138,7 +138,7 @@ def test_script(self): dropout=0.1, ) i = torch.rand([1, 3, 256, 512]) - test_script_save(net, i) + test_export_save(net, i) class TestMultiscalePatchGAN(unittest.TestCase): @@ -156,7 +156,7 @@ def test_too_small_shape(self): with self.assertRaises(AssertionError): MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0]) - def test_script(self): + def test_export(self): net = MultiScalePatchDiscriminator( num_d=2, num_layers_d=3, @@ -172,7 +172,7 @@ def test_script(self): minimum_size_im=256, ) i = torch.rand([1, 3, 256, 512]) - test_script_save(net, i) + test_export_save(net, i) if __name__ == "__main__": diff --git a/tests/networks/nets/test_quicknat.py b/tests/networks/nets/test_quicknat.py index 6653965c08..becaa670e4 100644 --- a/tests/networks/nets/test_quicknat.py +++ b/tests/networks/nets/test_quicknat.py @@ -19,7 +19,7 @@ from monai.networks import eval_mode from monai.networks.nets import Quicknat from monai.utils import optional_import -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save _, has_se = optional_import("squeeze_and_excitation") @@ -47,10 +47,10 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, 
expected_shape) - def test_script(self): + def test_export(self): net = Quicknat(num_classes=1, num_channels=1) test_data = torch.randn(16, 1, 32, 32) - test_script_save(net, test_data) + test_export_save(net, test_data) if __name__ == "__main__": diff --git a/tests/networks/nets/test_resnet.py b/tests/networks/nets/test_resnet.py index 371ec89682..c0e047e0b4 100644 --- a/tests/networks/nets/test_resnet.py +++ b/tests/networks/nets/test_resnet.py @@ -43,7 +43,7 @@ skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, - test_script_save, + test_export_save, ) if TYPE_CHECKING: @@ -228,7 +228,7 @@ for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]: TEST_CASES.append([ResNet, *case]) -TEST_SCRIPT_CASES = [ +TEST_EXPORT_CASES = [ [model, *TEST_CASE_1] for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200] ] @@ -310,11 +310,11 @@ def test_resnet_pretrained(self, model, input_param, _input_shape, _expected_sha } equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict) - @parameterized.expand(TEST_SCRIPT_CASES) - def test_script(self, model, input_param, input_shape, expected_shape): + @parameterized.expand(TEST_EXPORT_CASES) + def test_export(self, model, input_param, input_shape, expected_shape): net = model(**input_param) test_data = torch.randn(input_shape) - test_script_save(net, test_data) + test_export_save(net, test_data) @SkipIfNoModule("hf_hub_download") diff --git a/tests/networks/nets/test_segresnet.py b/tests/networks/nets/test_segresnet.py index 1536d33853..e6fe0e926a 100644 --- a/tests/networks/nets/test_segresnet.py +++ b/tests/networks/nets/test_segresnet.py @@ -19,7 +19,7 @@ from monai.networks import eval_mode from monai.networks.nets import SegResNet, SegResNetVAE from monai.utils import UpsampleMode -from tests.test_utils import dict_product, test_script_save +from tests.test_utils import dict_product, test_export_save device = "cuda" if 
torch.cuda.is_available() else "cpu" @@ -77,11 +77,11 @@ def test_ill_arg(self): with self.assertRaises(ValueError): SegResNet(spatial_dims=4) - def test_script(self): + def test_export(self): input_param, input_shape, expected_shape = TEST_CASE_SEGRESNET[0] net = SegResNet(**input_param) test_data = torch.randn(input_shape) - test_script_save(net, test_data) + test_export_save(net, test_data) class TestResNetVAE(unittest.TestCase): @@ -92,11 +92,11 @@ def test_vae_shape(self, input_param, input_shape, expected_shape): result, _ = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) - def test_script(self): + def test_export(self): input_param, input_shape, expected_shape = TEST_CASE_SEGRESNET_VAE[0] net = SegResNetVAE(**input_param) test_data = torch.randn(input_shape) - test_script_save(net, test_data) + test_export_save(net, test_data) if __name__ == "__main__": diff --git a/tests/networks/nets/test_segresnet_ds.py b/tests/networks/nets/test_segresnet_ds.py index 4a2cf40e6f..79a2405331 100644 --- a/tests/networks/nets/test_segresnet_ds.py +++ b/tests/networks/nets/test_segresnet_ds.py @@ -18,7 +18,7 @@ from monai.networks import eval_mode from monai.networks.nets import SegResNetDS, SegResNetDS2 -from tests.test_utils import dict_product, test_script_save +from tests.test_utils import dict_product, test_export_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -135,11 +135,11 @@ def test_ill_arg(self): with self.assertRaises(ValueError): SegResNetDS2(spatial_dims=4) - def test_script(self): + def test_export(self): input_param, input_shape, _ = TEST_CASE_SEGRESNET_DS[0] net = SegResNetDS(**input_param) test_data = torch.randn(input_shape) - test_script_save(net, test_data) + test_export_save(net, test_data) if __name__ == "__main__": diff --git a/tests/networks/nets/test_senet.py b/tests/networks/nets/test_senet.py index 90d711d0d9..1bed2e8214 100644 --- a/tests/networks/nets/test_senet.py +++ 
b/tests/networks/nets/test_senet.py @@ -24,7 +24,7 @@ from monai.networks import eval_mode from monai.networks.nets import SENet, SENet154, SEResNet50, SEResNet101, SEResNet152, SEResNext50, SEResNext101 from monai.utils import optional_import -from tests.test_utils import test_is_quick, test_pretrained_networks, test_script_save, testing_data_config +from tests.test_utils import test_export_save, test_is_quick, test_pretrained_networks, testing_data_config if TYPE_CHECKING: import pretrainedmodels @@ -69,10 +69,10 @@ def test_senet_shape(self, net, net_args): self.assertEqual(result.shape, expected_shape) @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) - def test_script(self, net, net_args): + def test_export(self, net, net_args): net = net(**net_args) input_data = torch.randn(2, 2, 64, 64, 64) - test_script_save(net, input_data) + test_export_save(net, input_data) class TestPretrainedSENET(unittest.TestCase): diff --git a/tests/networks/nets/test_unet.py b/tests/networks/nets/test_unet.py index 7a6d0e98bb..6cd87d35d9 100644 --- a/tests/networks/nets/test_unet.py +++ b/tests/networks/nets/test_unet.py @@ -19,7 +19,7 @@ from monai.networks import eval_mode from monai.networks.layers import Act, Norm from monai.networks.nets import UNet -from tests.test_utils import test_script_save +from tests.test_utils import test_export_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -172,14 +172,14 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) - def test_script(self): + def test_export(self): net = UNet( spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0 ) test_data = torch.randn(16, 1, 32, 32) - test_script_save(net, test_data) + test_export_save(net, test_data) - def test_script_without_running_stats(self): + def 
+    def test_export_without_running_stats(self):
         net = UNet(
             spatial_dims=2,
             in_channels=1,
@@ -190,7 +190,7 @@ def test_script_without_running_stats(self):
             norm=("batch", {"track_running_stats": False}),
         )
         test_data = torch.randn(16, 1, 16, 4)
-        test_script_save(net, test_data)
+        test_export_save(net, test_data)
 
     def test_ill_input_shape(self):
         net = UNet(spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2))
diff --git a/tests/networks/nets/test_unetr.py b/tests/networks/nets/test_unetr.py
index 5d4faa3979..fc321af1b1 100644
--- a/tests/networks/nets/test_unetr.py
+++ b/tests/networks/nets/test_unetr.py
@@ -18,7 +18,7 @@
 from monai.networks import eval_mode
 from monai.networks.nets.unetr import UNETR
-from tests.test_utils import dict_product, skip_if_quick, test_script_save
+from tests.test_utils import dict_product, skip_if_quick, test_export_save
 
 TEST_CASE_UNETR = [
     [
@@ -115,14 +115,11 @@ def test_ill_arg(self):
         )
 
     @parameterized.expand(TEST_CASE_UNETR)
-    def test_script(self, input_param, input_shape, _):
+    def test_export(self, input_param, input_shape, _):
         net = UNETR(**(input_param))
         net.eval()
-        with torch.no_grad():
-            torch.jit.script(net)
-
         test_data = torch.randn(input_shape)
-        test_script_save(net, test_data)
+        test_export_save(net, test_data)
 
 
 if __name__ == "__main__":
diff --git a/tests/networks/nets/test_varautoencoder.py b/tests/networks/nets/test_varautoencoder.py
index 459c537c55..641edc22f5 100644
--- a/tests/networks/nets/test_varautoencoder.py
+++ b/tests/networks/nets/test_varautoencoder.py
@@ -19,7 +19,7 @@
 from monai.networks import eval_mode
 from monai.networks.layers import Act
 from monai.networks.nets import VarAutoEncoder
-from tests.test_utils import test_script_save
+from tests.test_utils import test_export_save
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
@@ -115,12 +115,12 @@ def test_shape(self, input_param, input_shape, expected_shape):
         result = net.forward(torch.randn(input_shape).to(device))[0]
         self.assertEqual(result.shape, expected_shape)
 
-    def test_script(self):
+    def test_export(self):
         net = VarAutoEncoder(
             spatial_dims=2, in_shape=(1, 32, 32), out_channels=1, latent_size=2, channels=(4, 8), strides=(2, 2)
         )
         test_data = torch.randn(2, 1, 32, 32)
-        test_script_save(net, test_data, rtol=1e-3, atol=1e-3)
+        test_export_save(net, test_data, rtol=1e-3, atol=1e-3)
 
 
 if __name__ == "__main__":
diff --git a/tests/networks/nets/test_vit.py b/tests/networks/nets/test_vit.py
index 54ad3f863e..17db8df411 100644
--- a/tests/networks/nets/test_vit.py
+++ b/tests/networks/nets/test_vit.py
@@ -18,7 +18,7 @@
 from monai.networks import eval_mode
 from monai.networks.nets.vit import ViT
-from tests.test_utils import dict_product, skip_if_quick, test_script_save
+from tests.test_utils import dict_product, skip_if_quick, test_export_save
 
 TEST_CASE_Vit = [
     (
@@ -99,14 +99,11 @@ def test_ill_arg(
         )
 
     @parameterized.expand(TEST_CASE_Vit[:1])
-    def test_script(self, input_param, input_shape, _):
+    def test_export(self, input_param, input_shape, _):
         net = ViT(**(input_param))
         net.eval()
-        with torch.no_grad():
-            torch.jit.script(net)
-
         test_data = torch.randn(input_shape)
-        test_script_save(net, test_data)
+        test_export_save(net, test_data)
 
     def test_access_attn_matrix(self):
         # input format
diff --git a/tests/networks/nets/test_vnet.py b/tests/networks/nets/test_vnet.py
index 6c93893480..66b103226a 100644
--- a/tests/networks/nets/test_vnet.py
+++ b/tests/networks/nets/test_vnet.py
@@ -18,7 +18,7 @@
 from monai.networks import eval_mode
 from monai.networks.nets import VNet
-from tests.test_utils import test_script_save
+from tests.test_utils import test_export_save
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -71,10 +71,10 @@ def test_vnet_shape(self, input_param, input_shape, expected_shape):
         result = net.forward(torch.randn(input_shape).to(device))
         self.assertEqual(result.shape, expected_shape)
 
-    def test_script(self):
+    def test_export(self):
         net = VNet(spatial_dims=3, in_channels=1, out_channels=3, dropout_dim=3)
         test_data = torch.randn(1, 1, 32, 32, 32)
-        test_script_save(net, test_data)
+        test_export_save(net, test_data)
 
 
 if __name__ == "__main__":
diff --git a/tests/networks/nets/test_vqvae.py b/tests/networks/nets/test_vqvae.py
index 13a34f9657..6c0503fc01 100644
--- a/tests/networks/nets/test_vqvae.py
+++ b/tests/networks/nets/test_vqvae.py
@@ -125,8 +125,8 @@ def test_shape_with_checkpoint(self, input_param, input_shape, expected_shape):
 
         self.assertEqual(result.shape, expected_shape)
 
-    # Removed this test case since TorchScript currently does not support activation checkpoint.
-    # def test_script(self):
+    # Removed this test case since torch.export does not support activation checkpoint.
+    # def test_export(self):
     #     net = VQVAE(
     #         spatial_dims=2,
     #         in_channels=1,
@@ -141,7 +141,7 @@ def test_shape_with_checkpoint(self, input_param, input_shape, expected_shape):
     #         ddp_sync=False,
     #     )
     #     test_data = torch.randn(1, 1, 16, 16)
-    #     test_script_save(net, test_data)
+    #     test_export_save(net, test_data)
 
     def test_channels_not_same_size_of_num_res_channels(self):
         with self.assertRaises(ValueError):
diff --git a/tests/networks/test_convert_to_export.py b/tests/networks/test_convert_to_export.py
new file mode 100644
index 0000000000..9df3bfe8df
--- /dev/null
+++ b/tests/networks/test_convert_to_export.py
@@ -0,0 +1,79 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+import tempfile
+import unittest
+
+import torch
+
+from monai.networks import convert_to_export
+from monai.networks.nets import UNet
+from monai.utils.module import pytorch_after
+
+
+@unittest.skipUnless(pytorch_after(2, 6), "torch.export requires PyTorch >= 2.6")
+class TestConvertToExport(unittest.TestCase):
+    def test_basic_export(self):
+        """Export a UNet and verify output matches."""
+        model = UNet(
+            spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
+        )
+        with tempfile.TemporaryDirectory() as tempdir:
+            exported = convert_to_export(
+                model=model,
+                filename_or_obj=os.path.join(tempdir, "model.pt2"),
+                verify=True,
+                inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)],
+                device="cpu",
+                rtol=1e-3,
+                atol=1e-4,
+            )
+        self.assertIsInstance(exported, torch.export.ExportedProgram)
+
+    def test_export_without_save(self):
+        """Export a model without saving to disk."""
+        model = UNet(
+            spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
+        )
+        exported = convert_to_export(model=model, inputs=[torch.randn((2, 1, 32, 32))])
+        self.assertIsInstance(exported, torch.export.ExportedProgram)
+        out = exported.module()(torch.randn(2, 1, 32, 32))
+        self.assertEqual(out.shape, torch.Size([2, 3, 32, 32]))
+
+    def test_missing_inputs_raises(self):
+        """Verify that missing inputs raise ValueError."""
+        model = UNet(
+            spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
+        )
+        with self.assertRaises(ValueError):
+            convert_to_export(model=model)
+
+    @unittest.skipUnless(pytorch_after(2, 9), "torch.export.Dim.DYNAMIC requires PyTorch >= 2.9")
+    def test_export_with_dynamic_shapes(self):
+        """Export with dynamic batch dimension."""
+        model = UNet(
+            spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
+        )
+        dynamic = torch.export.Dim.DYNAMIC
+        static = torch.export.Dim.STATIC
+        exported = convert_to_export(
+            model=model, inputs=[torch.randn((2, 1, 32, 32))], dynamic_shapes=((dynamic, static, dynamic, dynamic),)
+        )
+        # Verify works with different batch size and spatial dims
+        out = exported.module()(torch.randn(4, 1, 64, 64))
+        self.assertEqual(out.shape, torch.Size([4, 3, 64, 64]))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/networks/test_varnet.py b/tests/networks/test_varnet.py
index 9cadf15ce5..8810baf733 100644
--- a/tests/networks/test_varnet.py
+++ b/tests/networks/test_varnet.py
@@ -20,7 +20,7 @@
 from monai.apps.reconstruction.networks.nets.complex_unet import ComplexUnet
 from monai.apps.reconstruction.networks.nets.varnet import VariationalNetworkModel
 from monai.networks import eval_mode
-from tests.test_utils import test_script_save
+from tests.test_utils import test_export_save
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 coil_sens_model = CoilSensitivityModel(spatial_dims=2, features=[8, 16, 32, 64, 128, 8])
@@ -45,7 +45,7 @@ def test_shape(self, coil_sens_model, refinement_model, num_cascades, input_shap
         self.assertEqual(result.shape, expected_shape)
 
     @parameterized.expand(TESTS)
-    def test_script(self, coil_sens_model, refinement_model, num_cascades, input_shape, expected_shape):
+    def test_export(self, coil_sens_model, refinement_model, num_cascades, input_shape, _expected_shape):
         net = VariationalNetworkModel(coil_sens_model, refinement_model, num_cascades)
 
         mask_shape = [1 for _ in input_shape]
@@ -55,7 +55,7 @@ def test_script(self, coil_sens_model, refinement_model, num_cascades, input_sha
 
         test_data = torch.randn(input_shape)
 
-        test_script_save(net, test_data, mask.bool())
+        test_export_save(net, test_data, mask.bool())
 
 
 if __name__ == "__main__":
diff --git a/tests/test_utils.py b/tests/test_utils.py
index e365237401..f0f29c1062 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -47,7 +47,8 @@
 from monai.config.type_definitions import NdarrayOrTensor
 from monai.data import create_test_image_2d, create_test_image_3d
 from monai.data.meta_tensor import MetaTensor, get_track_meta
-from monai.networks import convert_to_onnx, convert_to_torchscript
+from monai.networks import convert_to_export, convert_to_onnx, convert_to_torchscript
+from monai.networks.utils import _recursive_to
 from monai.utils import optional_import
 from monai.utils.misc import MONAIEnvVars
 from monai.utils.module import compute_capabilities_after, pytorch_after
@@ -773,6 +774,36 @@ def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0):
     )
 
 
+def test_export_save(net, *inputs, dynamic_shapes=None, rtol=1e-4, atol=0.0):
+    """
+    Test the ability to save ``net`` as a ``torch.export`` ``.pt2`` object, reload it, and apply inference.
+    The values in ``inputs`` are forward-passed through the original and the reloaded network and the
+    outputs compared within the given tolerances; both forward passes run without gradient accumulation.
+
+    Requires PyTorch >= 2.6.0. Skips silently on older versions.
+    """
+    if not pytorch_after(2, 6):
+        return
+    device = "cpu"
+    # Ensure model and inputs are on CPU to avoid device mismatches in exported constants
+    net = net.to(device)
+    inputs = tuple(_recursive_to(i, device) for i in inputs)
+    with tempfile.TemporaryDirectory() as tempdir:
+        convert_to_export(
+            model=net,
+            filename_or_obj=os.path.join(tempdir, "model.pt2"),
+            verify=True,
+            inputs=inputs,
+            dynamic_shapes=dynamic_shapes,
+            device=device,
+            rtol=rtol,
+            atol=atol,
+        )
+
+
+test_export_save.__test__ = False  # type: ignore[attr-defined]  # prevent pytest from collecting this helper
+
+
 def test_onnx_save(net, *inputs, device=None, rtol=1e-4, atol=0.0):
     """
     Test the ability to save `net` in ONNX format, reload it and validate with runtime.