Skip to content

Fix state.to() not propagating dtype/device to constraints & expose angle_tolerance in symmetry pipeline#527

Merged
CompRhys merged 10 commits intoTorchSim:mainfrom
danielzuegner:constraints_device_fix
Apr 2, 2026
Merged

Fix state.to() not propagating dtype/device to constraints & expose angle_tolerance in symmetry pipeline#527
CompRhys merged 10 commits intoTorchSim:mainfrom
danielzuegner:constraints_device_fix

Conversation

@danielzuegner
Copy link
Copy Markdown
Contributor

@danielzuegner danielzuegner commented Apr 1, 2026

Summary

  1. state.to(dtype=...) did not update constraint tensors.
    _state_to_device only converted top-level SimState tensor attributes (positions, cell, masses) but left constraint objects untouched. This meant internal tensors like FixSymmetry.rotations or FixCom.coms stayed on the original dtype/device, causing silent dtype mismatches.
  2. state.to(dtype=...) did not update sub-class tensor attributes like forces, velocities, ....
  3. angle_tolerance could not be passed to moyopy for symmetry detection.
    MoyoDataset supports an angle_tolerance parameter, but it wasn't exposed anywhere in the TorchSim symmetry pipeline.

Changes

Constraint dtype/device propagation

torch_sim/constraints.py, torch_sim/state.py

  • Added Constraint.to(device, dtype) as a base no-op, with concrete overrides on each subclass:
    Class What to() moves
    AtomConstraint atom_idx
    SystemConstraint system_idx
    FixCom system_idx + cached coms tensor
    FixSymmetry rotations (float→dtype), symm_maps (int→device only), reference_cells
  • Each to() follows the same explicit-constructor pattern used by reindex, merge, and select_constraint.
  • _state_to_device now calls .to() on each attached constraint.

_state_to_device dtype handling

torch_sim/state.py

  • The main tensor loop now applies both device and dtype: all floating-point tensors get .to(device=device, dtype=dtype), integer/bool tensors get .to(device=device) only.
  • This replaces the previous approach of hardcoding positions/masses/cell for dtype conversion, so subclass attributes like forces and momenta are now handled automatically.

Expose angle_tolerance

torch_sim/symmetrize.py, torch_sim/constraints.py

  • Added angle_tolerance: float | None = None parameter to:
    • _moyo_dataset
    • get_symmetry_datasets
    • prep_symmetry
    • _refine_symmetry_impl
    • refine_symmetry
    • refine_and_prep_symmetry
  • Threaded through FixSymmetry.from_state().
  • Passed to MoyoDataset(…, angle_tolerance=angle_tolerance).

Checklist

Before a pull request can be merged, the following items must be checked:

  • Doc strings have been added in the Google docstring format.
  • Run ruff on your code.
  • Tests have been added for any new functionality or bug fixes.

@CompRhys
Copy link
Copy Markdown
Member

CompRhys commented Apr 1, 2026

just the linting and then good to merge?

@danielzuegner
Copy link
Copy Markdown
Contributor Author

Hopefully good to go -- I've fixed the last remaining linting error that was introduced by this PR. The others are pre-existing I believe.

@CompRhys CompRhys merged commit 9a17923 into TorchSim:main Apr 2, 2026
63 of 69 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants