Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,3 +1210,66 @@ def test_fix_com_system_idx_remapped_on_reordered_slice(
c = sliced.constraints[0]
assert isinstance(c, FixCom)
assert sorted(c.system_idx.tolist()) == [0, 1]


class TestConstraintToDeviceDtype:
"""Test that state.to() propagates device/dtype to constraint tensors."""

def test_fix_atoms_dtype_propagation(
self, ar_supercell_sim_state: ts.SimState
) -> None:
"""FixAtoms indices should be moved to the new device by state.to()."""
indices = torch.tensor([0, 3, 5], dtype=torch.long)
ar_supercell_sim_state.constraints = [FixAtoms(atom_idx=indices)]
new_state = ar_supercell_sim_state.to(dtype=torch.float32)

c = new_state.constraints[0]
assert isinstance(c, FixAtoms)
assert torch.equal(c.atom_idx, indices)
# dtype change should not affect integer indices, but the constraint
# object must be a distinct copy
assert c is not ar_supercell_sim_state.constraints[0]

def test_fix_com_dtype_propagation(self, ar_supercell_sim_state: ts.SimState) -> None:
"""FixCom's cached coms tensor should follow state dtype changes."""
ar_supercell_sim_state.constraints = [FixCom([0])]
# Trigger lazy COM initialisation
ar_supercell_sim_state.set_constrained_positions(
ar_supercell_sim_state.positions.clone()
)
assert ar_supercell_sim_state.constraints[0].coms is not None

new_state = ar_supercell_sim_state.to(dtype=torch.float32)
c = new_state.constraints[0]
assert isinstance(c, FixCom)
assert c.coms is not None
assert c.coms.dtype == torch.float32

@pytest.mark.parametrize("target_dtype", [torch.float32, torch.float64])
def test_fix_symmetry_dtype_propagation(self, target_dtype: torch.dtype) -> None:
"""FixSymmetry rotations and reference_cells must follow dtype changes."""
rotations = [torch.eye(3, dtype=torch.float64).unsqueeze(0)]
symm_maps = [torch.zeros(1, 2, dtype=torch.long)]
ref_cells = [torch.eye(3, dtype=torch.float64)]

state = ts.SimState(
positions=torch.zeros(2, 3, dtype=torch.float64),
masses=torch.ones(2, dtype=torch.float64),
cell=torch.eye(3, dtype=torch.float64).unsqueeze(0) * 5.0,
pbc=True,
atomic_numbers=torch.tensor([14, 14]),
system_idx=torch.zeros(2, dtype=torch.long),
)
state.constraints = [FixSymmetry(rotations, symm_maps, reference_cells=ref_cells)]

new_state = state.to(dtype=target_dtype)
c = new_state.constraints[0]
assert isinstance(c, FixSymmetry)
assert c.rotations[0].dtype == target_dtype
assert c.reference_cells is not None
assert c.reference_cells[0].dtype == target_dtype
# integer symm_maps must stay long
assert c.symm_maps[0].dtype == torch.long
# original constraint unchanged
orig = state.constraints[0]
assert orig.rotations[0].dtype == torch.float64
71 changes: 70 additions & 1 deletion torch_sim/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,18 @@ def merge(cls, constraints: list[Constraint]) -> Self:
constraints: Constraints to merge (all same type, already reindexed)
"""

@abstractmethod
def to(
self,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Self:
"""Return a copy with all internal tensors moved to *device*/*dtype*.

Float tensors are cast to *dtype*; integer/bool tensors are only moved
to *device*.
"""


def _cumsum_with_zero(tensor: torch.Tensor) -> torch.Tensor:
"""Cumulative sum with a leading zero, e.g. [3, 2, 4] -> [0, 3, 5, 9]."""
Expand Down Expand Up @@ -272,6 +284,14 @@ def merge(cls, constraints: list[Constraint]) -> Self:
)
return cls(torch.cat([constraint.atom_idx for constraint in atom_constraints]))

def to(
self,
device: torch.device | None = None,
dtype: torch.dtype | None = None, # noqa: ARG002
) -> Self:
"""Return a copy with atom indices moved to *device*."""
return type(self)(self.atom_idx.to(device=device))


class SystemConstraint(Constraint):
"""Base class for constraints that act on specific system indices.
Expand Down Expand Up @@ -371,6 +391,14 @@ def merge(cls, constraints: list[Constraint]) -> Self:
torch.cat([constraint.system_idx for constraint in system_constraints])
)

def to(
self,
device: torch.device | None = None,
dtype: torch.dtype | None = None, # noqa: ARG002
) -> Self:
"""Return a copy with system indices moved to *device*."""
return type(self)(self.system_idx.to(device=device))


def merge_constraints(
constraint_lists: list[list[Constraint]],
Expand Down Expand Up @@ -612,6 +640,17 @@ def __repr__(self) -> str:
"""String representation of the constraint."""
return f"FixCom(system_idx={self.system_idx})"

def to(
self,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Self:
"""Return a copy with tensors moved to *device*/*dtype*."""
new = type(self)(self.system_idx.to(device=device))
if self.coms is not None:
new.coms = self.coms.to(device=device, dtype=dtype)
return new


def count_degrees_of_freedom(
state: SimState, constraints: list[Constraint] | None = None
Expand Down Expand Up @@ -801,6 +840,7 @@ def from_state(
adjust_positions: bool = True,
adjust_cell: bool = True,
refine_symmetry_state: bool = True,
angle_tolerance: float | None = None,
) -> Self:
"""Create from SimState, optionally refining to ideal symmetry first.

Expand All @@ -814,6 +854,8 @@ def from_state(
adjust_positions: Whether to symmetrize position displacements.
adjust_cell: Whether to symmetrize cell/stress adjustments.
refine_symmetry_state: Whether to refine positions/cell to ideal values.
angle_tolerance: Angle tolerance in radians for moyopy symmetry
detection. If None, moyopy uses its default behaviour.
"""
try:
import moyopy # noqa: F401
Expand All @@ -839,11 +881,18 @@ def from_state(
pos,
nums,
symprec=symprec,
angle_tolerance=angle_tolerance,
)
state.cell[sys_idx] = cell.mT # row→column vector convention
state.positions[start:end] = pos
else:
rots, smap = prep_symmetry(cell, pos, nums, symprec=symprec)
rots, smap = prep_symmetry(
cell,
pos,
nums,
symprec=symprec,
angle_tolerance=angle_tolerance,
)

rotations.append(rots)
symm_maps.append(smap)
Expand Down Expand Up @@ -973,6 +1022,26 @@ def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002
max_cumulative_strain=self.max_cumulative_strain,
)

def to(
self,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Self:
"""Return a copy with tensors moved to *device*/*dtype*."""
return type(self)(
[r.to(device=device, dtype=dtype) for r in self.rotations],
[s.to(device=device) for s in self.symm_maps],
self.system_idx.to(device=device),
adjust_positions=self.do_adjust_positions,
adjust_cell=self.do_adjust_cell,
reference_cells=(
[c.to(device=device, dtype=dtype) for c in self.reference_cells]
if self.reference_cells is not None
else None
),
max_cumulative_strain=self.max_cumulative_strain,
)

@classmethod
def merge(cls, constraints: list[Constraint]) -> Self:
"""Merge by concatenating rotations, symm_maps, and system indices."""
Expand Down
17 changes: 13 additions & 4 deletions torch_sim/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,15 +793,24 @@ def _state_to_device[T: SimState](
attrs = state.attributes
for attr_name, attr_value in attrs.items():
if isinstance(attr_value, torch.Tensor):
attrs[attr_name] = attr_value.to(device=device)
if attr_value.is_floating_point() and dtype is not None:
# also move floating point attributes like forces, velocities, etc.
# to dtype.
attrs[attr_name] = attr_value.to(device=device, dtype=dtype)
else:
# non-floating attributes like system_idx keep their dtype.
attrs[attr_name] = attr_value.to(device=device)
elif isinstance(attr_value, torch.Generator):
attrs[attr_name] = coerce_prng(attr_value, device)

if dtype is not None:
attrs["positions"] = attrs["positions"].to(dtype=dtype)
attrs["masses"] = attrs["masses"].to(dtype=dtype)
attrs["cell"] = attrs["cell"].to(dtype=dtype)
attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int)

if attrs.get("_constraints"):
attrs["_constraints"] = [
c.to(device=device, dtype=dtype) for c in attrs["_constraints"]
]

return type(state)(**attrs)


Expand Down
25 changes: 18 additions & 7 deletions torch_sim/symmetrize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def _moyo_dataset(
frac_pos: torch.Tensor,
atomic_numbers: torch.Tensor,
symprec: float = 1e-4,
angle_tolerance: float | None = None,
) -> MoyoDataset:
"""Get MoyoDataset from cell, fractional positions, and atomic numbers."""
from moyopy import Cell, MoyoDataset
Expand All @@ -31,7 +32,7 @@ def _moyo_dataset(
positions=frac_pos.detach().cpu().tolist(),
numbers=atomic_numbers.detach().cpu().int().tolist(),
)
return MoyoDataset(moyo_cell, symprec=symprec)
return MoyoDataset(moyo_cell, symprec=symprec, angle_tolerance=angle_tolerance)


def _extract_symmetry_ops(
Expand All @@ -51,13 +52,19 @@ def _extract_symmetry_ops(
return rotations, translations


def get_symmetry_datasets(state: SimState, symprec: float = 1e-4) -> list[MoyoDataset]:
def get_symmetry_datasets(
state: SimState,
symprec: float = 1e-4,
angle_tolerance: float | None = None,
) -> list[MoyoDataset]:
"""Get MoyoDataset for each system in a SimState."""
datasets = []
for single in state.split():
cell = single.row_vector_cell[0]
frac = single.positions @ torch.linalg.inv(cell)
datasets.append(_moyo_dataset(cell, frac, single.atomic_numbers, symprec))
datasets.append(
_moyo_dataset(cell, frac, single.atomic_numbers, symprec, angle_tolerance)
)
return datasets


Expand Down Expand Up @@ -105,14 +112,15 @@ def prep_symmetry(
positions: torch.Tensor,
atomic_numbers: torch.Tensor,
symprec: float = 1e-4,
angle_tolerance: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Get symmetry rotations and atom mappings for a structure.

Returns:
(rotations, symm_map) with shapes (n_ops, 3, 3) and (n_ops, n_atoms).
"""
frac_pos = positions @ torch.linalg.inv(cell)
dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec)
dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec, angle_tolerance)
rotations, translations = _extract_symmetry_ops(dataset, cell.dtype, cell.device)
return rotations, build_symmetry_map(rotations, translations, frac_pos)

Expand All @@ -122,6 +130,7 @@ def _refine_symmetry_impl(
positions: torch.Tensor,
atomic_numbers: torch.Tensor,
symprec: float = 0.01,
angle_tolerance: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Core refinement returning all intermediate data for reuse.

Expand All @@ -130,7 +139,7 @@ def _refine_symmetry_impl(
"""
dtype, device = cell.dtype, cell.device
frac_pos = positions @ torch.linalg.inv(cell)
dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec)
dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec, angle_tolerance)
rotations, translations = _extract_symmetry_ops(dataset, dtype, device)
n_ops, n_atoms = rotations.shape[0], positions.shape[0]

Expand Down Expand Up @@ -165,6 +174,7 @@ def refine_symmetry(
positions: torch.Tensor,
atomic_numbers: torch.Tensor,
symprec: float = 0.01,
angle_tolerance: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Symmetrize cell and positions according to the detected space group.

Expand All @@ -175,7 +185,7 @@ def refine_symmetry(
(symmetrized_cell, symmetrized_positions) as row vectors.
"""
new_cell, new_positions, _rotations, _translations = _refine_symmetry_impl(
cell, positions, atomic_numbers, symprec
cell, positions, atomic_numbers, symprec, angle_tolerance
)
return new_cell, new_positions

Expand All @@ -185,6 +195,7 @@ def refine_and_prep_symmetry(
positions: torch.Tensor,
atomic_numbers: torch.Tensor,
symprec: float = 0.01,
angle_tolerance: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Refine symmetry and get ops/mappings in a single moyopy call.

Expand All @@ -195,7 +206,7 @@ def refine_and_prep_symmetry(
(refined_cell, refined_positions, rotations, symm_map)
"""
new_cell, new_positions, rotations, translations = _refine_symmetry_impl(
cell, positions, atomic_numbers, symprec
cell, positions, atomic_numbers, symprec, angle_tolerance
)
# Build symm_map on the final refined fractional coordinates
refined_frac = new_positions @ torch.linalg.inv(new_cell)
Expand Down
Loading