-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Torchscript deprecation #8777
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
aymuos15
wants to merge
9
commits into
Project-MONAI:dev
Choose a base branch
from
aymuos15:torchscript_deprecation
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,164
−348
Open
Torchscript deprecation #8777
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
2db1d8f
Add torch.export save/load infrastructure and deprecate TorchScript u…
aymuos15 a769c7e
Add convert_to_export() and deprecate convert_to_torchscript()
aymuos15 be18b63
Add export_checkpoint() bundle CLI and update bundle load()
aymuos15 45c20ef
Remove torch.jit constructs from network architectures
aymuos15 fc612cb
Fix torch.export compatibility in loss functions
aymuos15 53f8c1d
Migrate tests from TorchScript to torch.export
aymuos15 2e8a7ab
Update bundle docs to reference torch.export and .pt2 format
aymuos15 cbc0346
Fix CI failures and address CodeRabbit review comments
aymuos15 75d5732
Merge branch 'dev' into torchscript_deprecation
aymuos15 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -85,32 +85,40 @@ def __init__( | |
| self.spatial_dims = spatial_dims | ||
| self.coil_dim = coil_dim | ||
|
|
||
| def get_fully_sampled_region(self, mask: Tensor) -> tuple[int, int]: | ||
| def _compute_acr_mask(self, mask: Tensor) -> Tensor: | ||
| """ | ||
| Extracts the size of the fully-sampled part of the kspace. Note that when a kspace | ||
| is under-sampled, a part of its center is fully sampled. This part is called the Auto | ||
| Calibration Region (ACR). ACR is used for sensitivity map computation. | ||
| Compute a boolean mask for the Auto Calibration Region (ACR) — the contiguous | ||
| fully-sampled center of the k-space sampling mask. | ||
|
|
||
| Uses pure tensor operations (``cumprod``) instead of while-loops so that | ||
| the computation is compatible with ``torch.export``. | ||
|
|
||
| Args: | ||
| mask: the under-sampling mask of shape (..., S, 1) where S denotes the sampling dimension | ||
| mask: the under-sampling mask of shape (..., S, 1) where S denotes the sampling dimension. | ||
|
|
||
| Returns: | ||
| A tuple containing | ||
| (1) left index of the region | ||
| (2) right index of the region | ||
|
|
||
| Note: | ||
| Suppose the mask is of shape (1,1,20,1). If this function returns 8,12 as left and right | ||
| indices, then it means that the fully-sampled center region has size 4 starting from 8 to 12. | ||
| A boolean tensor broadcastable to ``masked_kspace`` that is True inside the ACR. | ||
| """ | ||
| left = right = mask.shape[-2] // 2 | ||
| while mask[..., right, :]: | ||
| right += 1 | ||
| s_len = mask.shape[-2] | ||
| center = s_len // 2 | ||
|
|
||
| # Flatten to 1-D along the sampling axis | ||
| m = mask.reshape(-1)[:s_len].bool() | ||
|
|
||
|
Comment on lines
+106
to
+107
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ACR derivation currently depends on the first flattened slice
Proposed fix- m = mask.reshape(-1)[:s_len].bool()
+ # Collapse non-sampling dims and keep only frequencies sampled across all of them.
+ m = mask.reshape(-1, s_len).all(dim=0)🤖 Prompt for AI Agents |
||
| # Count consecutive True values from center going right | ||
| right_count = torch.cumprod(m[center:].int(), dim=0).sum() | ||
| # Count consecutive True values from center going left (including center) | ||
| left_count = torch.cumprod(m[: center + 1].flip(0).int(), dim=0).sum() | ||
| num_low_freqs = left_count + right_count - 1 | ||
|
|
||
| while mask[..., left, :]: | ||
| left -= 1 | ||
| # Build a boolean mask over the sampling dimension | ||
| start = (s_len - num_low_freqs + 1) // 2 | ||
| freq_idx = torch.arange(s_len, device=mask.device) | ||
| acr_1d = (freq_idx >= start) & (freq_idx < start + num_low_freqs) | ||
|
|
||
| return left + 1, right | ||
| # Reshape to (..., S, 1) so it broadcasts against masked_kspace | ||
| result: Tensor = acr_1d.view(*([1] * (mask.ndim - 2)), s_len, 1) | ||
| return result | ||
|
|
||
| def forward(self, masked_kspace: Tensor, mask: Tensor) -> Tensor: | ||
| """ | ||
|
|
@@ -122,13 +130,10 @@ def forward(self, masked_kspace: Tensor, mask: Tensor) -> Tensor: | |
| Returns: | ||
| predicted coil sensitivity maps with shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data. | ||
| """ | ||
| left, right = self.get_fully_sampled_region(mask) | ||
| num_low_freqs = right - left # size of the fully-sampled center | ||
| acr_mask = self._compute_acr_mask(mask) | ||
|
|
||
| # take out the fully-sampled region and set the rest of the data to zero | ||
| x = torch.zeros_like(masked_kspace) | ||
| start = (mask.shape[-2] - num_low_freqs + 1) // 2 # this marks the start of center extraction | ||
| x[..., start : start + num_low_freqs, :] = masked_kspace[..., start : start + num_low_freqs, :] | ||
| x = masked_kspace * acr_mask | ||
|
|
||
| # apply inverse fourier to the extracted fully-sampled data | ||
| x = ifftn_centered_t(x, spatial_dims=self.spatial_dims, is_complex=True) | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
| ckpt_export, | ||
| download, | ||
| download_large_files, | ||
| export_checkpoint, | ||
| init_bundle, | ||
| onnx_export, | ||
| run, | ||
|
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.