diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index a4919d65..96b85f51 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -935,8 +935,7 @@ def _get_ase_input( tensor.set_info("quantity", infos["quantity"]) tensor.set_info("unit", infos["unit"]) - tensor = tensor.to(dtype=dtype, device=device) - return tensor + return tensor.to(dtype=dtype, device=device) def _ase_to_torch_data(atoms, dtype, device): diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py new file mode 100644 index 00000000..23dec8cb --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -0,0 +1,380 @@ +from typing import Dict, List, Optional + +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from vesin.metatomic import compute_requested_neighbors_from_options + +from metatomic.torch import ( + AtomisticModel, + ModelEvaluationOptions, + ModelOutput, + NeighborListOptions, + System, +) + + +def _wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: + """ + Wrap positions into the periodic cell. + """ + fractional_positions = positions @ cell.inverse() + fractional_positions = fractional_positions - torch.floor(fractional_positions) + wrapped_positions = fractional_positions @ cell + + return wrapped_positions + + +def _check_close_to_cell_boundary( + cell: torch.Tensor, positions: torch.Tensor, cutoff: float, skin: float +) -> torch.Tensor: + """ + Detect atoms that lie within a cutoff distance (in our context, the interaction + range of the model + the skin) from the periodic cell boundaries, + i.e. have interactions with atoms at the opposite end of the cell. + """ + inv_cell = cell.inverse() + recip = inv_cell.T + norms = torch.linalg.norm(recip, dim=1) + heights = 1.0 / norms + if heights.min() < (cutoff + skin): + raise ValueError( + "Cell is too small compared to (cutoff + skin) = " + + str(cutoff + skin) + + ". " + "Ensure that all cell vectors are at least this length. Currently, the" + " minimum cell vector length is " + str(heights.min()) + "." + ) + + cutoff = cutoff + skin + normals = recip / norms[:, None] + norm_coords = positions @ normals.T + collisions = torch.hstack( + [norm_coords <= cutoff, norm_coords >= heights - cutoff], + ).to(device=positions.device) + + return collisions[ + :, [0, 3, 1, 4, 2, 5] # reorder to (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi) + ] + + +def _collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor: + """ + Convert boundary-collision flags into a boolean mask over all periodic image + displacements in {0, +1, -1}^3. e.g. for an atom colliding with the x_lo and y_hi + boundaries, we need the replicas at (1, 0, 0), (0, -1, 0), (1, -1, 0) image cells. + + collisions: [N, 6]: has collisions with (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi) + + returns: [N, 3, 3, 3] boolean mask over image displacements in {0, +1, -1}^3 + 0: no replica needed along that axis + 1: +1 replica needed along that axis (i.e., near low boundary, a replica is + placed just outside the high boundary) + 2: -1 replica needed along that axis (i.e., near high boundary, a replica is + placed just outside the low boundary) + axis order: x, y, z + """ + origin = torch.full( + (len(collisions),), True, dtype=torch.bool, device=collisions.device + ) + axs = torch.vstack([origin, collisions[:, 0], collisions[:, 1]]) + ays = torch.vstack([origin, collisions[:, 2], collisions[:, 3]]) + azs = torch.vstack([origin, collisions[:, 4], collisions[:, 5]]) + # leverage broadcasting + outs = axs[:, None, None] & ays[None, :, None] & azs[None, None, :] + outs = torch.movedim(outs, -1, 0) + outs[:, 0, 0, 0] = False # not close to any boundary -> no replica needed + return outs.to(device=collisions.device) + + +def _generate_replica_atoms( + types: torch.Tensor, + positions: torch.Tensor, + cell: torch.Tensor, + replicas: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + For atoms near the low boundary (x_lo/y_lo/z_lo), generate their images shifted + by +1 cell vector (i.e., placed just outside the high boundary). + For atoms near the high boundary (x_hi/y_hi/z_hi), generate images shifted by −1 + cell vector. + """ + replicas = torch.argwhere(replicas) + replica_idx = replicas[:, 0] + replica_offsets = torch.tensor( + [0, 1, -1], device=positions.device, dtype=positions.dtype + )[replicas[:, 1:]] + replica_positions = positions[replica_idx] + replica_offsets @ cell + + return replica_idx, types[replica_idx], replica_positions + + +def _unfold_system( + metatomic_system: System, cutoff: float, skin: float = 0.5 +) -> System: + """ + Unfold a periodic system by generating replica atoms for those near the cell + boundaries within the specified cutoff distance. + The unfolded system has no periodic boundary conditions. + """ + + if not metatomic_system.pbc.any(): + raise ValueError("Unfolding systems is only supported for periodic systems.") + wrapped_positions = _wrap_positions( + metatomic_system.positions, metatomic_system.cell + ) + collisions = _check_close_to_cell_boundary( + metatomic_system.cell, wrapped_positions, cutoff, skin + ) + replicas = _collisions_to_replicas(collisions) + replica_idx, replica_types, replica_positions = _generate_replica_atoms( + metatomic_system.types, wrapped_positions, metatomic_system.cell, replicas + ) + unfolded_types = torch.cat( + [ + metatomic_system.types, + replica_types, + ] + ) + unfolded_positions = torch.cat( + [ + wrapped_positions, + replica_positions, + ] + ) + unfolded_idx = torch.cat( + [ + torch.arange(len(metatomic_system.types), device=metatomic_system.device), + replica_idx, + ] + ) + unfolded_n_atoms = len(unfolded_types) + masses_block = metatomic_system.get_data("masses").block() + velocities_block = metatomic_system.get_data("velocities").block() + unfolded_masses = masses_block.values[unfolded_idx] + unfolded_velocities = velocities_block.values[unfolded_idx] + unfolded_masses_block = TensorBlock( + values=unfolded_masses, + samples=Labels( + ["atoms"], + torch.arange(unfolded_n_atoms, device=metatomic_system.device).reshape( + -1, 1 + ), + ), + components=masses_block.components, + properties=masses_block.properties, + ) + unfolded_velocities_block = TensorBlock( + values=unfolded_velocities, + samples=Labels( + ["atoms"], + torch.arange(unfolded_n_atoms, device=metatomic_system.device).reshape( + -1, 1 + ), + ), + components=velocities_block.components, + properties=velocities_block.properties, + ) + unfolded_system = System( + types=unfolded_types, + positions=unfolded_positions, + cell=torch.tensor( + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + dtype=unfolded_positions.dtype, + device=metatomic_system.device, + ), + pbc=torch.tensor([False, False, False], device=metatomic_system.device), + ) + unfolded_system.add_data( + "masses", + TensorMap( + Labels("_", torch.tensor([[0]], device=metatomic_system.device)), + [unfolded_masses_block], + ), + ) + unfolded_system.add_data( + "velocities", + TensorMap( + Labels("_", torch.tensor([[0]], device=metatomic_system.device)), + [unfolded_velocities_block], + ), + ) + return unfolded_system.to(metatomic_system.dtype, metatomic_system.device) + + +class HeatFluxWrapper(torch.nn.Module): + """ + A wrapper around an AtomisticModel that computes the heat flux of a system using the + unfolded system approach. The heat flux is computed using the atomic energies (eV), + positions(Å), masses(u), velocities(Å/fs), and the energy gradients. + + The unfolded system is generated by creating replica atoms for those near the cell + boundaries within the interaction range of the model wrapped. The wrapper adds the + heat flux to the model's outputs under the key "extra::heat_flux". + + For more details on the heat flux calculation, see `Langer, M. F., et al., Heat flux + for semilocal machine-learning potentials. (2023). Physical Review B, 108, L100302.` + """ + + def __init__(self, model: AtomisticModel, skin: float = 0.5): + """ + :param model: the :py:class:`AtomisticModel` to wrap, which should be able to + compute atomic energies and their gradients with respect to positions + :param skin: the skin parameter for unfolding the system. The wrapper will + generate replica atoms for those within (interaction_range + skin) distance from + the cell boundaries. A skin results in more replica atoms and thus higher + computational cost, but ensures that the heat flux is computed correctly. + """ + super().__init__() + + assert isinstance(model, AtomisticModel) + self._model = model.module + self.skin = skin + self._interaction_range = model.capabilities().interaction_range + + self._requested_neighbor_lists = model.requested_neighbor_lists() + self._requested_inputs = { + "masses": ModelOutput(quantity="mass", unit="u", per_atom=True), + "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True), + } + + hf_output = ModelOutput( + quantity="heat_flux", + unit="", + explicit_gradients=[], + per_atom=False, + ) + outputs = model.capabilities().outputs.copy() + outputs["extra::heat_flux"] = hf_output + if outputs["energy"].unit != "eV": + raise ValueError( + "HeatFluxWrapper can only be used with energy outputs in eV" + ) + energies_output = ModelOutput( + quantity="energy", unit=outputs["energy"].unit, per_atom=True + ) + self._unfolded_run_options = ModelEvaluationOptions( + length_unit=model.capabilities().length_unit, + outputs={"energy": energies_output}, + selected_atoms=None, + ) + + @torch.jit.export + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return self._requested_neighbor_lists + + @torch.jit.export + def requested_inputs(self) -> Dict[str, ModelOutput]: + return self._requested_inputs + + def _barycenter_and_atomic_energies(self, system: System, n_atoms: int): + energy_block = self._model( + [system], + self._unfolded_run_options.outputs, + self._unfolded_run_options.selected_atoms, + )["energy"].block(0) + atom_indices = energy_block.samples.column("atom").to(torch.long) + sorted_order = torch.argsort(atom_indices) + atomic_e = energy_block.values.flatten()[sorted_order] + + total_e = atomic_e[:n_atoms].sum() + r_aux = system.positions.detach() + barycenter = (atomic_e[:n_atoms, None] * r_aux[:n_atoms]).sum(dim=0) + + return barycenter, atomic_e, total_e + + def _calc_unfolded_heat_flux(self, system: System) -> torch.Tensor: + n_atoms = len(system.positions) + unfolded_system = _unfold_system(system, self._interaction_range, self.skin).to( + system.device + ) + compute_requested_neighbors_from_options( + [unfolded_system], + self.requested_neighbor_lists(), + self._unfolded_run_options.length_unit, + False, + ) + velocities: torch.Tensor = ( + unfolded_system.get_data("velocities").block().values.reshape(-1, 3) + ) + masses: torch.Tensor = ( + unfolded_system.get_data("masses").block().values.reshape(-1) + ) + barycenter, atomic_e, total_e = self._barycenter_and_atomic_energies( + unfolded_system, n_atoms + ) + + term1 = torch.zeros( + (3), device=system.positions.device, dtype=system.positions.dtype + ) + for i in range(3): + grad_i = torch.autograd.grad( + [barycenter[i]], + [unfolded_system.positions], + retain_graph=True, + create_graph=False, + )[0] + grad_i = torch.jit._unwrap_optional(grad_i) + term1[i] = (grad_i * velocities).sum() + + go = torch.jit.annotate( + Optional[List[Optional[torch.Tensor]]], [torch.ones_like(total_e)] + ) + grads = torch.autograd.grad( + [total_e], + [unfolded_system.positions], + grad_outputs=go, + )[0] + grads = torch.jit._unwrap_optional(grads) + term2 = ( + unfolded_system.positions * (grads * velocities).sum(dim=1, keepdim=True) + ).sum(dim=0) + + hf_pot = term1 - term2 + + hf_conv = ( + ( + atomic_e[:n_atoms] + + 0.5 + * masses[:n_atoms] + * torch.linalg.norm(velocities[:n_atoms], dim=1) ** 2 + * 103.6427 # u*A^2/fs^2 to eV + )[:, None] + * velocities[:n_atoms] + ).sum(dim=0) + + return hf_pot + hf_conv + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + outputs_wo_heat_flux = outputs.copy() + if "extra::heat_flux" in outputs: + del outputs_wo_heat_flux["extra::heat_flux"] + results = self._model(systems, outputs_wo_heat_flux, selected_atoms) + + if "extra::heat_flux" not in outputs: + return results + + device = systems[0].device + heat_fluxes: List[torch.Tensor] = [] + for system in systems: + system.positions.requires_grad_(True) + heat_fluxes.append(self._calc_unfolded_heat_flux(system)) + + samples = Labels( + ["system"], torch.arange(len(systems), device=device).reshape(-1, 1) + ) + + hf_block = TensorBlock( + values=torch.vstack(heat_fluxes).reshape(-1, 3, 1).to(device=device), + samples=samples, + components=[Labels(["xyz"], torch.arange(3, device=device).reshape(-1, 1))], + properties=Labels(["heat_flux"], torch.tensor([[0]], device=device)), + ) + results["extra::heat_flux"] = TensorMap( + Labels("_", torch.tensor([[0]], device=device)), [hf_block] + ) + return results diff --git a/python/metatomic_torch/tests/test_heat_flux.py b/python/metatomic_torch/tests/test_heat_flux.py new file mode 100644 index 00000000..5e5bcd5c --- /dev/null +++ b/python/metatomic_torch/tests/test_heat_flux.py @@ -0,0 +1,445 @@ +import metatomic_lj_test +import numpy as np +import pytest +import torch +from ase import Atoms +from ase.md.velocitydistribution import MaxwellBoltzmannDistribution +from metatensor.torch import Labels, TensorBlock, TensorMap + +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + System, +) +from metatomic.torch.ase_calculator import MetatomicCalculator +from metatomic.torch.heat_flux import ( + HeatFluxWrapper, + _check_close_to_cell_boundary, + _collisions_to_replicas, + _generate_replica_atoms, + _unfold_system, + _wrap_positions, +) + + +@pytest.fixture +def model(): + return metatomic_lj_test.lennard_jones_model( + atomic_type=18, + cutoff=7.0, + sigma=3.405, + epsilon=0.01032, + length_unit="Angstrom", + energy_unit="eV", + with_extension=False, + ) + + +@pytest.fixture +def model_in_kcal_per_mol(): + return metatomic_lj_test.lennard_jones_model( + atomic_type=18, + cutoff=7.0, + sigma=3.405, + epsilon=0.01032, + length_unit="Angstrom", + energy_unit="kcal/mol", + with_extension=False, + ) + + +@pytest.fixture +def atoms(): + cell = np.array([[6.0, 0.0, 0.0], [0.0, 6.0, 0.0], [0.0, 0.0, 6.0]]) + positions = np.array([[3.0, 3.0, 3.0]]) + atoms = Atoms("Ar", scaled_positions=positions, cell=cell, pbc=True).repeat( + (2, 2, 2) + ) + MaxwellBoltzmannDistribution( + atoms, temperature_K=300, rng=np.random.default_rng(42) + ) + return atoms + + +def _make_scalar_tensormap(values: torch.Tensor, property_name: str) -> TensorMap: + block = TensorBlock( + values=values, + samples=Labels( + ["atoms"], + torch.arange(values.shape[0], device=values.device).reshape(-1, 1), + ), + components=[], + properties=Labels([property_name], torch.tensor([[0]], device=values.device)), + ) + return TensorMap(Labels("_", torch.tensor([[0]], device=values.device)), [block]) + + +def _make_velocity_tensormap(values: torch.Tensor) -> TensorMap: + block = TensorBlock( + values=values, + samples=Labels( + ["atoms"], + torch.arange(values.shape[0], device=values.device).reshape(-1, 1), + ), + components=[ + Labels( + ["xyz"], + torch.arange(3, device=values.device).reshape(-1, 1), + ) + ], + properties=Labels(["velocity"], torch.tensor([[0]], device=values.device)), + ) + return TensorMap(Labels("_", torch.tensor([[0]], device=values.device)), [block]) + + +def _make_system_with_data(positions: torch.Tensor, cell: torch.Tensor) -> System: + types = torch.tensor([1] * len(positions), dtype=torch.int32) + system = System( + types=types, + positions=positions, + cell=cell, + pbc=torch.tensor([True, True, True]), + ) + masses = torch.ones((len(positions), 1), dtype=positions.dtype) + velocities = torch.zeros((len(positions), 3, 1), dtype=positions.dtype) + system.add_data("masses", _make_scalar_tensormap(masses, "mass")) + system.add_data("velocities", _make_velocity_tensormap(velocities)) + return system + + +class _DummyCapabilities: + """Reusable stub for ``model.capabilities()``.""" + + def __init__(self, energy_unit: str = "eV"): + self.outputs = {"energy": ModelOutput(quantity="energy", unit=energy_unit)} + self.length_unit = "A" + self.interaction_range = 1.0 + + +class _ZeroDummyModel: + """Dummy model returning zero energies. Accepts an optional *energy_unit*.""" + + def __init__(self, energy_unit: str = "eV"): + self._capabilities = _DummyCapabilities(energy_unit) + self.module = None + + def capabilities(self): + return self._capabilities + + def __call__(self, systems, options, check_consistency): + values = torch.zeros((len(systems), 1), dtype=systems[0].positions.dtype) + block = TensorBlock( + values=values, + samples=Labels( + ["system"], + torch.arange(len(systems), device=values.device).reshape(-1, 1), + ), + components=[], + properties=Labels(["energy"], torch.tensor([[0]], device=values.device)), + ) + return { + "energy": TensorMap( + Labels("_", torch.tensor([[0]], device=values.device)), [block] + ) + } + + +def test_wrap_positions_cubic_matches_expected(): + cell = torch.eye(3) * 2.0 + positions = torch.tensor([[-0.1, 0.0, 0.0], [2.1, 1.0, -0.5]]) + wrapped = _wrap_positions(positions, cell) + expected = torch.tensor([[1.9, 0.0, 0.0], [0.1, 1.0, 1.5]]) + assert torch.allclose(wrapped, expected) + + +def test_check_close_to_cell_boundary_cubic_axis_order(): + cell = torch.eye(3) * 2.0 + positions = torch.tensor([[0.1, 1.0, 1.9]]) + collisions = _check_close_to_cell_boundary(cell, positions, cutoff=0.2, skin=0.0) + assert collisions.shape == (1, 6) + assert collisions[0].tolist() == [True, False, False, False, False, True] + + +def test_generate_replica_atoms_cubic_offsets(): + types = torch.tensor([1]) + positions = torch.tensor([[0.1, 1.0, 1.0]]) + cell = torch.eye(3) * 2.0 + collisions = torch.tensor([[True, False, False, False, False, False]]) + replicas = _collisions_to_replicas(collisions) + replica_idx, replica_types, replica_positions = _generate_replica_atoms( + types, positions, cell, replicas + ) + assert replica_idx.tolist() == [0] + assert replica_types.tolist() == [1] + assert torch.allclose( + replica_positions, positions + torch.tensor([[2.0, 0.0, 0.0]]) + ) + + +def test_wrap_positions_triclinic_fractional_bounds_and_shift(): + cell = torch.tensor( + [ + [2.0, 0.3, 0.2], + [0.1, 1.7, 0.4], + [0.2, 0.5, 1.9], + ] + ) + positions = torch.tensor( + [ + [-0.1, 0.0, 0.0], + [2.1, 1.6, -0.5], + [4.2, -0.2, 6.1], + ] + ) + inv_cell = cell.inverse() + wrapped = _wrap_positions(positions, cell) + fractional_before = torch.einsum("iv,vk->ik", positions, inv_cell) + fractional_after = torch.einsum("iv,vk->ik", wrapped, inv_cell) + + assert torch.all(fractional_after >= 0) + assert torch.all(fractional_after < 1) + + delta_frac = fractional_after - fractional_before + rounded = torch.round(delta_frac) + assert torch.allclose(delta_frac, rounded, atol=1e-6, rtol=0) + assert torch.allclose(rounded, -torch.floor(fractional_before), atol=1e-6, rtol=0) + + +def test_check_close_to_cell_boundary_triclinic_targets(): + cell = torch.tensor( + [ + [2.0, 0.3, 0.2], + [0.1, 1.7, 0.4], + [0.2, 0.5, 1.9], + ] + ) + cutoff = 0.2 + inv_cell = cell.inverse() + recip = inv_cell.T + norms = torch.linalg.norm(recip, dim=1) + heights = 1.0 / norms + norm_vectors = recip / norms[:, None] + + target = torch.stack( + [ + torch.tensor([0.05, 0.6, 0.6]), + torch.tensor([heights[0] - 0.05, 0.05, heights[2] - 0.1]), + torch.tensor([0.3, heights[1] - 0.05, 0.1]), + ] + ) + positions = target @ torch.inverse(norm_vectors).T + + collisions = _check_close_to_cell_boundary(cell, positions, cutoff=cutoff, skin=0.0) + + expected_low = target <= cutoff + expected_high = target >= heights - cutoff + expected = torch.hstack([expected_low, expected_high]) + expected = expected[:, [0, 3, 1, 4, 2, 5]] + + assert torch.equal(collisions, expected) + + +def test_check_close_to_cell_boundary_raises_on_small_cell(): + cell = torch.eye(3) * 1.0 + positions = torch.zeros((1, 3)) + with pytest.raises(ValueError, match="Cell is too small"): + _check_close_to_cell_boundary(cell, positions, cutoff=0.9, skin=0.2) + + +def test_skin_parameter_affects_collisions(): + """Increasing the skin should extend the effective detection range.""" + cell = torch.eye(3) * 2.0 + # atom at distance 0.3 from the low-x boundary + positions = torch.tensor([[0.3, 1.0, 1.0]]) + + # cutoff=0.2, skin=0.0 → effective range 0.2 < 0.3 → no collision + collisions_no_skin = _check_close_to_cell_boundary( + cell, positions, cutoff=0.2, skin=0.0 + ) + assert not collisions_no_skin.any() + + # cutoff=0.2, skin=0.2 → effective range 0.4 > 0.3 → x_lo collision + collisions_with_skin = _check_close_to_cell_boundary( + cell, positions, cutoff=0.2, skin=0.2 + ) + assert collisions_with_skin[0, 0].item() # x_lo + + +def test_collisions_to_replicas_combines_displacements(): + collisions = torch.tensor([[True, False, False, True, False, False]]) + replicas = _collisions_to_replicas(collisions) + assert replicas.shape == (1, 3, 3, 3) + assert replicas[0, 0, 0, 0].item() is False + + nonzero = torch.nonzero(replicas, as_tuple=False) + expected = { + (0, 1, 0, 0), + (0, 0, 2, 0), + (0, 1, 2, 0), + } + assert {tuple(row.tolist()) for row in nonzero} == expected + + +def test_generate_replica_atoms_triclinic_offsets(): + cell = torch.tensor( + [ + [2.0, 0.3, 0.2], + [0.1, 1.7, 0.4], + [0.2, 0.5, 1.9], + ] + ) + types = torch.tensor([1]) + positions = torch.tensor([[0.2, 0.4, 0.6]]) + collisions = torch.tensor([[True, False, True, False, True, False]]) + replicas = _collisions_to_replicas(collisions) + replica_idx, replica_types, replica_positions = _generate_replica_atoms( + types, positions, cell, replicas + ) + + assert replica_idx.tolist() == [0, 0, 0, 0, 0, 0, 0] + assert replica_types.tolist() == [1, 1, 1, 1, 1, 1, 1] + + expected_offsets = [ + cell[0], + cell[1], + cell[2], + cell[0] + cell[1], + cell[0] + cell[2], + cell[1] + cell[2], + cell[0] + cell[1] + cell[2], + ] + expected_positions = [positions[0] + offset for offset in expected_offsets] + + for expected in expected_positions: + assert any( + torch.allclose(expected, actual, atol=1e-6, rtol=0) + for actual in replica_positions + ) + + +def test_unfold_system_adds_replica_and_data(): + cell = torch.eye(3) * 2.0 + positions = torch.tensor([[0.1, 1.0, 1.0]]) + system = _make_system_with_data(positions, cell) + unfolded = _unfold_system(system, cutoff=0.1) + assert len(unfolded.positions) == 2 + assert torch.all(unfolded.pbc == torch.tensor([False, False, False])) + assert torch.allclose(unfolded.cell, torch.zeros_like(unfolded.cell)) + + masses = unfolded.get_data("masses").block().values + velocities = unfolded.get_data("velocities").block().values + assert masses.shape[0] == 2 + assert velocities.shape[0] == 2 + + assert torch.allclose(unfolded.positions[0], positions[0]) + assert torch.allclose( + unfolded.positions[1], positions[0] + torch.tensor([2.0, 0.0, 0.0]) + ) + + +def test_unfold_system_no_replicas_for_interior_atoms(): + """Atoms well inside the cell should produce no replicas.""" + cell = torch.eye(3) * 10.0 + positions = torch.tensor([[5.0, 5.0, 5.0], [3.0, 4.0, 6.0]]) + system = _make_system_with_data(positions, cell) + unfolded = _unfold_system(system, cutoff=1.0, skin=0.0) + + assert len(unfolded.positions) == 2 + assert torch.allclose(unfolded.positions, _wrap_positions(positions, cell)) + + +def test_unfold_system_triclinic_cell(): + """Unfolding should work for triclinic cells and propagate all data.""" + cell = torch.tensor( + [ + [4.0, 0.6, 0.4], + [0.2, 3.4, 0.8], + [0.4, 1.0, 3.8], + ] + ) + # One atom near the origin (close to low boundaries), one in the interior + positions = torch.tensor( + [ + [0.05, 0.05, 0.05], + [2.0, 1.7, 1.9], + ] + ) + system = _make_system_with_data(positions, cell) + unfolded = _unfold_system(system, cutoff=0.3, skin=0.0) + + # The near-origin atom should generate at least one replica + assert len(unfolded.positions) > 2 + assert torch.all(unfolded.pbc == torch.tensor([False, False, False])) + assert torch.allclose(unfolded.cell, torch.zeros_like(unfolded.cell)) + assert torch.all(unfolded.types == 1) + assert unfolded.get_data("masses").block().values.shape[0] == len( + unfolded.positions + ) + assert unfolded.get_data("velocities").block().values.shape[0] == len( + unfolded.positions + ) + + +def test_heat_flux_wrapper_rejects_non_eV_energy(model_in_kcal_per_mol): + with pytest.raises(ValueError, match="energy outputs in eV"): + HeatFluxWrapper(model_in_kcal_per_mol) + + +def test_heat_flux_wrapper_requested_inputs(model): + wrapper = HeatFluxWrapper(model) + requested = wrapper.requested_inputs() + assert set(requested.keys()) == {"masses", "velocities"} + + +@pytest.mark.parametrize("use_script", [True, False]) +def test_heat_flux_wrapper_calc_heat_flux(model, atoms, use_script): + expected = [[8.8238e-05], [-2.5559e-04], [-2.0570e-04]] + + metadata = ModelMetadata() + wrapper = HeatFluxWrapper(model.eval()) + cap = model.capabilities() + outputs = cap.outputs.copy() + outputs["extra::heat_flux"] = ModelOutput( + quantity="heat_flux", + unit="", + explicit_gradients=[], + per_atom=False, + ) + + new_cap = ModelCapabilities( + outputs=outputs, + atomic_types=cap.atomic_types, + interaction_range=cap.interaction_range, + length_unit=cap.length_unit, + supported_devices=cap.supported_devices, + dtype=cap.dtype, + ) + + if use_script: + wrapper = torch.jit.script(wrapper) + + heat_model = AtomisticModel(wrapper.eval(), metadata, capabilities=new_cap).to( + device="cpu" + ) + calc = MetatomicCalculator( + heat_model, + device="cpu", + additional_outputs={ + "extra::heat_flux": ModelOutput( + quantity="heat_flux", + unit="", + explicit_gradients=[], + per_atom=False, + ) + }, + ) + atoms.calc = calc + atoms.get_potential_energy() + assert "extra::heat_flux" in atoms.calc.additional_outputs + results = atoms.calc.additional_outputs["extra::heat_flux"].block().values + assert torch.allclose( + results, + torch.tensor(expected, dtype=results.dtype), + ) diff --git a/tox.ini b/tox.ini index cd0a281d..d6ef0d9b 100644 --- a/tox.ini +++ b/tox.ini @@ -141,6 +141,7 @@ deps = torch=={env:METATOMIC_TESTS_TORCH_VERSION:2.10}.* numpy {env:METATOMIC_TESTS_NUMPY_VERSION_PIN} vesin + vesin-torch ase # for metatensor-lj-test setuptools-scm