From 937f4735950120c8322fdfe3953d214a176c6930 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 25 Mar 2026 11:32:12 -0400 Subject: [PATCH 1/3] fea: use upstream fairchem --- pyproject.toml | 2 +- tests/models/test_fairchem.py | 289 ---------------------------------- torch_sim/models/fairchem.py | 246 ++--------------------------- 3 files changed, 13 insertions(+), 524 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 90de3c4a..56c1da91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,8 +57,8 @@ orb = ["orb-models>=0.6.0"] sevenn = ["sevenn[torchsim]>=0.12.1"] graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"] nequip = ["nequip>=0.17.0"] +fairchem = ["fairchem-core @ git+https://github.com/facebookresearch/fairchem.git@main#subdirectory=packages/fairchem-core"] nequix = ["nequix[torch-sim]>=0.4.5"] -fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"] docs = [ "autodoc_pydantic==2.2.0", "furo==2024.8.6", diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index d259e2ec..7937c692 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -1,23 +1,14 @@ import traceback import pytest -import torch -import torch_sim as ts from tests.conftest import DEVICE, DTYPE from tests.models.conftest import make_validate_model_outputs_test try: - from collections.abc import Callable - - from ase.build import bulk, fcc100, molecule - from fairchem.core.calculate.pretrained_mlip import ( - pretrained_checkpoint_path_from_name, - ) from huggingface_hub.utils._auth import get_token - import torch_sim as ts from torch_sim.models.fairchem import FairChemModel except (ImportError, OSError, RuntimeError, AttributeError, ValueError): @@ -33,205 +24,6 @@ def eqv2_uma_model_pbc() -> FairChemModel: return FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE) -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -@pytest.mark.parametrize("task_name", ["omat", "omol", "oc20"]) -def test_task_initialization(task_name: str) -> None: - """Test that different UMA task names work correctly.""" - model = FairChemModel( - model="uma-s-1p1", task_name=task_name, device=torch.device("cpu") - ) - assert model.task_name - assert str(model.task_name.value) == task_name - assert hasattr(model, "predictor") - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -@pytest.mark.parametrize( - ("task_name", "systems_func"), - [ - ( - "omat", - lambda: [ - bulk("Si", "diamond", a=5.43), - bulk("Al", "fcc", a=4.05), - bulk("Fe", "bcc", a=2.87), - bulk("Cu", "fcc", a=3.61), - ], - ), - ( - "omol", - lambda: [molecule("H2O"), molecule("CO2"), molecule("CH4"), molecule("NH3")], - ), - ], -) -def test_homogeneous_batching(task_name: str, systems_func: Callable) -> None: - """Test batching multiple systems with the same task.""" - systems = systems_func() - - # Add molecular properties for molecules - if task_name == "omol": - for mol in systems: - mol.info |= {"charge": 0, "spin": 1} - - model = FairChemModel(model="uma-s-1p1", task_name=task_name, device=DEVICE) - state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) - results = model(state) - - # Check batch dimensions - assert results["energy"].shape == (4,) - assert results["forces"].shape[0] == sum(len(s) for s in systems) - assert results["forces"].shape[1] == 3 - - # Check that different systems have different energies - energies = results["energy"] - uniq_energies = torch.unique(energies, dim=0) - assert len(uniq_energies) > 1, "Different systems should have different energies" - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -def test_heterogeneous_tasks() -> None: - """Test different task types work with appropriate systems.""" - # Test molecule, material, and catalysis systems separately - test_cases = [ - ("omol", [molecule("H2O")]), - ("omat", [bulk("Pt", cubic=True)]), - ("oc20", [fcc100("Cu", (2, 2, 3), vacuum=8, periodic=True)]), - ] - - for task_name, systems in test_cases: - if task_name == "omol": - systems[0].info |= {"charge": 0, "spin": 1} - - model = FairChemModel( - model="uma-s-1p1", - task_name=task_name, - device=DEVICE, - ) - state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) - results = model(state) - - assert results["energy"].shape[0] == 1 - assert results["forces"].dim() == 2 - assert results["forces"].shape[1] == 3 - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -@pytest.mark.parametrize( - ("systems_func", "expected_count"), - [ - (lambda: [bulk("Si", "diamond", a=5.43)], 1), # Single system - ( - lambda: [ - bulk("H", "bcc", a=2.0), - bulk("Li", "bcc", a=3.0), - bulk("Si", "diamond", a=5.43), - bulk("Al", "fcc", a=4.05).repeat((2, 1, 1)), - ], - 4, - ), # Mixed sizes - ( - lambda: [ - bulk(element, "fcc", a=4.0) - for element in ("Al", "Cu", "Ni", "Pd", "Pt") * 3 - ], - 15, - ), # Large batch - ], -) -def test_batch_size_variations(systems_func: Callable, expected_count: int) -> None: - """Test batching with different numbers and sizes of systems.""" - systems = systems_func() - - model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE) - state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) - results = model(state) - - assert results["energy"].shape == (expected_count,) - assert results["forces"].shape[0] == sum(len(s) for s in systems) - assert results["forces"].shape[1] == 3 - assert torch.isfinite(results["energy"]).all() - assert torch.isfinite(results["forces"]).all() - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -@pytest.mark.parametrize("compute_stress", [True, False]) -def test_stress_computation(*, compute_stress: bool) -> None: - """Test stress tensor computation.""" - systems = [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)] - - model = FairChemModel( - model="uma-s-1p1", - task_name="omat", - device=DEVICE, - compute_stress=compute_stress, - ) - state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) - results = model(state) - - if compute_stress: - assert "stress" in results - assert results["stress"].shape == (2, 3, 3) - assert torch.isfinite(results["stress"]).all() - else: - assert "stress" not in results - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -def test_device_consistency() -> None: - """Test device consistency between model and data.""" - model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE) - system = bulk("Si", "diamond", a=5.43) - state = ts.io.atoms_to_state([system], device=DEVICE, dtype=DTYPE) - - results = model(state) - assert results["energy"].device == DEVICE - assert results["forces"].device == DEVICE - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -def test_empty_batch_error() -> None: - """Test that empty batches raise appropriate errors.""" - model = FairChemModel(model="uma-s-1p1", task_name="omat", device=torch.device("cpu")) - with pytest.raises((ValueError, RuntimeError, IndexError)): - model(ts.io.atoms_to_state([], device=torch.device("cpu"), dtype=torch.float32)) - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -def test_load_from_checkpoint_path() -> None: - """Test loading model from a saved checkpoint file path.""" - checkpoint_path = pretrained_checkpoint_path_from_name("uma-s-1p1") - loaded_model = FairChemModel( - model=str(checkpoint_path), task_name="omat", device=DEVICE - ) - - # Verify the loaded model works - system = bulk("Si", "diamond", a=5.43) - state = ts.io.atoms_to_state([system], device=DEVICE, dtype=DTYPE) - results = loaded_model(state) - - assert "energy" in results - assert "forces" in results - assert results["energy"].shape == (1,) - assert torch.isfinite(results["energy"]).all() - assert torch.isfinite(results["forces"]).all() - - test_fairchem_uma_model_outputs = pytest.mark.skipif( get_token() is None, reason="Requires HuggingFace authentication for UMA model access", @@ -240,84 +32,3 @@ def test_load_from_checkpoint_path() -> None: model_fixture_name="eqv2_uma_model_pbc", device=DEVICE, dtype=DTYPE ) ) - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -@pytest.mark.parametrize( - ("charge", "spin"), - [ - (0.0, 0.0), # Neutral, no spin - (1.0, 1.0), # +1 charge, spin=1 (doublet) - (-1.0, 0.0), # -1 charge, no spin (singlet) - (0.0, 2.0), # Neutral, spin=2 (triplet) - ], -) -def test_fairchem_charge_spin(charge: float, spin: float) -> None: - """Test that FairChemModel correctly handles charge and spin from atoms.info.""" - # Create a water molecule - mol = molecule("H2O") - - # Set charge and spin in ASE atoms.info - mol.info["charge"] = charge - mol.info["spin"] = spin - - # Convert to SimState (should extract charge/spin) - state = ts.io.atoms_to_state([mol], device=DEVICE, dtype=DTYPE) - - # Verify charge/spin were extracted correctly - assert state.charge is not None - assert state.spin is not None - assert state.charge[0].item() == charge - assert state.spin[0].item() == spin - - # Create model with UMA omol task (supports charge/spin for molecules) - model = FairChemModel( - model="uma-s-1p1", - task_name="omol", - device=DEVICE, - ) - - # This should not raise an error - result = model(state) - - # Verify outputs exist - assert "energy" in result - assert result["energy"].shape == (1,) - assert "forces" in result - assert result["forces"].shape == (len(mol), 3) - - # Verify outputs are finite - assert torch.isfinite(result["energy"]).all() - assert torch.isfinite(result["forces"]).all() - - -# TODO: we should perhaps put something like this inside `validate_model_outputs` -# the question is how we can do this with creating a circular dependency -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -def test_fairchem_single_step_relax(rattled_si_sim_state: ts.SimState) -> None: - """Test a single optimization step with FairChemModel. - - This verifies that the model works correctly with optimizers, particularly - that it doesn't have issues with the computational graph (e.g., missing - .detach() calls). - """ - model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE) - state = rattled_si_sim_state.to(device=DEVICE, dtype=DTYPE) - - # Initialize FIRE optimizer - opt_state = ts.fire_init(state, model) - initial_positions = opt_state.positions.clone() - _initial_energy = opt_state.energy.item() - - # Run exactly one step - opt_state = ts.fire_step(opt_state, model) - - # Verify positions changed - assert not torch.allclose(opt_state.positions, initial_positions) - # Verify energy is still available and finite - assert torch.isfinite(opt_state.energy).all() - assert isinstance(opt_state.energy.item(), float) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index cdc7fb24..3dc0ea6c 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -1,39 +1,30 @@ -"""FairChem model wrapper for torch-sim. +"""Wrapper for FairChem models in TorchSim. -Provides a TorchSim-compatible interface to FairChem models for computing -energies, forces, and stresses of atomistic systems. +This module re-exports the FairChem package's torch-sim integration for convenient +importing. The actual implementation is maintained in the `fairchem-core` package. -Requires fairchem-core to be installed. +References: + - FairChem Models Package: https://github.com/facebookresearch/fairchem """ -from __future__ import annotations - -import os import traceback -import typing import warnings -from pathlib import Path from typing import Any -import torch - -from torch_sim.models.interface import ModelInterface - try: - from fairchem.core import pretrained_mlip - from fairchem.core.calculate.ase_calculator import UMATask - from fairchem.core.common.utils import setup_imports, setup_logging - from fairchem.core.datasets.atomic_data import AtomicData, atomicdata_list_to_batch + from fairchem.core.calculate.torchsim_interface import FairChemModel except ImportError as exc: warnings.warn(f"FairChem import failed: {traceback.format_exc()}", stacklevel=2) + from torch_sim.models.interface import ModelInterface + class FairChemModel(ModelInterface): - """FairChem model wrapper for torch-sim. + """Dummy FairChem model wrapper for torch-sim to enable safe imports. - This class is a placeholder for the FairChemModel class. - It raises an ImportError if FairChem is not installed. + NOTE: This class is a placeholder when `fairchem-core` is not installed. + It raises an ImportError if accessed. """ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: @@ -41,217 +32,4 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: raise err -if typing.TYPE_CHECKING: - from collections.abc import Callable - - from torch_sim.state import SimState - - -class FairChemModel(ModelInterface): - """FairChem model wrapper for computing atomistic properties. - - Wraps FairChem models to compute energies, forces, and stresses. Can be - initialized with a model checkpoint path or pretrained model name. - - Uses the fairchem-core-2.2.0+ predictor API for batch inference. - - Attributes: - predictor: The FairChem predictor for batch inference - task_name (UMATask): Task type for the model - _device (torch.device): Device where computation is performed - _dtype (torch.dtype): Data type used for computation - _compute_stress (bool): Whether to compute stress tensor - implemented_properties (list): Model outputs the model can compute - - Examples: - >>> model = FairChemModel(model="path/to/checkpoint.pt", compute_stress=True) - >>> results = model(state) - """ - - def __init__( - self, - model: str | Path, - neighbor_list_fn: Callable | None = None, - *, # force remaining arguments to be keyword-only - model_cache_dir: str | Path | None = None, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - compute_stress: bool = False, - task_name: UMATask | str | None = None, - ) -> None: - """Initialize the FairChem model. - - Args: - model (str | Path): Either a pretrained model name or path to model - checkpoint file. The function will first check if the input matches - a known pretrained model name, then check if it's a valid file path. - neighbor_list_fn (Callable | None): Function to compute neighbor lists - (not currently supported) - model_cache_dir (str | Path | None): Path where to save the model - device (torch.device | None): Device to use for computation. If None, - defaults to CUDA if available, otherwise CPU. - dtype (torch.dtype | None): Data type to use for computation - compute_stress (bool): Whether to compute stress tensor - task_name (UMATask | str | None): Task type for UMA models (optional, - only needed for UMA models) - - Raises: - NotImplementedError: If custom neighbor list function is provided - ValueError: If model is not a known model name or valid file path - """ - setup_imports() - setup_logging() - super().__init__() - - self._dtype = dtype or torch.float32 - self._compute_stress = compute_stress - self._compute_forces = True - self._memory_scales_with = "n_atoms" - - if neighbor_list_fn is not None: - raise NotImplementedError( - "Custom neighbor list is not supported for FairChemModel." - ) - - # Convert Path to string for consistency - if isinstance(model, Path): - model = str(model) - - # Convert task_name to UMATask if it's a string (only for UMA models) - if isinstance(task_name, str): - task_name = UMATask(task_name) - - # Use the efficient predictor API for optimal performance - self._device = device or torch.device( - "cuda" if torch.cuda.is_available() else "cpu" - ) - device_str = str(self._device) - self.task_name = task_name - - # Create efficient batch predictor for fast inference - if model in pretrained_mlip.available_models: - cache_dir: str | Path | None = model_cache_dir - if ( - cache_dir is not None - and isinstance(cache_dir, Path) - and cache_dir.exists() - ): - self.predictor = pretrained_mlip.get_predict_unit( - model, device=device_str, cache_dir=cache_dir - ) - else: - self.predictor = pretrained_mlip.get_predict_unit( - model, device=device_str - ) - elif os.path.isfile(model): - self.predictor = pretrained_mlip.load_predict_unit(model, device=device_str) - else: - raise ValueError( - f"Invalid model name or checkpoint path: {model}. " - f"Available pretrained models are: {pretrained_mlip.available_models}" - ) - - # Determine implemented properties - # This is a simplified approach - in practice you might want to - # inspect the model configuration more carefully - self.implemented_properties = ["energy", "forces"] - if compute_stress: - self.implemented_properties.append("stress") - - @property - def dtype(self) -> torch.dtype: - """Return the data type used by the model.""" - return self._dtype - - @property - def device(self) -> torch.device: - """Return the device where the model is located.""" - return self._device - - def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]: - """Compute energies, forces, and other properties. - - Args: - state (SimState): State object containing positions, cells, atomic numbers, - and other system information. - **_kwargs: Unused; accepted for interface compatibility. - - Returns: - dict: Dictionary of model predictions, which may include: - - energy (torch.Tensor): Energy with shape [batch_size] - - forces (torch.Tensor): Forces with shape [n_atoms, 3] - - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3] - """ - sim_state = state - - if sim_state.device != self._device: - sim_state = sim_state.to(self._device) - - # Ensure system_idx has integer dtype (SimState guarantees presence) - if sim_state.system_idx.dtype != torch.int64: - sim_state.system_idx = sim_state.system_idx.to(dtype=torch.int64) - - # Convert SimState to AtomicData objects for efficient batch processing - from ase import Atoms - - n_atoms = torch.bincount(sim_state.system_idx) - atomic_data_list = [] - - pbc_np = sim_state.pbc.detach().cpu().numpy() - - for idx, (n, c) in enumerate( - zip(n_atoms, torch.cumsum(n_atoms, dim=0), strict=False) - ): - # Extract system data - positions = sim_state.positions[c - n : c].detach().cpu().numpy() - atomic_nums = sim_state.atomic_numbers[c - n : c].detach().cpu().numpy() - cell = ( - sim_state.row_vector_cell[idx].detach().cpu().numpy() - if sim_state.row_vector_cell is not None - else None - ) - - # Create ASE Atoms object first - atoms = Atoms( - numbers=atomic_nums, - positions=positions, - cell=cell, - pbc=pbc_np if cell is not None else False, - ) - - charge = sim_state.charge - spin = sim_state.spin - atoms.info["charge"] = charge[idx].item() if charge is not None else 0.0 - atoms.info["spin"] = spin[idx].item() if spin is not None else 0.0 - - # Convert ASE Atoms to AtomicData (task_name only applies to UMA models) - # r_data_keys must be passed for charge/spin to be read from atoms.info - if self.task_name is None: - atomic_data = AtomicData.from_ase(atoms, r_data_keys=["charge", "spin"]) - else: - atomic_data = AtomicData.from_ase( - atoms, task_name=self.task_name, r_data_keys=["charge", "spin"] - ) - atomic_data_list.append(atomic_data) - - # Create batch for efficient inference - batch = atomicdata_list_to_batch(atomic_data_list) - batch = batch.to(self._device) - - # Run efficient batch prediction - predictions = self.predictor.predict(batch) - - # Convert predictions to torch-sim format - results: dict[str, torch.Tensor] = {} - results["energy"] = predictions["energy"].to(dtype=self._dtype) - results["forces"] = predictions["forces"].to(dtype=self._dtype) - - # Handle stress if requested and available - if self._compute_stress and "stress" in predictions: - stress = predictions["stress"].to(dtype=self._dtype) - # Ensure stress has correct shape [batch_size, 3, 3] - if stress.dim() == 2 and stress.shape[0] == len(atomic_data_list): - stress = stress.view(-1, 3, 3) - results["stress"] = stress - - return {k: v.detach() for k, v in results.items()} +__all__ = ["FairChemModel"] From ea681c8bdb69a9b21d71642218916f26ac4abdca Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 25 Mar 2026 14:53:41 -0400 Subject: [PATCH 2/3] General Maintenance & reap FairchemV1 (#522) --- .github/workflows/test.yml | 25 -- tests/models/test_fairchem_legacy.py | 120 ------- tests/models/test_nequix.py | 2 +- tests/test_neighbors.py | 2 +- torch_sim/models/fairchem.py | 4 + torch_sim/models/fairchem_legacy.py | 457 ------------------------- torch_sim/models/graphpes_framework.py | 4 + torch_sim/models/mace.py | 4 + torch_sim/models/mattersim.py | 4 + torch_sim/models/nequip_framework.py | 4 + torch_sim/models/nequix.py | 4 + torch_sim/models/orb.py | 4 + torch_sim/models/sevennet.py | 4 + torch_sim/neighbors/__init__.py | 9 +- torch_sim/neighbors/vesin.py | 126 +++---- 15 files changed, 103 insertions(+), 670 deletions(-) delete mode 100644 tests/models/test_fairchem_legacy.py delete mode 100644 torch_sim/models/fairchem_legacy.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e36e6952..2decea17 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -61,7 +61,6 @@ jobs: - { python: '3.14', resolution: highest } model: - { name: fairchem, test_path: "tests/models/test_fairchem.py" } - - { name: fairchem-legacy, test_path: "tests/models/test_fairchem_legacy.py" } - { name: graphpes, test_path: "tests/models/test_graphpes_framework.py" } - { name: mace, test_path: "tests/models/test_mace.py" } - { name: mace, test_path: "tests/test_elastic.py" } @@ -78,7 +77,6 @@ jobs: - version: { python: '3.14', resolution: highest } model: { name: fairchem, test_path: 'tests/models/test_fairchem.py'} - version: { python: '3.14', resolution: highest } - model: { name: fairchem-legacy, test_path: 'tests/models/test_fairchem_legacy.py'} - version: { python: '3.14', resolution: highest } model: { name: nequip, test_path: 'tests/models/test_nequip_framework.py'} runs-on: ${{ matrix.os }} @@ -87,14 +85,6 @@ jobs: - name: Check out repo uses: actions/checkout@v4 - - name: Check out fairchem repository - if: ${{ matrix.model.name == 'fairchem-legacy' }} - uses: actions/checkout@v4 - with: - repository: FAIR-Chem/fairchem - path: fairchem-repo - ref: fairchem_core-1.10.0 - - name: Set up Python uses: actions/setup-python@v5 with: @@ -107,22 +97,7 @@ jobs: - name: Set up uv uses: astral-sh/setup-uv@v6 - - name: Install legacy fairchem repository and dependencies - if: ${{ matrix.model.name == 'fairchem-legacy' }} - run: | - if [ -f fairchem-repo/packages/requirements.txt ]; then - uv pip install -r fairchem-repo/packages/requirements.txt --system - fi - if [ -f fairchem-repo/packages/requirements-optional.txt ]; then - uv pip install -r fairchem-repo/packages/requirements-optional.txt --system - fi - uv pip install -e fairchem-repo/packages/fairchem-core[dev] --system - uv pip install -e "." --no-deps --system - uv pip install "h5py>=3.12.1" "numpy>=1.26,<3" "scipy<1.17.0" "tables>=3.10.2" "torch>=2" "tqdm>=4.67" --system - uv pip install "ase>=3.26" "phonopy>=2.37.0" "psutil>=7.0.0" "pymatgen>=2025.6.14" "pytest-cov>=6" "pytest>=8" --resolution=${{ matrix.version.resolution }} --system - - name: Install torch_sim with model dependencies - if: ${{ matrix.model.name != 'fairchem-legacy' }} run: | # setuptools <82 provides pkg_resources needed by mattersim and fairchem (via torchtnt). # setuptools 82+ removed pkg_resources. Remove pin once those packages migrate. diff --git a/tests/models/test_fairchem_legacy.py b/tests/models/test_fairchem_legacy.py deleted file mode 100644 index f7977c6b..00000000 --- a/tests/models/test_fairchem_legacy.py +++ /dev/null @@ -1,120 +0,0 @@ -import os -import traceback - -import pytest -import torch - -import torch_sim as ts -from tests.conftest import DEVICE -from tests.models.conftest import ( - make_model_calculator_consistency_test, - make_validate_model_outputs_test, -) -from torch_sim.testing import SIMSTATE_BULK_GENERATORS, SIMSTATE_MOLECULE_GENERATORS - - -try: - from fairchem.core import OCPCalculator - from fairchem.core.models.model_registry import model_name_to_local_file - from huggingface_hub.utils._auth import get_token - - from torch_sim.models.fairchem_legacy import FairChemV1Model - -except (ImportError, OSError, RuntimeError, AttributeError, ValueError): - pytest.skip( - f"FairChem not installed: {traceback.format_exc()}", - allow_module_level=True, - ) - - -@pytest.fixture(scope="session") -def model_path_oc20(tmp_path_factory: pytest.TempPathFactory) -> str: - tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints") - model_name = "EquiformerV2-31M-S2EF-OC20-All+MD" - return model_name_to_local_file(model_name, local_cache=str(tmp_path)) - - -@pytest.fixture -def eqv2_oc20_model_pbc(model_path_oc20: str) -> FairChemV1Model: - return FairChemV1Model(model=model_path_oc20, device=DEVICE, seed=0, pbc=True) - - -@pytest.fixture -def eqv2_oc20_model_non_pbc( - model_path_oc20: str, -) -> FairChemV1Model: - return FairChemV1Model(model=model_path_oc20, device=DEVICE, seed=0, pbc=False) - - -if get_token(): - - @pytest.fixture(scope="session") - def model_path_omat24(tmp_path_factory: pytest.TempPathFactory) -> str: - tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints") - model_name = "EquiformerV2-31M-OMAT24-MP-sAlex" - return model_name_to_local_file(model_name, local_cache=str(tmp_path)) - - @pytest.fixture - def eqv2_omat24_model_pbc( - model_path_omat24: str, - ) -> FairChemV1Model: - return FairChemV1Model(model=model_path_omat24, device=DEVICE, seed=0, pbc=True) - - -@pytest.fixture -def ocp_calculator(model_path_oc20: str) -> OCPCalculator: - return OCPCalculator(checkpoint_path=model_path_oc20, cpu=False, seed=0) - - -test_fairchem_ocp_consistency_pbc = make_model_calculator_consistency_test( - test_name="fairchem_ocp", - model_fixture_name="eqv2_oc20_model_pbc", - calculator_fixture_name="ocp_calculator", - sim_state_names=tuple(SIMSTATE_BULK_GENERATORS.keys()), - energy_rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models - energy_atol=5e-4, - force_rtol=5e-4, - force_atol=5e-4, - stress_rtol=5e-4, - stress_atol=5e-4, -) - -test_fairchem_non_pbc = make_model_calculator_consistency_test( - test_name="fairchem_non_pbc_benzene", - model_fixture_name="eqv2_oc20_model_non_pbc", - calculator_fixture_name="ocp_calculator", - sim_state_names=tuple(SIMSTATE_MOLECULE_GENERATORS.keys()), - energy_rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models - energy_atol=5e-4, - force_rtol=5e-4, - force_atol=5e-4, - stress_rtol=5e-4, - stress_atol=5e-4, -) - - -# Skip this test due to issues with how the older models -# handled supercells (see related issue here: https://github.com/facebookresearch/fairchem/issues/428) - -test_fairchem_ocp_model_outputs = pytest.mark.skipif( - os.environ.get("HF_TOKEN") is None, - reason="Issues in graph construction of older models", -)(make_validate_model_outputs_test(model_fixture_name="eqv2_omat24_model_pbc")) - - -def test_fairchem_mixed_pbc_init_raises(model_path_oc20: str) -> None: - """Test that initializing FairChemV1Model with mixed PBC raises ValueError.""" - mixed_pbc = torch.tensor([True, False, True], dtype=torch.bool) - with pytest.raises(ValueError, match="FairChemV1Model does not support mixed PBC"): - FairChemV1Model(model=model_path_oc20, device=DEVICE, seed=0, pbc=mixed_pbc) - - -def test_fairchem_mixed_pbc_forward_raises( - eqv2_oc20_model_pbc: FairChemV1Model, si_sim_state: ts.SimState -) -> None: - """Test that calling forward with a SimState that has mixed PBC raises ValueError.""" - mixed_pbc_state = ts.SimState.from_state( - si_sim_state, pbc=torch.tensor([True, False, True], dtype=torch.bool) - ) - with pytest.raises(ValueError, match="FairChemV1Model does not support mixed PBC"): - eqv2_oc20_model_pbc(mixed_pbc_state) diff --git a/tests/models/test_nequix.py b/tests/models/test_nequix.py index 49b004e0..f272e5a5 100644 --- a/tests/models/test_nequix.py +++ b/tests/models/test_nequix.py @@ -16,7 +16,7 @@ from torch_sim.models.nequix import NequixModel except (ImportError, ModuleNotFoundError): pytest.skip( - f"nequix not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + f"nequix not installed: {traceback.format_exc()}", allow_module_level=True, ) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index c4e8ae0f..ec6a8bb5 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -171,7 +171,7 @@ def _all_nl_backends() -> list[Any]: not neighbors.VESIN_AVAILABLE, reason="Vesin is not installed" ) _skip_vesin_ts = pytest.mark.skipif( - not neighbors.VESIN_TORCH_AVAILABLE, reason="Vesin is not installed" + not neighbors.VESIN_TORCHSCRIPT_AVAILABLE, reason="Vesin is not installed" ) _skip_alchemiops = pytest.mark.skipif( diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index cdc7fb24..1c95b00d 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -40,6 +40,10 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: """Dummy init for type checking.""" raise err + def forward(self, *_args: Any, **_kwargs: Any) -> Any: + """Unreachable — __init__ always raises.""" + raise NotImplementedError + if typing.TYPE_CHECKING: from collections.abc import Callable diff --git a/torch_sim/models/fairchem_legacy.py b/torch_sim/models/fairchem_legacy.py deleted file mode 100644 index 05e452f1..00000000 --- a/torch_sim/models/fairchem_legacy.py +++ /dev/null @@ -1,457 +0,0 @@ -"""Wrapper for Legacy FairChem ecosystem models in TorchSim. - -This module provides a TorchSim wrapper of the FairChem models for computing -energies, forces, and stresses of atomistic systems. It serves as a wrapper around -the FairChem library, integrating it with the torch_sim framework to enable seamless -simulation of atomistic systems with machine learning potentials. - -The FairChemV1Model class adapts FairChem models to the ModelInterface protocol, -allowing them to be used within the broader torch_sim simulation framework. - -Notes: - This implementation requires FairChem < 2.0.0 to be installed and accessible. - It supports various model configurations through configuration files or - pretrained model checkpoints. -""" - -# ruff: noqa: T201 - -from __future__ import annotations - -import copy -import os -import traceback -import typing -import warnings -from pathlib import Path -from types import MappingProxyType -from typing import Any - -import torch - -from torch_sim.models.interface import ModelInterface - - -if typing.TYPE_CHECKING: - from torch_sim.state import SimState - - -def _validate_fairchem_version() -> None: - """Check for a compatible legacy FairChem version.""" - from importlib.metadata import version - - from packaging.version import parse - - fairchem_version = parse(version("fairchem-core")) - if fairchem_version >= parse("2.0.0"): - raise ImportError("FairChem v1.10.0 or lower is required") - - -try: - _validate_fairchem_version() - from fairchem.core.common.registry import registry - from fairchem.core.common.utils import ( - load_config, - setup_imports, - setup_logging, - update_config, - ) - from fairchem.core.models.model_registry import model_name_to_local_file - from torch_geometric.data import Batch, Data - -except ImportError as exc: - warnings.warn(f"FairChem import failed: {traceback.format_exc()}", stacklevel=2) - - class FairChemV1Model(ModelInterface): - """FairChem model wrapper for torch_sim. - - This class is a placeholder for the FairChemV1Model class. - It raises an ImportError if FairChem is not installed. - """ - - def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: - """Dummy init for type checking.""" - raise err - - -if typing.TYPE_CHECKING: - from collections.abc import Callable - - -_DTYPE_DICT = { - torch.float16: "float16", - torch.float32: "float32", - torch.float64: "float64", -} - - -class FairChemV1Model(ModelInterface): - """Computes atomistic energies, forces and stresses using a FairChem model. - - This class wraps a FairChem model to compute energies, forces, and stresses for - atomistic systems. It handles model initialization, checkpoint loading, and - provides a forward pass that accepts a SimState object and returns model - predictions. - - The model can be initialized either with a configuration file or a pretrained - checkpoint. It supports various model architectures and configurations supported by - FairChem. - - Attributes: - neighbor_list_fn (Callable | None): Function to compute neighbor lists - config (dict): Complete model configuration dictionary - trainer: FairChem trainer object that contains the model - data_object (Batch): Data object containing system information - implemented_properties (list): Model outputs the model can compute - pbc (bool): Whether periodic boundary conditions are used - _dtype (torch.dtype): Data type used for computation - _compute_stress (bool): Whether to compute stress tensor - _compute_forces (bool): Whether to compute forces - _device (torch.device): Device where computation is performed - _reshaped_props (dict): Properties that need reshaping after computation - - Examples: - >>> model = FairChemV1Model(model="path/to/checkpoint.pt", compute_stress=True) - >>> results = model(state) - - """ - - _reshaped_props = MappingProxyType( - {"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)} - ) - - def __init__( # noqa: C901, PLR0915 - self, - model: str | Path | None = None, - neighbor_list_fn: Callable | None = None, - *, # force remaining arguments to be keyword-only - config_yml: str | None = None, - local_cache: str | None = None, - trainer: str | None = None, - device: torch.device | None = None, - seed: int | None = None, - dtype: torch.dtype | None = None, - compute_stress: bool = False, - pbc: torch.Tensor | bool = True, - disable_amp: bool = True, - ) -> None: - """Initialize the FairChemV1Model with specified configuration. - - Loads a FairChem model from either a checkpoint path or a configuration file. - Sets up the model parameters, trainer, and configuration for subsequent use - in energy and force calculations. - - Args: - model (str | Path | None): Either a pretrained model name or path to model - checkpoint file. The function will first check if it's a valid file - path, and if not, will attempt to load it as a pretrained model name - (requires local_cache to be set). If None, config_yml must be provided. - neighbor_list_fn (Callable | None): Function to compute neighbor lists - (not currently supported) - config_yml (str | None): Path to configuration YAML file - local_cache (str | None): Path to local model cache directory (required - when using pretrained model names) - trainer (str | None): Name of trainer class to use - device (torch.device | None): Device to use for computation. If None, - defaults to CUDA if available, otherwise CPU. - seed (int | None): Random seed for reproducibility - dtype (torch.dtype | None): Data type to use for computation - compute_stress (bool): Whether to compute stress tensor - pbc (torch.Tensor | bool): Whether to use periodic boundary conditions - disable_amp (bool): Whether to disable AMP - Raises: - NotImplementedError: If custom neighbor list function is provided - ValueError: If stress computation is requested but not supported by model - ValueError: If neither config_yml nor model is provided - ValueError: If model cannot be loaded as file or pretrained model - - Notes: - Either config_yml or model must be provided. The model loads configuration - from the checkpoint if config_yml is not specified. - """ - setup_imports() - setup_logging() - super().__init__() - - self._dtype = dtype or torch.float32 - self._compute_stress = compute_stress - self._compute_forces = True - self._memory_scales_with = "n_atoms" - if isinstance(pbc, bool): - pbc = torch.tensor([pbc] * 3, dtype=torch.bool) - elif not torch.all(pbc == pbc[0]): - raise ValueError( - f"FairChemV1Model does not support mixed PBC (got pbc={pbc.tolist()})" - ) - self.pbc = pbc - - # Process model parameter if provided - if model is not None: - # Convert Path to string for consistency - if isinstance(model, Path): - model = str(model) - - # Determine if model is a file path or a pretrained model name - # First check if it's a valid file path - if not os.path.isfile(model): - # If not a file, try to load as pretrained model name - if local_cache is None: - raise ValueError( - f"Model '{model}' is not a valid file path. " - "If using a pretrained model name, local_cache must be set." - ) - # Attempt to load as pretrained model name - model = model_name_to_local_file( - model_name=model, local_cache=local_cache - ) - - # Either the config path or the checkpoint path needs to be provided - if not config_yml and model is None: - raise ValueError("Either config_yml or model must be provided") - - checkpoint = None - if config_yml is not None: - if isinstance(config_yml, str): - config, duplicates_warning, duplicates_error = load_config(config_yml) - if len(duplicates_warning) > 0: - print( - "Overwritten config parameters from included configs " - f"(non-included parameters take precedence): {duplicates_warning}" - ) - if len(duplicates_error) > 0: - raise ValueError( - "Conflicting (duplicate) parameters in simultaneously " - f"included configs: {duplicates_error}" - ) - else: - config = config_yml - - # Only keeps the train data that might have normalizer values - if isinstance(config["dataset"], list): - config["dataset"] = config["dataset"][0] - elif isinstance(config["dataset"], dict): - config["dataset"] = config["dataset"].get("train", None) - else: - # Loads the config from the checkpoint directly (always on CPU). - if model is None: - raise ValueError("model must be provided when config_yml is not set") - checkpoint = torch.load(model, map_location=torch.device("cpu")) - config = checkpoint["config"] - - if trainer is not None: - config["trainer"] = trainer - else: - config["trainer"] = config.get("trainer", "ocp") - - if "model_attributes" in config: - config["model_attributes"]["name"] = config.pop("model") - config["model"] = config["model_attributes"] - - self.neighbor_list_fn = neighbor_list_fn - - if neighbor_list_fn is None: - # Calculate the edge indices on the fly - config["model"]["otf_graph"] = True - else: - raise NotImplementedError( - "Custom neighbor list is not supported for FairChemV1Model." - ) - - pbc_bool = bool(self.pbc[0].item()) - - if "backbone" in config["model"]: - config["model"]["backbone"]["use_pbc"] = pbc_bool - config["model"]["backbone"]["use_pbc_single"] = False - if dtype is not None: - try: - config["model"]["backbone"].update({"dtype": _DTYPE_DICT[dtype]}) - for key in config["model"]["heads"]: - config["model"]["heads"][key].update( - {"dtype": _DTYPE_DICT[dtype]} - ) - except KeyError: - print( - "WARNING: dtype not found in backbone, using default model dtype" - ) - else: - config["model"]["use_pbc"] = pbc_bool - config["model"]["use_pbc_single"] = False - if dtype is not None: - try: - config["model"].update({"dtype": _DTYPE_DICT[dtype]}) - except KeyError: - print( - "WARNING: dtype not found in backbone, using default model dtype" - ) - - ### backwards compatibility with OCP v<2.0 - config = update_config(config) - - self.config = copy.deepcopy(config) - self.config["checkpoint"] = str(model) - del config["dataset"]["src"] - - # Determine if CPU should be used (for the legacy trainer API) - cpu = device is not None and device.type == "cpu" - if device is None: - cpu = not torch.cuda.is_available() - - self.trainer = registry.get_trainer_class(config["trainer"])( - task=config.get("task", {}), - model=config["model"], - dataset=[config["dataset"]], - outputs=config["outputs"], - loss_functions=config["loss_functions"], - evaluation_metrics=config["evaluation_metrics"], - optimizer=config["optim"], - identifier="", - slurm=config.get("slurm", {}), - local_rank=config.get("local_rank", 0), - is_debug=config.get("is_debug", True), - cpu=cpu, - amp=False if dtype is not None else config.get("amp", False), - inference_only=True, - ) - - if dtype is not None: - # Convert model parameters to specified dtype - self.trainer.model = self.trainer.model.to(dtype=self.dtype) - - if model is not None: - self.load_checkpoint(checkpoint_path=model, checkpoint=checkpoint) - - seed = seed if seed is not None else self.trainer.config["cmd"]["seed"] - if seed is None: - print( - "No seed has been set in model checkpoint or OCPCalculator! Results may " - "not be reproducible on re-run" - ) - else: - self.trainer.set_seed(seed) - - if disable_amp: - self.trainer.scaler = None - - self.implemented_properties = list(self.config["outputs"]) - - self._device = self.trainer.device - - stress_output = "stress" in self.implemented_properties - if not stress_output and compute_stress: - raise NotImplementedError("Stress output not implemented for this model") - - def load_checkpoint( - self, checkpoint_path: str, checkpoint: dict | None = None - ) -> None: - """Load an existing trained model checkpoint. - - Loads model parameters from a checkpoint file or dictionary, - setting the model to inference mode. - - Args: - checkpoint_path (str): Path to the trained model checkpoint file - checkpoint (dict | None): A pretrained checkpoint dictionary. If provided, - this dictionary is used instead of loading from checkpoint_path. - - Notes: - If loading fails, a message is printed but no exception is raised. - """ - try: - self.trainer.load_checkpoint(checkpoint_path, checkpoint, inference_only=True) - except NotImplementedError: - print("Unable to load checkpoint!") - - def forward( # noqa: C901 - self, state: SimState, **_kwargs: object - ) -> dict[str, torch.Tensor]: - """Perform forward pass to compute energies, forces, and other properties. - - Takes a simulation state and computes the properties implemented by the model, - such as energy, forces, and stresses. - - Args: - state (SimState): State object containing positions, cells, atomic numbers, - and other system information. - **_kwargs: Unused; accepted for interface compatibility. - - Returns: - dict: Dictionary of model predictions, which may include: - - energy (torch.Tensor): Energy with shape [batch_size] - - forces (torch.Tensor): Forces with shape [n_atoms, 3] - - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3], - if compute_stress is True - - Notes: - The state is automatically transferred to the model's device if needed. - All output tensors are detached from the computation graph. - """ - sim_state = state - - if sim_state.device != self._device: - sim_state = sim_state.to(self._device) - - if sim_state.system_idx is None: - sim_state.system_idx = torch.zeros( - sim_state.positions.shape[0], dtype=torch.int - ) - - # Extract uniform PBC value from state (validate it's uniform) - if isinstance(sim_state.pbc, torch.Tensor): - if not torch.all(sim_state.pbc == sim_state.pbc[0]): - raise ValueError( - "FairChemV1Model does not support mixed PBC " - f"(got state.pbc={sim_state.pbc.tolist()})" - ) - state_pbc_bool = bool(sim_state.pbc[0].item()) - else: - state_pbc_bool = bool(sim_state.pbc) - - model_pbc_bool = bool(self.pbc[0].item()) - - if model_pbc_bool != state_pbc_bool: - raise ValueError( - f"PBC mismatch: model has pbc={model_pbc_bool}, " - f"but state has pbc={state_pbc_bool}. " - "FairChemV1Model requires model and state PBC to match." - ) - - natoms = torch.bincount(sim_state.system_idx) - fixed = torch.zeros( - (sim_state.system_idx.size(0), int(natoms.sum().item())), dtype=torch.int - ) - data_list = [] - for idx, (n, c) in enumerate( - zip(natoms, torch.cumsum(natoms, dim=0), strict=False) - ): - data_list.append( - Data( - pos=sim_state.positions[c - n : c].detach().clone(), - cell=sim_state.row_vector_cell[idx, None].detach().clone(), - atomic_numbers=sim_state.atomic_numbers[c - n : c].detach().clone(), - fixed=fixed[c - n : c].detach().clone(), - natoms=n, - pbc=sim_state.pbc, - ) - ) - self.data_object = Batch.from_data_list(data_list) - - if self.dtype is not None: - self.data_object.pos = self.data_object.pos.to(self.dtype) - self.data_object.cell = self.data_object.cell.to(self.dtype) - - predictions = self.trainer.predict( - self.data_object, per_image=False, disable_tqdm=True - ) - - results = {} - - for key in predictions: - _pred = predictions[key] - if key in self._reshaped_props: - _pred = _pred.reshape(self._reshaped_props.get(key)).squeeze() - results[key] = _pred - - results["energy"] = results["energy"].squeeze(dim=1) - if results.get("stress") is not None and len(results["stress"].shape) == 2: - results["stress"] = results["stress"].unsqueeze(dim=0) - return {k: v.detach() for k, v in results.items()} diff --git a/torch_sim/models/graphpes_framework.py b/torch_sim/models/graphpes_framework.py index 800fe819..16c5d82c 100644 --- a/torch_sim/models/graphpes_framework.py +++ b/torch_sim/models/graphpes_framework.py @@ -45,6 +45,10 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: """Dummy init for type checking.""" raise err + def forward(self, *_args: Any, **_kwargs: Any) -> Any: + """Unreachable — __init__ always raises.""" + raise NotImplementedError + class AtomicGraph: # noqa: D101 def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: D107,ARG002 raise ImportError("graph_pes must be installed to use this model.") diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index e90392c6..e13e3ce8 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -48,6 +48,10 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: """Dummy init for type checking.""" raise err + def forward(self, *_args: Any, **_kwargs: Any) -> Any: + """Unreachable — __init__ always raises.""" + raise NotImplementedError + def to_one_hot( indices: torch.Tensor, num_classes: int, dtype: torch.dtype diff --git a/torch_sim/models/mattersim.py b/torch_sim/models/mattersim.py index ce2a68d5..d5b1e6d1 100644 --- a/torch_sim/models/mattersim.py +++ b/torch_sim/models/mattersim.py @@ -34,6 +34,10 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: """Dummy init for type checking.""" raise err + def forward(self, *_args: Any, **_kwargs: Any) -> Any: + """Unreachable — __init__ always raises.""" + raise NotImplementedError + if TYPE_CHECKING: from mattersim.forcefield import Potential diff --git a/torch_sim/models/nequip_framework.py b/torch_sim/models/nequip_framework.py index ac493abe..cdb846bd 100644 --- a/torch_sim/models/nequip_framework.py +++ b/torch_sim/models/nequip_framework.py @@ -43,6 +43,10 @@ def __init__( """Dummy init for type checking.""" raise err + def forward(self, *_args: Any, **_kwargs: Any) -> Any: + """Unreachable — __init__ always raises.""" + raise NotImplementedError + @classmethod def from_compiled_model(cls, _path: Any, *_args: Any, **_kwargs: Any) -> Self: """Dummy classmethod for type checking when NequIP is not installed.""" diff --git a/torch_sim/models/nequix.py b/torch_sim/models/nequix.py index 38f5cbde..6b9bc557 100644 --- a/torch_sim/models/nequix.py +++ b/torch_sim/models/nequix.py @@ -37,6 +37,10 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: """Dummy init for type checking.""" raise err + def forward(self, *_args: Any, **_kwargs: Any) -> Any: + """Unreachable — __init__ always raises.""" + raise NotImplementedError + @classmethod def from_compiled_model(cls, _path: Any, *_args: Any, **_kwargs: Any) -> Self: """Dummy classmethod for type checking when nequix is not installed.""" diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 47417d08..80587a37 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -42,5 +42,9 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: """Dummy init for type checking.""" raise err + def forward(self, *_args: Any, **_kwargs: Any) -> Any: + """Unreachable — __init__ always raises.""" + raise NotImplementedError + __all__ = ["OrbModel"] diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index efefa5e9..0777ddab 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -31,5 +31,9 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: """Dummy init for type checking.""" raise err + def forward(self, *_args: Any, **_kwargs: Any) -> Any: + """Unreachable — __init__ always raises.""" + raise NotImplementedError + __all__ = ["SevenNetModel"] diff --git a/torch_sim/neighbors/__init__.py b/torch_sim/neighbors/__init__.py index 3c123db2..10b5a1db 100644 --- a/torch_sim/neighbors/__init__.py +++ b/torch_sim/neighbors/__init__.py @@ -24,7 +24,7 @@ from torch_sim.neighbors.torch_nl import strict_nl, torch_nl_linked_cell, torch_nl_n2 from torch_sim.neighbors.vesin import ( VESIN_AVAILABLE, - VESIN_TORCH_AVAILABLE, + VESIN_TORCHSCRIPT_AVAILABLE, vesin_nl, vesin_nl_ts, ) @@ -67,8 +67,9 @@ def _normalize_inputs( # Set default neighbor list based on what's available (priority order) if ALCHEMIOPS_AVAILABLE: # Alchemiops is fastest on NVIDIA GPUs + # TODO: why default to n2? we should document the cross-over point default_batched_nl = alchemiops_nl_n2 -elif VESIN_TORCH_AVAILABLE: +elif VESIN_TORCHSCRIPT_AVAILABLE: default_batched_nl = vesin_nl_ts elif VESIN_AVAILABLE: default_batched_nl = vesin_nl @@ -120,7 +121,7 @@ def torchsim_nl( positions, cell, pbc, cutoff, system_idx, self_interaction ) - if VESIN_TORCH_AVAILABLE: + if VESIN_TORCHSCRIPT_AVAILABLE: return vesin_nl_ts(positions, cell, pbc, cutoff, system_idx, self_interaction) if VESIN_AVAILABLE: @@ -135,7 +136,7 @@ def torchsim_nl( "ALCHEMIOPS_AVAILABLE", "ALCHEMIOPS_TORCH_AVAILABLE", "VESIN_AVAILABLE", - "VESIN_TORCH_AVAILABLE", + "VESIN_TORCHSCRIPT_AVAILABLE", "alchemiops_nl_cell_list", "alchemiops_nl_n2", "strict_nl", diff --git a/torch_sim/neighbors/vesin.py b/torch_sim/neighbors/vesin.py index 33aa8155..009fe9bb 100644 --- a/torch_sim/neighbors/vesin.py +++ b/torch_sim/neighbors/vesin.py @@ -12,36 +12,37 @@ try: from vesin import NeighborList as VesinNeighborList except ImportError: - VesinNeighborList = None # ty:ignore[invalid-assignment] + VesinNeighborList = None # type: ignore[assignment] + + try: from vesin.torch import NeighborList as VesinNeighborListTorch except ImportError: VesinNeighborListTorch = None # ty:ignore[invalid-assignment] VESIN_AVAILABLE = VesinNeighborList is not None -VESIN_TORCH_AVAILABLE = VesinNeighborListTorch is not None - +VESIN_TORCHSCRIPT_AVAILABLE = VesinNeighborListTorch is not None if VESIN_AVAILABLE: - def vesin_nl_ts( + def vesin_nl( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, - cutoff: torch.Tensor, + cutoff: float | torch.Tensor, system_idx: torch.Tensor, self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute neighbor lists using TorchScript-compatible Vesin. + """Compute neighbor lists using the standard Vesin implementation. - This function provides a TorchScript-compatible interface to the Vesin - neighbor list algorithm using VesinNeighborListTorch. + This function provides an interface to the standard Vesin neighbor list + algorithm using VesinNeighborList. Args: positions: Atomic positions tensor [n_atoms, 3] cell: Unit cell vectors [n_systems, 3, 3] or [3, 3] pbc: Boolean tensor [n_systems, 3] or [3] - cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors + cutoff: Maximum distance for considering atoms as neighbors system_idx: Tensor [n_atoms] indicating which system each atom belongs to self_interaction: If True, include self-pairs. Default: False @@ -55,26 +56,25 @@ def vesin_nl_ts( >>> # Single system >>> positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) >>> system_idx = torch.zeros(2, dtype=torch.long) - >>> mapping, sys_map, shifts = vesin_nl_ts( + >>> mapping, sys_map, shifts = vesin_nl( ... positions, cell, pbc, cutoff, system_idx ... ) Notes: - - Uses VesinNeighborListTorch for TorchScript compatibility + - Uses standard VesinNeighborList implementation - Requires CPU tensors in float64 precision internally - Returns tensors on the same device as input with original precision - For non-periodic systems, shifts will be zero vectors - The neighbor list includes both (i,j) and (j,i) pairs References: - https://github.com/Luthaf/vesin + - https://github.com/Luthaf/vesin """ from torch_sim.neighbors import _normalize_inputs - if VesinNeighborListTorch is None: + if VesinNeighborList is None: raise RuntimeError( - "vesin.torch is not available. " - "Install it with: [uv] pip install vesin[torch]" + "vesin is not installed. Install it with: [uv] pip install vesin" ) device = positions.device dtype = positions.dtype @@ -94,27 +94,32 @@ def vesin_nl_ts( if n_atoms_in_system == 0: continue - # Calculate neighbor list for this system - neighbor_list_fn = VesinNeighborListTorch(cutoff.item(), full_list=True) - # Get the cell for this system cell_sys = cell[sys_idx] - # Convert tensors to CPU and float64 properly - positions_cpu = positions[system_mask].cpu().to(dtype=torch.float64) - cell_cpu = cell_sys.cpu().to(dtype=torch.float64) - periodic_cpu = pbc[sys_idx].to(dtype=torch.bool).cpu() + # Calculate neighbor list for this system + neighbor_list_fn = VesinNeighborList( + (float(cutoff)), full_list=True, sorted=False + ) - # Only works on CPU and requires float64 + # Convert tensors to CPU and float64 without gradients + positions_cpu = positions[system_mask].detach().cpu().to(dtype=torch.float64) + cell_cpu = cell_sys.detach().cpu().to(dtype=torch.float64) + periodic_cpu = pbc[sys_idx].detach().to(dtype=torch.bool).cpu() + + # Only works on CPU and returns numpy arrays i, j, S = neighbor_list_fn.compute( points=positions_cpu, box=cell_cpu, periodic=periodic_cpu, quantities="ijS", ) - - edge_idx = torch.stack((i, j), dim=0).to(dtype=torch.long, device=device) - shifts = S.to(dtype=dtype, device=device) + i, j = ( + torch.tensor(i, dtype=torch.long, device=device), + torch.tensor(j, dtype=torch.long, device=device), + ) + edge_idx = torch.stack((i, j), dim=0) + shifts = torch.tensor(S, dtype=dtype, device=device) # Adjust indices for the global atom indexing edge_idx = edge_idx + offset @@ -152,24 +157,36 @@ def vesin_nl_ts( return mapping, system_mapping, shifts_idx +else: + def vesin_nl( + *args, # noqa: ARG001 + **kwargs, # noqa: ARG001 + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Stub function when Vesin is not available.""" + raise ImportError("Vesin is not installed. Install it with: pip install vesin") + + +if VESIN_TORCHSCRIPT_AVAILABLE: + + def vesin_nl_ts( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, - cutoff: float | torch.Tensor, + cutoff: torch.Tensor, system_idx: torch.Tensor, self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute neighbor lists using the standard Vesin implementation. + """Compute neighbor lists using TorchScript-compatible Vesin. - This function provides an interface to the standard Vesin neighbor list - algorithm using VesinNeighborList. + This function provides a TorchScript-compatible interface to the Vesin + neighbor list algorithm using VesinNeighborListTorch. Args: positions: Atomic positions tensor [n_atoms, 3] cell: Unit cell vectors [n_systems, 3, 3] or [3, 3] pbc: Boolean tensor [n_systems, 3] or [3] - cutoff: Maximum distance for considering atoms as neighbors + cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors system_idx: Tensor [n_atoms] indicating which system each atom belongs to self_interaction: If True, include self-pairs. Default: False @@ -183,26 +200,24 @@ def vesin_nl( >>> # Single system >>> positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) >>> system_idx = torch.zeros(2, dtype=torch.long) - >>> mapping, sys_map, shifts = vesin_nl( + >>> mapping, sys_map, shifts = vesin_nl_ts( ... positions, cell, pbc, cutoff, system_idx ... ) Notes: - - Uses standard VesinNeighborList implementation + - Uses VesinNeighborListTorch for TorchScript compatibility - Requires CPU tensors in float64 precision internally - Returns tensors on the same device as input with original precision - For non-periodic systems, shifts will be zero vectors - The neighbor list includes both (i,j) and (j,i) pairs References: - - https://github.com/Luthaf/vesin + https://github.com/Luthaf/vesin """ from torch_sim.neighbors import _normalize_inputs - if VesinNeighborList is None: - raise RuntimeError( - "vesin is not installed. Install it with: [uv] pip install vesin" - ) + if VesinNeighborListTorch is None: + raise RuntimeError("vesin[torch] package is not installed") device = positions.device dtype = positions.dtype n_systems = int(system_idx.max().item()) + 1 @@ -221,32 +236,27 @@ def vesin_nl( if n_atoms_in_system == 0: continue + # Calculate neighbor list for this system + neighbor_list_fn = VesinNeighborListTorch(cutoff.item(), full_list=True) + # Get the cell for this system cell_sys = cell[sys_idx] - # Calculate neighbor list for this system - neighbor_list_fn = VesinNeighborList( - (float(cutoff)), full_list=True, sorted=False - ) - - # Convert tensors to CPU and float64 without gradients - positions_cpu = positions[system_mask].detach().cpu().to(dtype=torch.float64) - cell_cpu = cell_sys.detach().cpu().to(dtype=torch.float64) - periodic_cpu = pbc[sys_idx].detach().to(dtype=torch.bool).cpu() + # Convert tensors to CPU and float64 properly + positions_cpu = positions[system_mask].cpu().to(dtype=torch.float64) + cell_cpu = cell_sys.cpu().to(dtype=torch.float64) + periodic_cpu = pbc[sys_idx].to(dtype=torch.bool).cpu() - # Only works on CPU and returns numpy arrays + # Only works on CPU and requires float64 i, j, S = neighbor_list_fn.compute( points=positions_cpu, box=cell_cpu, periodic=periodic_cpu, quantities="ijS", ) - i, j = ( - torch.tensor(i, dtype=torch.long, device=device), - torch.tensor(j, dtype=torch.long, device=device), - ) - edge_idx = torch.stack((i, j), dim=0) - shifts = torch.tensor(S, dtype=dtype, device=device) + + edge_idx = torch.stack((i, j), dim=0).to(dtype=torch.long, device=device) + shifts = S.to(dtype=dtype, device=device) # Adjust indices for the global atom indexing edge_idx = edge_idx + offset @@ -283,7 +293,6 @@ def vesin_nl( system_mapping = torch.cat([system_mapping, self_sys_mapping], dim=0) return mapping, system_mapping, shifts_idx - else: # Provide stub functions that raise informative errors def vesin_nl_ts( @@ -292,10 +301,3 @@ def vesin_nl_ts( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Stub function when Vesin is not available.""" raise ImportError("Vesin is not installed. Install it with: pip install vesin") - - def vesin_nl( - *args, # noqa: ARG001 - **kwargs, # noqa: ARG001 - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Stub function when Vesin is not available.""" - raise ImportError("Vesin is not installed. Install it with: pip install vesin") From 8eaa56c696d73d84f8169b99f65c7956ef5591f5 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 25 Mar 2026 16:16:45 -0400 Subject: [PATCH 3/3] run 3.14 tests for fairchem --- .github/workflows/test.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2decea17..c5dbbd1f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -74,9 +74,6 @@ jobs: exclude: - version: { python: '3.14', resolution: highest } model: { name: orb, test_path: 'tests/models/test_orb.py' } - - version: { python: '3.14', resolution: highest } - model: { name: fairchem, test_path: 'tests/models/test_fairchem.py'} - - version: { python: '3.14', resolution: highest } - version: { python: '3.14', resolution: highest } model: { name: nequip, test_path: 'tests/models/test_nequip_framework.py'} runs-on: ${{ matrix.os }}