diff --git a/python/metatomic_torch/metatomic/torch/__init__.py b/python/metatomic_torch/metatomic/torch/__init__.py index 8855e1d6..4be82aec 100644 --- a/python/metatomic_torch/metatomic/torch/__init__.py +++ b/python/metatomic_torch/metatomic/torch/__init__.py @@ -70,5 +70,10 @@ def __getattr__(name): import metatomic.torch.ase_calculator return metatomic.torch.ase_calculator + + elif name == "rotational_utils": + import metatomic.torch.rotational_utils + + return metatomic.torch.rotational_utils else: raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/python/metatomic_torch/metatomic/torch/rotational_utils.py b/python/metatomic_torch/metatomic/torch/rotational_utils.py new file mode 100644 index 00000000..a2a56132 --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/rotational_utils.py @@ -0,0 +1,788 @@ +""" +Utilities for diagnosing rotational equivariance of models and for enforcing +rotational symmetry in data augmentation and model evaluation. +""" + +import warnings +from typing import Dict, List, Optional, Tuple + +import metatensor.torch as mts +import numpy as np +import torch +from metatensor.torch import TensorMap +from metatrain.utils.augmentation import ( + _apply_augmentations, + _complex_to_real_spherical_harmonics_transform, + _scipy_quaternion_to_quaternionic, +) + +import metatomic.torch # noqa: F401 +from metatomic.torch import ModelEvaluationOptions, System, register_autograd_neighbors +from metatomic.torch.model import AtomisticModel + + +try: + from scipy.spatial.transform import Rotation # noqa: F401 +except ImportError as e: + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `scipy` package with `pip install scipy`." + ) from e + + +def _choose_quadrature(L_max: int) -> Tuple[int, int]: + """ + Choose a Lebedev quadrature order and number of in-plane rotations to integrate + spherical harmonics up to degree ``L_max``. 
def _transform_system(system: System, transformation: torch.Tensor) -> System:
    """
    Build a new ``System`` with ``transformation`` applied to ``system``.

    Positions, cell vectors and the vectors of every known neighbor list are
    right-multiplied by ``transformation.T``; ``types`` and ``pbc`` are passed through
    unchanged.

    :param system: system to transform
    :param transformation: matrix applied to every Cartesian vector (the callers in
        this file pass rotation matrices, optionally multiplied by ``-1``)
    :return: the transformed system, carrying one transformed neighbor list for each
        neighbor list known to ``system``
    """
    transformed_system = System(
        positions=system.positions @ transformation.T,
        types=system.types,
        cell=system.cell @ transformation.T,
        pbc=system.pbc,
    )
    for options in system.known_neighbor_lists():
        neighbors = mts.detach_block(system.get_neighbor_list(options))

        # NOTE(review): this writes in place into the block returned by
        # ``detach_block``; if that block shares storage with the list stored on
        # ``system``, the original neighbor list is modified too -- confirm.
        neighbors.values[:] = (
            neighbors.values.squeeze(-1) @ transformation.T
        ).unsqueeze(-1)

        # The neighbor vectors now correspond to the transformed positions and cell,
        # so they have to be registered against the system they are attached to.
        register_autograd_neighbors(transformed_system, neighbors)
        transformed_system.add_neighbor_list(options, neighbors)
    return transformed_system
+ batch_shape = R.shape[:-2] + R_flat = R.reshape(-1, 3, 3) + + # Read commonly-used entries with explicit names for readability + R00 = R_flat[:, 0, 0] + # R01 = R_flat[:, 0, 1] # unused + R02 = R_flat[:, 0, 2] + R10 = R_flat[:, 1, 0] + # R11 = R_flat[:, 1, 1] # unused + R12 = R_flat[:, 1, 2] + R20 = R_flat[:, 2, 0] + R21 = R_flat[:, 2, 1] + R22 = R_flat[:, 2, 2] + + # Default (non-singular) extraction + zz = torch.clip(R22, -1.0, 1.0) + betas = torch.arccos(zz) + + # For Z-Y-Z, standard formulas away from the singular set + alphas = torch.arctan2(R12, R02) + gammas = torch.arctan2(R21, -R20) + + # Normalize into [0, 2π) + two_pi = 2.0 * torch.pi + alphas = torch.remainder(alphas, two_pi) + gammas = torch.remainder(gammas, two_pi) + + # Gimbal-lock detection via sin(beta) + sinb = torch.sin(betas) + near = torch.abs(sinb) < eps + if torch.any(near): + # Split the two singular bands using zz = cos(beta) + near_zero = near & (zz > 0) # beta≈0 + near_pi = near & (zz < 0) # beta≈pi + + if torch.any(near_zero): + # beta≈0: rotation ≈ Rz(alpha+gamma). Choose gamma=0, recover alpha from 2x2 + # block. + betas[near_zero] = 0.0 + gammas[near_zero] = 0.0 + alphas[near_zero] = torch.arctan2(R10[near_zero], R00[near_zero]) + alphas[near_zero] = torch.remainder(alphas[near_zero], two_pi) + + if torch.any(near_pi): + # beta≈pi: choose alpha=0, recover gamma from 2x2 block with sign flip on + # R00. 
def get_so3_character(
    alphas: torch.Tensor,
    betas: torch.Tensor,
    gammas: torch.Tensor,
    o3_lambda: int,
    tol: float = 1e-7,
) -> torch.Tensor:
    """
    Numerically stable evaluation of the character function χ_{o3_lambda}(R) over SO(3).

    The character is χ_ℓ(ω) = sin((2ℓ+1)t)/sin(t) with t = ω/2. A small-angle Taylor
    expansion is used when ``|t| < tol``, the limit value ``2ℓ+1`` is used wherever
    ``|sin(t)| < tol``, and the plain ratio is used everywhere else.

    :param alphas: first Z-Y-Z Euler angle
    :param betas: second Z-Y-Z Euler angle
    :param gammas: third Z-Y-Z Euler angle
    :param o3_lambda: angular momentum ℓ of the irreducible representation
    :param tol: threshold below which ``|t|`` and ``|sin(t)|`` are treated as zero
    :return: tensor of characters with the broadcast shape of the three angle tensors
    """
    # Half-angle t = ω/2 via the Z–Y–Z relation: cos t = cos(β/2) cos((α+γ)/2)
    cos_t = torch.cos(betas / 2.0) * torch.cos((alphas + gammas) / 2.0)
    t = torch.arccos(torch.clip(cos_t, -1.0, 1.0))

    L = o3_lambda
    a = 2 * L + 1
    ll1 = L * (L + 1)

    # Series up to t^4: χ ≈ a [1 - (2/3) ℓ(ℓ+1) t^2 + (1/45) ℓ(ℓ+1)(3ℓ^2+3ℓ-1) t^4]
    t2 = t * t
    coeff4 = ll1 * (3 * L * L + 3 * L - 1)
    series = a * (1.0 - (2.0 / 3.0) * ll1 * t2 + (1.0 / 45.0) * coeff4 * t2 * t2)

    # Guarded ratio: the denominator is replaced by 1 wherever sin(t) is (close to)
    # zero, so no inf/nan is ever produced, and those entries get the limit value.
    # Everything is out-of-place, so the function also works on tensors that require
    # grad (``torch.div(..., out=...)`` does not).
    sin_t = torch.sin(t)
    degenerate = torch.abs(sin_t) < tol
    safe_sin = torch.where(degenerate, torch.ones_like(sin_t), sin_t)
    ratio = torch.sin(a * t) / safe_sin
    ratio = torch.where(degenerate, torch.full_like(t, a), ratio)

    return torch.where(torch.abs(t) < tol, series, ratio)
def get_pso3_characters_dict(
    so3_character: Dict[int, torch.Tensor], o3_lambda_max: int
) -> Dict[Tuple[int, int], torch.Tensor]:
    """
    Build the P⋅SO(3) characters from pre-computed SO(3) characters.

    For every ``o3_lambda`` in ``[0, o3_lambda_max]`` and every ``o3_sigma`` in
    ``{-1, +1}`` the entry is ``o3_sigma * (-1) ** o3_lambda * so3_character[o3_lambda]``.

    :param so3_character: SO(3) characters, keyed by ``o3_lambda``
    :param o3_lambda_max: largest ``o3_lambda`` to include
    :return: dictionary keyed by ``(o3_lambda, o3_sigma)``
    """
    return {
        (ell, sigma): sigma * ((-1) ** ell) * so3_character[ell]
        for ell in range(o3_lambda_max + 1)
        for sigma in (-1, +1)
    }
