Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions python/metatomic_torch/metatomic/torch/ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,8 +935,7 @@ def _get_ase_input(
tensor.set_info("quantity", infos["quantity"])
tensor.set_info("unit", infos["unit"])

tensor = tensor.to(dtype=dtype, device=device)
return tensor
return tensor.to(dtype=dtype, device=device)


def _ase_to_torch_data(atoms, dtype, device):
Expand Down
379 changes: 379 additions & 0 deletions python/metatomic_torch/metatomic/torch/heat_flux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,379 @@
from typing import Dict, List, Optional

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from vesin.metatomic import compute_requested_neighbors_from_options

from metatomic.torch import (
AtomisticModel,
ModelEvaluationOptions,
ModelOutput,
NeighborListOptions,
System,
)


def _wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor:
"""
Wrap positions into the periodic cell.
"""
fractional_positions = positions @ cell.inverse()
fractional_positions = fractional_positions - torch.floor(fractional_positions)
wrapped_positions = fractional_positions @ cell

return wrapped_positions


def _check_close_to_cell_boundary(
cell: torch.Tensor, positions: torch.Tensor, cutoff: float, skin: float
) -> torch.Tensor:
"""
Detect atoms that lie within a cutoff distance (in our context, the interaction
range of the model + the skin) from the periodic cell boundaries,
i.e. have interactions with atoms at the opposite end of the cell.
"""
inv_cell = cell.inverse()
recip = inv_cell.T
norms = torch.linalg.norm(recip, dim=1)
heights = 1.0 / norms
Comment on lines +37 to +38
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a strange way to get the cell vectors length. Are you sure it works even for very slanted triclinic cells?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think here we calculate the distance between the surfaces, for example, the first height is the distance between the surface spanded by the b vector and c vector and the opposite surface

if heights.min() < (cutoff + skin):
raise ValueError(
"Cell is too small compared to (cutoff + skin) = "
+ str(cutoff + skin)
+ ". "
"Ensure that all cell vectors are at least this length. Currently, the"
" minimum cell vector length is " + str(heights.min()) + "."
)

cutoff = cutoff + skin
normals = recip / norms[:, None]
norm_coords = positions @ normals.T
collisions = torch.hstack(
[norm_coords <= cutoff, norm_coords >= heights - cutoff],
).to(device=positions.device)

return collisions[
:, [0, 3, 1, 4, 2, 5] # reorder to (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi)
]


def _collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor:
"""
Convert boundary-collision flags into a boolean mask over all periodic image
displacements in {0, +1, -1}^3. e.g. for an atom colliding with the x_lo and y_hi
boundaries, we need the replicas at (1, 0, 0), (0, -1, 0), (1, -1, 0) image cells.

collisions: [N, 6]: has collisions with (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi)

returns: [N, 3, 3, 3] boolean mask over image displacements in {0, +1, -1}^3
0: no replica needed along that axis
1: +1 replica needed along that axis (i.e., near low boundary, a replica is
placed just outside the high boundary)
2: -1 replica needed along that axis (i.e., near high boundary, a replica is
placed just outside the low boundary)
axis order: x, y, z
"""
origin = torch.full(
(len(collisions),), True, dtype=torch.bool, device=collisions.device
)
axs = torch.vstack([origin, collisions[:, 0], collisions[:, 1]])
ays = torch.vstack([origin, collisions[:, 2], collisions[:, 3]])
azs = torch.vstack([origin, collisions[:, 4], collisions[:, 5]])
# leverage broadcasting
outs = axs[:, None, None] & ays[None, :, None] & azs[None, None, :]
outs = torch.movedim(outs, -1, 0)
outs[:, 0, 0, 0] = False # not close to any boundary -> no replica needed
return outs.to(device=collisions.device)


def _generate_replica_atoms(
types: torch.Tensor,
positions: torch.Tensor,
cell: torch.Tensor,
replicas: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
For atoms near the low boundary (x_lo/y_lo/z_lo), generate their images shifted
by +1 cell vector (i.e., placed just outside the high boundary).
For atoms near the high boundary (x_hi/y_hi/z_hi), generate images shifted by −1
cell vector.
"""
replicas = torch.argwhere(replicas)
replica_idx = replicas[:, 0]
replica_offsets = torch.tensor(
[0, 1, -1], device=positions.device, dtype=positions.dtype
)[replicas[:, 1:]]
replica_positions = positions[replica_idx] + replica_offsets @ cell

return replica_idx, types[replica_idx], replica_positions


def _unfold_system(
metatomic_system: System, cutoff: float, skin: float = 0.5
) -> System:
"""
Unfold a periodic system by generating replica atoms for those near the cell
boundaries within the specified cutoff distance.
The unfolded system has no periodic boundary conditions.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain in the docstring why we also need a skin?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This parameter is from Marcel's implementation, but more for the neighbor list caching but not a special design for the heat flux calculation, so I think I can remove this later

"""

if not metatomic_system.pbc.any():
raise ValueError("Unfolding systems is only supported for periodic systems.")
wrapped_positions = _wrap_positions(
metatomic_system.positions, metatomic_system.cell
)
collisions = _check_close_to_cell_boundary(
metatomic_system.cell, wrapped_positions, cutoff, skin
)
replicas = _collisions_to_replicas(collisions)
replica_idx, replica_types, replica_positions = _generate_replica_atoms(
metatomic_system.types, wrapped_positions, metatomic_system.cell, replicas
)
unfolded_types = torch.cat(
[
metatomic_system.types,
replica_types,
]
)
unfolded_positions = torch.cat(
[
wrapped_positions,
replica_positions,
]
)
unfolded_idx = torch.cat(
[
torch.arange(len(metatomic_system.types), device=metatomic_system.device),
replica_idx,
]
)
unfolded_n_atoms = len(unfolded_types)
masses_block = metatomic_system.get_data("masses").block()
velocities_block = metatomic_system.get_data("velocities").block()
unfolded_masses = masses_block.values[unfolded_idx]
unfolded_velocities = velocities_block.values[unfolded_idx]
unfolded_masses_block = TensorBlock(
values=unfolded_masses,
samples=Labels(
["atoms"],
torch.arange(unfolded_n_atoms, device=metatomic_system.device).reshape(
-1, 1
),
),
components=masses_block.components,
properties=masses_block.properties,
)
unfolded_velocities_block = TensorBlock(
values=unfolded_velocities,
samples=Labels(
["atoms"],
torch.arange(unfolded_n_atoms, device=metatomic_system.device).reshape(
-1, 1
),
),
components=velocities_block.components,
properties=velocities_block.properties,
)
unfolded_system = System(
types=unfolded_types,
positions=unfolded_positions,
cell=torch.tensor(
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
dtype=unfolded_positions.dtype,
device=metatomic_system.device,
),
pbc=torch.tensor([False, False, False], device=metatomic_system.device),
)
unfolded_system.add_data(
"masses",
TensorMap(
Labels("_", torch.tensor([[0]], device=metatomic_system.device)),
[unfolded_masses_block],
),
)
unfolded_system.add_data(
"velocities",
TensorMap(
Labels("_", torch.tensor([[0]], device=metatomic_system.device)),
[unfolded_velocities_block],
),
)
return unfolded_system.to(metatomic_system.dtype, metatomic_system.device)


class HeatFluxWrapper(torch.nn.Module):
"""
A wrapper around an AtomisticModel that computes the heat flux of a system using the
unfolded system approach. The heat flux is computed using the atomic energies (eV),
positions(Å), masses(u), velocities(Å/fs), and the energy gradients.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do the unit matter here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm yes users don't need to know these, I think only need to tell them the unit of the output heat flux


The unfolded system is generated by creating replica atoms for those near the cell
boundaries within the interaction range of the model wrapped. The wrapper adds the
heat flux to the model's outputs under the key "extra::heat_flux".

For more details on the heat flux calculation, see `Langer, M. F., et al., Heat flux
for semilocal machine-learning potentials. (2023). Physical Review B, 108, L100302.`
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also contain some information about how to then wrap this class back inside an AtomisticModel, in particular how to define capabilities/metadata.


def __init__(self, model: AtomisticModel, skin: float = 0.5):
"""
:param model: the :py:class:`AtomisticModel` to wrap, which should be able to
compute atomic energies and their gradients with respect to positions
:param skin: the skin parameter for unfolding the system. The wrapper will
generate replica atoms for those within (interaction_range + skin) distance from
the cell boundaries. A skin results in more replica atoms and thus higher
computational cost, but ensures that the heat flux is computed correctly.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it ensure the heat flux is computed correctly? What kind of issue would one have with a skin too small? (this should be explained in the docstring for users)

"""
super().__init__()

assert isinstance(model, AtomisticModel)
self._model = model.module
self.skin = skin
self._interaction_range = model.capabilities().interaction_range

self._requested_neighbor_lists = model.requested_neighbor_lists()
self._requested_inputs = {
"masses": ModelOutput(quantity="mass", unit="u", per_atom=True),
"velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might not match the positions units required by the underlying model

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I'm not sure what happens if the underlying model already request these inputs but with a different unit. We should at least check & error.

}

hf_output = ModelOutput(
quantity="heat_flux",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not exist right now, but we should add it!

unit="",
explicit_gradients=[],
per_atom=False,
)
outputs = model.capabilities().outputs.copy()
outputs["extra::heat_flux"] = hf_output
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have multiple energy variant in the underlying model, I think this should expose multiple heat_flux variants with the same names/descriptions.

if outputs["energy"].unit != "eV":
raise ValueError(
"HeatFluxWrapper can only be used with energy outputs in eV"
)
energies_output = ModelOutput(
quantity="energy", unit=outputs["energy"].unit, per_atom=True
)
self._unfolded_run_options = ModelEvaluationOptions(
length_unit=model.capabilities().length_unit,
outputs={"energy": energies_output},
selected_atoms=None,
)

@torch.jit.export
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@torch.jit.export

Not required, we will do the export on AtomisticModel

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope if I remove this, the test fails

self = RecursiveScriptModule(
  original_name=HeatFluxWrapper
  (_model): RecursiveScriptModule(
    original_name=AtomisticModel
    (module): RecursiveScriptModule(original_name=LennardJonesPurePyTorch)
  )
), args = ()
kwargs = {'outputs': {'energy': <torch.ScriptObject object at 0x6123e80c6670>, 'energy_uncertainty': <torch.ScriptObject object...123e40184f0>}, 'selected_atoms': None, 'systems': [System with 8 atoms, periodic cell: [12, 0, 0, 0, 12, 0, 0, 0, 12]]}

    def _call_impl(self, *args, **kwargs):
        forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
        # If we don't have any hooks, we want to skip the rest of the logic in
        # this function, and just call forward.
        if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
                or _global_backward_pre_hooks or _global_backward_hooks
                or _global_forward_hooks or _global_forward_pre_hooks):
>           return forward_call(*args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E           RuntimeError: The following operation failed in the TorchScript interpreter.
E           Traceback of TorchScript (most recent call last):
E             File "/home/qxu/repos/metatomic/.tox/torch-tests/lib/python3.13/site-packages/metatomic/torch/heat_flux.py", line 357, in forward
E                   for system in systems:
E                       system.positions.requires_grad_(True)
E                       heat_fluxes.append(self._calc_unfolded_heat_flux(system))
E                                          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
E               
E                   samples = Labels(
E             File "/home/qxu/repos/metatomic/.tox/torch-tests/lib/python3.13/site-packages/metatomic/torch/heat_flux.py", line 277, in _calc_unfolded_heat_flux
E               def _calc_unfolded_heat_flux(self, system: System) -> torch.Tensor:
E                   n_atoms = len(system.positions)
E                   unfolded_system = _unfold_system(system, self._interaction_range, self.skin).to(
E                                     ~~~~~~~~~~~~~~ <--- HERE
E                       system.device
E                   )
E             File "/home/qxu/repos/metatomic/.tox/torch-tests/lib/python3.13/site-packages/metatomic/torch/heat_flux.py", line 149, in _unfold_system
E               )
E               unfolded_n_atoms = len(unfolded_types)
E               masses_block = metatomic_system.get_data("masses").block()
E                              ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
E               velocities_block = metatomic_system.get_data("velocities").block()
E               unfolded_masses = masses_block.values[unfolded_idx]
E           RuntimeError: no data for 'masses' found in this system

/home/qxu/repos/metatomic/.tox/torch-tests/lib/python3.13/site-packages/torch/nn/modules/module.py:1787: RuntimeError

def requested_neighbor_lists(self) -> List[NeighborListOptions]:
return self._requested_neighbor_lists

@torch.jit.export
def requested_inputs(self) -> Dict[str, ModelOutput]:
return self._requested_inputs

def _barycenter_and_atomic_energies(self, system: System, n_atoms: int):
energy_block = self._model(
[system],
self._unfolded_run_options.outputs,
self._unfolded_run_options.selected_atoms,
)["energy"].block(0)
atom_indices = energy_block.samples.column("atom").to(torch.long)
sorted_order = torch.argsort(atom_indices)
atomic_e = energy_block.values.flatten()[sorted_order]

total_e = atomic_e[:n_atoms].sum()
r_aux = system.positions.detach()
barycenter = (atomic_e[:n_atoms, None] * r_aux[:n_atoms]).sum(dim=0)

return barycenter, atomic_e, total_e

def _calc_unfolded_heat_flux(self, system: System) -> torch.Tensor:
n_atoms = len(system.positions)
unfolded_system = _unfold_system(system, self._interaction_range, self.skin).to(
system.device
)
compute_requested_neighbors_from_options(
[unfolded_system],
self.requested_neighbor_lists(),
self._unfolded_run_options.length_unit,
False,
)
velocities: torch.Tensor = (
unfolded_system.get_data("velocities").block().values.reshape(-1, 3)
)
masses: torch.Tensor = (
unfolded_system.get_data("masses").block().values.reshape(-1)
)
barycenter, atomic_e, total_e = self._barycenter_and_atomic_energies(
unfolded_system, n_atoms
)

term1 = torch.zeros(
(3), device=system.positions.device, dtype=system.positions.dtype
)
for i in range(3):
grad_i = torch.autograd.grad(
[barycenter[i]],
[unfolded_system.positions],
retain_graph=True,
create_graph=False,
)[0]
grad_i = torch.jit._unwrap_optional(grad_i)
term1[i] = (grad_i * velocities).sum()

go = torch.jit.annotate(
Optional[List[Optional[torch.Tensor]]], [torch.ones_like(total_e)]
)
grads = torch.autograd.grad(
[total_e],
[unfolded_system.positions],
grad_outputs=go,
)[0]
grads = torch.jit._unwrap_optional(grads)
term2 = (
unfolded_system.positions * (grads * velocities).sum(dim=1, keepdim=True)
).sum(dim=0)

hf_pot = term1 - term2

hf_conv = (
(
atomic_e[:n_atoms]
+ 0.5
* masses[:n_atoms]
* torch.linalg.norm(velocities[:n_atoms], dim=1) ** 2
* 103.6427 # u*A^2/fs^2 to eV
)[:, None]
* velocities[:n_atoms]
).sum(dim=0)

return hf_pot + hf_conv

def forward(
self,
systems: List[System],
outputs: Dict[str, ModelOutput],
selected_atoms: Optional[Labels],
) -> Dict[str, TensorMap]:
outputs_wo_heat_flux = outputs.copy()
del outputs_wo_heat_flux["extra::heat_flux"]
results = self._model(systems, outputs_wo_heat_flux, selected_atoms)

if "extra::heat_flux" not in outputs:
return results

device = systems[0].device
heat_fluxes: List[torch.Tensor] = []
for system in systems:
system.positions.requires_grad_(True)
heat_fluxes.append(self._calc_unfolded_heat_flux(system))

samples = Labels(
["system"], torch.arange(len(systems), device=device).reshape(-1, 1)
)

hf_block = TensorBlock(
values=torch.vstack(heat_fluxes).reshape(-1, 3, 1).to(device=device),
samples=samples,
components=[Labels(["xyz"], torch.arange(3, device=device).reshape(-1, 1))],
properties=Labels(["heat_flux"], torch.tensor([[0]], device=device)),
)
results["extra::heat_flux"] = TensorMap(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of having a custom output here, it would be better to also add heat_flux as a standard output =)

Labels("_", torch.tensor([[0]], device=device)), [hf_block]
)
return results
Loading
Loading