From 999ad282a87ab4b0273d69c3011fbae5faa9c43e Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 1 Apr 2026 13:12:12 +0000 Subject: [PATCH 01/10] .to() for constraints; expose angle tolerance --- torch_sim/constraints.py | 48 +++++++++++++++++++++++++++++++++++++++- torch_sim/state.py | 6 +++++ torch_sim/symmetrize.py | 25 +++++++++++++++------ 3 files changed, 71 insertions(+), 8 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 91323880..3c60cd2b 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -151,6 +151,18 @@ def merge(cls, constraints: list[Constraint]) -> Self: constraints: Constraints to merge (all same type, already reindexed) """ + def to( + self, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> Self: + """Return a copy with all internal tensors moved to *device*/*dtype*. + + Float tensors are cast to *dtype*; integer/bool tensors are only moved + to *device*. Subclasses with tensor attributes must override this. + """ + return self + def _cumsum_with_zero(tensor: torch.Tensor) -> torch.Tensor: """Cumulative sum with a leading zero, e.g. [3, 2, 4] -> [0, 3, 5, 9].""" @@ -272,6 +284,14 @@ def merge(cls, constraints: list[Constraint]) -> Self: ) return cls(torch.cat([constraint.atom_idx for constraint in atom_constraints])) + def to( + self, + device: torch.device | None = None, + dtype: torch.dtype | None = None, # noqa: ARG002 + ) -> Self: + """Return a copy with atom indices moved to *device*.""" + return type(self)(self.atom_idx.to(device=device)) + class SystemConstraint(Constraint): """Base class for constraints that act on specific system indices. @@ -371,6 +391,14 @@ def merge(cls, constraints: list[Constraint]) -> Self: torch.cat([constraint.system_idx for constraint in system_constraints]) ) + def to( + self, + device: torch.device | None = None, + dtype: torch.dtype | None = None, # noqa: ARG002 + ) -> Self: + """Return a copy with system indices moved to *device*.""" + return type(self)(self.system_idx.to(device=device)) + def merge_constraints( constraint_lists: list[list[Constraint]], @@ -612,6 +640,17 @@ def __repr__(self) -> str: """String representation of the constraint.""" return f"FixCom(system_idx={self.system_idx})" + def to( + self, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> Self: + """Return a copy with tensors moved to *device*/*dtype*.""" + new = type(self)(self.system_idx.to(device=device)) + if self.coms is not None: + new.coms = self.coms.to(device=device, dtype=dtype) + return new + def count_degrees_of_freedom( state: SimState, constraints: list[Constraint] | None = None @@ -801,6 +840,7 @@ def from_state( adjust_positions: bool = True, adjust_cell: bool = True, refine_symmetry_state: bool = True, + angle_tolerance: float | None = None, ) -> Self: """Create from SimState, optionally refining to ideal symmetry first. @@ -814,6 +854,8 @@ def from_state( adjust_positions: Whether to symmetrize position displacements. adjust_cell: Whether to symmetrize cell/stress adjustments. refine_symmetry_state: Whether to refine positions/cell to ideal values. + angle_tolerance: Angle tolerance in degrees for moyopy symmetry + detection. If None, moyopy uses its default behaviour. """ try: import moyopy # noqa: F401 @@ -839,11 +881,15 @@ def from_state( pos, nums, symprec=symprec, + angle_tolerance=angle_tolerance, ) state.cell[sys_idx] = cell.mT # row→column vector convention state.positions[start:end] = pos else: - rots, smap = prep_symmetry(cell, pos, nums, symprec=symprec) + rots, smap = prep_symmetry( + cell, pos, nums, symprec=symprec, + angle_tolerance=angle_tolerance, + ) rotations.append(rots) symm_maps.append(smap) diff --git a/torch_sim/state.py b/torch_sim/state.py index 390b55fd..f5e3aea7 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -802,6 +802,12 @@ def _state_to_device[T: SimState]( attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) + + if attrs.get("_constraints"): + attrs["_constraints"] = [ + c.to(device=device, dtype=dtype) for c in attrs["_constraints"] + ] + return type(state)(**attrs) diff --git a/torch_sim/symmetrize.py b/torch_sim/symmetrize.py index 8677e8f0..9e3eaa15 100644 --- a/torch_sim/symmetrize.py +++ b/torch_sim/symmetrize.py @@ -22,6 +22,7 @@ def _moyo_dataset( frac_pos: torch.Tensor, atomic_numbers: torch.Tensor, symprec: float = 1e-4, + angle_tolerance: float | None = None, ) -> MoyoDataset: """Get MoyoDataset from cell, fractional positions, and atomic numbers.""" from moyopy import Cell, MoyoDataset @@ -31,7 +32,7 @@ def _moyo_dataset( positions=frac_pos.detach().cpu().tolist(), numbers=atomic_numbers.detach().cpu().int().tolist(), ) - return MoyoDataset(moyo_cell, symprec=symprec) + return MoyoDataset(moyo_cell, symprec=symprec, angle_tolerance=angle_tolerance) def _extract_symmetry_ops( @@ -51,13 +52,19 @@ def _extract_symmetry_ops( return rotations, translations -def get_symmetry_datasets(state: SimState, symprec: float = 1e-4) -> list[MoyoDataset]: +def get_symmetry_datasets( + state: SimState, + symprec: float = 1e-4, + angle_tolerance: float | None = None, +) -> list[MoyoDataset]: """Get MoyoDataset for each system in a SimState.""" datasets = [] for single in state.split(): cell = single.row_vector_cell[0] frac = single.positions @ torch.linalg.inv(cell) - datasets.append(_moyo_dataset(cell, frac, single.atomic_numbers, symprec)) + datasets.append( + _moyo_dataset(cell, frac, single.atomic_numbers, symprec, angle_tolerance) + ) return datasets @@ -105,6 +112,7 @@ def prep_symmetry( positions: torch.Tensor, atomic_numbers: torch.Tensor, symprec: float = 1e-4, + angle_tolerance: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Get symmetry rotations and atom mappings for a structure. @@ -112,7 +120,7 @@ def prep_symmetry( (rotations, symm_map) with shapes (n_ops, 3, 3) and (n_ops, n_atoms). """ frac_pos = positions @ torch.linalg.inv(cell) - dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec) + dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec, angle_tolerance) rotations, translations = _extract_symmetry_ops(dataset, cell.dtype, cell.device) return rotations, build_symmetry_map(rotations, translations, frac_pos) @@ -122,6 +130,7 @@ def _refine_symmetry_impl( positions: torch.Tensor, atomic_numbers: torch.Tensor, symprec: float = 0.01, + angle_tolerance: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Core refinement returning all intermediate data for reuse. @@ -130,7 +139,7 @@ def _refine_symmetry_impl( """ dtype, device = cell.dtype, cell.device frac_pos = positions @ torch.linalg.inv(cell) - dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec) + dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec, angle_tolerance) rotations, translations = _extract_symmetry_ops(dataset, dtype, device) n_ops, n_atoms = rotations.shape[0], positions.shape[0] @@ -165,6 +174,7 @@ def refine_symmetry( positions: torch.Tensor, atomic_numbers: torch.Tensor, symprec: float = 0.01, + angle_tolerance: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Symmetrize cell and positions according to the detected space group. @@ -175,7 +185,7 @@ def refine_symmetry( (symmetrized_cell, symmetrized_positions) as row vectors. """ new_cell, new_positions, _rotations, _translations = _refine_symmetry_impl( - cell, positions, atomic_numbers, symprec + cell, positions, atomic_numbers, symprec, angle_tolerance ) return new_cell, new_positions @@ -185,6 +195,7 @@ def refine_and_prep_symmetry( positions: torch.Tensor, atomic_numbers: torch.Tensor, symprec: float = 0.01, + angle_tolerance: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Refine symmetry and get ops/mappings in a single moyopy call. @@ -195,7 +206,7 @@ def refine_and_prep_symmetry( (refined_cell, refined_positions, rotations, symm_map) """ new_cell, new_positions, rotations, translations = _refine_symmetry_impl( - cell, positions, atomic_numbers, symprec + cell, positions, atomic_numbers, symprec, angle_tolerance ) # Build symm_map on the final refined fractional coordinates refined_frac = new_positions @ torch.linalg.inv(new_cell) From 1bc6cfa05d5207b577144172320c576ecc2e7ee2 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 1 Apr 2026 13:34:08 +0000 Subject: [PATCH 02/10] add tests --- tests/test_constraints.py | 66 +++++++++++++++++++++++++++++++++++++++ torch_sim/constraints.py | 20 ++++++++++++ 2 files changed, 86 insertions(+) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index d2c015e1..918c42ec 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -1210,3 +1210,69 @@ def test_fix_com_system_idx_remapped_on_reordered_slice( c = sliced.constraints[0] assert isinstance(c, FixCom) assert sorted(c.system_idx.tolist()) == [0, 1] + + +class TestConstraintToDeviceDtype: + """Test that state.to() propagates device/dtype to constraint tensors.""" + + def test_fix_atoms_dtype_propagation( + self, ar_supercell_sim_state: ts.SimState + ) -> None: + """FixAtoms indices should be moved to the new device by state.to().""" + indices = torch.tensor([0, 3, 5], dtype=torch.long) + ar_supercell_sim_state.constraints = [FixAtoms(atom_idx=indices)] + new_state = ar_supercell_sim_state.to(dtype=torch.float32) + + c = new_state.constraints[0] + assert isinstance(c, FixAtoms) + assert torch.equal(c.atom_idx, indices) + # dtype change should not affect integer indices, but the constraint + # object must be a distinct copy + assert c is not ar_supercell_sim_state.constraints[0] + + def test_fix_com_dtype_propagation( + self, ar_supercell_sim_state: ts.SimState + ) -> None: + """FixCom's cached coms tensor should follow state dtype changes.""" + ar_supercell_sim_state.constraints = [FixCom([0])] + # Trigger lazy COM initialisation + ar_supercell_sim_state.set_constrained_positions( + ar_supercell_sim_state.positions.clone() + ) + assert ar_supercell_sim_state.constraints[0].coms is not None + + new_state = ar_supercell_sim_state.to(dtype=torch.float32) + c = new_state.constraints[0] + assert isinstance(c, FixCom) + assert c.coms is not None + assert c.coms.dtype == torch.float32 + + @pytest.mark.parametrize("target_dtype", [torch.float32, torch.float64]) + def test_fix_symmetry_dtype_propagation(self, target_dtype: torch.dtype) -> None: + """FixSymmetry rotations and reference_cells must follow dtype changes.""" + rotations = [torch.eye(3, dtype=torch.float64).unsqueeze(0)] + symm_maps = [torch.zeros(1, 2, dtype=torch.long)] + ref_cells = [torch.eye(3, dtype=torch.float64)] + + state = ts.SimState( + positions=torch.zeros(2, 3, dtype=torch.float64), + masses=torch.ones(2, dtype=torch.float64), + cell=torch.eye(3, dtype=torch.float64).unsqueeze(0) * 5.0, + pbc=True, + atomic_numbers=torch.tensor([14, 14]), + system_idx=torch.zeros(2, dtype=torch.long), + ) + state.constraints = [ + FixSymmetry(rotations, symm_maps, reference_cells=ref_cells) + ] + + new_state = state.to(dtype=target_dtype) + c = new_state.constraints[0] + assert isinstance(c, FixSymmetry) + assert c.rotations[0].dtype == target_dtype + assert c.reference_cells[0].dtype == target_dtype + # integer symm_maps must stay long + assert c.symm_maps[0].dtype == torch.long + # original constraint unchanged + orig = state.constraints[0] + assert orig.rotations[0].dtype == torch.float64 diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 3c60cd2b..f5e47f1e 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -1019,6 +1019,26 @@ def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002 max_cumulative_strain=self.max_cumulative_strain, ) + def to( + self, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> Self: + """Return a copy with tensors moved to *device*/*dtype*.""" + return type(self)( + [r.to(device=device, dtype=dtype) for r in self.rotations], + [s.to(device=device) for s in self.symm_maps], + self.system_idx.to(device=device), + adjust_positions=self.do_adjust_positions, + adjust_cell=self.do_adjust_cell, + reference_cells=( + [c.to(device=device, dtype=dtype) for c in self.reference_cells] + if self.reference_cells is not None + else None + ), + max_cumulative_strain=self.max_cumulative_strain, + ) + @classmethod def merge(cls, constraints: list[Constraint]) -> Self: """Merge by concatenating rotations, symm_maps, and system indices.""" From 2dd3b66debc5f7d24ea29f4b0a4cce107ceba7d3 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 1 Apr 2026 13:38:35 +0000 Subject: [PATCH 03/10] style --- tests/test_constraints.py | 8 ++------ torch_sim/constraints.py | 5 ++++- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 918c42ec..f869da4a 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -1230,9 +1230,7 @@ def test_fix_atoms_dtype_propagation( # object must be a distinct copy assert c is not ar_supercell_sim_state.constraints[0] - def test_fix_com_dtype_propagation( - self, ar_supercell_sim_state: ts.SimState - ) -> None: + def test_fix_com_dtype_propagation(self, ar_supercell_sim_state: ts.SimState) -> None: """FixCom's cached coms tensor should follow state dtype changes.""" ar_supercell_sim_state.constraints = [FixCom([0])] # Trigger lazy COM initialisation @@ -1262,9 +1260,7 @@ def test_fix_symmetry_dtype_propagation(self, target_dtype: torch.dtype) -> None atomic_numbers=torch.tensor([14, 14]), system_idx=torch.zeros(2, dtype=torch.long), ) - state.constraints = [ - FixSymmetry(rotations, symm_maps, reference_cells=ref_cells) - ] + state.constraints = [FixSymmetry(rotations, symm_maps, reference_cells=ref_cells)] new_state = state.to(dtype=target_dtype) c = new_state.constraints[0] diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index f5e47f1e..0cd54591 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -887,7 +887,10 @@ def from_state( state.positions[start:end] = pos else: rots, smap = prep_symmetry( - cell, pos, nums, symprec=symprec, + cell, + pos, + nums, + symprec=symprec, angle_tolerance=angle_tolerance, ) From 00b3b34f871cf4b69201a5c25c8ed1140eef4887 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 1 Apr 2026 13:40:32 +0000 Subject: [PATCH 04/10] style --- torch_sim/constraints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 0cd54591..c793482b 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -153,8 +153,8 @@ def merge(cls, constraints: list[Constraint]) -> Self: def to( self, - device: torch.device | None = None, - dtype: torch.dtype | None = None, + device: torch.device | None = None, # noqa: ARG002 + dtype: torch.dtype | None = None, # noqa: ARG002 ) -> Self: """Return a copy with all internal tensors moved to *device*/*dtype*. From ad3cf04b258fac72b1a9971bdfb12f00f902dc4a Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 1 Apr 2026 14:16:06 +0000 Subject: [PATCH 05/10] move all floating point tensors to dtype --- torch_sim/state.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index f5e3aea7..39854f43 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -793,14 +793,16 @@ def _state_to_device[T: SimState]( attrs = state.attributes for attr_name, attr_value in attrs.items(): if isinstance(attr_value, torch.Tensor): - attrs[attr_name] = attr_value.to(device=device) + if attr_value.is_floating_point() and dtype is not None: + # also move floating point attributes like forces, velocities, etc. to dtype. + attrs[attr_name] = attr_value.to(device=device, dtype=dtype) + else: + # non-floating attributes like system_idx keep their dtype. + attrs[attr_name] = attr_value.to(device=device) elif isinstance(attr_value, torch.Generator): attrs[attr_name] = coerce_prng(attr_value, device) if dtype is not None: - attrs["positions"] = attrs["positions"].to(dtype=dtype) - attrs["masses"] = attrs["masses"].to(dtype=dtype) - attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) if attrs.get("_constraints"): From 8ece7acd0cdaf764192d3fa44c9e307019def217 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 1 Apr 2026 14:16:54 +0000 Subject: [PATCH 06/10] style --- torch_sim/state.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 39854f43..1e63b8d9 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -794,7 +794,8 @@ def _state_to_device[T: SimState]( for attr_name, attr_value in attrs.items(): if isinstance(attr_value, torch.Tensor): if attr_value.is_floating_point() and dtype is not None: - # also move floating point attributes like forces, velocities, etc. to dtype. + # also move floating point attributes like forces, velocities, etc. + # to dtype. attrs[attr_name] = attr_value.to(device=device, dtype=dtype) else: # non-floating attributes like system_idx keep their dtype. From a4ab8e29799769146c40b7d4cd1274de7b189aef Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 1 Apr 2026 15:19:14 +0000 Subject: [PATCH 07/10] make .to() abstract --- torch_sim/constraints.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index c793482b..8d4f03d0 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -151,17 +151,17 @@ def merge(cls, constraints: list[Constraint]) -> Self: constraints: Constraints to merge (all same type, already reindexed) """ + @abstractmethod def to( self, - device: torch.device | None = None, # noqa: ARG002 - dtype: torch.dtype | None = None, # noqa: ARG002 + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> Self: """Return a copy with all internal tensors moved to *device*/*dtype*. Float tensors are cast to *dtype*; integer/bool tensors are only moved - to *device*. Subclasses with tensor attributes must override this. + to *device*. """ - return self def _cumsum_with_zero(tensor: torch.Tensor) -> torch.Tensor: From 02731c039c0132568c9440a582128ec423cf399f Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 1 Apr 2026 15:19:31 +0000 Subject: [PATCH 08/10] style --- torch_sim/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 1e63b8d9..68b656e0 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -794,7 +794,7 @@ def _state_to_device[T: SimState]( for attr_name, attr_value in attrs.items(): if isinstance(attr_value, torch.Tensor): if attr_value.is_floating_point() and dtype is not None: - # also move floating point attributes like forces, velocities, etc. + # also move floating point attributes like forces, velocities, etc. # to dtype. attrs[attr_name] = attr_value.to(device=device, dtype=dtype) else: From 563d991f1c8640c9d40b5bb855f826b22e60d965 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 1 Apr 2026 15:58:40 +0000 Subject: [PATCH 09/10] fix comment --- torch_sim/constraints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 8d4f03d0..c6073f01 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -854,7 +854,7 @@ def from_state( adjust_positions: Whether to symmetrize position displacements. adjust_cell: Whether to symmetrize cell/stress adjustments. refine_symmetry_state: Whether to refine positions/cell to ideal values. - angle_tolerance: Angle tolerance in degrees for moyopy symmetry + angle_tolerance: Angle tolerance in radians for moyopy symmetry detection. If None, moyopy uses its default behaviour. """ try: From 07b1e0246a8ba4eb70d5f62c1298eccf6fdc54fa Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Thu, 2 Apr 2026 09:03:33 +0000 Subject: [PATCH 10/10] style --- tests/test_constraints.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index f869da4a..f04eb262 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -1266,6 +1266,7 @@ def test_fix_symmetry_dtype_propagation(self, target_dtype: torch.dtype) -> None c = new_state.constraints[0] assert isinstance(c, FixSymmetry) assert c.rotations[0].dtype == target_dtype + assert c.reference_cells is not None assert c.reference_cells[0].dtype == target_dtype # integer symm_maps must stay long assert c.symm_maps[0].dtype == torch.long