+ ), + stacklevel=2, + ) + + # Get the quadrature + self.lebedev_order: int + """Number of Lebedev nodes on the unit sphere.""" + + self.n_inplane_rotations: int + """Number of in-plane rotations per Lebedev node.""" + self.lebedev_order, self.n_inplane_rotations = _choose_quadrature( + self.quad_l_max + ) + + self.w_so3: torch.Tensor + """Weights associated to each rotation in the SO(3) Haar measure.""" + + alpha, beta, gamma, self.w_so3 = get_euler_angles_quadrature( + self.lebedev_order, self.n_inplane_rotations + ) + self.w_so3 = torch.from_numpy(self.w_so3) + + # For active rotation of systems + self.R_so3 = torch.from_numpy( + _rotations_from_angles(alpha, beta, gamma).as_matrix() + ) + """Rotation matrices.""" + + self.n_rotations = self.R_so3.size(0) + + # For inverse rotation of tensors + R_pso3 = _rotations_from_angles(np.pi - alpha, beta, np.pi - gamma) + self.wigner_D: Dict[int, torch.Tensor] = _compute_wigner_D_matrices( + self.project_l_max, R_pso3 + ) + """Dict mapping l to (N, (2l+1), (2l+1)) torch.Tensor of Wigner D matrices.""" + + self.R_pso3 = torch.from_numpy(R_pso3.as_matrix()) + """Inverse rotation matrices.""" + + self.batch_size = batch_size + + _external_product_euler_angles = _extract_euler_zyz( + (self.R_so3[:, None, :, :] @ self.R_pso3[None, :, :, :]), eps=1e-6 + ) + self.so3_characters = get_so3_characters_dict( + *_external_product_euler_angles, self.project_l_max + ) + self.pso3_characters = get_pso3_characters_dict( + self.so3_characters, self.project_l_max + ) + + def evaluate( + self, + model: AtomisticModel, + systems: List[System], + options: ModelEvaluationOptions, + check_consistency: bool = False, + ): + """ + Sample the model on the O(3) quadrature. 
+ + :param systems: list of systems to evaluate + :param model: atomistic model to evaluate + :param device: device to use for computation + :return: list of list of model outputs, shape (len(systems), N) + where N is the number of quadrature points + """ + + device = systems[0].positions.device + dtype = systems[0].positions.dtype + + transformed_outputs = { + name: [{-1: None, 1: None} for _ in systems] + for name in options.outputs.keys() + } + backtransformed_outputs = { + name: [{-1: None, 1: None} for _ in systems] + for name in options.outputs.keys() + } + for i_sys, system in enumerate(systems): + for inversion in [-1, 1]: + rotation_outputs = [] + for batch in range(0, len(self.R_so3), self.batch_size): + transformed_systems = [ + _transform_system( + system, inversion * R.to(device=device, dtype=dtype) + ) + for R in self.R_so3[batch : batch + self.batch_size] + ] + outputs = model( + transformed_systems, + options=options, + check_consistency=check_consistency, + ) + rotation_outputs.append(outputs) + + for name in transformed_outputs: + tensor = mts.join( + [r[name] for r in rotation_outputs], + "samples", + remove_tensor_name=True, + ) + transformed_outputs[name][i_sys][inversion] = mts.rename_dimension( + tensor, "samples", "tensor", "o3_sample" + ) + + n_rot = self.R_so3.size(0) + for name in transformed_outputs: + for i_sys, system in enumerate(systems): + for inversion in [-1, 1]: + tensor = transformed_outputs[name][i_sys][inversion] + _, backtransformed, _ = _apply_augmentations( + [system] * n_rot, + {name: tensor}, + list( + ( + self.R_pso3.to(device=device, dtype=dtype) * inversion + ).unbind(0) + ), + self.wigner_D, + ) + backtransformed_outputs[name][i_sys][inversion] = backtransformed[ + name + ] + + return transformed_outputs, backtransformed_outputs + + +class TokenProjector(torch.nn.Module): + """ + Wrap an atomistic model to project its predictions onto spherical sectors. 
class SymmetrizedAtomisticModel(torch.nn.Module):
    """
    Wrap an atomistic model, evaluate it on a quadrature over O(3) and project the
    sampled predictions with :py:func:`compute_projections`.

    :param model: atomistic model to wrap
    :param quad_l_max: maximum spherical harmonic degree for quadrature
    :param project_l_max: maximum spherical harmonic degree to project onto
    :param batch_size: number of rotations to process in a single batch; ``None``
        falls back to the default batch size of :py:class:`O3Sampler`
    """

    def __init__(
        self,
        model: AtomisticModel,
        quad_l_max: int,
        project_l_max: int,
        batch_size: Optional[int] = None,
    ):
        # ``torch.nn.Module.__init__`` does not accept positional arguments
        super().__init__()
        self.model = model
        """The underlying atomistic model."""

        # ``O3Sampler.evaluate`` uses ``batch_size`` as the step of a ``range``, so
        # ``None`` must never reach it.
        # NOTE(review): ``None`` is mapped to the sampler's own default here; confirm
        # this is the intended meaning (as opposed to "all rotations in one batch").
        sampler_kwargs = {} if batch_size is None else {"batch_size": batch_size}
        self.o3_sampler = O3Sampler(quad_l_max, project_l_max, **sampler_kwargs)
        """The sampler evaluating the model on the O(3) quadrature."""

    def forward(
        self,
        systems: List[System],
        options: ModelEvaluationOptions,
        check_consistency: bool = False,
    ) -> Tuple[Dict[str, list], Dict[str, list], Dict[str, list]]:
        """
        Evaluate the wrapped model on the quadrature and compute the projections.

        :param systems: list of systems to evaluate
        :param options: model evaluation options
        :param check_consistency: whether to check model consistency
        :return: the ``(norms, convolution_integrals,
            normalized_convolution_integrals)`` tuple produced by
            :py:func:`compute_projections`
        """

        # ``O3Sampler.evaluate`` expects the model first and the systems second
        transformed_outputs, _ = self.o3_sampler.evaluate(
            self.model, systems, options, check_consistency
        )

        return compute_projections(
            self.o3_sampler.project_l_max,
            systems,
            transformed_outputs,
            self.o3_sampler.w_so3,
            self.o3_sampler.so3_characters,
            self.o3_sampler.pso3_characters,
        )
