-
Notifications
You must be signed in to change notification settings - Fork 7
Add a heat flux wrapper #144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
80bfb0b
2decdf4
e4abd21
0954a05
3f7f07b
7dfb8bf
e57623f
79f3f78
e4093c2
155747a
ec04879
677cd19
2af67ec
9eebb91
1dcf3ed
f78cfa7
ab9e819
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -0,0 +1,379 @@ | ||||
| from typing import Dict, List, Optional | ||||
|
|
||||
| import torch | ||||
| from metatensor.torch import Labels, TensorBlock, TensorMap | ||||
| from vesin.metatomic import compute_requested_neighbors_from_options | ||||
|
|
||||
| from metatomic.torch import ( | ||||
| AtomisticModel, | ||||
| ModelEvaluationOptions, | ||||
| ModelOutput, | ||||
| NeighborListOptions, | ||||
| System, | ||||
| ) | ||||
|
|
||||
|
|
||||
| def _wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: | ||||
| """ | ||||
| Wrap positions into the periodic cell. | ||||
| """ | ||||
| fractional_positions = positions @ cell.inverse() | ||||
| fractional_positions = fractional_positions - torch.floor(fractional_positions) | ||||
| wrapped_positions = fractional_positions @ cell | ||||
|
|
||||
| return wrapped_positions | ||||
|
|
||||
|
|
||||
| def _check_close_to_cell_boundary( | ||||
| cell: torch.Tensor, positions: torch.Tensor, cutoff: float, skin: float | ||||
| ) -> torch.Tensor: | ||||
| """ | ||||
| Detect atoms that lie within a cutoff distance (in our context, the interaction | ||||
| range of the model + the skin) from the periodic cell boundaries, | ||||
| i.e. have interactions with atoms at the opposite end of the cell. | ||||
| """ | ||||
| inv_cell = cell.inverse() | ||||
| recip = inv_cell.T | ||||
| norms = torch.linalg.norm(recip, dim=1) | ||||
| heights = 1.0 / norms | ||||
| if heights.min() < (cutoff + skin): | ||||
| raise ValueError( | ||||
| "Cell is too small compared to (cutoff + skin) = " | ||||
| + str(cutoff + skin) | ||||
| + ". " | ||||
| "Ensure that all cell vectors are at least this length. Currently, the" | ||||
| " minimum cell vector length is " + str(heights.min()) + "." | ||||
| ) | ||||
|
|
||||
| cutoff = cutoff + skin | ||||
| normals = recip / norms[:, None] | ||||
| norm_coords = positions @ normals.T | ||||
| collisions = torch.hstack( | ||||
| [norm_coords <= cutoff, norm_coords >= heights - cutoff], | ||||
| ).to(device=positions.device) | ||||
|
|
||||
| return collisions[ | ||||
| :, [0, 3, 1, 4, 2, 5] # reorder to (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi) | ||||
| ] | ||||
|
|
||||
|
|
||||
| def _collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor: | ||||
| """ | ||||
| Convert boundary-collision flags into a boolean mask over all periodic image | ||||
| displacements in {0, +1, -1}^3. e.g. for an atom colliding with the x_lo and y_hi | ||||
| boundaries, we need the replicas at (1, 0, 0), (0, -1, 0), (1, -1, 0) image cells. | ||||
|
|
||||
| collisions: [N, 6]: has collisions with (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi) | ||||
|
|
||||
| returns: [N, 3, 3, 3] boolean mask over image displacements in {0, +1, -1}^3 | ||||
| 0: no replica needed along that axis | ||||
| 1: +1 replica needed along that axis (i.e., near low boundary, a replica is | ||||
| placed just outside the high boundary) | ||||
| 2: -1 replica needed along that axis (i.e., near high boundary, a replica is | ||||
| placed just outside the low boundary) | ||||
| axis order: x, y, z | ||||
| """ | ||||
| origin = torch.full( | ||||
| (len(collisions),), True, dtype=torch.bool, device=collisions.device | ||||
| ) | ||||
| axs = torch.vstack([origin, collisions[:, 0], collisions[:, 1]]) | ||||
| ays = torch.vstack([origin, collisions[:, 2], collisions[:, 3]]) | ||||
| azs = torch.vstack([origin, collisions[:, 4], collisions[:, 5]]) | ||||
| # leverage broadcasting | ||||
| outs = axs[:, None, None] & ays[None, :, None] & azs[None, None, :] | ||||
| outs = torch.movedim(outs, -1, 0) | ||||
| outs[:, 0, 0, 0] = False # not close to any boundary -> no replica needed | ||||
| return outs.to(device=collisions.device) | ||||
|
|
||||
|
|
||||
| def _generate_replica_atoms( | ||||
| types: torch.Tensor, | ||||
| positions: torch.Tensor, | ||||
| cell: torch.Tensor, | ||||
| replicas: torch.Tensor, | ||||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||||
| """ | ||||
| For atoms near the low boundary (x_lo/y_lo/z_lo), generate their images shifted | ||||
| by +1 cell vector (i.e., placed just outside the high boundary). | ||||
| For atoms near the high boundary (x_hi/y_hi/z_hi), generate images shifted by −1 | ||||
| cell vector. | ||||
| """ | ||||
| replicas = torch.argwhere(replicas) | ||||
| replica_idx = replicas[:, 0] | ||||
| replica_offsets = torch.tensor( | ||||
| [0, 1, -1], device=positions.device, dtype=positions.dtype | ||||
| )[replicas[:, 1:]] | ||||
| replica_positions = positions[replica_idx] + replica_offsets @ cell | ||||
|
|
||||
| return replica_idx, types[replica_idx], replica_positions | ||||
|
|
||||
|
|
||||
def _unfold_system(
    metatomic_system: System, cutoff: float, skin: float = 0.5
) -> System:
    """
    Unfold a periodic system by generating replica atoms for those near the cell
    boundaries within the specified cutoff distance.

    The positions are first wrapped inside the cell, then every atom closer than
    ``cutoff + skin`` to a cell face is copied to the matching periodic image(s). The
    unfolded system has no periodic boundary conditions and a zero cell. The
    ``"masses"`` and ``"velocities"`` data of the input system are copied to the
    unfolded system, replicas getting the values of the atom they come from.

    :param metatomic_system: fully periodic system carrying ``"masses"`` and
        ``"velocities"`` data
    :param cutoff: distance from the cell faces within which atoms are replicated
    :param skin: extra margin added to ``cutoff``. NOTE(review): according to the
        review discussion this parameter comes from an earlier implementation where
        it served neighbor list caching; it is not specific to the heat flux
        calculation and may be removed later.
    :return: the unfolded, non-periodic system; the original atoms come first and in
        their original order, followed by the replicas
    :raises ValueError: if the system is not periodic along all three directions, or
        if the cell is too small compared to ``cutoff + skin``
    """
    # All three directions must be periodic: the positions are wrapped and replicated
    # along every cell vector, which is not meaningful along a non-periodic direction
    # (and the cell inversion presumably fails if that direction has a zero cell
    # vector -- TODO confirm against the System conventions).
    if not metatomic_system.pbc.all():
        raise ValueError("Unfolding systems is only supported for periodic systems.")

    device = metatomic_system.device
    cell = metatomic_system.cell

    wrapped_positions = _wrap_positions(metatomic_system.positions, cell)
    collisions = _check_close_to_cell_boundary(cell, wrapped_positions, cutoff, skin)
    replicas = _collisions_to_replicas(collisions)
    replica_idx, replica_types, replica_positions = _generate_replica_atoms(
        metatomic_system.types, wrapped_positions, cell, replicas
    )

    unfolded_types = torch.cat([metatomic_system.types, replica_types])
    unfolded_positions = torch.cat([wrapped_positions, replica_positions])
    # index of the original atom behind every atom of the unfolded system
    unfolded_idx = torch.cat(
        [torch.arange(len(metatomic_system.types), device=device), replica_idx]
    )

    unfolded_system = System(
        types=unfolded_types,
        positions=unfolded_positions,
        cell=torch.zeros((3, 3), dtype=unfolded_positions.dtype, device=device),
        pbc=torch.tensor([False, False, False], device=device),
    )

    # the per-atom data share the same samples: one entry per unfolded atom
    samples = Labels(
        ["atoms"],
        torch.arange(len(unfolded_types), device=device).reshape(-1, 1),
    )
    for name in ["masses", "velocities"]:
        block = metatomic_system.get_data(name).block()
        unfolded_block = TensorBlock(
            values=block.values[unfolded_idx],
            samples=samples,
            components=block.components,
            properties=block.properties,
        )
        unfolded_system.add_data(
            name,
            TensorMap(
                Labels("_", torch.tensor([[0]], device=device)),
                [unfolded_block],
            ),
        )

    return unfolded_system.to(metatomic_system.dtype, device)
|
|
||||
|
|
||||
class HeatFluxWrapper(torch.nn.Module):
    """
    A wrapper around an :py:class:`AtomisticModel` that additionally computes the heat
    flux of periodic systems using the unfolded system approach.

    The unfolded system is generated by creating replica atoms for those near the cell
    boundaries within the interaction range of the wrapped model (plus ``skin``), and
    dropping the periodic boundary conditions. The wrapped model is then evaluated on
    the unfolded system to get the atomic energies and their gradients with respect
    to the positions.

    The wrapper adds the heat flux to the model's outputs under the key
    ``"extra::heat_flux"``, as a :py:class:`TensorMap` with one sample per system, one
    ``"xyz"`` component of size 3 and a single ``"heat_flux"`` property. As computed
    here, the heat flux is a sum of energies (eV) times velocities (requested in
    A/fs), i.e. it is expressed in eV*A/fs and is not divided by the cell volume.

    The systems must carry ``"masses"`` and ``"velocities"`` data, see
    :py:meth:`requested_inputs`.

    For more details on the heat flux calculation, see `Langer, M. F., et al., Heat flux
    for semilocal machine-learning potentials. (2023). Physical Review B, 108, L100302.`

    NOTE(review): points raised during review and still open:

    - this docstring should explain how to wrap this class back inside an
      :py:class:`AtomisticModel`, in particular how to define the
      capabilities/metadata (the ``"extra::heat_flux"`` output has to be declared
      there for it to be requested);
    - it would be better to expose the heat flux as a standard output instead of a
      custom ``extra::`` one;
    - if the wrapped model has multiple energy variants, matching heat flux variants
      should be exposed with the same names/descriptions.
    """

    def __init__(self, model: AtomisticModel, skin: float = 0.5):
        """
        :param model: the :py:class:`AtomisticModel` to wrap, which should be able to
            compute atomic energies and their gradients with respect to positions
        :param skin: the skin parameter for unfolding the system. The wrapper will
            generate replica atoms for those within (interaction_range + skin)
            distance from the cell boundaries. A larger skin results in more replica
            atoms and thus a higher computational cost. NOTE(review): the docstring
            should explain which issue a too small skin would cause.
        :raises TypeError: if ``model`` is not an :py:class:`AtomisticModel`
        :raises ValueError: if the model has no ``"energy"`` output, or if this
            output is not in eV
        """
        super().__init__()

        # explicit check instead of ``assert``, which is stripped with ``python -O``
        if not isinstance(model, AtomisticModel):
            raise TypeError(
                "HeatFluxWrapper expects an AtomisticModel, got "
                + type(model).__name__
            )

        capabilities = model.capabilities()
        outputs = capabilities.outputs.copy()
        if "energy" not in outputs:
            raise ValueError(
                "HeatFluxWrapper requires the wrapped model to have an 'energy' output"
            )
        energy_unit = outputs["energy"].unit
        # the kinetic energy is converted to eV in `_calc_unfolded_heat_flux`, the
        # potential energy has to use the same unit
        if energy_unit != "eV":
            raise ValueError(
                "HeatFluxWrapper can only be used with energy outputs in eV"
            )

        self._model = model.module
        self.skin = skin
        self._interaction_range = capabilities.interaction_range

        self._requested_neighbor_lists = model.requested_neighbor_lists()
        # NOTE(review): the velocities unit might not match the positions unit
        # required by the underlying model. It is also unclear what happens if the
        # underlying model already requests these inputs with a different unit; this
        # should at least be checked and reported as an error.
        self._requested_inputs = {
            "masses": ModelOutput(quantity="mass", unit="u", per_atom=True),
            "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True),
        }

        # NOTE(review): there is no "heat_flux" quantity for ModelOutput right now,
        # it should be added.

        # options used to evaluate the wrapped model on the unfolded systems: only
        # the per-atom energies are needed
        self._unfolded_run_options = ModelEvaluationOptions(
            length_unit=capabilities.length_unit,
            outputs={
                "energy": ModelOutput(
                    quantity="energy", unit=energy_unit, per_atom=True
                )
            },
            selected_atoms=None,
        )

    # NOTE(review): a reviewer suggested that this export is not required since the
    # export is done on AtomisticModel; the author reports that the tests fail
    # without it.
    @torch.jit.export
    def requested_neighbor_lists(self) -> List[NeighborListOptions]:
        """Get the neighbor lists requested by the wrapped model."""
        return self._requested_neighbor_lists

    @torch.jit.export
    def requested_inputs(self) -> Dict[str, ModelOutput]:
        """
        Get the extra per-atom data required to compute the heat flux: ``"masses"``
        (in u) and ``"velocities"`` (in A/fs).
        """
        return self._requested_inputs

    def _barycenter_and_atomic_energies(self, system: System, n_atoms: int):
        """
        Evaluate the wrapped model on an unfolded ``system``.

        :param system: unfolded system, with the ``n_atoms`` original atoms first
        :param n_atoms: number of atoms in the original (folded) system
        :return: a tuple ``(barycenter, atomic_e, total_e)`` with the energy-weighted
            sum of the positions of the original atoms, the energies of all atoms
            of the unfolded system ordered by atom index, and the sum of the
            energies of the original atoms
        """
        energy_block = self._model(
            [system],
            self._unfolded_run_options.outputs,
            self._unfolded_run_options.selected_atoms,
        )["energy"].block(0)
        # order the energies by atom index, to be able to select the original atoms
        atom_indices = energy_block.samples.column("atom").to(torch.long)
        sorted_order = torch.argsort(atom_indices)
        atomic_e = energy_block.values.flatten()[sorted_order]

        total_e = atomic_e[:n_atoms].sum()
        # the positions are detached: only the energies contribute to the gradient
        # of the barycenter
        r_aux = system.positions.detach()
        barycenter = (atomic_e[:n_atoms, None] * r_aux[:n_atoms]).sum(dim=0)

        return barycenter, atomic_e, total_e

    def _calc_unfolded_heat_flux(self, system: System) -> torch.Tensor:
        """
        Compute the heat flux of a single periodic system.

        :param system: periodic system carrying ``"masses"`` and ``"velocities"``
            data, and whose positions require gradients
        :return: tensor with the 3 Cartesian components of the heat flux
        """
        n_atoms = len(system.positions)
        unfolded_system = _unfold_system(system, self._interaction_range, self.skin).to(
            system.device
        )
        compute_requested_neighbors_from_options(
            [unfolded_system],
            self.requested_neighbor_lists(),
            self._unfolded_run_options.length_unit,
            False,
        )
        velocities: torch.Tensor = (
            unfolded_system.get_data("velocities").block().values.reshape(-1, 3)
        )
        masses: torch.Tensor = (
            unfolded_system.get_data("masses").block().values.reshape(-1)
        )
        barycenter, atomic_e, total_e = self._barycenter_and_atomic_energies(
            unfolded_system, n_atoms
        )

        # term1[i]: gradient of the i-th component of the barycenter with respect to
        # all unfolded positions, contracted with the velocities
        term1 = torch.zeros(
            [3], device=system.positions.device, dtype=system.positions.dtype
        )
        for i in range(3):
            grad_i = torch.autograd.grad(
                [barycenter[i]],
                [unfolded_system.positions],
                retain_graph=True,  # the graph is re-used by the next calls
                create_graph=False,
            )[0]
            grad_i = torch.jit._unwrap_optional(grad_i)
            term1[i] = (grad_i * velocities).sum()

        # term2: unfolded positions weighted by the gradient of the energy of the
        # original atoms, contracted with the velocities
        go = torch.jit.annotate(
            Optional[List[Optional[torch.Tensor]]], [torch.ones_like(total_e)]
        )
        grads = torch.autograd.grad(
            [total_e],
            [unfolded_system.positions],
            grad_outputs=go,
        )[0]
        grads = torch.jit._unwrap_optional(grads)
        term2 = (
            unfolded_system.positions * (grads * velocities).sum(dim=1, keepdim=True)
        ).sum(dim=0)

        hf_pot = term1 - term2

        # convective part: (potential + kinetic energy) times velocity, summed over
        # the original atoms only
        hf_conv = (
            (
                atomic_e[:n_atoms]
                + 0.5
                * masses[:n_atoms]
                * torch.linalg.norm(velocities[:n_atoms], dim=1) ** 2
                * 103.6427  # u*A^2/fs^2 to eV
            )[:, None]
            * velocities[:n_atoms]
        ).sum(dim=0)

        return hf_pot + hf_conv

    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels],
    ) -> Dict[str, TensorMap]:
        """
        Run the wrapped model and, if requested, add the heat flux to its results.

        All outputs except ``"extra::heat_flux"`` are forwarded to the wrapped model.
        When the heat flux is requested, gradient tracking is enabled in-place on the
        positions of the given ``systems``.

        :param systems: systems to run the model on
        :param outputs: requested outputs, which may contain ``"extra::heat_flux"``
        :param selected_atoms: forwarded to the wrapped model; not used for the heat
            flux itself
        :return: the results of the wrapped model, plus ``"extra::heat_flux"`` when
            requested
        """
        outputs_wo_heat_flux = outputs.copy()
        # only remove the key when it is there: an unconditional ``del`` fails for
        # every call which does not request the heat flux
        if "extra::heat_flux" in outputs_wo_heat_flux:
            del outputs_wo_heat_flux["extra::heat_flux"]
        results = self._model(systems, outputs_wo_heat_flux, selected_atoms)

        if "extra::heat_flux" not in outputs:
            return results

        device = systems[0].device
        heat_fluxes: List[torch.Tensor] = []
        for system in systems:
            # the unfolded positions are derived from these positions, and need to
            # be part of the autograd graph
            system.positions.requires_grad_(True)
            heat_fluxes.append(self._calc_unfolded_heat_flux(system))

        samples = Labels(
            ["system"], torch.arange(len(systems), device=device).reshape(-1, 1)
        )

        hf_block = TensorBlock(
            values=torch.vstack(heat_fluxes).reshape(-1, 3, 1).to(device=device),
            samples=samples,
            components=[Labels(["xyz"], torch.arange(3, device=device).reshape(-1, 1))],
            properties=Labels(["heat_flux"], torch.tensor([[0]], device=device)),
        )
        # NOTE(review): instead of a custom output, it would be better to add
        # heat_flux as a standard output.
        results["extra::heat_flux"] = TensorMap(
            Labels("_", torch.tensor([[0]], device=device)), [hf_block]
        )
        return results
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a strange way to get the cell vectors length. Are you sure it works even for very slanted triclinic cells?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think here we calculate the distance between opposite faces of the cell: for example, the first height is the distance between the face spanned by the b and c vectors and the opposite face.