fix(models): preserve CUDA context for MACE cuEq conversion #114
zubatyuk wants to merge 4 commits into
Apply only the symmetric part of the displacement tensor when prepare_strain deforms positions and cells. Document that the returned displacement remains an unconstrained autograd leaf, and add tests for symmetric stress from an asymmetric toy energy plus preservation of a symmetric raw stress response. Signed-off-by: Roman Zubatyuk <rzubatiuk@nvidia.com>
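The symmetrization described in this commit message can be sketched as follows. This is a minimal illustration, assuming the displacement is a square matrix D and that row-vector positions are deformed as r' = r (I + sym(D)). The function names `symmetric_part` and `deform` are hypothetical and are not the toolkit's `prepare_strain` API.

```python
def symmetric_part(d):
    """Return 0.5 * (d + d^T) for a square matrix given as nested lists."""
    n = len(d)
    return [[0.5 * (d[i][j] + d[j][i]) for j in range(n)] for i in range(n)]


def deform(points, d):
    """Apply r' = r @ (I + sym(d)) to each row vector r (assumed convention)."""
    s = symmetric_part(d)
    n = len(s)
    # Build the strain matrix I + sym(d); the antisymmetric (rotational)
    # part of d never reaches the positions.
    strain = [
        [(1.0 if i == j else 0.0) + s[i][j] for j in range(n)]
        for i in range(n)
    ]
    return [
        [sum(r[k] * strain[k][j] for k in range(n)) for j in range(n)]
        for r in points
    ]
```

Because only sym(D) enters the deformation, an energy differentiated with respect to D through this path yields a symmetric response even when D itself is left unconstrained.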
Normalize MACE checkpoint target devices before loading and final placement. Run cuEquivariance conversion under the requested CUDA device context while passing the literal "cuda" device string to the mace-torch converter so its fused cuEq path is selected. Warn and skip cuEq conversion for non-CUDA targets. Add regression coverage for CPU warning behavior, device normalization, and explicit CUDA-index conversion. Signed-off-by: Roman Zubatyuk <rzubatiuk@nvidia.com>
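The device handling in this commit message can be illustrated with a small stand-alone sketch. Everything here is a stand-in: `FakeDevice` mimics a device object such as torch.device, `toy_converter_uses_fused` mimics an exact-string check of the kind the PR attributes to the mace-torch converter, and `plan_conversion` is a hypothetical helper. None of these names come from the PR or from mace-torch.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for a device object such as torch.device."""
    type: str
    index: Optional[int] = None


def parse_device(spec):
    """Normalize 'cpu', 'cuda', or 'cuda:1' (or a FakeDevice) to a FakeDevice."""
    if isinstance(spec, FakeDevice):
        return spec
    kind, _, idx = spec.partition(":")
    return FakeDevice(kind, int(idx) if idx else None)


def toy_converter_uses_fused(device):
    """Assumed converter check: fused path only for the exact string 'cuda'."""
    return device == "cuda"


def plan_conversion(spec):
    """Return (converter_device_arg, cuda_index) or None for non-CUDA targets."""
    target = parse_device(spec)
    if target.type != "cuda":
        return None  # the wrapper would warn and skip cuEq conversion here
    # Pass the literal string to the converter; keep the index separately so
    # the caller can activate that device around the conversion.
    return "cuda", target.index
```

A device object never compares equal to the string, so `toy_converter_uses_fused(parse_device("cuda:1"))` is false, while `plan_conversion("cuda:1")` yields the literal `"cuda"` plus index 1. In the actual change, the index is kept active through torch.cuda.device(...) rather than returned.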
Greptile Summary
This PR fixes MACE cuEquivariance checkpoint loading so explicit CUDA device indices are preserved while still satisfying
Important Files Changed
Reviews (2): Last reviewed commit: "fix(models): report MACE cuEq warning at..."
Set the non-CUDA MACE cuEquivariance warning stacklevel so it points to the from_checkpoint caller instead of the warning implementation. Add regression coverage that checks the warning is attributed to the test call site. Signed-off-by: Roman Zubatyuk <rzubatiuk@nvidia.com>
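The effect of the stacklevel change can be sketched with the standard library alone. `load_model` below is a hypothetical loader, not the toolkit's from_checkpoint, and the warning text is made up for illustration.

```python
import warnings


def load_model(device, enable_cueq=True):
    """Hypothetical loader that warns when cuEq is requested off-CUDA."""
    if enable_cueq and not str(device).startswith("cuda"):
        warnings.warn(
            "cuEq conversion skipped: target device is not CUDA",
            UserWarning,
            stacklevel=2,  # attribute the warning to load_model's caller
        )
    # Loading continues on the requested device either way.
    return {"device": str(device)}
```

With stacklevel=1 (the default), the reported file and line would be the `warnings.warn` call inside the loader; with stacklevel=2, they are the line that called `load_model`, which is usually where the user needs to change the configuration.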
model = _convert_mace_weights(model, return_model=True, device=device)

model = model.to(device)
if target_device.type != "cuda":
I'm wondering whether we should just throw an exception, instead of issuing a warning.
If the user wants to use CPU then they should disable cuEq. Conversely, if they want to run with cuEq, then they should make sure to configure the device correctly.
        stacklevel=2,
    )
else:
    with torch.cuda.device(target_device):
What would the situation be where you would want to load the weights with cuEq and then move the weights to CPU?
ALCHEMI Toolkit Pull Request
Description
Fix MACE cuEquivariance checkpoint loading so CUDA device selection and fused cuEq conversion use the device contract expected by mace-torch==0.3.15. The MACE converter enables its fused cuEq operation only when it receives the literal device string "cuda"; passing torch.device("cuda") or torch.device("cuda:1") can silently skip that fused path. This change keeps explicit CUDA indices active through torch.cuda.device(...) while passing "cuda" to the converter.
Type of Change
Related Issues
None.
Changes Made
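- Normalize the target device to torch.device for checkpoint loading and final model placement.
- Run cuEq conversion under the requested CUDA device context, passing "cuda" to the MACE converter so the fused cuEq path is selected.
- Warn when enable_cueq=True is requested for a non-CUDA device, then continue loading the model on the requested device.
Testing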
- make pytest
- make lint
Targeted MACE wrapper tests passed locally, including no-network tests for CPU warning behavior, device normalization, and explicit CUDA-context conversion, plus slow CUDA cuEq checkpoint tests.
Checklist
Additional Notes
The relevant mace-torch==0.3.15 behavior is that cuEquivariance conversion decides whether to use fused convolution from an exact device == "cuda" check. The wrapper now keeps the requested CUDA device active with PyTorch's CUDA device context, which avoids losing explicit device selection while preventing the fused operation from being skipped by a non-string device argument.
Tip
This repository uses Greptile, an AI code review service, to help conduct
pull request reviews. We encourage contributors to read and consider suggestions
made by Greptile, but note that human maintainers will provide the necessary
reviews for merging: Greptile's comments are not a qualitative judgement
of your code, nor are they an indication that the PR will be accepted or rejected.
We encourage the use of emoji reactions to Greptile comments, depending on
their usefulness and accuracy.