+ + :param l_max: maximum spherical harmonic degree + :param rotations: list of scipy Rotation objects + :param complex_to_real: optional dict mapping l to (2l+1, (2l+1)) array to convert + complex spherical harmonics to real spherical harmonics + :return: dict mapping l to (N, (2l+1), (2l+1)) array of Wigner D matrices + """ + + try: + import spherical + except ImportError as e: + # quaternionic (used below) is a dependency of spherical + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `spherical` package with `pip install spherical`." + ) from e + + wigner = spherical.Wigner(l_max) + scipy_quaternions = [r.as_quat() for r in rotations] + quaternionic_quaternions = [ + _scipy_quaternion_to_quaternionic(q) for q in scipy_quaternions + ] + wigner_D_matrices_complex = [wigner.D(q) for q in quaternionic_quaternions] + + if complex_to_real is None: + complex_to_real = { + ell: _complex_to_real_spherical_harmonics_transform(ell) + for ell in range(l_max + 1) + } + + wigner_D_matrices = {} + for ell in range(l_max + 1): + U = complex_to_real[ell] + wigner_D_matrices_l = [] + for wigner_D_matrix_complex in wigner_D_matrices_complex: + wigner_D_matrix = np.zeros((2 * ell + 1, 2 * ell + 1), dtype=np.complex128) + for mp in range(-ell, ell + 1): + for m in range(-ell, ell + 1): + wigner_D_matrix[m + ell, mp + ell] = ( + wigner_D_matrix_complex[wigner.Dindex(ell, m, mp)] + ).conj() + + wigner_D_matrix = U.conj() @ wigner_D_matrix @ U.T + assert np.allclose(wigner_D_matrix.imag, 0.0) + wigner_D_matrix = wigner_D_matrix.real + wigner_D_matrices_l.append(torch.from_numpy(wigner_D_matrix)) + wigner_D_matrices[ell] = wigner_D_matrices_l + + return wigner_D_matrices + + +# O3-integrals utilities + + +def compute_projections( + max_l: int, + systems: List[System], + transformed_outputs: Dict[str, List[TensorMap]], + weights: torch.Tensor, + so3_characters: Dict[int, torch.Tensor], + pso3_characters: Dict[Tuple[int, int], 
torch.Tensor], +) -> Tuple[ + Dict[str, List[Dict[int, TensorMap]]], + Dict[str, List[Dict[Tuple[int, int], TensorMap]]], + Dict[str, List[Dict[Tuple[int, int], TensorMap]]], +]: + """ + + TODO docstring, check type annotations + + - Take model outputs on a quadrature + - Manipulate dimensions + - Compute some integrals + - Return projections + + """ + + device = systems[0].positions.device + dtype = systems[0].positions.dtype + + weights = weights.to(device, dtype) + so3_characters = {k: v.to(device, dtype) for k, v in so3_characters.items()} + pso3_characters = {k: v.to(device, dtype) for k, v in pso3_characters.items()} + + n_rotations = len(weights) + norms = {} + convolution_integrals = {} + normalized_convolution_integrals = {} + # Loop over targets + for name, transformed_output in transformed_outputs.items(): + norms[name] = [] + convolution_integrals[name] = [] + normalized_convolution_integrals[name] = [] + for o3_output_for_system in transformed_output: + proper = o3_output_for_system[1] + improper = o3_output_for_system[-1] + + # Weighting the tensors + broadcasted_w = ( + weights[proper[0].samples.column("o3_sample")] / 16 / torch.pi**2 + ) + proper_weighted = proper.copy() + improper_weighted = improper.copy() + for k in proper_weighted.keys: + proper_block = proper_weighted[k] + improper_block = improper_weighted[k] + proper_block.values[:] *= broadcasted_w.view( + -1, *[1] * (proper_block.values.ndim - 1) + ) + improper_block.values[:] *= broadcasted_w.view( + -1, *[1] * (improper_block.values.ndim - 1) + ) + + # Compute norms + proper_norm = mts.multiply(proper, proper_weighted) + improper_norm = mts.multiply(improper, improper_weighted) + norm = mts.add(proper_norm, improper_norm) + norm = mts.sum_over_samples(norm, "o3_sample") + norms[name].append(norm) + + # Compute convolution integrals + convolution_integral = {} + normalized_convolution_integral = {} + for ell in range(max_l + 1): + so3_char = so3_characters[ell] + for sigma in [-1, 1]: + 
def norms_to(norms, dtype, device):
    """
    Move every norm to the requested ``dtype`` and ``device``.

    :param norms: mapping from output name to a list of objects exposing
        ``.to(dtype=..., device=...)``
    :param dtype: target dtype
    :param device: target device
    :return: new dictionary with the same keys and the converted lists; the input
        dictionary and its lists are left untouched
    """
    return {
        output_name: [
            quantity.to(dtype=dtype, device=device) for quantity in quantities
        ]
        for output_name, quantities in norms.items()
    }
device=device) + quantity_list.append(quantity_dict_to) + integral_to[output_name] = quantity_list + + return integral_to diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py new file mode 100644 index 00000000..911a82a2 --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -0,0 +1,1736 @@ +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +import metatensor.torch as mts + + +if TYPE_CHECKING: + + class TensorBlock: ... + + class System: ... + + class TensorMap: ... + + class ModelOutput: ... + + class Labels: ... + + class ModelInterface: ... + +else: + from metatensor.torch import Labels, TensorBlock, TensorMap + + from metatomic.torch import ModelOutput, System + +import numpy as np +import torch +from metatrain.utils.augmentation import _apply_augmentations + +from metatomic.torch import ModelInterface, register_autograd_neighbors + + +try: + from scipy.integrate import lebedev_rule # noqa: F401 + from scipy.spatial.transform import Rotation # noqa: F401 +except ImportError as e: + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `scipy` package with `pip install scipy`." + ) from e +try: + import spherical # noqa: F401 +except ImportError as e: + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `spherical` package with `pip install spherical`." + ) from e +try: + import quaternionic # noqa: F401 +except ImportError as e: + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `quaternionic` package with `pip install quaternionic`." + ) from e + + +def _choose_quadrature(L_max: int) -> Tuple[int, int]: + """ + Choose a Lebedev quadrature order and number of in-plane rotations to integrate + spherical harmonics up to degree ``L_max``. 
+ + :param L_max: maximum spherical harmonic degree + :return: (lebedev_order, n_inplane_rotations) + """ + available = [ + 3, + 5, + 7, + 9, + 11, + 13, + 15, + 17, + 19, + 21, + 23, + 25, + 27, + 29, + 31, + 35, + 41, + 47, + 53, + 59, + 65, + 71, + 77, + 83, + 89, + 95, + 101, + 107, + 113, + 119, + 125, + 131, + ] + # pick smallest order >= L_max + n = min(o for o in available if o >= L_max) + # minimal gamma count + K = L_max + 1 + return n, K + + +def get_euler_angles_quadrature(lebedev_order: int, n_rotations: int): + """ + Get the Euler angles and weights for a Lebedev quadrature combined with in-plane + rotations for SO(3) integration. + + :param lebedev_order: order of the Lebedev quadrature on the unit sphere + :param n_rotations: number of in-plane rotations per Lebedev node + :return: alpha, beta, gamma, w arrays of shape (M,), (M,), (K,), (M,) + respectively, where M is the number of Lebedev nodes and K is the number of + in-plane rotations. + """ + + # Lebedev nodes (X: (3, M)) + X, w = lebedev_rule(lebedev_order) # w sums to 4*pi + x, y, z = X + alpha = np.arctan2(y, x) # (M,) + beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + gamma = np.linspace(0.0, 2 * np.pi, n_rotations, endpoint=False) # (K,) + + w_so3 = np.repeat(w / (4 * np.pi * n_rotations), repeats=gamma.size) # (M*K,) + + A = np.repeat(alpha, gamma.size) # (N,) + B = np.repeat(beta, gamma.size) # (N,) + G = np.tile(gamma, alpha.size) # (N,) + + return A, B, G, w_so3 + + +def _rotations_from_angles( + alpha: np.ndarray, beta: np.ndarray, gamma: np.ndarray +) -> Rotation: + """ + Compose rotations from ZYZ Euler angles. 
+ + :param alpha: array of alpha angles (M,) + :param beta: array of beta angles (M,) + :param gamma: array of gamma angles (K,) + :return: Rotation object containing all (M*K,) rotations + """ + + # Compose ZYZ rotations in SO(3) + Rot = ( + Rotation.from_euler("z", alpha.reshape(-1, 1)) + * Rotation.from_euler("y", beta.reshape(-1, 1)) + * Rotation.from_euler("z", gamma.reshape(-1, 1)) + ) + + return Rot + + +def _transform_system(system: System, transformation: torch.Tensor) -> System: + transformed_system = System( + positions=system.positions @ transformation.T, + types=system.types, + cell=system.cell @ transformation.T, + pbc=system.pbc, + ) + for options in system.known_neighbor_lists(): + neighbors = mts.detach_block(system.get_neighbor_list(options)) + + neighbors.values[:] = ( + neighbors.values.squeeze(-1) @ transformation.T + ).unsqueeze(-1) + + register_autograd_neighbors(system, neighbors) + transformed_system.add_neighbor_list(options, neighbors) + return transformed_system + + +def _complex_to_real_spherical_harmonics_transform(ell: int) -> np.ndarray: + """ + Generate the transformation matrix from complex spherical harmonics + to real spherical harmonics for a given l. + Returns a transformation matrix of shape ((2l+1), (2l+1)). 
+ """ + if ell < 0 or not isinstance(ell, int): + raise ValueError("l must be a non-negative integer.") + + # The size of the transformation matrix is (2l+1) x (2l+1) + size = 2 * ell + 1 + T = np.zeros((size, size), dtype=complex) + + for m in range(-ell, ell + 1): + m_index = m + ell # Index in the matrix + if m > 0: + # Real part of Y_{l}^{m} + T[m_index, ell + m] = 1 / np.sqrt(2) * (-1) ** m + T[m_index, ell - m] = 1 / np.sqrt(2) + elif m < 0: + # Imaginary part of Y_{l}^{|m|} + T[m_index, ell + abs(m)] = -1j / np.sqrt(2) * (-1) ** m + T[m_index, ell - abs(m)] = 1j / np.sqrt(2) + else: # m == 0 + # Y_{l}^{0} remains unchanged + T[m_index, ell] = 1 + + # Return the transformation matrix to convert complex to real spherical harmonics + return T + + +def _compute_real_wigner_matrices( + o3_lambda_max: int, + angles: Tuple[np.ndarray, np.ndarray, np.ndarray], # alpha, beta, gamma +) -> Dict[int, np.ndarray]: + wigner = spherical.Wigner(o3_lambda_max) + R = quaternionic.array.from_euler_angles(*angles) + D = wigner.D(R) + wigner_D_matrices = {} + for ell in range(o3_lambda_max + 1): + wigner_D_matrices[ell] = np.zeros( + angles[0].shape + (2 * ell + 1, 2 * ell + 1), dtype=np.complex128 + ) + for mp in range(-ell, ell + 1): + for m in range(-ell, ell + 1): + # There is an unexplained conjugation factor in the definition given in + # the quaternionic library. 
+ wigner_D_matrices[ell][..., mp + ell, m + ell] = ( + D[..., wigner.Dindex(ell, mp, m)] + ).conj() + U = _complex_to_real_spherical_harmonics_transform(ell) + wigner_D_matrices[ell] = np.einsum( + "ij,...jk,kl->...il", U.conj(), wigner_D_matrices[ell], U.T + ) + assert np.allclose(wigner_D_matrices[ell].imag, 0) + wigner_D_matrices[ell] = torch.from_numpy(wigner_D_matrices[ell].real) + + return wigner_D_matrices + + +def _angles_from_rotations( + R: np.ndarray, + eps: float = 1e-6, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Extract Z-Y-Z Euler angles (alpha, beta, gamma) from rotation matrices, with + explicit handling of the gimbal-lock cases (beta≈0 and beta≈pi). + TODO: This function is extremely sensitive to eps and will be modified. + Parameters + ---------- + R : np.ndarray + Rotation matrices with arbitrary batch shape `(..., 3, 3)`. + eps : float + Tolerance used to detect gimbal lock via `sin(beta) < eps`. + + Returns + ------- + (alphas, betas, gammas) : Tuple[np.ndarray, np.ndarray, np.ndarray] + Each with the same batch shape as `R[..., 0, 0]` (i.e., `R.shape[:-2]`). + + Notes + ----- + Conventions: + - Base convention is Z-Y-Z (Rz(alpha) Ry(beta) Rz(gamma)). + - For beta≈0: set beta=0, gamma=0, alpha=atan2(R[1,0], R[0,0]). + - For beta≈pi: set beta=pi, alpha=0, gamma=atan2(R[1,0], -R[0,0]). + These conventions ensure a deterministic inverse where the standard formulas + are ill-conditioned. + """ + # Accept any batch shape. Flatten to (N, 3, 3) for clarity, then unflatten. 
+ batch_shape = R.shape[:-2] + R_flat = R.reshape(-1, 3, 3) + + # Read commonly-used entries with explicit names for readability + R00 = R_flat[:, 0, 0] + # R01 = R_flat[:, 0, 1] + R02 = R_flat[:, 0, 2] + R10 = R_flat[:, 1, 0] + # R11 = R_flat[:, 1, 1] + R12 = R_flat[:, 1, 2] + R20 = R_flat[:, 2, 0] + R21 = R_flat[:, 2, 1] + R22 = R_flat[:, 2, 2] + + # Default (non-singular) extraction + zz = np.clip(R22, -1.0, 1.0) + betas = np.arccos(zz) + + # For Z–Y–Z, standard formulas away from the singular set + alphas = np.arctan2(R12, R02) + gammas = np.arctan2(R21, -R20) + + # Normalize into [0, 2π) + two_pi = 2.0 * np.pi + alphas = np.mod(alphas, two_pi) + gammas = np.mod(gammas, two_pi) + + # Gimbal-lock detection via sin(beta) + sinb = np.sin(betas) + near = np.abs(sinb) < eps + if np.any(near): + # Split the two singular bands using zz = cos(beta) + near_zero = near & (zz > 0) # beta≈0 + near_pi = near & (zz < 0) # beta≈pi + + if np.any(near_zero): + # beta≈0: rotation ≈ Rz(alpha+gamma). Choose gamma=0, recover alpha from + # 2x2 block. + betas[near_zero] = 0.0 + gammas[near_zero] = 0.0 + alphas[near_zero] = np.arctan2(R10[near_zero], R00[near_zero]) + alphas[near_zero] = np.mod(alphas[near_zero], two_pi) + + if np.any(near_pi): + # beta≈pi: choose alpha=0, recover gamma from 2x2 block with sign flip on + # R00. + betas[near_pi] = np.pi + alphas[near_pi] = 0.0 + gammas[near_pi] = np.arctan2(R10[near_pi], -R00[near_pi]) + gammas[near_pi] = np.mod(gammas[near_pi], two_pi) + + # Unflatten back to the original batch shape + alphas = alphas.reshape(batch_shape) + betas = betas.reshape(batch_shape) + gammas = gammas.reshape(batch_shape) + return alphas, betas, gammas + + +def _l0_components_from_matrices(A: torch.Tensor) -> torch.Tensor: + """ + Extract the L=0 components from a (3, 3) tensor. 
def _l0_components_from_matrices(A: torch.Tensor) -> torch.Tensor:
    """
    Extract the L=0 component (the trace) from a stack of (3, 3) matrices.

    The two 3-sized axes are expected at positions 1 and 2 of a 4-dimensional
    input; the result keeps the first and last axes and has a single component
    in between.
    """
    # Move the two matrix axes to the end so they can be indexed with `...`.
    mats = A.permute(0, 3, 1, 2)
    assert mats.shape[-2:] == (3, 3), "The last two dimensions of A must be (3, 3)."

    trace = mats[..., 0, 0] + mats[..., 1, 1] + mats[..., 2, 2]

    # Add the size-1 component axis, then put it back in the middle.
    return trace.unsqueeze(-1).permute(0, 2, 1)


def _l2_components_from_matrices(A: torch.Tensor) -> torch.Tensor:
    """
    Extract the five L=2 components from a stack of (3, 3) matrices.

    Only the symmetric part of each matrix contributes. The two 3-sized axes are
    expected at positions 1 and 2 of a 4-dimensional input; the result keeps the
    first and last axes with the 5 components in between.
    """
    # Move the two matrix axes to the end so they can be indexed with `...`.
    mats = A.permute(0, 3, 1, 2)
    assert mats.shape[-2:] == (3, 3), "The last two dimensions of A must be (3, 3)."

    xx, yy, zz = mats[..., 0, 0], mats[..., 1, 1], mats[..., 2, 2]
    components = [
        (mats[..., 0, 1] + mats[..., 1, 0]) / 2.0,
        (mats[..., 1, 2] + mats[..., 2, 1]) / 2.0,
        (2.0 * zz - xx - yy) / ((2.0) * np.sqrt(3.0)),
        (mats[..., 0, 2] + mats[..., 2, 0]) / 2.0,
        (xx - yy) / 2.0,
    ]

    # Keep the input dtype, as the original pre-allocated buffer did.
    stacked = torch.stack(components, dim=-1).to(dtype=mats.dtype)
    return stacked.permute(0, 2, 1)
