diff --git a/tests/test_constraints.py b/tests/test_constraints.py index d2c015e1..f04eb262 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -1210,3 +1210,66 @@ def test_fix_com_system_idx_remapped_on_reordered_slice( c = sliced.constraints[0] assert isinstance(c, FixCom) assert sorted(c.system_idx.tolist()) == [0, 1] + + +class TestConstraintToDeviceDtype: + """Test that state.to() propagates device/dtype to constraint tensors.""" + + def test_fix_atoms_dtype_propagation( + self, ar_supercell_sim_state: ts.SimState + ) -> None: + """FixAtoms indices should be moved to the new device by state.to().""" + indices = torch.tensor([0, 3, 5], dtype=torch.long) + ar_supercell_sim_state.constraints = [FixAtoms(atom_idx=indices)] + new_state = ar_supercell_sim_state.to(dtype=torch.float32) + + c = new_state.constraints[0] + assert isinstance(c, FixAtoms) + assert torch.equal(c.atom_idx, indices) + # dtype change should not affect integer indices, but the constraint + # object must be a distinct copy + assert c is not ar_supercell_sim_state.constraints[0] + + def test_fix_com_dtype_propagation(self, ar_supercell_sim_state: ts.SimState) -> None: + """FixCom's cached coms tensor should follow state dtype changes.""" + ar_supercell_sim_state.constraints = [FixCom([0])] + # Trigger lazy COM initialisation + ar_supercell_sim_state.set_constrained_positions( + ar_supercell_sim_state.positions.clone() + ) + assert ar_supercell_sim_state.constraints[0].coms is not None + + new_state = ar_supercell_sim_state.to(dtype=torch.float32) + c = new_state.constraints[0] + assert isinstance(c, FixCom) + assert c.coms is not None + assert c.coms.dtype == torch.float32 + + @pytest.mark.parametrize("target_dtype", [torch.float32, torch.float64]) + def test_fix_symmetry_dtype_propagation(self, target_dtype: torch.dtype) -> None: + """FixSymmetry rotations and reference_cells must follow dtype changes.""" + rotations = [torch.eye(3, dtype=torch.float64).unsqueeze(0)] + symm_maps = [torch.zeros(1, 2, dtype=torch.long)] + ref_cells = [torch.eye(3, dtype=torch.float64)] + + state = ts.SimState( + positions=torch.zeros(2, 3, dtype=torch.float64), + masses=torch.ones(2, dtype=torch.float64), + cell=torch.eye(3, dtype=torch.float64).unsqueeze(0) * 5.0, + pbc=True, + atomic_numbers=torch.tensor([14, 14]), + system_idx=torch.zeros(2, dtype=torch.long), + ) + state.constraints = [FixSymmetry(rotations, symm_maps, reference_cells=ref_cells)] + + new_state = state.to(dtype=target_dtype) + c = new_state.constraints[0] + assert isinstance(c, FixSymmetry) + assert c.rotations[0].dtype == target_dtype + assert c.reference_cells is not None + assert c.reference_cells[0].dtype == target_dtype + # integer symm_maps must stay long + assert c.symm_maps[0].dtype == torch.long + # original constraint unchanged + orig = state.constraints[0] + assert orig.rotations[0].dtype == torch.float64 diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 91323880..c6073f01 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -151,6 +151,18 @@ def merge(cls, constraints: list[Constraint]) -> Self: constraints: Constraints to merge (all same type, already reindexed) """ + @abstractmethod + def to( + self, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> Self: + """Return a copy with all internal tensors moved to *device*/*dtype*. + + Float tensors are cast to *dtype*; integer/bool tensors are only moved + to *device*. + """ + def _cumsum_with_zero(tensor: torch.Tensor) -> torch.Tensor: """Cumulative sum with a leading zero, e.g. [3, 2, 4] -> [0, 3, 5, 9].""" @@ -272,6 +284,14 @@ def merge(cls, constraints: list[Constraint]) -> Self: ) return cls(torch.cat([constraint.atom_idx for constraint in atom_constraints])) + def to( + self, + device: torch.device | None = None, + dtype: torch.dtype | None = None, # noqa: ARG002 + ) -> Self: + """Return a copy with atom indices moved to *device*.""" + return type(self)(self.atom_idx.to(device=device)) + class SystemConstraint(Constraint): """Base class for constraints that act on specific system indices. @@ -371,6 +391,14 @@ def merge(cls, constraints: list[Constraint]) -> Self: torch.cat([constraint.system_idx for constraint in system_constraints]) ) + def to( + self, + device: torch.device | None = None, + dtype: torch.dtype | None = None, # noqa: ARG002 + ) -> Self: + """Return a copy with system indices moved to *device*.""" + return type(self)(self.system_idx.to(device=device)) + def merge_constraints( constraint_lists: list[list[Constraint]], @@ -612,6 +640,17 @@ def __repr__(self) -> str: """String representation of the constraint.""" return f"FixCom(system_idx={self.system_idx})" + def to( + self, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> Self: + """Return a copy with tensors moved to *device*/*dtype*.""" + new = type(self)(self.system_idx.to(device=device)) + if self.coms is not None: + new.coms = self.coms.to(device=device, dtype=dtype) + return new + def count_degrees_of_freedom( state: SimState, constraints: list[Constraint] | None = None @@ -801,6 +840,7 @@ def from_state( adjust_positions: bool = True, adjust_cell: bool = True, refine_symmetry_state: bool = True, + angle_tolerance: float | None = None, ) -> Self: """Create from SimState, optionally refining to ideal symmetry first. @@ -814,6 +854,8 @@ def from_state( adjust_positions: Whether to symmetrize position displacements. adjust_cell: Whether to symmetrize cell/stress adjustments. refine_symmetry_state: Whether to refine positions/cell to ideal values. + angle_tolerance: Angle tolerance in radians for moyopy symmetry + detection. If None, moyopy uses its default behaviour. """ try: import moyopy # noqa: F401 @@ -839,11 +881,18 @@ def from_state( pos, nums, symprec=symprec, + angle_tolerance=angle_tolerance, ) state.cell[sys_idx] = cell.mT # row→column vector convention state.positions[start:end] = pos else: - rots, smap = prep_symmetry(cell, pos, nums, symprec=symprec) + rots, smap = prep_symmetry( + cell, + pos, + nums, + symprec=symprec, + angle_tolerance=angle_tolerance, + ) rotations.append(rots) symm_maps.append(smap) @@ -973,6 +1022,26 @@ def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002 max_cumulative_strain=self.max_cumulative_strain, ) + def to( + self, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> Self: + """Return a copy with tensors moved to *device*/*dtype*.""" + return type(self)( + [r.to(device=device, dtype=dtype) for r in self.rotations], + [s.to(device=device) for s in self.symm_maps], + self.system_idx.to(device=device), + adjust_positions=self.do_adjust_positions, + adjust_cell=self.do_adjust_cell, + reference_cells=( + [c.to(device=device, dtype=dtype) for c in self.reference_cells] + if self.reference_cells is not None + else None + ), + max_cumulative_strain=self.max_cumulative_strain, + ) + @classmethod def merge(cls, constraints: list[Constraint]) -> Self: """Merge by concatenating rotations, symm_maps, and system indices.""" diff --git a/torch_sim/state.py b/torch_sim/state.py index 390b55fd..68b656e0 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -793,15 +793,24 @@ def _state_to_device[T: SimState]( attrs = state.attributes for attr_name, attr_value in attrs.items(): if isinstance(attr_value, torch.Tensor): - attrs[attr_name] = attr_value.to(device=device) + if attr_value.is_floating_point() and dtype is not None: + # also move floating point attributes like forces, velocities, etc. + # to dtype. + attrs[attr_name] = attr_value.to(device=device, dtype=dtype) + else: + # non-floating attributes like system_idx keep their dtype. + attrs[attr_name] = attr_value.to(device=device) elif isinstance(attr_value, torch.Generator): attrs[attr_name] = coerce_prng(attr_value, device) if dtype is not None: - attrs["positions"] = attrs["positions"].to(dtype=dtype) - attrs["masses"] = attrs["masses"].to(dtype=dtype) - attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) + + if attrs.get("_constraints"): + attrs["_constraints"] = [ + c.to(device=device, dtype=dtype) for c in attrs["_constraints"] + ] + return type(state)(**attrs) diff --git a/torch_sim/symmetrize.py b/torch_sim/symmetrize.py index 8677e8f0..9e3eaa15 100644 --- a/torch_sim/symmetrize.py +++ b/torch_sim/symmetrize.py @@ -22,6 +22,7 @@ def _moyo_dataset( frac_pos: torch.Tensor, atomic_numbers: torch.Tensor, symprec: float = 1e-4, + angle_tolerance: float | None = None, ) -> MoyoDataset: """Get MoyoDataset from cell, fractional positions, and atomic numbers.""" from moyopy import Cell, MoyoDataset @@ -31,7 +32,7 @@ def _moyo_dataset( positions=frac_pos.detach().cpu().tolist(), numbers=atomic_numbers.detach().cpu().int().tolist(), ) - return MoyoDataset(moyo_cell, symprec=symprec) + return MoyoDataset(moyo_cell, symprec=symprec, angle_tolerance=angle_tolerance) def _extract_symmetry_ops( @@ -51,13 +52,19 @@ def _extract_symmetry_ops( return rotations, translations -def get_symmetry_datasets(state: SimState, symprec: float = 1e-4) -> list[MoyoDataset]: +def get_symmetry_datasets( + state: SimState, + symprec: float = 1e-4, + angle_tolerance: float | None = None, +) -> list[MoyoDataset]: """Get MoyoDataset for each system in a SimState.""" datasets = [] for single in state.split(): cell = single.row_vector_cell[0] frac = single.positions @ torch.linalg.inv(cell) - datasets.append(_moyo_dataset(cell, frac, single.atomic_numbers, symprec)) + datasets.append( + _moyo_dataset(cell, frac, single.atomic_numbers, symprec, angle_tolerance) + ) return datasets @@ -105,6 +112,7 @@ def prep_symmetry( positions: torch.Tensor, atomic_numbers: torch.Tensor, symprec: float = 1e-4, + angle_tolerance: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Get symmetry rotations and atom mappings for a structure. @@ -112,7 +120,7 @@ def prep_symmetry( (rotations, symm_map) with shapes (n_ops, 3, 3) and (n_ops, n_atoms). """ frac_pos = positions @ torch.linalg.inv(cell) - dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec) + dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec, angle_tolerance) rotations, translations = _extract_symmetry_ops(dataset, cell.dtype, cell.device) return rotations, build_symmetry_map(rotations, translations, frac_pos) @@ -122,6 +130,7 @@ def _refine_symmetry_impl( positions: torch.Tensor, atomic_numbers: torch.Tensor, symprec: float = 0.01, + angle_tolerance: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Core refinement returning all intermediate data for reuse. @@ -130,7 +139,7 @@ def _refine_symmetry_impl( """ dtype, device = cell.dtype, cell.device frac_pos = positions @ torch.linalg.inv(cell) - dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec) + dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec, angle_tolerance) rotations, translations = _extract_symmetry_ops(dataset, dtype, device) n_ops, n_atoms = rotations.shape[0], positions.shape[0] @@ -165,6 +174,7 @@ def refine_symmetry( positions: torch.Tensor, atomic_numbers: torch.Tensor, symprec: float = 0.01, + angle_tolerance: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Symmetrize cell and positions according to the detected space group. @@ -175,7 +185,7 @@ def refine_symmetry( (symmetrized_cell, symmetrized_positions) as row vectors. """ new_cell, new_positions, _rotations, _translations = _refine_symmetry_impl( - cell, positions, atomic_numbers, symprec + cell, positions, atomic_numbers, symprec, angle_tolerance ) return new_cell, new_positions @@ -185,6 +195,7 @@ def refine_and_prep_symmetry( positions: torch.Tensor, atomic_numbers: torch.Tensor, symprec: float = 0.01, + angle_tolerance: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Refine symmetry and get ops/mappings in a single moyopy call. @@ -195,7 +206,7 @@ def refine_and_prep_symmetry( (refined_cell, refined_positions, rotations, symm_map) """ new_cell, new_positions, rotations, translations = _refine_symmetry_impl( - cell, positions, atomic_numbers, symprec + cell, positions, atomic_numbers, symprec, angle_tolerance ) # Build symm_map on the final refined fractional coordinates refined_frac = new_positions @ torch.linalg.inv(new_cell)