diff --git a/pyproject.toml b/pyproject.toml index 90de3c4a..56c1da91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,8 +57,8 @@ orb = ["orb-models>=0.6.0"] sevenn = ["sevenn[torchsim]>=0.12.1"] graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"] nequip = ["nequip>=0.17.0"] +fairchem = ["fairchem-core @ git+https://github.com/facebookresearch/fairchem.git@main#subdirectory=packages/fairchem-core"] nequix = ["nequix[torch-sim]>=0.4.5"] -fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"] docs = [ "autodoc_pydantic==2.2.0", "furo==2024.8.6", diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index d259e2ec..7937c692 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -1,23 +1,14 @@ import traceback import pytest -import torch -import torch_sim as ts from tests.conftest import DEVICE, DTYPE from tests.models.conftest import make_validate_model_outputs_test try: - from collections.abc import Callable - - from ase.build import bulk, fcc100, molecule - from fairchem.core.calculate.pretrained_mlip import ( - pretrained_checkpoint_path_from_name, - ) from huggingface_hub.utils._auth import get_token - import torch_sim as ts from torch_sim.models.fairchem import FairChemModel except (ImportError, OSError, RuntimeError, AttributeError, ValueError): @@ -33,205 +24,6 @@ def eqv2_uma_model_pbc() -> FairChemModel: return FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE) -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -@pytest.mark.parametrize("task_name", ["omat", "omol", "oc20"]) -def test_task_initialization(task_name: str) -> None: - """Test that different UMA task names work correctly.""" - model = FairChemModel( - model="uma-s-1p1", task_name=task_name, device=torch.device("cpu") - ) - assert model.task_name - assert str(model.task_name.value) == task_name - assert hasattr(model, "predictor") - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -@pytest.mark.parametrize( - ("task_name", "systems_func"), - [ - ( - "omat", - lambda: [ - bulk("Si", "diamond", a=5.43), - bulk("Al", "fcc", a=4.05), - bulk("Fe", "bcc", a=2.87), - bulk("Cu", "fcc", a=3.61), - ], - ), - ( - "omol", - lambda: [molecule("H2O"), molecule("CO2"), molecule("CH4"), molecule("NH3")], - ), - ], -) -def test_homogeneous_batching(task_name: str, systems_func: Callable) -> None: - """Test batching multiple systems with the same task.""" - systems = systems_func() - - # Add molecular properties for molecules - if task_name == "omol": - for mol in systems: - mol.info |= {"charge": 0, "spin": 1} - - model = FairChemModel(model="uma-s-1p1", task_name=task_name, device=DEVICE) - state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) - results = model(state) - - # Check batch dimensions - assert results["energy"].shape == (4,) - assert results["forces"].shape[0] == sum(len(s) for s in systems) - assert results["forces"].shape[1] == 3 - - # Check that different systems have different energies - energies = results["energy"] - uniq_energies = torch.unique(energies, dim=0) - assert len(uniq_energies) > 1, "Different systems should have different energies" - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -def test_heterogeneous_tasks() -> None: - """Test different task types work with appropriate systems.""" - # Test molecule, material, and catalysis systems separately - test_cases = [ - ("omol", [molecule("H2O")]), - ("omat", [bulk("Pt", cubic=True)]), - ("oc20", [fcc100("Cu", (2, 2, 3), vacuum=8, periodic=True)]), - ] - - for task_name, systems in test_cases: - if task_name == "omol": - systems[0].info |= {"charge": 0, "spin": 1} - - model = FairChemModel( - model="uma-s-1p1", - task_name=task_name, - device=DEVICE, - ) - state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) - results = model(state) - - assert results["energy"].shape[0] == 1 - assert results["forces"].dim() == 2 - assert results["forces"].shape[1] == 3 - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -@pytest.mark.parametrize( - ("systems_func", "expected_count"), - [ - (lambda: [bulk("Si", "diamond", a=5.43)], 1), # Single system - ( - lambda: [ - bulk("H", "bcc", a=2.0), - bulk("Li", "bcc", a=3.0), - bulk("Si", "diamond", a=5.43), - bulk("Al", "fcc", a=4.05).repeat((2, 1, 1)), - ], - 4, - ), # Mixed sizes - ( - lambda: [ - bulk(element, "fcc", a=4.0) - for element in ("Al", "Cu", "Ni", "Pd", "Pt") * 3 - ], - 15, - ), # Large batch - ], -) -def test_batch_size_variations(systems_func: Callable, expected_count: int) -> None: - """Test batching with different numbers and sizes of systems.""" - systems = systems_func() - - model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE) - state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) - results = model(state) - - assert results["energy"].shape == (expected_count,) - assert results["forces"].shape[0] == sum(len(s) for s in systems) - assert results["forces"].shape[1] == 3 - assert torch.isfinite(results["energy"]).all() - assert torch.isfinite(results["forces"]).all() - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -@pytest.mark.parametrize("compute_stress", [True, False]) -def test_stress_computation(*, compute_stress: bool) -> None: - """Test stress tensor computation.""" - systems = [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)] - - model = FairChemModel( - model="uma-s-1p1", - task_name="omat", - device=DEVICE, - compute_stress=compute_stress, - ) - state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) - results = model(state) - - if compute_stress: - assert "stress" in results - assert results["stress"].shape == (2, 3, 3) - assert torch.isfinite(results["stress"]).all() - else: - assert "stress" not in results - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -def test_device_consistency() -> None: - """Test device consistency between model and data.""" - model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE) - system = bulk("Si", "diamond", a=5.43) - state = ts.io.atoms_to_state([system], device=DEVICE, dtype=DTYPE) - - results = model(state) - assert results["energy"].device == DEVICE - assert results["forces"].device == DEVICE - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -def test_empty_batch_error() -> None: - """Test that empty batches raise appropriate errors.""" - model = FairChemModel(model="uma-s-1p1", task_name="omat", device=torch.device("cpu")) - with pytest.raises((ValueError, RuntimeError, IndexError)): - model(ts.io.atoms_to_state([], device=torch.device("cpu"), dtype=torch.float32)) - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -def test_load_from_checkpoint_path() -> None: - """Test loading model from a saved checkpoint file path.""" - checkpoint_path = pretrained_checkpoint_path_from_name("uma-s-1p1") - loaded_model = FairChemModel( - model=str(checkpoint_path), task_name="omat", device=DEVICE - ) - - # Verify the loaded model works - system = bulk("Si", "diamond", a=5.43) - state = ts.io.atoms_to_state([system], device=DEVICE, dtype=DTYPE) - results = loaded_model(state) - - assert "energy" in results - assert "forces" in results - assert results["energy"].shape == (1,) - assert torch.isfinite(results["energy"]).all() - assert torch.isfinite(results["forces"]).all() - - test_fairchem_uma_model_outputs = pytest.mark.skipif( get_token() is None, reason="Requires HuggingFace authentication for UMA model access", @@ -240,84 +32,3 @@ def test_load_from_checkpoint_path() -> None: model_fixture_name="eqv2_uma_model_pbc", device=DEVICE, dtype=DTYPE ) ) - - -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -@pytest.mark.parametrize( - ("charge", "spin"), - [ - (0.0, 0.0), # Neutral, no spin - (1.0, 1.0), # +1 charge, spin=1 (doublet) - (-1.0, 0.0), # -1 charge, no spin (singlet) - (0.0, 2.0), # Neutral, spin=2 (triplet) - ], -) -def test_fairchem_charge_spin(charge: float, spin: float) -> None: - """Test that FairChemModel correctly handles charge and spin from atoms.info.""" - # Create a water molecule - mol = molecule("H2O") - - # Set charge and spin in ASE atoms.info - mol.info["charge"] = charge - mol.info["spin"] = spin - - # Convert to SimState (should extract charge/spin) - state = ts.io.atoms_to_state([mol], device=DEVICE, dtype=DTYPE) - - # Verify charge/spin were extracted correctly - assert state.charge is not None - assert state.spin is not None - assert state.charge[0].item() == charge - assert state.spin[0].item() == spin - - # Create model with UMA omol task (supports charge/spin for molecules) - model = FairChemModel( - model="uma-s-1p1", - task_name="omol", - device=DEVICE, - ) - - # This should not raise an error - result = model(state) - - # Verify outputs exist - assert "energy" in result - assert result["energy"].shape == (1,) - assert "forces" in result - assert result["forces"].shape == (len(mol), 3) - - # Verify outputs are finite - assert torch.isfinite(result["energy"]).all() - assert torch.isfinite(result["forces"]).all() - - -# TODO: we should perhaps put something like this inside `validate_model_outputs` -# the question is how we can do this with creating a circular dependency -@pytest.mark.skipif( - get_token() is None, reason="Requires HuggingFace authentication for UMA model access" -) -def test_fairchem_single_step_relax(rattled_si_sim_state: ts.SimState) -> None: - """Test a single optimization step with FairChemModel. - - This verifies that the model works correctly with optimizers, particularly - that it doesn't have issues with the computational graph (e.g., missing - .detach() calls). - """ - model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE) - state = rattled_si_sim_state.to(device=DEVICE, dtype=DTYPE) - - # Initialize FIRE optimizer - opt_state = ts.fire_init(state, model) - initial_positions = opt_state.positions.clone() - _initial_energy = opt_state.energy.item() - - # Run exactly one step - opt_state = ts.fire_step(opt_state, model) - - # Verify positions changed - assert not torch.allclose(opt_state.positions, initial_positions) - # Verify energy is still available and finite - assert torch.isfinite(opt_state.energy).all() - assert isinstance(opt_state.energy.item(), float) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 1c95b00d..7eb341df 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -1,39 +1,30 @@ -"""FairChem model wrapper for torch-sim. +"""Wrapper for FairChem models in TorchSim. -Provides a TorchSim-compatible interface to FairChem models for computing -energies, forces, and stresses of atomistic systems. +This module re-exports the FairChem package's torch-sim integration for convenient +importing. The actual implementation is maintained in the `fairchem-core` package. -Requires fairchem-core to be installed. +References: + - FairChem Models Package: https://github.com/facebookresearch/fairchem """ -from __future__ import annotations - -import os import traceback -import typing import warnings -from pathlib import Path from typing import Any -import torch - -from torch_sim.models.interface import ModelInterface - try: - from fairchem.core import pretrained_mlip - from fairchem.core.calculate.ase_calculator import UMATask - from fairchem.core.common.utils import setup_imports, setup_logging - from fairchem.core.datasets.atomic_data import AtomicData, atomicdata_list_to_batch + from fairchem.core.calculate.torchsim_interface import FairChemModel except ImportError as exc: warnings.warn(f"FairChem import failed: {traceback.format_exc()}", stacklevel=2) + from torch_sim.models.interface import ModelInterface + class FairChemModel(ModelInterface): - """FairChem model wrapper for torch-sim. + """Dummy FairChem model wrapper for torch-sim to enable safe imports. - This class is a placeholder for the FairChemModel class. - It raises an ImportError if FairChem is not installed. + NOTE: This class is a placeholder when `fairchem-core` is not installed. + It raises an ImportError if accessed. """ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: @@ -45,217 +36,4 @@ def forward(self, *_args: Any, **_kwargs: Any) -> Any: raise NotImplementedError -if typing.TYPE_CHECKING: - from collections.abc import Callable - - from torch_sim.state import SimState - - -class FairChemModel(ModelInterface): - """FairChem model wrapper for computing atomistic properties. - - Wraps FairChem models to compute energies, forces, and stresses. Can be - initialized with a model checkpoint path or pretrained model name. - - Uses the fairchem-core-2.2.0+ predictor API for batch inference. - - Attributes: - predictor: The FairChem predictor for batch inference - task_name (UMATask): Task type for the model - _device (torch.device): Device where computation is performed - _dtype (torch.dtype): Data type used for computation - _compute_stress (bool): Whether to compute stress tensor - implemented_properties (list): Model outputs the model can compute - - Examples: - >>> model = FairChemModel(model="path/to/checkpoint.pt", compute_stress=True) - >>> results = model(state) - """ - - def __init__( - self, - model: str | Path, - neighbor_list_fn: Callable | None = None, - *, # force remaining arguments to be keyword-only - model_cache_dir: str | Path | None = None, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - compute_stress: bool = False, - task_name: UMATask | str | None = None, - ) -> None: - """Initialize the FairChem model. - - Args: - model (str | Path): Either a pretrained model name or path to model - checkpoint file. The function will first check if the input matches - a known pretrained model name, then check if it's a valid file path. - neighbor_list_fn (Callable | None): Function to compute neighbor lists - (not currently supported) - model_cache_dir (str | Path | None): Path where to save the model - device (torch.device | None): Device to use for computation. If None, - defaults to CUDA if available, otherwise CPU. - dtype (torch.dtype | None): Data type to use for computation - compute_stress (bool): Whether to compute stress tensor - task_name (UMATask | str | None): Task type for UMA models (optional, - only needed for UMA models) - - Raises: - NotImplementedError: If custom neighbor list function is provided - ValueError: If model is not a known model name or valid file path - """ - setup_imports() - setup_logging() - super().__init__() - - self._dtype = dtype or torch.float32 - self._compute_stress = compute_stress - self._compute_forces = True - self._memory_scales_with = "n_atoms" - - if neighbor_list_fn is not None: - raise NotImplementedError( - "Custom neighbor list is not supported for FairChemModel." - ) - - # Convert Path to string for consistency - if isinstance(model, Path): - model = str(model) - - # Convert task_name to UMATask if it's a string (only for UMA models) - if isinstance(task_name, str): - task_name = UMATask(task_name) - - # Use the efficient predictor API for optimal performance - self._device = device or torch.device( - "cuda" if torch.cuda.is_available() else "cpu" - ) - device_str = str(self._device) - self.task_name = task_name - - # Create efficient batch predictor for fast inference - if model in pretrained_mlip.available_models: - cache_dir: str | Path | None = model_cache_dir - if ( - cache_dir is not None - and isinstance(cache_dir, Path) - and cache_dir.exists() - ): - self.predictor = pretrained_mlip.get_predict_unit( - model, device=device_str, cache_dir=cache_dir - ) - else: - self.predictor = pretrained_mlip.get_predict_unit( - model, device=device_str - ) - elif os.path.isfile(model): - self.predictor = pretrained_mlip.load_predict_unit(model, device=device_str) - else: - raise ValueError( - f"Invalid model name or checkpoint path: {model}. " - f"Available pretrained models are: {pretrained_mlip.available_models}" - ) - - # Determine implemented properties - # This is a simplified approach - in practice you might want to - # inspect the model configuration more carefully - self.implemented_properties = ["energy", "forces"] - if compute_stress: - self.implemented_properties.append("stress") - - @property - def dtype(self) -> torch.dtype: - """Return the data type used by the model.""" - return self._dtype - - @property - def device(self) -> torch.device: - """Return the device where the model is located.""" - return self._device - - def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]: - """Compute energies, forces, and other properties. - - Args: - state (SimState): State object containing positions, cells, atomic numbers, - and other system information. - **_kwargs: Unused; accepted for interface compatibility. - - Returns: - dict: Dictionary of model predictions, which may include: - - energy (torch.Tensor): Energy with shape [batch_size] - - forces (torch.Tensor): Forces with shape [n_atoms, 3] - - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3] - """ - sim_state = state - - if sim_state.device != self._device: - sim_state = sim_state.to(self._device) - - # Ensure system_idx has integer dtype (SimState guarantees presence) - if sim_state.system_idx.dtype != torch.int64: - sim_state.system_idx = sim_state.system_idx.to(dtype=torch.int64) - - # Convert SimState to AtomicData objects for efficient batch processing - from ase import Atoms - - n_atoms = torch.bincount(sim_state.system_idx) - atomic_data_list = [] - - pbc_np = sim_state.pbc.detach().cpu().numpy() - - for idx, (n, c) in enumerate( - zip(n_atoms, torch.cumsum(n_atoms, dim=0), strict=False) - ): - # Extract system data - positions = sim_state.positions[c - n : c].detach().cpu().numpy() - atomic_nums = sim_state.atomic_numbers[c - n : c].detach().cpu().numpy() - cell = ( - sim_state.row_vector_cell[idx].detach().cpu().numpy() - if sim_state.row_vector_cell is not None - else None - ) - - # Create ASE Atoms object first - atoms = Atoms( - numbers=atomic_nums, - positions=positions, - cell=cell, - pbc=pbc_np if cell is not None else False, - ) - - charge = sim_state.charge - spin = sim_state.spin - atoms.info["charge"] = charge[idx].item() if charge is not None else 0.0 - atoms.info["spin"] = spin[idx].item() if spin is not None else 0.0 - - # Convert ASE Atoms to AtomicData (task_name only applies to UMA models) - # r_data_keys must be passed for charge/spin to be read from atoms.info - if self.task_name is None: - atomic_data = AtomicData.from_ase(atoms, r_data_keys=["charge", "spin"]) - else: - atomic_data = AtomicData.from_ase( - atoms, task_name=self.task_name, r_data_keys=["charge", "spin"] - ) - atomic_data_list.append(atomic_data) - - # Create batch for efficient inference - batch = atomicdata_list_to_batch(atomic_data_list) - batch = batch.to(self._device) - - # Run efficient batch prediction - predictions = self.predictor.predict(batch) - - # Convert predictions to torch-sim format - results: dict[str, torch.Tensor] = {} - results["energy"] = predictions["energy"].to(dtype=self._dtype) - results["forces"] = predictions["forces"].to(dtype=self._dtype) - - # Handle stress if requested and available - if self._compute_stress and "stress" in predictions: - stress = predictions["stress"].to(dtype=self._dtype) - # Ensure stress has correct shape [batch_size, 3, 3] - if stress.dim() == 2 and stress.shape[0] == len(atomic_data_list): - stress = stress.view(-1, 3, 3) - results["stress"] = stress - - return {k: v.detach() for k, v in results.items()} +__all__ = ["FairChemModel"]