def _get_so3_character(
    alphas: np.ndarray,
    betas: np.ndarray,
    gammas: np.ndarray,
    o3_lambda: int,
    tol: float = 1e-7,
) -> np.ndarray:
    """
    Numerically stable evaluation of the SO(3) character chi_l for rotations
    given by Z-Y-Z Euler angles.

    With ``t`` the half rotation angle, ``chi_l = sin((2l+1) t) / sin(t)``. The
    half angle is obtained from ``cos(t) = cos(beta/2) * cos((alpha+gamma)/2)``.
    For ``|t| < tol`` a Taylor expansion up to ``t^4`` is used; elsewhere the
    ratio is evaluated directly, except where ``|sin(t)| < tol`` (``t`` close to
    ``pi``), where the limit value ``2l+1`` is used.

    :param alphas: first Euler angles
    :param betas: second Euler angles, same shape as ``alphas``
    :param gammas: third Euler angles, same shape as ``alphas``
    :param o3_lambda: degree ``l`` of the irreducible representation
    :param tol: threshold on ``|t|`` and ``|sin(t)|`` below which the series or
        the limit value replaces the direct ratio
    :return: array of characters with the same shape as the input angles
    """
    # Half-angle t = ω/2 via the Z–Y–Z relation; clip guards arccos against
    # round-off pushing the product slightly outside [-1, 1].
    cos_t = np.cos(betas / 2.0) * np.cos((alphas + gammas) / 2.0)
    cos_t = np.clip(cos_t, -1.0, 1.0)
    t = np.arccos(cos_t)

    chi = np.empty_like(t)

    L = o3_lambda
    a = 2 * L + 1
    ll1 = L * (L + 1)

    small = np.abs(t) < tol
    if np.any(small):
        # sin(a t)/sin(t) = a [1 - (2/3) ℓ(ℓ+1) t^2
        #                        + (2/45) ℓ(ℓ+1)(3ℓ^2+3ℓ-1) t^4 + O(t^6)]
        # (check with ℓ=1: sin(3t)/sin(t) = 3 - 4 t^2 + (4/3) t^4 - ...)
        ts = t[small]
        t2 = ts * ts
        coeff4 = ll1 * (3 * L * L + 3 * L - 1)
        chi[small] = a * (
            1.0 - (2.0 / 3.0) * ll1 * t2 + (2.0 / 45.0) * coeff4 * t2 * t2
        )

    large = ~small
    if np.any(large):
        tl = t[large]
        sin_t = np.sin(tl)
        numer = np.sin(a * tl)
        mask = np.abs(sin_t) >= tol
        out = np.empty_like(tl)
        np.divide(numer, sin_t, out=out, where=mask)
        # Here t is not small, so sin(t) ≈ 0 means t ≈ pi. Because a = 2ℓ+1 is
        # odd, the ratio also tends to a in that limit.
        out[~mask] = a
        chi[large] = out

    return chi
