fix(models): preserve CUDA context for MACE cuEq conversion #114

Open
zubatyuk wants to merge 4 commits into NVIDIA:main from zubatyuk:fix/mace-cueq-device

Conversation

@zubatyuk

Collaborator

ALCHEMI Toolkit Pull Request

Description

Fix MACE cuEquivariance checkpoint loading so CUDA device selection and fused cuEq conversion use the device contract expected by mace-torch==0.3.15. The MACE converter enables its fused cuEq operation only when it receives the literal device string "cuda"; passing torch.device("cuda") or torch.device("cuda:1") can silently skip that fused path. This change keeps explicit CUDA indices active through torch.cuda.device(...) while passing "cuda" to the converter.

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Performance improvement
  • Documentation update
  • Refactoring (no functional changes)
  • CI/CD or infrastructure change

Related Issues

None.

Changes Made

  • Normalize checkpoint target devices once and use the normalized torch.device for checkpoint loading and final model placement.
  • Preserve explicit CUDA device indices during cuEquivariance conversion while passing "cuda" to the MACE converter so the fused cuEq path is selected.
  • Warn and skip cuEquivariance conversion when enable_cueq=True is requested for a non-CUDA device, then continue loading the model on the requested device.
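
The control flow in the list above can be sketched without PyTorch or MACE installed. This is only an illustration under stated assumptions: `split_device`, `convert_with_cueq`, `fake_device_context`, and `fake_converter` are hypothetical names, not the wrapper's actual functions. In the PR itself the device context is `torch.cuda.device(...)` and the converter comes from mace-torch.

```python
import warnings
from contextlib import contextmanager
from typing import Optional, Tuple


def split_device(device: str) -> Tuple[str, Optional[int]]:
    """Split 'cuda:1' into ('cuda', 1). Hypothetical stand-in for torch.device."""
    kind, _, index = device.partition(":")
    return kind, int(index) if index else None


def convert_with_cueq(model, device, converter, device_context, enable_cueq=True):
    """Sketch of the guard: skip on non-CUDA, otherwise convert under a device context."""
    kind, index = split_device(device)
    if not enable_cueq:
        return model
    if kind != "cuda":
        # Non-CUDA target: warn, skip conversion, keep loading on the requested device.
        warnings.warn(
            f"cuEquivariance conversion needs a CUDA device, got {device!r}; skipping.",
            UserWarning,
            stacklevel=2,
        )
        return model
    # Keep the requested index active, but hand the converter the literal "cuda".
    with device_context(index):
        return converter(model, device="cuda")


# Hypothetical fakes that record what happened, in order.
calls = []


@contextmanager
def fake_device_context(index):
    calls.append(("enter", index))
    try:
        yield
    finally:
        calls.append(("exit", index))


def fake_converter(model, device):
    calls.append(("convert", device))
    return f"cueq({model})"


result = convert_with_cueq("mace", "cuda:1", fake_converter, fake_device_context)
print(result)
print(calls)
```

The recorded calls show the point of the change: the index `1` is carried by the context, while the converter only ever sees `"cuda"`.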

Testing

  • Unit tests pass locally (make pytest)
  • Linting passes (make lint)
  • New tests added for new functionality meet coverage expectations

Targeted MACE wrapper tests passed locally, including no-network tests for CPU warning behavior, device normalization, and explicit CUDA-context conversion, plus slow CUDA cuEq checkpoint tests.

Checklist

  • I have read and understand the Contributing Guidelines
  • I have updated the CHANGELOG.md
  • I have performed a self-review of my code
  • I have added docstrings to new functions/classes
  • I have updated the documentation (if applicable)

Additional Notes

The relevant mace-torch==0.3.15 behavior is that cuEquivariance conversion decides whether to use fused convolution from an exact device == "cuda" check. The wrapper now keeps the requested CUDA device active with PyTorch's CUDA device context, which avoids losing explicit device selection while preventing the fused operation from being skipped by a non-string device argument.
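
A toy model of why an exact-string gate is fragile, for readers who have not hit this before. `uses_fused_path` and `DeviceLike` are hypothetical stand-ins written for this note. They are not mace-torch or PyTorch code, and the real converter's check may be structured differently than this one-liner.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class DeviceLike:
    """Hypothetical device object; like any non-str object, it never equals "cuda"."""
    type: str
    index: Optional[int] = None


def uses_fused_path(device) -> bool:
    """Toy gate mirroring the exact `device == "cuda"` comparison described above."""
    return device == "cuda"


print(uses_fused_path("cuda"))                 # True
print(uses_fused_path("cuda:1"))               # False: indexed string is not the literal
print(uses_fused_path(DeviceLike("cuda")))     # False: object, not a string
print(uses_fused_path(DeviceLike("cuda", 1)))  # False
```

Only the bare string passes the gate, which is why the wrapper separates "which GPU is active" from "what the converter is told".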

Tip

This repository uses Greptile, an AI code review service, to help conduct
pull request reviews. We encourage contributors to read and consider suggestions
made by Greptile, but note that human maintainers will provide the necessary
reviews for merging: Greptile's comments are not a qualitative judgement
of your code, nor is it an indication that the PR will be accepted/rejected.
We encourage the use of emoji reactions to Greptile comments, depending on
their usefulness and accuracy.

zubatyuk added 2 commits June 15, 2026 11:33
Apply only the symmetric part of the displacement tensor when prepare_strain deforms positions and cells.
Document that the returned displacement remains an unconstrained autograd leaf, and add tests for symmetric stress from an asymmetric toy energy plus preservation of a symmetric raw stress response.

Signed-off-by: Roman Zubatyuk <rzubatiuk@nvidia.com>
Normalize MACE checkpoint target devices before loading and final placement.
Run cuEquivariance conversion under the requested CUDA device context while
passing the literal "cuda" device string to the mace-torch converter so its
fused cuEq path is selected. Warn and skip cuEq conversion for non-CUDA
targets.
Add regression coverage for CPU warning behavior, device normalization, and
explicit CUDA-index conversion.

Signed-off-by: Roman Zubatyuk <rzubatiuk@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 16, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps

greptile-apps Bot commented Jun 16, 2026

Contributor

Greptile Summary

This PR fixes MACE cuEquivariance checkpoint loading so explicit CUDA device indices are preserved while still satisfying mace-torch==0.3.15's requirement of an exact "cuda" string to enable the fused convolution path.

  • Device normalization: torch.device(device) is called once at entry, so both str and torch.device inputs are handled uniformly for torch.load and final .to() placement.
  • CUDA context fix: torch.cuda.device(target_device) activates the correct GPU index before calling the converter with device="cuda", ensuring the fused cuEq path fires on the intended device.
  • Non-CUDA guard: When enable_cueq=True is requested on a non-CUDA device, the code now warns and skips conversion rather than silently passing a CPU device to the converter.

Important Files Changed

| Filename | Overview |
| --- | --- |
| nvalchemi/models/mace.py | Normalizes device to torch.device upfront, guards cuEq conversion to CUDA-only with a warning, and uses torch.cuda.device() context to preserve explicit CUDA indices while passing the literal "cuda" string to the mace converter. |
| test/models/test_mace.py | Adds three targeted unit tests covering CPU-device warning/skip behavior, device normalization through load and final placement, and CUDA context preservation during cuEq conversion; all tests are well-isolated with monkeypatching. |

Reviews (2): Last reviewed commit: "fix(models): report MACE cuEq warning at..."

Comment thread nvalchemi/models/mace.py
Set the non-CUDA MACE cuEquivariance warning stacklevel so it points
to the from_checkpoint caller instead of the warning implementation.
Add regression coverage that checks the warning is attributed to the
test call site.

Signed-off-by: Roman Zubatyuk <rzubatiuk@nvidia.com>
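
For context on the stacklevel change, here is a small standalone sketch of how `warnings.warn(..., stacklevel=2)` attributes a warning to the caller. The `from_checkpoint` stub below is hypothetical and models only the warning, not the real loader, and the check at the end is one possible way to assert the call site, not necessarily how the PR's regression test does it.

```python
import warnings


def from_checkpoint(device: str, enable_cueq: bool = False) -> str:
    """Hypothetical loader stub; only the warning behaviour is modelled."""
    if enable_cueq and not device.startswith("cuda"):
        warnings.warn(
            "enable_cueq=True ignored on non-CUDA device",
            UserWarning,
            stacklevel=2,  # attribute the warning to whoever called from_checkpoint
        )
    return device


def user_code() -> str:
    return from_checkpoint("cpu", enable_cueq=True)  # the warning should point here


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    user_code()

reported_line = caught[0].lineno
# user_code's body sits on the line right after its `def` line.
expected_line = user_code.__code__.co_firstlineno + 1
print(reported_line == expected_line)
```

With the default `stacklevel=1` the reported line would instead be the `warnings.warn` call inside the stub, which is less useful to someone debugging their own script.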
Comment thread nvalchemi/models/mace.py
model = _convert_mace_weights(model, return_model=True, device=device)

model = model.to(device)
if target_device.type != "cuda":

Collaborator

I'm wondering whether we should just throw an exception, instead of issuing a warning.

If the user wants to use CPU then they should disable cuEq. Conversely, if they want to run with cuEq, then they should make sure to configure the device correctly.
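
One way the stricter behaviour suggested here could look, as a minimal sketch. `check_cueq_device` is a hypothetical helper named for this example. The choice of `ValueError` and the message wording are assumptions, not something settled in this thread.

```python
def check_cueq_device(device: str, enable_cueq: bool) -> None:
    """Hypothetical strict guard: refuse cuEq on non-CUDA devices instead of warning."""
    kind = device.split(":", 1)[0]
    if enable_cueq and kind != "cuda":
        raise ValueError(
            f"enable_cueq=True requires a CUDA device, got {device!r}; "
            "pass enable_cueq=False to load on this device."
        )


check_cueq_device("cuda:0", enable_cueq=True)   # accepted
check_cueq_device("cpu", enable_cueq=False)     # accepted: cuEq not requested

try:
    check_cueq_device("cpu", enable_cueq=True)
except ValueError as exc:
    print(f"rejected: {exc}")
```

The trade-off is that existing callers who pass `enable_cueq=True` with a CPU device would start failing rather than loading with a warning.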

Comment thread nvalchemi/models/mace.py
stacklevel=2,
)
else:
with torch.cuda.device(target_device):

Collaborator

What would the situation be where you would want to load the weights with cuEq and then move the weights to CPU?

2 participants