def _weighted_rotation_stack(
    block: TensorBlock, n_rot: int, weight: torch.Tensor
) -> torch.Tensor:
    """
    Regroup the values of ``block`` so that the rotation index is the leading
    axis, then scale each rotation slice by its quadrature weight.

    The rows are first split per value of the second sample column, each chunk
    is split into ``n_rot`` equal parts that are stacked on a new leading axis,
    and the per-chunk stacks are concatenated along axis 1.
    """
    split_sizes = torch.bincount(block.samples.values[:, 1]).tolist()
    per_system = torch.split(block.values, split_sizes, dim=0)

    stacked: List[torch.Tensor] = []
    for system_values, size in zip(per_system, split_sizes, strict=True):
        per_rotation = torch.split(system_values, [size // n_rot] * n_rot, dim=0)
        stacked.append(torch.stack(per_rotation))
    by_rotation = torch.cat(stacked, dim=1)

    # Broadcast the weights over every axis except the leading (rotation) one.
    view: List[int] = [-1] + [1] * (by_rotation.ndim - 1)
    return weight.view(view) * by_rotation


def _character_convolution(
    chi: torch.Tensor, block1: TensorBlock, block2: TensorBlock, w: torch.Tensor
) -> TensorBlock:
    """
    Compute the character convolution of a block containing SO(3)-sampled tensors.
    Then contract with another block.

    The weighted values of ``block1`` are multiplied by ``chi`` along the
    rotation axis; the result is multiplied element-wise with the weighted values
    of ``block2`` and summed over the rotation axis.

    :param chi: character matrix; its second dimension gives the number of
        rotations
    :param block1: block whose first sample dimension is ``"so3_rotation"``
    :param block2: block contracted with the convolved ``block1``; split in the
        same way as ``block1``
    :param w: quadrature weights, one per rotation
    :return: block without the ``"so3_rotation"`` sample dimension, reusing the
        components and properties of ``block1``
    """
    samples = block1.samples
    assert samples.names[0] == "so3_rotation"

    values = block1.values
    chi = chi.to(dtype=values.dtype, device=values.device)
    weight = w.to(dtype=values.dtype, device=values.device)
    n_rot = chi.size(1)

    # chi @ (w * x1): flatten all trailing axes for the matmul, then restore them.
    weighted_values = _weighted_rotation_stack(block1, n_rot, weight)
    contracted_shape: List[int] = [chi.shape[0]] + list(weighted_values.shape[1:])
    convolved = (
        chi @ weighted_values.reshape(weighted_values.shape[0], -1)
    ).reshape(contracted_shape)

    # Sum over the rotation axis of (w * x2) * (chi @ (w * x1)).
    weighted_values2 = _weighted_rotation_stack(block2, n_rot, weight)
    contracted_values = torch.einsum(
        "i...,i...->...",
        weighted_values2,
        convolved,
    )

    names: List[str] = [name for name in samples.names if name != "so3_rotation"]
    # The remaining sample labels are taken from the rows of rotation 0.
    return TensorBlock(
        samples=Labels(names, samples.values[samples.values[:, 0] == 0][:, 1:]),
        components=block1.components,
        properties=block1.properties,
        values=contracted_values,
    )
def decompose_forces_tensor(
    tensor_dict: Dict[str, TensorMap],
) -> Dict[str, TensorMap]:
    """
    Decompose forces tensors into L=1 irreducible representations.

    Forces are Cartesian vectors (x, y, z). This reorders them to spherical
    component order (y, z, x) → (m=-1, m=0, m=1) via a cyclic roll, and
    labels the component axis as ``o3_mu``.

    Handles both ``"forces"`` (conservative) and ``"non_conservative_forces"`` keys.

    :param tensor_dict: dictionary of TensorMaps (modified in place)
    :return: the same dictionary with forces keys replaced by ``"..._l1"`` variants
    """
    for key in ["forces", "non_conservative_forces"]:
        if key not in tensor_dict:
            continue

        tensor = tensor_dict.pop(key)
        blocks_l1: List[TensorBlock] = []
        for block in tensor.blocks():
            blocks_l1.append(
                TensorBlock(
                    # cyclic shift along axis 1: (x, y, z) -> (y, z, x)
                    values=block.values.roll(-1, 1),
                    samples=block.samples,
                    components=[
                        Labels(
                            names=["o3_mu"],
                            values=torch.tensor(
                                [[mu] for mu in range(-1, 2)],
                                device=block.values.device,
                                dtype=torch.int32,
                            ),
                        )
                    ],
                    properties=block.properties,
                )
            )
        tensor_dict[key + "_l1"] = TensorMap(tensor.keys, blocks_l1)
    return tensor_dict
+ + The 3x3 stress tensor decomposes as: trace (L=0 scalar) + symmetric traceless + (L=2, 5 components). The antisymmetric part (L=1) is zero for physical stress. + + Handles both ``"stress"`` (conservative) and ``"non_conservative_stress"`` keys. + + :param tensor_dict: dictionary of TensorMaps (modified in place) + :return: the same dictionary with stress keys replaced by ``"..._l0"`` and + ``"..._l2"`` variants + """ + for key in ["stress", "non_conservative_stress"]: + if key not in tensor_dict: + continue + + tensor = tensor_dict[key] + blocks_l0 = [] + blocks_l2 = [] + for block in tensor.blocks(): + trace_values = _l0_components_from_matrices(block.values) + block_l0 = TensorBlock( + values=trace_values, + samples=block.samples, + components=[ + Labels( + names=["o3_mu"], + values=torch.tensor( + [[0]], device=block.values.device, dtype=torch.int32 + ), + ) + ], + properties=block.properties, + ) + blocks_l0.append(block_l0) + + block_l2 = TensorBlock( + values=_l2_components_from_matrices(block.values), + samples=block.samples, + components=[ + Labels( + names="o3_mu", + values=torch.tensor( + [[mu] for mu in range(-2, 3)], + device=block.values.device, + dtype=torch.int32, + ), + ) + ], + properties=block.properties, + ) + blocks_l2.append(block_l2) + + tensor_dict[key + "_l0"] = TensorMap(tensor.keys, blocks_l0) + tensor_dict[key + "_l2"] = TensorMap(tensor.keys, blocks_l2) + tensor_dict.pop(key) + + return tensor_dict + + +def decompose_tensors( + tensor_dict: Dict[str, TensorMap], + device: torch.device, +) -> Dict[str, TensorMap]: + """ + Decompose all tensors in the dictionary into irreducible representations of O(3). 
def compute_norm_per_property(
    tensor_dict: Dict[str, TensorMap],
    so3_weights: torch.Tensor,
) -> Dict[str, TensorMap]:
    """
    Compute the quadrature-weighted squared norm of every tensor, per property.

    Each value is squared and multiplied by half the quadrature weight of the
    rotation its sample belongs to. The ``inversion`` key dimension is then moved
    to the samples and the result is summed over ``inversion`` and
    ``so3_rotation``.

    :param tensor_dict: dictionary of TensorMaps with ``so3_rotation`` in samples
    :param so3_weights: quadrature weights, shape ``(n_rotations,)``
    :return: dictionary of TensorMaps with componentwise squared norms, stored
        under ``"<name>_componentwise_norm_squared"``
    """
    norms: Dict[str, TensorMap] = {}
    for name, tensor in tensor_dict.items():
        weighted_blocks: List[TensorBlock] = []
        for block in tensor.blocks():
            rotation_index = block.samples.column("so3_rotation")
            squared = block.values**2

            # One weight per sample row, broadcast over all remaining axes.
            broadcast_shape: List[int] = [squared.size(0)] + [1] * (squared.ndim - 1)
            sample_weights = so3_weights[rotation_index].view(broadcast_shape)

            weighted_blocks.append(
                TensorBlock(
                    values=0.5 * sample_weights * squared,
                    samples=block.samples,
                    components=block.components,
                    properties=block.properties,
                )
            )

        weighted = TensorMap(tensor.keys, weighted_blocks)
        norms[name + "_componentwise_norm_squared"] = mts.sum_over_samples(
            weighted.keys_to_samples("inversion"), ["inversion", "so3_rotation"]
        )
    return norms
max_o3_lambda_character: int, +) -> Dict[str, TensorMap]: + """ + Compute character convolution integrals over O(3) for each tensor. + + Projects each output onto O(3) irrep sectors by convolving with + the characters chi_{l,sigma}. The result measures how much of the + output's variance lives in each (l, sigma) sector. + + :param tensor_dict: dictionary of TensorMaps with rotation samples + :param so3_weights: quadrature weights + :param so3_characters: SO(3) characters, mapping l → tensor of shape (N_rot, N_rot) + :param pso3_characters: P*SO(3) characters, mapping "l_sigma" → tensor + :param max_o3_lambda_character: maximum angular momentum for projection + :return: dictionary of TensorMaps with character projections + """ + new_tensors: Dict[str, TensorMap] = {} + for name, tensor in tensor_dict.items(): + keys = tensor.keys + remaining_keys = Labels( + keys.names[:-1], + keys.values[keys.column("inversion") == 1][:, :-1], + ) + new_blocks: List[TensorBlock] = [] + new_keys: List[torch.Tensor] = [] + for key_values in remaining_keys.values: + key_to_match_plus: Dict[str, int] = {} + key_to_match_minus: Dict[str, int] = {} + for k, v in zip(remaining_keys.names, key_values, strict=True): + key_to_match_plus[k] = int(v) + key_to_match_minus[k] = int(v) + key_to_match_plus["inversion"] = 1 + key_to_match_minus["inversion"] = -1 + so3_block = tensor.block(key_to_match_plus) + pso3_block = tensor.block(key_to_match_minus) + + for o3_lambda in range(max_o3_lambda_character + 1): + so3_chi = so3_characters[o3_lambda] + first_term = _character_convolution( + so3_chi, so3_block, so3_block, so3_weights + ) + second_term = _character_convolution( + so3_chi, pso3_block, pso3_block, so3_weights + ) + for o3_sigma in [1, -1]: + label = str(o3_lambda) + "_" + str(o3_sigma) + pso3_chi = pso3_characters[label] + third_term = _character_convolution( + pso3_chi, pso3_block, so3_block, so3_weights + ) + block = TensorBlock( + samples=first_term.samples, + 
components=first_term.components, + properties=first_term.properties, + values=( + 0.25 * (first_term.values + second_term.values) + + 0.5 * third_term.values + ) + * (2 * o3_lambda + 1), + ) + new_blocks.append(block) + new_keys.append( + torch.cat( + [ + key_values, + torch.tensor( + [o3_lambda, o3_sigma], + device=key_values.device, + dtype=key_values.dtype, + ), + ] + ) + ) + key_names: List[str] = [] + for key_name in tensor.keys.names: + if key_name != "inversion": + key_names.append(key_name) + new_tensor = TensorMap( + Labels( + key_names + ["chi_lambda", "chi_sigma"], + torch.stack(new_keys), + ), + new_blocks, + ) + if "_" in new_tensor.keys.names: + new_tensor = mts.remove_dimension(new_tensor, "keys", "_") + new_tensors[name + "_character_projection"] = new_tensor + return new_tensors + + +class SymmetrizedModel(torch.nn.Module): + r""" + Wrapper around an atomistic model that symmetrizes its outputs over :math:`O(3)` + and computes equivariance metrics. + + The model is evaluated over a quadrature grid on :math:`O(3)`, constructed from a + Lebedev grid supplemented by in-plane rotations. For each sampled group element, the + model outputs are "back-rotated" according to the known :math:`O(3)` action + appropriate for their tensorial type (scalar, vector, tensor, etc.). Averaging these + back-rotated predictions over the quadrature grid yields fully + :math:`O(3)`-symmetrized outputs. In addition, two complementary equivariance + metrics are computed: + + 1. Variance under :math:`O(3)` of the back-rotated outputs. + + For a perfectly equivariant model, the back-rotated output :math:`x(g)` is + independent of the group element :math:`g`. Deviations from perfect equivariance + are quantified by the difference between the average squared norm over + :math:`O(3)` and the squared norm of the :math:`O(3)`-averaged output: + + .. 
math:: + + \mathrm{Var}_{O(3)}[x] + = + \left\langle \,\| x(g) \|^{2} \,\right\rangle_{O(3)} + - + \left\| \left\langle x(g) \right\rangle_{O(3)} \right\|^{2} . + + Here, :math:`\|\cdot\|` denotes the Euclidean norm over the ``component`` axis, + and :math:`\langle \cdot \rangle_{O(3)}` denotes averaging over the quadrature + grid. This quantity is the squared norm of the component orthogonal to the + perfectly equivariant subspace and therefore provides a scalar measure of the + deviation from exact equivariance. + + 2. Decomposition into isotypical components of :math:`O(3)`. + + Each output component may be viewed as a scalar function on :math:`O(3)`, + which can be decomposed into isotypical components labeled by the irreducible + representations :math:`\ell,\sigma` of :math:`O(3)`. The projection onto the + :math:`(\ell,\sigma)`-th isotypical subspace is computed as a convolution with + the corresponding character :math:`\chi_{\ell}`: + + .. math:: + + (P_{\ell,\sigma} x)(g) + = + \int_{O(3)} \chi_{\ell,\sigma}(h^{-1} g)\, x(h)\, \mathrm{d}\mu(h). + + Its squared :math:`L^{2}` norm over :math:`O(3)` is + + .. math:: + + \| P_{\ell,\sigma} x \|^{2} + = + \left\langle \, | (P_{\ell,\sigma} x)(g) |^{2} \, \right\rangle_{O(3)} . + + These quantities describe how the model output is distributed across the + different :math:`O(3)` irreducible sectors. The complementary component, + orthogonal to all isotypical subspaces, is given by + + .. math:: + + \| x \|^{2} + - + \sum_{\ell,\sigma} \| P_{\ell,\sigma} x \|^{2} , + + and provides a refined measure of the deviation from lying entirely within any + prescribed set of :math:`O(3)` irreducible representations. + + :param base_model: atomistic model to symmetrize + :param max_o3_lambda: maximum O(3) angular momentum the grid integrates exactly + :param batch_size: number of rotations to evaluate in a single batch + :param max_o3_lambda_character: maximum O(3) angular momentum for character + projections. 
If None, set to ``max_o3_lambda``. + """ + + def __init__( + self, + base_model, + max_o3_lambda_character: int, + max_o3_lambda_target: int, + batch_size: int = 32, + max_o3_lambda_grid: Optional[int] = None, + ): + super().__init__() + self.base_model = base_model + + try: + ref_param = next(base_model.parameters()) + device = ref_param.device + dtype = ref_param.dtype + except StopIteration: + device = torch.device("cpu") + dtype = torch.get_default_dtype() + + self.max_o3_lambda_target = max_o3_lambda_target + self.batch_size = batch_size + if max_o3_lambda_grid is None: + max_o3_lambda_grid = int(2 * max_o3_lambda_character + 1) + self.max_o3_lambda_grid = max_o3_lambda_grid + self.max_o3_lambda_character = max_o3_lambda_character + + # Compute grid (unchanged) + lebedev_order, n_inplane_rotations = _choose_quadrature(self.max_o3_lambda_grid) + if lebedev_order < 2 * self.max_o3_lambda_character: + warnings.warn( + "Lebedev order may be insufficient for character projections.", + stacklevel=2, + ) + alpha, beta, gamma, w_so3 = get_euler_angles_quadrature( + lebedev_order, n_inplane_rotations + ) + so3_weights = torch.from_numpy(w_so3).to(device=device, dtype=dtype) + self.register_buffer("so3_weights", so3_weights) + + so3_rotations = torch.from_numpy( + _rotations_from_angles(alpha, beta, gamma).as_matrix() + ).to(device=device, dtype=dtype) + self.register_buffer("so3_rotations", so3_rotations) + self.n_so3_rotations = self.so3_rotations.size(0) + + angles_inverse_rotations = (np.pi - gamma, beta, np.pi - alpha) + so3_inverse_rotations = torch.from_numpy( + _rotations_from_angles(*angles_inverse_rotations).as_matrix() + ).to(device=device, dtype=dtype) + self.register_buffer("so3_inverse_rotations", so3_inverse_rotations) + + self._wigner_D_inverse_jit: Dict[int, torch.Tensor] = {} + self._so3_characters_jit: Dict[int, torch.Tensor] = {} + self._pso3_characters_jit: Dict[str, torch.Tensor] = {} + # Since Wigner D matrices are stored in dicts, we need a bit 
of gymnastics to + # register the buffers + raw_wigner = _compute_real_wigner_matrices( + self.max_o3_lambda_target, angles_inverse_rotations + ) + self._wigner_D_inverse_names: Dict[int, str] = {} + for ell, D in raw_wigner.items(): + if isinstance(D, np.ndarray): + D = torch.from_numpy(D) + D = D.to(dtype=dtype, device=device) + name = f"wigner_D_inverse_rotations_l{ell}" + self.register_buffer(name, D) + self._wigner_D_inverse_names[ell] = name + # TorchScript dict view uses the same tensor + self._wigner_D_inverse_jit[ell] = D + + # Compute characters + so3_characters, pso3_characters = compute_characters( + self.max_o3_lambda_character, + (alpha, beta, gamma), + angles_inverse_rotations, + ) + self._so3_char_names: Dict[int, str] = {} + self._pso3_char_names: Dict[str, str] = {} + + # Since characters are stored in dicts, we need a bit of gymnastics to + # register the buffers + for ell, ch in so3_characters.items(): + if isinstance(ch, np.ndarray): + ch = torch.from_numpy(ch) + + ch = ch.to(dtype=dtype, device="cpu") # stay on CPU + name = f"so3_characters_l{ell}" + self.register_buffer(name, ch) + self._so3_char_names[ell] = name + + self._so3_characters_jit = {} # kill the CUDA dict cache + + for ell, ch in pso3_characters.items(): + if isinstance(ch, np.ndarray): + ch = torch.from_numpy(ch) + + ch = ch.to(dtype=dtype, device="cpu") # stay on CPU + name = f"pso3_characters_l{ell}" + self.register_buffer(name, ch) + self._pso3_char_names[ell] = name + + self._pso3_characters_jit = {} + + @torch.jit.ignore + def _wigner_D_inverse_dict(self) -> Dict[int, torch.Tensor]: + return { + ell: getattr(self, name) + for ell, name in self._wigner_D_inverse_names.items() + } + + @property + def wigner_D_inverse_rotations(self) -> Dict[int, torch.Tensor]: + # Python-only nice view + return self._wigner_D_inverse_dict() + + @torch.jit.ignore + def _so3_characters_dict(self) -> Dict[int, torch.Tensor]: + return {ell: getattr(self, name) for ell, name in 
self._so3_char_names.items()} + + @property + def so3_characters(self) -> Dict[int, torch.Tensor]: + # Python-only nice view + return self._so3_characters_dict() + + @torch.jit.ignore + def _pso3_characters_dict(self) -> Dict[str, torch.Tensor]: + return {key: getattr(self, name) for key, name in self._pso3_char_names.items()} + + @property + def pso3_characters(self) -> Dict[str, torch.Tensor]: + # Python-only nice view + return self._pso3_characters_dict() + + def _get_wigner_D_inverse(self, ell: int) -> torch.Tensor: + return self._wigner_D_inverse_jit[ell] + + def _get_so3_character(self, o3_lambda: int) -> torch.Tensor: + name = self._so3_char_names[o3_lambda] + ch_cpu = getattr(self, name) + + # follow the base model device/dtype + try: + ref = next(self.base_model.parameters()) + device = ref.device + dtype = ref.dtype + except StopIteration: + device = torch.device("cpu") + dtype = torch.get_default_dtype() + + return ch_cpu.to(device=device, dtype=dtype, non_blocking=True) + + def _get_pso3_character(self, o3_lambda: int, o3_sigma: int) -> torch.Tensor: + label = str(o3_lambda) + "_" + str(o3_sigma) + name = self._pso3_char_names[label] + ch_cpu = getattr(self, name) + + try: + ref = next(self.base_model.parameters()) + device = ref.device + dtype = ref.dtype + except StopIteration: + device = torch.device("cpu") + dtype = torch.get_default_dtype() + + return ch_cpu.to(device=device, dtype=dtype, non_blocking=True) + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + project_tokens: bool = False, + compute_gradients: bool = False, + ) -> Dict[str, TensorMap]: + """ + Symmetrize the model outputs over :math:`O(3)` and compute equivariance + metrics. 
+ + :param systems: list of systems to evaluate + :param outputs: dictionary of model outputs to symmetrize + :param selected_atoms: optional Labels specifying which atoms to consider + :param project_tokens: if True, also compute character projections + :param compute_gradients: if True, compute conservative forces and stress + via autograd. When False (default), the grid evaluation runs under + ``torch.no_grad()`` to save memory. + :return: dictionary with symmetrized outputs and equivariance metrics + """ + device = self.so3_weights.device + + with torch.no_grad() if not compute_gradients else torch.enable_grad(): + transformed_outputs, backtransformed_outputs = self._eval_over_grid( + systems, + outputs, + selected_atoms, + return_transformed=project_tokens, + compute_gradients=compute_gradients, + ) + + if not compute_gradients: + # Move to CPU to free GPU memory; all downstream ops are pure + # tensor algebra that runs fine on CPU + transformed_outputs = { + k: v.to(device="cpu") for k, v in transformed_outputs.items() + } + backtransformed_outputs = { + k: v.to(device="cpu") for k, v in backtransformed_outputs.items() + } + + decompose_device = torch.device("cpu") if not compute_gradients else device + transformed_outputs = decompose_tensors(transformed_outputs, decompose_device) + backtransformed_outputs = decompose_tensors( + backtransformed_outputs, decompose_device + ) + + out_dict: Dict[str, TensorMap] = {} + + so3_weights = self.so3_weights + if not compute_gradients: + so3_weights = so3_weights.to(device="cpu") + + mean_var = symmetrize_over_grid(backtransformed_outputs, so3_weights) + for name, tensor in mean_var.items(): + out_dict[name] = tensor + + if not project_tokens: + return out_dict + + norms = compute_norm_per_property(transformed_outputs, so3_weights) + for name, tensor in norms.items(): + out_dict[name] = tensor + + so3_chars = self.so3_characters + pso3_chars = self.pso3_characters + if not compute_gradients: + so3_chars = {k: 
v.to(device="cpu") for k, v in so3_chars.items()} + pso3_chars = {k: v.to(device="cpu") for k, v in pso3_chars.items()} + convolution_integrals = compute_conv_integral( + transformed_outputs, + so3_weights, + so3_chars, + pso3_chars, + self.max_o3_lambda_character, + ) + for name, integral in convolution_integrals.items(): + out_dict[name] = integral + + return out_dict + + def _eval_over_grid( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + return_transformed: bool, + compute_gradients: bool = False, + ) -> Tuple[Dict[str, TensorMap], Dict[str, TensorMap]]: + """ + Sample the model on the O(3) quadrature. + + :param systems: list of systems to evaluate + :param outputs: dictionary of model outputs to symmetrize + :param selected_atoms: optional Labels specifying which atoms to consider + :param return_transformed: if True, also return un-back-rotated outputs + :param compute_gradients: if True, compute forces/stress via autograd + :return: (transformed_outputs, backtransformed_outputs) dictionaries + """ + + results = evaluate_model_over_grid( + self.base_model, + self.batch_size, + self.so3_rotations, + self.so3_inverse_rotations, + self._wigner_D_inverse_jit, + return_transformed, + systems, + outputs, + selected_atoms, + compute_gradients=compute_gradients, + ) + + if return_transformed: + transformed_outputs_tensor, backtransformed_outputs_tensor = results + else: + backtransformed_outputs_tensor = results + transformed_outputs_tensor: Dict[str, TensorMap] = {} + + # TODO: possibly remove + if "energy" in transformed_outputs_tensor: + energy_tm = transformed_outputs_tensor["energy"] + if "atom" in energy_tm[0].samples.names: + # Sum over atoms while keeping system and rotation indices. 
+ energy_total_tm = mts.sum_over_samples(energy_tm, ["atom"]) + transformed_outputs_tensor["energy_total"] = energy_total_tm + + if "energy" in backtransformed_outputs_tensor: + energy_tm_bt = backtransformed_outputs_tensor["energy"] + if "atom" in energy_tm_bt[0].samples.names: + energy_total_tm_bt = mts.sum_over_samples(energy_tm_bt, ["atom"]) + backtransformed_outputs_tensor["energy_total"] = energy_total_tm_bt + return transformed_outputs_tensor, backtransformed_outputs_tensor + + +def _evaluate_with_gradients( + model: ModelInterface, + system: System, + rotation: torch.Tensor, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + device: torch.device, + dtype: torch.dtype, +) -> Dict[str, TensorMap]: + """ + Evaluate model on a single rotated system and compute conservative forces/stress + via autograd. + + Forces are ``-dE/d(positions)`` in the rotated frame. Stress is computed via the + strain trick as ``(1/V) dE/d(strain)`` in the rotated frame. Both are packaged as + Cartesian TensorMaps suitable for back-rotation by the existing pipeline. 
+ + :param model: atomistic model to evaluate + :param system: input system (original frame) + :param rotation: 3x3 rotation matrix (may include inversion) + :param outputs: model output specifications + :param selected_atoms: optional atom selection + :param device: device for tensors + :param dtype: dtype for tensors + :return: model output dict with added ``"forces"`` and (if periodic) ``"stress"`` + """ + n_atoms = system.positions.shape[0] + R = rotation.to(device=device, dtype=dtype) + + # Rotate positions (detached from original graph) and enable grad tracking + rotated_positions = (system.positions.detach() @ R.T).requires_grad_(True) + rotated_cell = system.cell.detach() @ R.T + + # Strain trick for stress (applied in the rotated frame) + has_cell = bool(torch.any(system.pbc).item()) + if has_cell: + strain = torch.eye(3, requires_grad=True, device=device, dtype=dtype) + final_positions = rotated_positions @ strain + final_cell = rotated_cell @ strain + else: + strain = None + final_positions = rotated_positions + final_cell = rotated_cell + + # Build transformed system + transformed = System( + types=system.types, + positions=final_positions, + cell=final_cell, + pbc=system.pbc, + ) + + # Copy and register neighbor lists for autograd + for options in system.known_neighbor_lists(): + neighbors = mts.detach_block(system.get_neighbor_list(options)) + neighbors.values[:] = (neighbors.values.squeeze(-1) @ R.T).unsqueeze(-1) + register_autograd_neighbors(transformed, neighbors) + transformed.add_neighbor_list(options, neighbors) + + # Evaluate model + out = model([transformed], outputs, selected_atoms) + + if "energy" not in out: + raise ValueError("compute_gradients=True requires the model to output 'energy'") + energy_sum = out["energy"].block().values.sum() + + # Compute gradients via autograd + grad_targets = [rotated_positions] + if strain is not None: + grad_targets.append(strain) + grads = torch.autograd.grad(energy_sum, grad_targets, create_graph=False) 
def evaluate_model_over_grid(
    model: ModelInterface,
    batch_size: int,
    so3_rotations: torch.Tensor,
    so3_rotations_inverse: torch.Tensor,
    wigner_D_inverse: Dict[int, torch.Tensor],
    return_transformed: bool,
    systems: List[System],
    outputs: Dict[str, ModelOutput],
    selected_atoms: Optional[Labels] = None,
    compute_gradients: bool = False,
) -> Dict[str, TensorMap] | Tuple[Dict[str, TensorMap], Dict[str, TensorMap]]:
    """
    Evaluate the model on rotated copies of the input systems over an O(3) quadrature
    grid, and optionally back-rotate the outputs.

    Every system is evaluated for each rotation in ``so3_rotations``, once multiplied
    by ``-1`` (inversion) and once by ``+1``.

    This function does **not** manage gradient context (``torch.no_grad`` etc.).
    Callers are responsible for wrapping in the appropriate context.

    :param model: atomistic model to evaluate
    :param batch_size: number of rotations to evaluate in a single batch (ignored
        when ``compute_gradients=True``, where rotations are processed one by one)
    :param so3_rotations: SO(3) rotation matrices, shape ``(N, 3, 3)``
    :param so3_rotations_inverse: inverse rotation matrices, shape ``(N, 3, 3)``
    :param wigner_D_inverse: Wigner D matrices for back-rotation, mapping l to tensor
    :param return_transformed: if True, also return un-back-rotated outputs
    :param systems: list of systems to evaluate; device and dtype are taken from the
        positions of the first system
    :param outputs: dictionary of model outputs to compute
    :param selected_atoms: optional Labels specifying which atoms to consider
    :param compute_gradients: if True, compute conservative forces and stress via
        autograd on each rotated evaluation. Results are added as ``"forces"`` and
        ``"stress"`` keys (distinct from any ``"non_conservative_*"`` model outputs).
    :return: back-rotated outputs, or (transformed, back-rotated) if
        ``return_transformed=True``
    """

    device = systems[0].positions.device
    dtype = systems[0].positions.dtype

    # Collect the names to gather, keeping a deterministic order (requested outputs
    # first, then the autograd-derived ones) instead of going through a ``set``.
    output_names = list(outputs.keys())
    if compute_gradients:
        if "forces" not in output_names:
            output_names.append("forces")
        any_periodic = any(bool(torch.any(s.pbc).item()) for s in systems)
        if any_periodic and "stress" not in output_names:
            output_names.append("stress")

    # transformed_outputs[name][i_sys][inversion] -> TensorMap over all rotations
    transformed_outputs: Dict[str, List[Dict[int, TensorMap]]] = {}
    for name in output_names:
        lst: List[Dict[int, TensorMap]] = []
        for _ in systems:
            d: Dict[int, TensorMap] = {}
            lst.append(d)
        transformed_outputs[name] = lst

    for i_sys, system in enumerate(systems):
        for inversion in [-1, 1]:
            rotation_outputs: List[Dict[str, TensorMap]] = []

            if compute_gradients:
                # Process one rotation at a time for per-rotation autograd
                for R in so3_rotations:
                    rotation = inversion * R.to(device=device, dtype=dtype)
                    out = _evaluate_with_gradients(
                        model,
                        system,
                        rotation,
                        outputs,
                        selected_atoms,
                        device,
                        dtype,
                    )
                    rotation_outputs.append(out)
                effective_batch_size = 1
            else:
                for batch_start in range(0, len(so3_rotations), batch_size):
                    transformed_systems = [
                        _transform_system(
                            system,
                            inversion * R.to(device=device, dtype=dtype),
                        )
                        for R in so3_rotations[batch_start : batch_start + batch_size]
                    ]
                    out = model(
                        transformed_systems,
                        outputs,
                        selected_atoms,
                    )
                    rotation_outputs.append(out)
                effective_batch_size = batch_size

            # Combine batch outputs
            for name in output_names:
                if name not in rotation_outputs[0]:
                    continue
                combined_: List[TensorMap] = [r[name] for r in rotation_outputs]
                combined = mts.join(
                    combined_,
                    "samples",
                    add_dimension="batch_rotation",
                )
                if "batch_rotation" in combined[0].samples.names:
                    # Fold (batch index, index inside the batch) into one global
                    # rotation index stored in the first sample column, and drop
                    # the trailing "batch_rotation" column.
                    blocks: List[TensorBlock] = []
                    for block in combined.blocks():
                        batch_id = block.samples.column("batch_rotation")
                        rot_id = block.samples.column("system")
                        # Clone before writing: the slice is a view, so assigning
                        # into it would modify the tensor backing `block.samples`.
                        new_sample_values = block.samples.values[:, :-1].clone()
                        new_sample_values[:, 0] = (
                            batch_id * effective_batch_size + rot_id
                        )
                        blocks.append(
                            TensorBlock(
                                values=block.values,
                                samples=Labels(
                                    block.samples.names[:-1],
                                    new_sample_values,
                                ),
                                components=block.components,
                                properties=block.properties,
                            )
                        )
                    combined = TensorMap(combined.keys, blocks)
                transformed_outputs[name][i_sys][inversion] = combined

    backtransformed_outputs = backtransform_outputs(
        transformed_outputs, systems, so3_rotations_inverse, wigner_D_inverse
    )
    backtransformed_outputs_tensor = to_metatensor(backtransformed_outputs, systems)

    if return_transformed:
        transformed_outputs_tensor = to_metatensor(transformed_outputs, systems)
        return transformed_outputs_tensor, backtransformed_outputs_tensor
    return backtransformed_outputs_tensor
+ """ + + out_tensor_dict: Dict[str, TensorMap] = {} + # Massage outputs to have desired shape + for name in tensor_dict: + joined_plus = mts.join( + [tensor_dict[name][i_sys][1] for i_sys in range(len(systems))], + "samples", + add_dimension="phys_system", + ) + joined_minus = mts.join( + [tensor_dict[name][i_sys][-1] for i_sys in range(len(systems))], + "samples", + add_dimension="phys_system", + ) + joined = mts.join( + [ + mts.append_dimension(joined_plus, "keys", "inversion", 1), + mts.append_dimension(joined_minus, "keys", "inversion", -1), + ], + "samples", + different_keys="union", + ) + joined = mts.rename_dimension(joined, "samples", "system", "so3_rotation") + + if "phys_system" in joined[0].samples.names: + joined = mts.rename_dimension(joined, "samples", "phys_system", "system") + else: + joined = mts.insert_dimension( + joined, + "samples", + 1, + "system", + torch.zeros( + joined[0].samples.values.shape[0], + dtype=torch.long, + device=joined[0].samples.values.device, + ), + ) + if "atom" in joined[0].samples.names or "first_atom" in joined[0].samples.names: + perm = _permute_system_before_atom(joined[0].samples.names) + joined = mts.permute_dimensions(joined, "samples", perm) + out_tensor_dict[name] = joined + + return out_tensor_dict + + +def backtransform_outputs( + tensor_dict: Dict[str, List[Dict[int, TensorMap]]], + systems: List[System], + so3_rotations_inverse: torch.Tensor, + wigner_D_inverse: Dict[int, torch.Tensor], +) -> Dict[str, List[Dict[int, TensorMap]]]: + """ + Given the outputs of the model evaluated on rotated systems, backtransform them to + the original frame according to the equivariance labels in the TensorMap keys. 
+ """ + + device = systems[0].positions.device + dtype = systems[0].positions.dtype + + backtransformed_tensor_dict: Dict[str, List[Dict[int, TensorMap]]] = {} + for name in tensor_dict: + lst: List[Dict[int, TensorMap]] = [] + for _ in systems: + d: Dict[int, TensorMap] = {} + lst.append(d) + backtransformed_tensor_dict[name] = lst + + n_rot = so3_rotations_inverse.size(0) + for name in tensor_dict: + for i_sys, system in enumerate(systems): + for inversion in [-1, 1]: + tensor = tensor_dict[name][i_sys][inversion] + wigner_dict: Dict[int, List[torch.Tensor]] = {} + for ell in wigner_D_inverse: + wigner_dict[ell] = ( + wigner_D_inverse[ell].to(device=device, dtype=dtype).unbind(0) + ) + + _, backtransformed, _ = _apply_augmentations( + [system] * n_rot, + {name: tensor}, + list( + ( + so3_rotations_inverse.to(device=device, dtype=dtype) + * inversion + ).unbind(0) + ), + wigner_dict, + ) + backtransformed_tensor_dict[name][i_sys][inversion] = backtransformed[ + name + ] + return backtransformed_tensor_dict + + +def _permute_system_before_atom(labels: List[str]) -> List[int]: + # find positions + sys_idx = -1 + atom_idx = -1 + for i in range(len(labels)): + if labels[i] == "system": + sys_idx = i + elif labels[i] == "atom": + atom_idx = i + elif labels[i] == "first_atom": + atom_idx = i + + # identity permutation + perm = list(range(len(labels))) + + # reorder only if both present and system is after atom + if sys_idx != -1 and atom_idx != -1 and sys_idx > atom_idx: + v = perm[sys_idx] + # remove system + for k in range(sys_idx, len(perm) - 1): + perm[k] = perm[k + 1] + perm.pop() + # insert before atom + perm.insert(atom_idx, v) + + return perm + + +def symmetrize_over_grid( + tensor_dict: Dict[str, TensorMap], + so3_weights: torch.Tensor, +) -> Dict[str, TensorMap]: + """ + Compute the mean and variance of the outputs over O(3). 
+ + :param tensor_dict: dictionary of TensorMaps with rotated and backtransformed + outputs to compute mean, variance, and norm squared for + :param so3_weights: weights of the SO(3) quadrature + :return: dictionary of TensorMaps with mean, variance, and norm squared + """ + mean_var: Dict[str, TensorMap] = {} + for name in tensor_dict: + # cannot compute a mean or variance as these have no known behaviour under + # rotations + if "features" in name: + continue + tensor = tensor_dict[name] + mean_blocks: List[TensorBlock] = [] + second_moment_blocks: List[TensorBlock] = [] + for block in tensor.blocks(): + rot_ids = block.samples.column("so3_rotation") + + values = block.values + if values.ndim > 2: + dims: List[int] = [] + for i in range(1, values.ndim - 1): + dims.append(i) + values_squared = torch.sum(values**2, dim=dims) + else: + values_squared = values**2 + + view: List[int] = [] + view.append(values.size(0)) + for _ in range(values.ndim - 1): + view.append(1) + values = 0.5 * so3_weights[rot_ids].view(view) * values + + view: List[int] = [] + view.append(values_squared.size(0)) + for _ in range(values_squared.ndim - 1): + view.append(1) + values_squared = 0.5 * so3_weights[rot_ids].view(view) * values_squared + + mean_blocks.append( + TensorBlock( + values=values, + samples=block.samples, + components=block.components, + properties=block.properties, + ) + ) + second_moment_blocks.append( + TensorBlock( + values=values_squared, + samples=block.samples, + components=[], + properties=block.properties, + ) + ) + + # Mean + tensor_mean = TensorMap(tensor.keys, mean_blocks) + tensor_mean = mts.sum_over_samples( + tensor_mean.keys_to_samples("inversion"), ["inversion", "so3_rotation"] + ) + + # Mean norm + mean_norm_squared_blocks: List[TensorBlock] = [] + for block in tensor_mean.blocks(): + vals = block.values + if vals.ndim > 2: + dims: List[int] = [] + for i in range(1, vals.ndim - 1): + dims.append(i) + vals = torch.sum(vals**2, dim=dims) + else: + vals = 
class TestL0Components:
    """Checks for the L=0 (trace) part extracted from batches of 3x3 matrices."""

    def test_identity_trace(self):
        # The inputs are laid out as (a, 3, 3, b); the 3x3 identity has trace 3.
        eye_batch = torch.eye(3, dtype=torch.float64)[None, :, :, None]
        trace = _l0_components_from_matrices(eye_batch)
        expected = torch.tensor([[[3.0]]], dtype=torch.float64)
        assert trace.shape == (1, 1, 1)
        assert torch.allclose(trace, expected)

    def test_traceless_matrix(self):
        # diag(1, -1, 0) has zero trace, so its L=0 component must vanish
        traceless = torch.diag(torch.tensor([1.0, -1.0, 0.0], dtype=torch.float64))
        trace = _l0_components_from_matrices(traceless[None, :, :, None])
        expected = torch.tensor([[[0.0]]], dtype=torch.float64)
        assert torch.allclose(trace, expected, atol=1e-14)

    def test_batch_dimensions(self):
        # Several samples and several properties at once
        n_samples, n_properties = 5, 3
        matrices = torch.randn(n_samples, 3, 3, n_properties, dtype=torch.float64)
        trace = _l0_components_from_matrices(matrices)
        assert trace.shape == (n_samples, 1, n_properties)
        # explicit sum of the three diagonal entries for every (sample, property)
        expected = matrices[:, 0, 0, :] + matrices[:, 1, 1, :] + matrices[:, 2, 2, :]
        assert torch.allclose(trace[:, 0, :], expected, atol=1e-14)
+ A = torch.eye(3, dtype=torch.float64).unsqueeze(0).unsqueeze(-1) + result = _l2_components_from_matrices(A) + assert result.shape == (1, 5, 1) + assert torch.allclose( + result, torch.zeros(1, 5, 1, dtype=torch.float64), atol=1e-14 + ) + + def test_diagonal_traceless(self): + # diag(1, -1, 0) is traceless and has known L=2 components + M = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 0.0]], + dtype=torch.float64, + ) + A = M.unsqueeze(0).unsqueeze(-1) + result = _l2_components_from_matrices(A) + assert result.shape == (1, 5, 1) + # m=0: (2*0 - 1 - (-1)) / (2*sqrt(3)) = 0 + assert torch.allclose(result[0, 2, 0], torch.tensor(0.0, dtype=torch.float64)) + # m=2 (last component): (1 - (-1)) / 2 = 1 + assert torch.allclose(result[0, 4, 0], torch.tensor(1.0, dtype=torch.float64)) + + def test_frobenius_norm_relation(self): + """For a symmetric traceless matrix S, the L=2 decomposition should satisfy + a norm relation: sum(c_i^2) relates to (1/2) * sum(S_ij * S_ji). + """ + # Build a symmetric traceless matrix + S = torch.tensor( + [[2.0, 1.0, 0.5], [1.0, -1.0, 0.3], [0.5, 0.3, -1.0]], + dtype=torch.float64, + ) + A = S.unsqueeze(0).unsqueeze(-1) + l2 = _l2_components_from_matrices(A) + l2_norm_sq = (l2**2).sum() + + # The L=2 norm squared should equal half the Frobenius norm of the + # symmetric part (since the decomposition extracts the symmetric part) + sym_S = 0.5 * (S + S.T) + half_frob = 0.5 * (sym_S**2).sum() + # They won't be exactly equal because S has an L=0 part too. + # But for a traceless symmetric matrix, L=0 is zero, so they match. 
class TestWignerD:
    """Test properties of real Wigner D matrices."""

    def test_orthogonality(self):
        """D(R)^T D(R) = I for all ell."""
        n_rotations = 5
        rng = np.random.default_rng(42)
        R = Rotation.random(n_rotations, random_state=rng)
        # ZYZ Euler angles of the sampled rotations, passed as one array per angle
        euler = R.as_euler("ZYZ")
        angles = (euler[:, 0], euler[:, 1], euler[:, 2])

        l_max = 4
        wigner = _compute_real_wigner_matrices(l_max, angles)
        for ell in range(l_max + 1):
            D = wigner[ell]  # one (2l+1, 2l+1) matrix per rotation
            for i in range(n_rotations):
                Di = D[i]
                product = Di.T @ Di
                identity = torch.eye(2 * ell + 1, dtype=Di.dtype)
                assert torch.allclose(product, identity, atol=1e-10), (
                    f"D^T D != I for ell={ell}, rotation {i}"
                )

    def test_identity_rotation(self):
        """D(identity) = I for all ell."""
        # all three Euler angles equal to zero describe the identity rotation
        angles = (np.array([0.0]), np.array([0.0]), np.array([0.0]))
        l_max = 4
        wigner = _compute_real_wigner_matrices(l_max, angles)
        for ell in range(l_max + 1):
            D = wigner[ell][0]
            identity = torch.eye(2 * ell + 1, dtype=D.dtype)
            assert torch.allclose(D, identity, atol=1e-10), (
                f"D(identity) != I for ell={ell}"
            )

    def test_composition(self):
        """D(R1) @ D(R2) ≈ D(R1 @ R2) for random rotations."""
        rng = np.random.default_rng(123)
        R1 = Rotation.random(random_state=rng)
        R2 = Rotation.random(random_state=rng)
        R12 = R1 * R2

        l_max = 3
        # a single rotation yields a 1D angle array; promote it to shape (1, 3)
        e1 = np.atleast_2d(R1.as_euler("ZYZ"))
        e2 = np.atleast_2d(R2.as_euler("ZYZ"))
        e12 = np.atleast_2d(R12.as_euler("ZYZ"))

        D1 = _compute_real_wigner_matrices(l_max, (e1[:, 0], e1[:, 1], e1[:, 2]))
        D2 = _compute_real_wigner_matrices(l_max, (e2[:, 0], e2[:, 1], e2[:, 2]))
        D12 = _compute_real_wigner_matrices(l_max, (e12[:, 0], e12[:, 1], e12[:, 2]))

        for ell in range(l_max + 1):
            product = D1[ell][0] @ D2[ell][0]
            expected = D12[ell][0]
            assert torch.allclose(product, expected, atol=1e-10), (
                f"D(R1)D(R2) != D(R1R2) for ell={ell}"
            )
class _QuadraticEnergyModel(torch.nn.Module):
    """Toy model with E = sum(positions^2), so the exact forces are -2*positions."""

    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels] = None,
    ) -> Dict[str, TensorMap]:
        # one scalar energy per system, stacked into a (n_systems, 1) column
        per_system = [torch.sum(system.positions**2) for system in systems]
        energy_values = torch.stack(per_system).unsqueeze(-1)

        single_key = Labels(
            names=["_"],
            values=torch.tensor([[0]], dtype=torch.int64),
        )
        system_samples = Labels(
            names=["system"],
            values=torch.arange(len(systems), dtype=torch.int64).unsqueeze(1),
        )
        energy_property = Labels(
            names=["energy"],
            values=torch.tensor([[0]], dtype=torch.int64),
        )

        block = TensorBlock(
            values=energy_values,
            samples=system_samples,
            components=[],
            properties=energy_property,
        )
        return {"energy": TensorMap(single_key, [block])}

    def requested_neighbor_lists(self):
        # the energy only depends on positions, no neighbor list is needed
        return []
torch.tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float64 + ) + system = System( + types=torch.tensor([1, 1]), + positions=positions, + cell=torch.zeros(3, 3, dtype=torch.float64), + pbc=torch.tensor([False, False, False]), + ) + rotation = torch.eye(3, dtype=torch.float64) + outputs = {"energy": ModelOutput(per_atom=False)} + + out = _evaluate_with_gradients( + model, + system, + rotation, + outputs, + None, + device=torch.device("cpu"), + dtype=torch.float64, + ) + + assert "forces" in out + forces = out["forces"].block().values.squeeze(-1) # (2, 3) + expected = -2.0 * positions + assert torch.allclose(forces, expected, atol=1e-12) + + def test_forces_with_rotation(self): + """Forces in rotated frame should equal R @ (forces in lab frame). + For E=sum(pos^2), forces_lab = -2*pos_lab. + In rotated frame: forces_rot = -dE/d(pos_rot) where pos_rot = pos_lab @ R.T. + Since E = sum((pos_rot @ R)^2) = sum(pos_rot^2) (R is orthogonal), + forces_rot = -2*pos_rot = -2*(pos_lab @ R.T). + """ + model = _QuadraticEnergyModel() + positions = torch.tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float64 + ) + system = System( + types=torch.tensor([1, 1]), + positions=positions, + cell=torch.zeros(3, 3, dtype=torch.float64), + pbc=torch.tensor([False, False, False]), + ) + # Random rotation + rng = np.random.default_rng(42) + R_scipy = Rotation.random(random_state=rng) + R = torch.tensor(R_scipy.as_matrix(), dtype=torch.float64) + outputs = {"energy": ModelOutput(per_atom=False)} + + out = _evaluate_with_gradients( + model, + system, + R, + outputs, + None, + device=torch.device("cpu"), + dtype=torch.float64, + ) + + forces_rot = out["forces"].block().values.squeeze(-1) + expected_rot = -2.0 * (positions @ R.T) + assert torch.allclose(forces_rot, expected_rot, atol=1e-12) + + def test_stress_periodic_system(self): + """For a periodic system with E=sum(pos^2), check stress via strain trick. 
+ + With strain trick: pos_final = pos_rot @ strain, so + E = sum((pos_rot @ strain)^2) = sum_i sum_a (sum_b pos_rot_ib * strain_ba)^2 + dE/d(strain_cd) = 2 * sum_i sum_a (pos_rot @ strain)_ia * pos_rot_ic * delta_da + = 2 * (pos_rot.T @ (pos_rot @ strain))_{ca} (at strain=I) + = 2 * pos_rot.T @ pos_rot + stress = (1/V) * dE/d(strain) = (2/V) * pos_rot.T @ pos_rot + """ + model = _QuadraticEnergyModel() + positions = torch.tensor( + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float64 + ) + cell = torch.eye(3, dtype=torch.float64) * 5.0 + system = System( + types=torch.tensor([1, 1]), + positions=positions, + cell=cell, + pbc=torch.tensor([True, True, True]), + ) + R = torch.eye(3, dtype=torch.float64) + outputs = {"energy": ModelOutput(per_atom=False)} + + out = _evaluate_with_gradients( + model, + system, + R, + outputs, + None, + device=torch.device("cpu"), + dtype=torch.float64, + ) + + assert "stress" in out + stress = out["stress"].block().values.squeeze(0).squeeze(-1) # (3, 3) + volume = torch.abs(torch.linalg.det(cell)) + expected_stress = 2.0 * positions.T @ positions / volume + assert torch.allclose(stress, expected_stress, atol=1e-12) + + def test_no_stress_for_nonperiodic(self): + """Non-periodic systems should not produce stress output.""" + model = _QuadraticEnergyModel() + system = System( + types=torch.tensor([1]), + positions=torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float64), + cell=torch.zeros(3, 3, dtype=torch.float64), + pbc=torch.tensor([False, False, False]), + ) + R = torch.eye(3, dtype=torch.float64) + outputs = {"energy": ModelOutput(per_atom=False)} + + out = _evaluate_with_gradients( + model, + system, + R, + outputs, + None, + device=torch.device("cpu"), + dtype=torch.float64, + ) + + assert "forces" in out + assert "stress" not in out diff --git a/tox.ini b/tox.ini index cd0a281d..406a5781 100644 --- a/tox.ini +++ b/tox.ini @@ -150,6 +150,8 @@ deps = spglib # uncomment for testing nvalchemiops integration on GPU (requires Python 
3.11+) # nvalchemi-toolkit-ops + metatrain + spherical changedir = python/metatomic_torch commands =