From df440c5c4fc4375d57f88c9cd120760f2483e2b2 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 25 Mar 2026 19:46:29 -0400 Subject: [PATCH 1/5] fix: ty now doesn't complain but a bunch of tests fail. --- tests/test_extras.py | 152 +++++++++++++++ tests/test_io.py | 84 ++++++++ torch_sim/integrators/md.py | 1 + torch_sim/integrators/npt.py | 19 +- torch_sim/integrators/nve.py | 5 +- torch_sim/integrators/nvt.py | 13 +- torch_sim/io.py | 64 ++++-- torch_sim/models/interface.py | 48 ++++- torch_sim/models/mace.py | 7 + torch_sim/monte_carlo.py | 5 +- torch_sim/neighbors/vesin.py | 4 +- torch_sim/optimizers/bfgs.py | 95 +++------ torch_sim/optimizers/fire.py | 9 +- torch_sim/optimizers/gradient_descent.py | 3 + torch_sim/optimizers/lbfgs.py | 40 ++-- torch_sim/runners.py | 21 +- torch_sim/state.py | 238 ++++++++++++++++++++--- 17 files changed, 661 insertions(+), 147 deletions(-) create mode 100644 tests/test_extras.py diff --git a/tests/test_extras.py b/tests/test_extras.py new file mode 100644 index 00000000..75947706 --- /dev/null +++ b/tests/test_extras.py @@ -0,0 +1,152 @@ +import pytest +import torch + +import torch_sim as ts + + +DEVICE = torch.device("cpu") +DTYPE = torch.float64 + + +class TestExtras: + def test_system_extras_construction(self): + """Extras can be passed at construction time.""" + field = torch.randn(1, 3) + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + external_E_field=field, + ) + assert torch.equal(state.external_E_field, field) + + def test_atom_extras_construction(self): + """Per-atom extras work at construction time.""" + tags = torch.tensor([1.0, 2.0]) + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _atom_extras={"tags": tags}, + ) + assert torch.equal(state.tags, tags) + + def test_getattr_missing_raises_attribute_error(self, cu_sim_state: ts.SimState): + with pytest.raises(AttributeError, match="nonexistent_key"): + _ = cu_sim_state.nonexistent_key + + def test_post_init_validation_rejects_bad_shape(self): + with pytest.raises(ValueError, match="leading dim must be n_systems"): + ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"bad": torch.randn(5, 3)}, + ) + + def test_construction_extras_cannot_shadow(self): + # Post-init validation should also catch shadowing during construction + with pytest.raises(ValueError, match="shadows an existing attribute"): + ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"cell": torch.zeros(1, 3)}, + ) + + # store_model_extras + def test_store_model_extras_canonical_keys_not_stored( + self, si_double_sim_state: ts.SimState + ): + """Canonical keys (energy, forces, stress) must not land in extras.""" + state = si_double_sim_state.clone() + state.store_model_extras( + { + "energy": torch.randn(state.n_systems), + "forces": torch.randn(state.n_atoms, 3), + "stress": torch.randn(state.n_systems, 3, 3), + } + ) + assert not state._system_extras # noqa: SLF001 + assert not state._atom_extras # noqa: SLF001 + + def test_store_model_extras_per_system(self, si_double_sim_state: ts.SimState): + """Tensors with leading dim == n_systems go into system_extras.""" + state = si_double_sim_state.clone() + dipole = torch.randn(state.n_systems, 3) + state.store_model_extras( + {"energy": torch.randn(state.n_systems), "dipole": dipole} + ) + assert torch.equal(state.dipole, dipole) + + def test_store_model_extras_per_atom(self, si_double_sim_state: ts.SimState): + """Tensors with leading dim == n_atoms go into atom_extras.""" + state = si_double_sim_state.clone() + charges = torch.randn(state.n_atoms) + density = torch.randn(state.n_atoms, 8) + state.store_model_extras( + { + "energy": torch.randn(state.n_systems), + "charges": charges, + "density_coefficients": density, + } + ) + assert torch.equal(state.charges, charges) + assert state.density_coefficients.shape == (state.n_atoms, 8) + + def test_store_model_extras_skips_scalars(self, si_double_sim_state: ts.SimState): + """0-d tensors and non-Tensor values are silently ignored.""" + state = si_double_sim_state.clone() + state.store_model_extras( + { + "scalar": torch.tensor(3.14), + "string": "not a tensor", + } + ) + assert not state.has_extras("scalar") + assert not state.has_extras("string") + + +def test_system_extras_atoms_roundtrip(): + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"external_E_field": torch.tensor([[1.0, 0.0, 0.0]])}, + ) + atoms_list = state.to_atoms() + assert "external_E_field" in atoms_list[0].info + restored = ts.io.atoms_to_state( + atoms_list, + system_extras_keys=["external_E_field"], + ) + assert torch.allclose(restored.external_E_field, state.external_E_field) + + +def test_atom_extras_atoms_roundtrip(): + tags = torch.tensor([1.0, 2.0]) + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _atom_extras={"tags": tags}, + ) + atoms_list = state.to_atoms() + assert "tags" in atoms_list[0].arrays + restored = ts.io.atoms_to_state( + atoms_list, + atom_extras_keys=["tags"], + ) + assert torch.allclose(restored.tags, state.tags) diff --git a/tests/test_io.py b/tests/test_io.py index 2bb4f017..8e1de1ac 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -5,6 +5,7 @@ import pytest import torch from ase import Atoms +from ase.build import molecule from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure @@ -91,6 +92,69 @@ def test_multiple_atoms_to_state(si_atoms: Atoms) -> None: ) +@pytest.mark.parametrize( + ("charge", "spin", "expected_charge", "expected_spin"), + [ + (1.0, 1.0, 1.0, 1.0), # Non-zero charge and spin + (0.0, 0.0, 0.0, 0.0), # Explicit zero charge and spin + (None, None, 0.0, 0.0), # No charge/spin set, defaults to zero + ], +) +def test_atoms_to_state_with_charge_spin( + charge: float | None, + spin: float | None, + expected_charge: float, + expected_spin: float, +) -> None: + """Test conversion from ASE Atoms with charge and spin to state tensors.""" + mol = molecule("H2O") + if charge is not None: + mol.info["charge"] = charge + if spin is not None: + mol.info["spin"] = spin + + state = ts.io.atoms_to_state([mol], DEVICE, DTYPE) + + # Check basic properties + assert isinstance(state, SimState) + assert state.charge is not None + assert state.spin is not None + assert state.charge.shape == (1,) + assert state.spin.shape == (1,) + assert state.charge[0].item() == expected_charge + assert state.spin[0].item() == expected_spin + + +def test_multiple_atoms_to_state_with_charge_spin() -> None: + """Test conversion from multiple ASE Atoms with different charge/spin values.""" + mol1 = molecule("H2O") + mol1.info["charge"] = 1.0 + mol1.info["spin"] = 1.0 + + mol2 = molecule("CH4") + mol2.info["charge"] = -1.0 + mol2.info["spin"] = 0.0 + + mol3 = molecule("NH3") + mol3.info["charge"] = 0.0 + mol3.info["spin"] = 2.0 + + state = ts.io.atoms_to_state([mol1, mol2, mol3], DEVICE, DTYPE) + + # Check basic properties + assert isinstance(state, SimState) + assert state.charge is not None + assert state.spin is not None + assert state.charge.shape == (3,) + assert state.spin.shape == (3,) + assert state.charge[0].item() == 1.0 + assert state.charge[1].item() == -1.0 + assert state.charge[2].item() == 0.0 + assert state.spin[0].item() == 1.0 + assert state.spin[1].item() == 0.0 + assert state.spin[2].item() == 2.0 + + def test_state_to_structure(ar_supercell_sim_state: SimState) -> None: """Test conversion from state tensors to list of pymatgen Structure.""" structures = ts.io.state_to_structures(ar_supercell_sim_state) @@ -117,6 +181,23 @@ def test_state_to_atoms(ar_supercell_sim_state: SimState) -> None: assert len(atoms[0]) == 32 +def test_state_to_atoms_with_charge_spin() -> None: + """Test conversion from state with charge/spin to ASE Atoms preserves charge/spin.""" + mol = molecule("H2O") + mol.info["charge"] = 1.0 + mol.info["spin"] = 1.0 + + state = ts.io.atoms_to_state([mol], DEVICE, DTYPE) + atoms = ts.io.state_to_atoms(state) + + assert len(atoms) == 1 + assert isinstance(atoms[0], Atoms) + assert "charge" in atoms[0].info + assert "spin" in atoms[0].info + assert atoms[0].info["charge"] == 1 + assert atoms[0].info["spin"] == 1 + + def test_state_to_multiple_atoms(ar_double_sim_state: SimState) -> None: """Test conversion from state tensors to list of ASE Atoms.""" atoms = ts.io.state_to_atoms(ar_double_sim_state) @@ -259,6 +340,9 @@ def test_state_round_trip( # since both use their own isotope masses based on species, # not the ones in the state assert torch.allclose(sim_state.masses, round_trip_state.masses) + # Check charge/spin round trip + assert torch.allclose(sim_state.charge, round_trip_state.charge) + assert torch.allclose(sim_state.spin, round_trip_state.spin) def test_state_to_atoms_importerror(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 76ff69c1..49d9b62a 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -233,6 +233,7 @@ def velocity_verlet_step[T: MDState]( state.energy = model_output["energy"] state.forces = model_output["forces"] + state.store_model_extras(model_output) return momentum_step(state, dt_2) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 4a86084c..eaa5b812 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -633,7 +633,7 @@ def npt_langevin_init( logger.warning(msg) # Create the initial state - return NPTLangevinState.from_state( + npt_state = NPTLangevinState.from_state( state, momenta=momenta, energy=model_output["energy"], @@ -647,6 +647,8 @@ def npt_langevin_init( cell_masses=cell_masses, cell_alpha=cell_alpha, ) + npt_state.store_model_extras(model_output) + return npt_state @dcite("10.1063/1.4901303") @@ -708,6 +710,7 @@ def npt_langevin_step( model_output = model(state) state.forces = model_output["forces"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Store initial values for integration forces = state.forces @@ -747,6 +750,7 @@ def npt_langevin_step( state.energy = model_output["energy"] state.forces = model_output["forces"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Compute updated pressure force F_p_n_new = _compute_cell_force( @@ -1291,6 +1295,7 @@ def _npt_nose_hoover_inner_step( state.forces = model_output["forces"] state.stress = model_output["stress"] state.energy = model_output["energy"] + state.store_model_extras(model_output) state.cell_position = cell_position state.cell_momentum = cell_momentum state.cell_mass = cell_mass @@ -1444,7 +1449,7 @@ def npt_nose_hoover_init( logger.warning(msg) # Create initial state - return NPTNoseHooverState.from_state( + npt_state = NPTNoseHooverState.from_state( state, momenta=momenta, energy=energy, @@ -1460,6 +1465,8 @@ def npt_nose_hoover_init( barostat_fns=barostat_fns, thermostat_fns=thermostat_fns, ) + npt_state.store_model_extras(model_output) + return npt_state @dcite("10.1080/00268979600100761") @@ -2082,6 +2089,7 @@ def npt_crescale_anisotropic_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt_tensor / 2) @@ -2157,6 +2165,7 @@ def npt_crescale_independent_lengths_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2233,6 +2242,7 @@ def npt_crescale_average_anisotropic_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2310,6 +2320,7 @@ def npt_crescale_isotropic_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2383,7 +2394,7 @@ def npt_crescale_init( ) # Create the initial state - return NPTCRescaleState.from_state( + npt_state = NPTCRescaleState.from_state( state, momenta=momenta, energy=model_output["energy"], @@ -2392,3 +2403,5 @@ def npt_crescale_init( tau_p=tau_p, isothermal_compressibility=isothermal_compressibility, ) + npt_state.store_model_extras(model_output) + return npt_state diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index 07f3064b..316ef78c 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -57,12 +57,14 @@ def nve_init( state.rng, ) - return MDState.from_state( + md_state = MDState.from_state( state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], ) + md_state.store_model_extras(model_output) + return md_state def nve_step( @@ -100,5 +102,6 @@ def nve_step( model_output = model(state) state.energy = model_output["energy"] state.forces = model_output["forces"] + state.store_model_extras(model_output) return momentum_step(state, dt / 2) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 8e74bf85..841399b4 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -126,12 +126,14 @@ def nvt_langevin_init( kT, state.rng, ) - return MDState.from_state( + md_state = MDState.from_state( state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], ) + md_state.store_model_extras(model_output) + return md_state @dcite("10.1098/rspa.2016.0138") @@ -191,6 +193,7 @@ def nvt_langevin_step( model_output = model(state) state.energy = model_output["energy"] state.forces = model_output["forces"] + state.store_model_extras(model_output) return momentum_step(state, dt_tensor / 2) @@ -321,7 +324,7 @@ def nvt_nose_hoover_init( ) # n_atoms * n_dimensions # Initialize state - return NVTNoseHooverState.from_state( + nh_state = NVTNoseHooverState.from_state( state, momenta=momenta, energy=model_output["energy"], @@ -330,6 +333,8 @@ def nvt_nose_hoover_init( chain=chain_fns.initialize(dof_per_system, KE, kT_tensor), _chain_fns=chain_fns, ) + nh_state.store_model_extras(model_output) + return nh_state @dcite("10.1080/00268979600100761") @@ -609,12 +614,14 @@ def nvt_vrescale_init( state.rng, ) - return NVTVRescaleState.from_state( + vr_state = NVTVRescaleState.from_state( state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], ) + vr_state.store_model_extras(model_output) + return vr_state @dcite("10.1063/1.2408420") diff --git a/torch_sim/io.py b/torch_sim/io.py index edce091e..6044c813 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -32,11 +32,17 @@ description="ASE: Atomic Simulation Environment", path="ase", ) -def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: +def state_to_atoms( + state: "ts.SimState", + system_extras_keys: list[str] | None = None, + atom_extras_keys: list[str] | None = None, +) -> list["Atoms"]: """Convert a SimState to a list of ASE Atoms objects. Args: state (SimState): Batched state containing positions, cell, and atomic numbers + system_extras_keys: Keys for per-system extras to include in atoms.info + atom_extras_keys: Keys for per-atom extras to include in atoms.arrays Returns: list[Atoms]: ASE Atoms objects, one per system @@ -70,10 +76,6 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: else np.array([state.pbc] * 3 if isinstance(state.pbc, bool) else state.pbc) ) - # Extract charge and spin if available (per-system attributes) - charge = state.charge.detach().cpu().numpy() if state.charge is not None else None - spin = state.spin.detach().cpu().numpy() if state.spin is not None else None - atoms_list = [] for sys_idx in np.unique(system_indices): mask = system_indices == sys_idx @@ -91,11 +93,18 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: symbols=symbols, positions=system_positions, cell=system_cell, pbc=pbc_for_sys ) - # Preserve charge and spin in atoms.info (as integers for FairChem compatibility) - if charge is not None: - atoms.info["charge"] = int(charge[sys_idx].item()) - if spin is not None: - atoms.info["spin"] = int(spin[sys_idx].item()) + # Write system extras to atoms.info + # charge/spin stored as int scalars for FairChem compatibility + if system_extras_keys is not None: + for key in system_extras_keys: + val = state.system_extras[key][sys_idx].detach().cpu().numpy() + atoms.info[key] = val + + # Write atom extras to atoms.arrays + if atom_extras_keys is not None: + for key in atom_extras_keys: + val = state.atom_extras[key][mask].detach().cpu().numpy() + atoms.arrays[key] = val atoms_list.append(atoms) @@ -244,6 +253,8 @@ def atoms_to_state( atoms: "Atoms | list[Atoms]", device: torch.device | None = None, dtype: torch.dtype | None = None, + system_extras_keys: list[str] | None = None, + atom_extras_keys: list[str] | None = None, ) -> "ts.SimState": """Convert an ASE Atoms object or list of Atoms objects to a SimState. @@ -252,6 +263,10 @@ def atoms_to_state( device (torch.device): Device to create tensors on dtype (torch.dtype): Data type for tensors (typically torch.float32 or torch.float64) + system_extras_keys (list[str]): Optional list of keys to read from atoms.info + into _system_extras + atom_extras_keys (list[str]): Optional list of keys to read from atoms.arrays + into _atom_extras Returns: SimState: TorchSim SimState object. @@ -298,12 +313,25 @@ def atoms_to_state( if not all(np.all(np.equal(at.pbc, atoms_list[0].pbc)) for at in atoms_list[1:]): raise ValueError("All systems must have the same periodic boundary conditions") - charge = torch.tensor( - [at.info.get("charge", 0.0) for at in atoms_list], dtype=dtype, device=device - ) - spin = torch.tensor( - [at.info.get("spin", 0.0) for at in atoms_list], dtype=dtype, device=device - ) + _system_extras: dict[str, torch.Tensor] = {} + if system_extras_keys: + for key in system_extras_keys: + vals = [at.info.get(key) for at in atoms_list] + non_none_vals = [v for v in vals if v is not None] + if len(non_none_vals) == len(vals): + _system_extras[key] = torch.tensor( + np.stack(non_none_vals), dtype=dtype, device=device + ) + + _atom_extras: dict[str, torch.Tensor] = {} + if atom_extras_keys: + for key in atom_extras_keys: + arrays = [at.arrays.get(key) for at in atoms_list] + non_none_arrays = [a for a in arrays if a is not None] + if len(non_none_arrays) == len(arrays): + _atom_extras[key] = torch.tensor( + np.concatenate(non_none_arrays), dtype=dtype, device=device + ) return ts.SimState( positions=positions, @@ -312,8 +340,8 @@ def atoms_to_state( pbc=atoms_list[0].pbc, atomic_numbers=atomic_numbers, system_idx=system_idx, - charge=charge, - spin=spin, + _system_extras=_system_extras, + _atom_extras=_atom_extras, ) diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 8aa6bb5e..d09878de 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -243,7 +243,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 and primitive BCC iron) for validation. It tests both single and multi-batch processing capabilities. """ - from ase.build import bulk + from ase.build import bulk, molecule for attr in ("dtype", "device", "compute_stress", "compute_forces"): if not hasattr(model, attr): @@ -273,6 +273,8 @@ def validate_model_outputs( # noqa: C901, PLR0915 system_idx = sim_state.system_idx og_system_idx = system_idx.clone() og_atomic_nums = sim_state.atomic_numbers.clone() + og_charge = sim_state.charge.clone() + og_spin = sim_state.spin.clone() if check_detached and hasattr(model, "retain_graph"): model.__dict__["retain_graph"] = True @@ -293,6 +295,10 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{og_system_idx=} != {sim_state.system_idx=}") if not torch.allclose(og_atomic_nums, sim_state.atomic_numbers): raise ValueError(f"{og_atomic_nums=} != {sim_state.atomic_numbers=}") + if not torch.allclose(og_charge, sim_state.charge): + raise ValueError(f"{og_charge=} != {sim_state.charge=}") + if not torch.allclose(og_spin, sim_state.spin): + raise ValueError(f"{og_spin=} != {sim_state.spin=}") # assert model output has the correct keys if "energy" not in model_output: @@ -407,3 +413,43 @@ def validate_model_outputs( # noqa: C901, PLR0915 "vector: max diff = " f"{(shifted_output['stress'] - si_model_output['stress']).abs().max()}" ) + + # Test that models can handle non-zero charge and spin + benzene_atoms = molecule("C6H6") + benzene_atoms.info["charge"] = 1.0 + benzene_atoms.info["spin"] = 1.0 + charged_state = ts.io.atoms_to_state([benzene_atoms], device, dtype) + + # Ensure state has charge/spin before testing model + if charged_state.charge is None or charged_state.spin is None: + raise ValueError( + "atoms_to_state did not extract charge/spin. " + "Cannot test model charge/spin handling." + ) + + # Test that model can handle charge/spin without crashing + og_charged_charge = charged_state.charge.clone() + og_charged_spin = charged_state.spin.clone() + try: + charged_output = model.forward(charged_state) + except Exception as e: + raise ValueError( + "Model failed to handle non-zero charge/spin. " + "Models must be able to process states with charge and spin values. " + ) from e + + # Verify model didn't mutate charge/spin + if not torch.allclose(og_charged_charge, charged_state.charge): + raise ValueError( + f"Model mutated charge: {og_charged_charge=} != {charged_state.charge=}" + ) + if not torch.allclose(og_charged_spin, charged_state.spin): + raise ValueError( + f"Model mutated spin: {og_charged_spin=} != {charged_state.spin=}" + ) + # Verify output shape is still correct + if charged_output["energy"].shape != (1,): + raise ValueError( + f"energy shape incorrect with charge/spin: " + f"{charged_output['energy'].shape=} != (1,)" + ) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index e13e3ce8..75078670 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -336,6 +336,13 @@ def forward( # noqa: C901 if stress is not None: results["stress"] = stress.detach() + # Propagate additional model outputs (e.g. dipole, charges, etc.) + for key, val in out.items(): + if key not in ("energy", "forces", "stress") and isinstance( + val, torch.Tensor + ): + results[key] = val.detach() + return results diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 04dfde31..8a4a0d37 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -223,7 +223,7 @@ def swap_mc_init( """ model_output = model(state) - return SwapMCState( + mc_state = SwapMCState( positions=state.positions, masses=state.masses, cell=state.cell, @@ -233,6 +233,8 @@ def swap_mc_init( energy=model_output["energy"], _constraints=state.constraints, ) + mc_state.store_model_extras(model_output) + return mc_state def swap_mc_step( @@ -292,5 +294,6 @@ def swap_mc_step( state.energy = torch.where(accepted, energies_new, energies_old) state.last_permutation = permutation[reverse_rejected_swaps].clone() + state.store_model_extras(model_output) return state diff --git a/torch_sim/neighbors/vesin.py b/torch_sim/neighbors/vesin.py index 009fe9bb..16648950 100644 --- a/torch_sim/neighbors/vesin.py +++ b/torch_sim/neighbors/vesin.py @@ -12,13 +12,13 @@ try: from vesin import NeighborList as VesinNeighborList except ImportError: - VesinNeighborList = None # type: ignore[assignment] + VesinNeighborList = None try: from vesin.torch import NeighborList as VesinNeighborListTorch except ImportError: - VesinNeighborListTorch = None # ty:ignore[invalid-assignment] + VesinNeighborListTorch = None VESIN_AVAILABLE = VesinNeighborList is not None VESIN_TORCHSCRIPT_AVAILABLE = VesinNeighborListTorch is not None diff --git a/torch_sim/optimizers/bfgs.py b/torch_sim/optimizers/bfgs.py index c3a344cf..c9ada499 100644 --- a/torch_sim/optimizers/bfgs.py +++ b/torch_sim/optimizers/bfgs.py @@ -25,13 +25,13 @@ _clamp_deform_grad_log, frechet_cell_filter_init, ) +from torch_sim.optimizers.state import BFGSState from torch_sim.state import SimState if TYPE_CHECKING: from torch_sim.models.interface import ModelInterface from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs - from torch_sim.optimizers.state import BFGSState BFGS_EPS = 1e-7 # eps kept same as ASE's BFGS. @@ -115,8 +115,6 @@ def bfgs_init( Returns: BFGSState or CellBFGSState if cell_filter is provided """ - from torch_sim.optimizers import BFGSState, CellBFGSState - device: torch.device = model.device dtype: torch.dtype = model.dtype @@ -143,6 +141,19 @@ def bfgs_init( n_iter = torch.zeros((n_systems,), device=model.device, dtype=torch.int32) # [S] + bfgs_attrs = { + "forces": forces, # [N, 3] + "energy": energy, # [S] + "stress": stress, # [S, 3, 3] or None + "prev_forces": forces.clone(), # [N, 3] + "prev_positions": state.positions.clone(), # [N, 3] + "alpha": alpha_t, # [S] + "max_step": max_step_t, # [S] + "n_iter": n_iter, # [S] + "atom_idx_in_system": atom_idx, # [N] + "max_atoms": max_atoms, # [S] + } + if cell_filter is not None: # Extended Hessian: (3*global_max_atoms + 9) x (3*global_max_atoms + 9) # The extra 9 DOFs are for cell parameters (3x3 matrix flattened) @@ -153,59 +164,31 @@ def bfgs_init( cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) - # Note (AG): At initialization, deform_grad is identity, so we have: - # fractional = Cartesian / cell and scaled forces = forces @ I = forces - # For ASE compatibility, we need to store prev_positions as fractional coords - # and prev_forces as scaled forces - - # Get initial deform_grad (identity at start since reference_cell = current_cell) + # At initialization, deform_grad is identity, so fractional = Cartesian + # and scaled forces = forces. For ASE compatibility, store prev_positions + # as fractional coords and prev_forces as scaled forces. reference_cell = state.cell.clone() # [S, 3, 3] cur_deform_grad = cell_filters.deform_grad( reference_cell.mT, state.cell.mT ) # [S, 3, 3] - # Initial fractional positions = solve(deform_grad, positions) = positions - # cur_deform_grad[system_idx]: [N, 3, 3], positions: [N, 3] frac_positions = torch.linalg.solve( cur_deform_grad[state.system_idx], # [N, 3, 3] state.positions.unsqueeze(-1), # [N, 3, 1] ).squeeze(-1) # [N, 3] - # Initial scaled forces = forces @ deform_grad = forces - # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3] scaled_forces = torch.bmm( forces.unsqueeze(1), # [N, 1, 3] cur_deform_grad[state.system_idx], # [N, 3, 3] ).squeeze(1) - common_args = { - "positions": state.positions.clone(), # [N, 3] - "masses": state.masses.clone(), # [N] - "cell": state.cell.clone(), # [S, 3, 3] - "atomic_numbers": state.atomic_numbers.clone(), # [N] - "forces": forces, # [N, 3] - "energy": energy, # [S] - "stress": stress, # [S, 3, 3] or None - "hessian": hessian, # [S, D_ext, D_ext] - # Note (AG): Store fractional positions and scaled forces - # for ASE compatibility - "prev_forces": scaled_forces, # [N, 3] (scaled) - "prev_positions": frac_positions, # [N, 3] (fractional) - "alpha": alpha_t, # [S] - "max_step": max_step_t, # [S] - "n_iter": n_iter, # [S] - "atom_idx_in_system": atom_idx, # [N] - "max_atoms": max_atoms, # scalar M - "system_idx": state.system_idx.clone(), # [N] - "pbc": state.pbc, # [S, 3] - "reference_cell": reference_cell, # [S, 3, 3] - "cell_filter": cell_filter_funcs, - "charge": state.charge, # preserve charge - "spin": state.spin, # preserve spin - "_constraints": state.constraints, # preserve constraints - } - - cell_state = CellBFGSState(**common_args) # ty: ignore[invalid-argument-type] + bfgs_attrs["hessian"] = hessian # [S, D_ext, D_ext] + bfgs_attrs["prev_forces"] = scaled_forces # [N, 3] (scaled) + bfgs_attrs["prev_positions"] = frac_positions # [N, 3] (fractional) + bfgs_attrs["reference_cell"] = reference_cell # [S, 3, 3] + bfgs_attrs["cell_filter"] = cell_filter_funcs + + cell_state = CellBFGSState.from_state(state, **bfgs_attrs) # Initialize cell-specific attributes (cell_positions, cell_forces, etc.) # After init: cell_positions [S, 3, 3], cell_forces [S, 3, 3], cell_factor [S] @@ -215,6 +198,7 @@ def bfgs_init( cell_state.prev_cell_positions = cell_state.cell_positions.clone() # [S, 3, 3] cell_state.prev_cell_forces = cell_state.cell_forces.clone() # [S, 3, 3] + cell_state.store_model_extras(model_output) return cell_state # Position-only Hessian: 3*global_max_atoms x 3*global_max_atoms @@ -222,31 +206,11 @@ def bfgs_init( hessian = torch.eye(dim, device=device, dtype=dtype).unsqueeze(0).repeat( n_systems, 1, 1 ) * alpha_t.view(n_systems, 1, 1) # [S, D, D] + bfgs_attrs["hessian"] = hessian # [S, D, D] - common_args = { - "positions": state.positions.clone(), # [N, 3] - "masses": state.masses.clone(), # [N] - "cell": state.cell.clone(), # [S, 3, 3] - "atomic_numbers": state.atomic_numbers.clone(), # [N] - "forces": forces, # [N, 3] - "energy": energy, # [S] - "stress": stress, # [S, 3, 3] or None - "hessian": hessian, # [S, D, D] - "prev_forces": forces.clone(), # [N, 3] - "prev_positions": state.positions.clone(), # [N, 3] - "alpha": alpha_t, # [S] - "max_step": max_step_t, # [S] - "n_iter": n_iter, # [S] - "atom_idx_in_system": atom_idx, # [N] - "max_atoms": max_atoms, # scalar M - "system_idx": state.system_idx.clone(), # [N] - "pbc": state.pbc, # [S, 3] - "charge": state.charge, # preserve charge - "spin": state.spin, # preserve spin - "_constraints": state.constraints, # preserve constraints - } - - return BFGSState(**common_args) # ty: ignore[invalid-argument-type] + bfgs_state = BFGSState.from_state(state, **bfgs_attrs) + bfgs_state.store_model_extras(model_output) + return bfgs_state def bfgs_step( # noqa: C901, PLR0915 @@ -550,6 +514,7 @@ def bfgs_step( # noqa: C901, PLR0915 state.energy = model_output["energy"] # [S] if "stress" in model_output: state.stress = model_output["stress"] # [S, 3, 3] + state.store_model_extras(model_output) # Update cell forces for next step # Update cell forces for cell state: [S, 3, 3] diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 8efcb3a7..e45c7ec8 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -106,9 +106,12 @@ def fire_init( cell_state.cell_forces.shape, torch.nan, device=device, dtype=dtype ) + cell_state.store_model_extras(model_output) return cell_state # Create regular FireState without cell optimization - return FireState.from_state(state, **fire_attrs) + fire_state = FireState.from_state(state, **fire_attrs) + fire_state.store_model_extras(model_output) + return fire_state def fire_step( @@ -173,7 +176,7 @@ def fire_step( return step_func(state, **step_func_kwargs) # ty: ignore[invalid-argument-type] -def _vv_fire_step[T: "FireState | CellFireState"]( +def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 state: T, model: "ModelInterface", *, @@ -215,6 +218,7 @@ def _vv_fire_step[T: "FireState | CellFireState"]( state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] + state.store_model_extras(model_output) # Update cell forces if isinstance(state, CellFireState): @@ -465,6 +469,7 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] + state.store_model_extras(model_output) # Update cell forces if isinstance(state, CellFireState): diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index 6f940ff0..7356ffe8 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -53,6 +53,8 @@ def gradient_descent_init( "stress": stress, } + state.store_model_extras(model_output) + if cell_filter is not None: # Create cell optimization state cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) optim_attrs["reference_cell"] = state.cell.clone() @@ -112,6 +114,7 @@ def gradient_descent_step( state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] + state.store_model_extras(model_output) # Update cell forces if isinstance(state, CellOptimState): diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index f8413399..cb88c657 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -192,22 +192,10 @@ def lbfgs_init( if step_size_tensor.ndim == 0: step_size_tensor = step_size_tensor.expand(n_systems) - common_args = { - # Copy SimState attributes - "positions": state.positions.clone(), # [N, 3] - "masses": state.masses.clone(), # [N] - "cell": state.cell.clone(), # [S, 3, 3] - "atomic_numbers": state.atomic_numbers.clone(), # [N] - "system_idx": state.system_idx.clone(), # [N] - "pbc": state.pbc, # [S, 3] - "charge": state.charge, # preserve charge - "spin": state.spin, # preserve spin - "_constraints": state.constraints, # preserve constraints - # Optimization state + lbfgs_attrs = { "forces": forces, # [N, 3] "energy": energy, # [S] "stress": stress, # [S, 3, 3] or None - # L-BFGS specific state "prev_forces": forces.clone(), # [N, 3] "prev_positions": state.positions.clone(), # [N, 3] "s_history": s_history, # [S, 0, M, 3] @@ -227,41 +215,35 @@ def lbfgs_init( reference_cell = state.cell.clone() # [S, 3, 3] cur_deform_grad = deform_grad(reference_cell.mT, state.cell.mT) # [S, 3, 3] - # Initial fractional positions = positions - # cur_deform_grad[system_idx]: [N, 3, 3], positions: [N, 3] -> [N, 3] frac_positions = torch.linalg.solve( cur_deform_grad[state.system_idx], # [N, 3, 3] state.positions.unsqueeze(-1), # [N, 3, 1] ).squeeze(-1) # [N, 3] - # Initial scaled forces = forces @ deform_grad = forces - # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3] scaled_forces = torch.bmm( forces.unsqueeze(1), # [N, 1, 3] cur_deform_grad[state.system_idx], # [N, 3, 3] ).squeeze(1) # [N, 3] - common_args["reference_cell"] = reference_cell # [S, 3, 3] - common_args["cell_filter"] = cell_filter_funcs - # Store fractional positions and scaled forces for ASE compatibility - common_args["prev_positions"] = frac_positions # [N, 3] - common_args["prev_forces"] = scaled_forces # [N, 3] + lbfgs_attrs["reference_cell"] = reference_cell # [S, 3, 3] + lbfgs_attrs["cell_filter"] = cell_filter_funcs + lbfgs_attrs["prev_positions"] = frac_positions # [N, 3] (fractional) + lbfgs_attrs["prev_forces"] = scaled_forces # [N, 3] (scaled) # Extended per-system history includes cell DOFs (3 "virtual atoms" per system) - # History shape: [S, H, M+3, 3] where M = global_max_atoms extended_size_per_system = global_max_atoms + 3 # M_ext = M + 3 - common_args["s_history"] = torch.zeros( + lbfgs_attrs["s_history"] = torch.zeros( (n_systems, 0, extended_size_per_system, 3), device=device, dtype=dtype, ) # [S, 0, M_ext, 3] - common_args["y_history"] = torch.zeros( + lbfgs_attrs["y_history"] = torch.zeros( (n_systems, 0, extended_size_per_system, 3), device=device, dtype=dtype, ) # [S, 0, M_ext, 3] - cell_state = CellLBFGSState(**common_args) # ty: ignore[invalid-argument-type] + cell_state = CellLBFGSState.from_state(state, **lbfgs_attrs) # Initialize cell-specific attributes # After init: cell_positions [S, 3, 3], cell_forces [S, 3, 3], cell_factor [S] @@ -271,9 +253,12 @@ def lbfgs_init( cell_state.prev_cell_positions = cell_state.cell_positions.clone() # [S, 3, 3] cell_state.prev_cell_forces = cell_state.cell_forces.clone() # [S, 3, 3] + cell_state.store_model_extras(model_output) return cell_state - return LBFGSState(**common_args) # ty: ignore[invalid-argument-type] + lbfgs_state = LBFGSState.from_state(state, **lbfgs_attrs) + lbfgs_state.store_model_extras(model_output) + return lbfgs_state def lbfgs_step( # noqa: PLR0915, C901 @@ -536,6 +521,7 @@ def lbfgs_step( # noqa: PLR0915, C901 new_forces = model_output["forces"] # [N, 3] new_energy = model_output["energy"] # [S] new_stress = model_output.get("stress") # [S, 3, 3] or None + state.store_model_extras(model_output) # Update cell forces for next step: [S, 3, 3] if isinstance(state, CellLBFGSState): diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 72c25c07..7e66ca26 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -22,7 +22,7 @@ from torch_sim.integrators.md import MDState from torch_sim.models.interface import ModelInterface from torch_sim.optimizers import OPTIM_REGISTRY, FireState, Optimizer, OptimState -from torch_sim.state import SimState +from torch_sim.state import _CANONICAL_MODEL_KEYS, SimState from torch_sim.trajectory import TrajectoryReporter from torch_sim.typing import StateLike from torch_sim.units import UnitSystem @@ -732,7 +732,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 ) -def static( +def static( # noqa: C901 system: StateLike, model: ModelInterface, *, @@ -836,8 +836,25 @@ class StaticState(SimState): else torch.full_like(sub_state.cell, fill_value=float("nan")) ), ) + static_state.store_model_extras(model_outputs) props = trajectory_reporter.report(static_state, 0, model=model) + + # Merge extra model outputs into per-system property dicts + # TODO: this should be cleaner? + extra_keys = {k for k in model_outputs if k not in _CANONICAL_MODEL_KEYS} + if extra_keys: + for sys_idx, sys_props in enumerate(props): + for key in extra_keys: + val = model_outputs[key] + if not isinstance(val, torch.Tensor) or val.ndim == 0: + continue + if val.shape[0] == static_state.n_atoms: + mask = static_state.system_idx == sys_idx + sys_props[key] = val[mask] + elif val.shape[0] == static_state.n_systems: + sys_props[key] = val[sys_idx : sys_idx + 1] + all_props.extend(props) if tqdm_pbar: diff --git a/torch_sim/state.py b/torch_sim/state.py index 390b55fd..e9079d91 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -5,11 +5,12 @@ """ import copy +import functools import importlib import typing from collections import defaultdict from collections.abc import Generator, Sequence -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self import torch @@ -32,6 +33,10 @@ ) +# Canonical model output keys that are handled explicitly by integrators/runners +_CANONICAL_MODEL_KEYS = frozenset({"energy", "forces", "stress"}) + + def coerce_prng(rng: PRNGLike, device: DeviceLikeType | None) -> torch.Generator: """Coerce an int seed or existing Generator into a ``torch.Generator``. @@ -67,6 +72,30 @@ def require_system_idx(system_idx: torch.Tensor | None) -> torch.Tensor: return system_idx +_EXTRAS_COMPAT_KEYS = frozenset({"charge", "spin"}) + + +def _wrap_init_for_extras(cls: type) -> None: + """Wrap a dataclass __init__ to route unknown kwargs into _system_extras.""" + original_init = cls.__init__ + all_fields = {f.name for f in fields(cls)} + + @functools.wraps(original_init) + def _wrapped_init(self: Any, *args: Any, **kwargs: Any) -> None: + extras = kwargs.get("_system_extras") + if extras is None: + extras = {} + kwargs["_system_extras"] = extras + unknown = [k for k in kwargs if k not in all_fields] + for key in unknown: + val = kwargs.pop(key) + if val is not None: + extras[key] = val + original_init(self, *args, **kwargs) + + cls.__init__ = _wrapped_init # type: ignore[assignment] + + @dataclass(kw_only=True) class SimState: """State representation for atomistic systems with batched operations support. @@ -129,10 +158,10 @@ class SimState: cell: torch.Tensor pbc: torch.Tensor # coerced from bool/list[bool] by __setattr__ atomic_numbers: torch.Tensor - charge: torch.Tensor | None = field(default=None) - spin: torch.Tensor | None = field(default=None) system_idx: torch.Tensor = field(default=None) # type: ignore[assignment] # coerced from None by __setattr__ - _constraints: list["Constraint"] = field(default_factory=lambda: []) # noqa: PIE807 + _constraints: list["Constraint"] = field(default_factory=list) + _system_extras: dict[str, torch.Tensor] = field(default_factory=dict) + _atom_extras: dict[str, torch.Tensor] = field(default_factory=dict) _rng: PRNGLike = field(default=None, repr=False) if TYPE_CHECKING: @@ -145,11 +174,10 @@ def __init__( # noqa: D107 cell: torch.Tensor, pbc: torch.Tensor | list[bool] | bool, atomic_numbers: torch.Tensor, - charge: torch.Tensor | None = None, - spin: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, _constraints: list[Constraint] | None = None, _rng: PRNGLike = None, + **kwargs: Any, ) -> None: ... _atom_attributes: ClassVar[set[str]] = { @@ -158,7 +186,7 @@ def __init__( # noqa: D107 "atomic_numbers", "system_idx", } - _system_attributes: ClassVar[set[str]] = {"cell", "charge", "spin"} + _system_attributes: ClassVar[set[str]] = {"cell"} _global_attributes: ClassVar[set[str]] = {"pbc", "_rng"} @property @@ -171,8 +199,20 @@ def rng(self) -> torch.Generator: def rng(self, value: PRNGLike) -> None: self._rng = value - def __setattr__(self, name: str, value: object) -> None: - """Coerce pbc and system_idx on every assignment.""" + def __setattr__(self, name: str, value: object) -> None: # noqa: C901 + """Coerce pbc and system_idx on every assignment. + + Routes charge/spin writes to _system_extras for backward compatibility. + """ + if name in _EXTRAS_COMPAT_KEYS: + try: + extras = object.__getattribute__(self, "_system_extras") + except AttributeError: + extras = {} + super().__setattr__("_system_extras", extras) + if value is not None: + extras[name] = value + return if name == "pbc" and not isinstance(value, torch.Tensor): if isinstance(value, bool): value = [value] * 3 @@ -210,14 +250,14 @@ def __post_init__(self) -> None: # noqa: C901 if self.constraints: validate_constraints(self.constraints, state=self) - if self.charge is None: - self.charge = torch.zeros(n_systems, device=self.device, dtype=self.dtype) - elif self.charge.shape[0] != n_systems: - raise ValueError(f"Charge must have shape (n_systems={n_systems},)") - if self.spin is None: - self.spin = torch.zeros(n_systems, device=self.device, dtype=self.dtype) - elif self.spin.shape[0] != n_systems: - raise ValueError(f"Spin must have shape (n_systems={n_systems},)") + if "charge" not in self._system_extras: + self._system_extras["charge"] = torch.zeros( + n_systems, device=self.device, dtype=self.dtype + ) + if "spin" not in self._system_extras: + self._system_extras["spin"] = torch.zeros( + n_systems, device=self.device, dtype=self.dtype + ) if self.cell.ndim != 3: self.cell = self.cell.unsqueeze(0) @@ -246,6 +286,29 @@ def __post_init__(self) -> None: # noqa: C901 if len(set(devices.values())) > 1: raise ValueError("All tensors must be on the same device") + # Validate extras shapes and prevent shadowing + all_attrs = self._get_all_attributes() + for key, val in self._system_extras.items(): + if key in all_attrs or hasattr(type(self), key): + raise ValueError(f"System extra '{key}' shadows an existing attribute") + if not isinstance(val, torch.Tensor): + raise TypeError(f"System extra '{key}' must be a torch.Tensor") + if val.shape[0] != n_systems: + raise ValueError( + f"System extra '{key}' leading dim must be " + f"n_systems={n_systems}, got {val.shape[0]}" + ) + for key, val in self._atom_extras.items(): + if key in all_attrs or hasattr(type(self), key): + raise ValueError(f"Atom extra '{key}' shadows an existing attribute") + if not isinstance(val, torch.Tensor): + raise TypeError(f"Atom extra '{key}' must be a torch.Tensor") + if val.shape[0] != self.n_atoms: + raise ValueError( + f"Atom extra '{key}' leading dim must be " + f"n_atoms={self.n_atoms}, got {val.shape[0]}" + ) + @classmethod def _get_all_attributes(cls) -> set[str]: """Get all attributes of the SimState.""" @@ -253,9 +316,72 @@ def _get_all_attributes(cls) -> set[str]: cls._atom_attributes | cls._system_attributes | cls._global_attributes - | {"_constraints"} + | {"_constraints", "_system_extras", "_atom_extras"} + ) + + def __getattr__(self, name: str) -> Any: + """Allow attribute-style access to extras dict entries.""" + # Guard: don't look up private attrs in extras (avoids recursion during init) + if name.startswith("_"): + raise AttributeError(name) + for extras_attr in ("_system_extras", "_atom_extras"): + try: + extras = object.__getattribute__(self, extras_attr) + except AttributeError: + continue + if name in extras: + return extras[name] + + # Raise AttributeError so that Python's getattr(obj, name, default), + # hasattr(obj, name), and other descriptor-protocol machinery work correctly. + raise AttributeError( + f"'{type(self).__name__}' has no attribute or extra '{name}'" ) + @property + def system_extras(self) -> dict[str, torch.Tensor]: + """Get the system extras.""" + return self._system_extras + + @property + def atom_extras(self) -> dict[str, torch.Tensor]: + """Get the atom extras.""" + return self._atom_extras + + def has_extras(self, key: str) -> bool: + """Check if an extras key exists.""" + return key in self._system_extras or key in self._atom_extras + + def store_model_extras(self, model_output: dict[str, torch.Tensor]) -> None: + """Store non-canonical model outputs into state extras (in-place). + + Any key in *model_output* that is not in ``{"energy", "forces", "stress"}`` + is classified by its leading dimension: + + * ``n_atoms`` → stored in ``_atom_extras`` + * ``n_systems`` → stored in ``_system_extras`` + * otherwise → skipped (ambiguity or scalar) + + When ``n_atoms == n_systems`` (single-atom system), the tensor is stored as + per-atom by convention. + + Args: + model_output: Full dict returned by ``model.forward()``. + """ + n_atoms = self.n_atoms + n_systems = self.n_systems + + for key, val in model_output.items(): + if key in _CANONICAL_MODEL_KEYS: + continue + if not isinstance(val, torch.Tensor) or val.ndim == 0: + continue + leading = val.shape[0] + if leading == n_atoms: + self._atom_extras[key] = val + elif leading == n_systems: + self._system_extras[key] = val + @property def wrap_positions(self) -> torch.Tensor: """Atomic positions wrapped according to periodic boundary conditions if pbc=True, @@ -486,8 +612,18 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: if attr_name in cls._get_all_attributes(): attrs[attr_name] = cls._clone_attr(attr_value) - # Add/override with additional attributes - attrs.update(additional_attrs) + # Route additional_attrs: known attrs go directly, unknown tensor attrs + # go to _system_extras (backward compat for charge/spin and extensibility) + all_known = cls._get_all_attributes() + for key, val in additional_attrs.items(): + if key in all_known: + attrs[key] = val + elif isinstance(val, torch.Tensor): + if "_system_extras" not in attrs: + attrs["_system_extras"] = {} + attrs["_system_extras"][key] = val + else: + attrs[key] = val return cls(**attrs) @@ -595,9 +731,13 @@ def __init_subclass__(cls, **kwargs) -> None: Also enforce all of child classes's attributes are specified in _atom_attributes, _system_attributes, or _global_attributes. + + Also wraps __init__ to pop deprecated charge/spin kwargs and route them + to _system_extras for backward compatibility. """ cls._assert_no_tensor_attributes_can_be_none() cls._assert_all_attributes_have_defined_scope() + _wrap_init_for_extras(cls) super().__init_subclass__(**kwargs) @classmethod @@ -606,7 +746,7 @@ def _assert_no_tensor_attributes_can_be_none(cls) -> None: # exceptions exist because the type hint doesn't actually reflect the real type # (since we change their type in the post_init) - exceptions = {"system_idx", "charge", "spin"} + exceptions = {"system_idx"} type_hints = typing.get_type_hints(cls) for attr_name, attr_type_hint in type_hints.items(): @@ -684,6 +824,9 @@ def _assert_all_attributes_have_defined_scope(cls) -> None: ) +_wrap_init_for_extras(SimState) + + @dataclass(kw_only=True) class DeformGradMixin: """Mixin for states that support deformation gradients.""" @@ -769,7 +912,7 @@ def _normalize_system_indices( raise TypeError(f"Unsupported index type: {type(system_indices)}") -def _state_to_device[T: SimState]( +def _state_to_device[T: SimState]( # noqa: C901 state: T, device: torch.device | None = None, dtype: torch.dtype | None = None ) -> T: """Convert the SimState to a new device and dtype. @@ -797,11 +940,25 @@ def _state_to_device[T: SimState]( elif isinstance(attr_value, torch.Generator): attrs[attr_name] = coerce_prng(attr_value, device) + for extras_key in ("_system_extras", "_atom_extras"): + if extras_key in attrs and isinstance(attrs[extras_key], dict): + attrs[extras_key] = { + k: v.to(device=device) for k, v in attrs[extras_key].items() + } + if dtype is not None: attrs["positions"] = attrs["positions"].to(dtype=dtype) attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) + + # Update floating point extras to new dtype + for extras_key in ("_system_extras", "_atom_extras"): + if extras_key in attrs and isinstance(attrs[extras_key], dict): + attrs[extras_key] = { + k: v.to(dtype=dtype) if v.is_floating_point() else v + for k, v in attrs[extras_key].items() + } return type(state)(**attrs) @@ -894,6 +1051,13 @@ def _filter_attrs_by_index( val[system_indices] if isinstance(val, torch.Tensor) else val ) + filtered_attrs["_system_extras"] = { + key: val[system_indices] for key, val in state.system_extras.items() + } + filtered_attrs["_atom_extras"] = { + key: val[atom_indices] for key, val in state.atom_extras.items() + } + return filtered_attrs @@ -926,6 +1090,14 @@ def _split_state[T: SimState](state: T) -> list[T]: global_attrs = dict(get_attrs_for_scope(state, "global")) + split_system_extras: dict[str, list[torch.Tensor]] = {} + for key, val in state._system_extras.items(): # noqa: SLF001 + split_system_extras[key] = list(torch.split(val, 1, dim=0)) + + split_atom_extras: dict[str, list[torch.Tensor]] = {} + for key, val in state._atom_extras.items(): # noqa: SLF001 + split_atom_extras[key] = list(torch.split(val, system_sizes, dim=0)) + # Create a state for each system states: list[T] = [] n_systems = len(system_sizes) @@ -952,6 +1124,12 @@ def _split_state[T: SimState](state: T) -> list[T]: **per_system_dict, # Add the global attributes **global_attrs, + "_system_extras": { + key: split_system_extras[key][sys_idx] for key in split_system_extras + }, + "_atom_extras": { + key: split_atom_extras[key][sys_idx] for key in split_atom_extras + }, } start_idx = int(cumsum_atoms[sys_idx].item()) @@ -1098,6 +1276,8 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Pre-allocate lists for tensors to concatenate per_atom_tensors = defaultdict(list) per_system_tensors = defaultdict(list) + system_extras_tensors: dict[str, list[torch.Tensor]] = defaultdict(list) + atom_extras_tensors: dict[str, list[torch.Tensor]] = defaultdict(list) new_system_indices = [] system_offset = 0 num_atoms_per_state = [] @@ -1119,6 +1299,12 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 for prop, val in get_attrs_for_scope(state, "per-system"): per_system_tensors[prop].append(val) + # Collect extras + for key, val in state.system_extras.items(): + system_extras_tensors[key].append(val) + for key, val in state.atom_extras.items(): + atom_extras_tensors[key].append(val) + # Update system indices num_systems = state.n_systems new_indices = state.system_idx + system_offset @@ -1189,6 +1375,14 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) + # Concatenate extras + concatenated["_system_extras"] = { + key: torch.cat(tensors, dim=0) for key, tensors in system_extras_tensors.items() + } + concatenated["_atom_extras"] = { + key: torch.cat(tensors, dim=0) for key, tensors in atom_extras_tensors.items() + } + # Merge constraints constraint_lists = [state.constraints for state in states] num_systems_per_state = [state.n_systems for state in states] From 5e7af8eb6bf9d71932c3918940eb38c8ffba872e Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 25 Mar 2026 19:50:58 -0400 Subject: [PATCH 2/5] remove privileged role of spin and charge --- torch_sim/state.py | 35 +++++++++++++---------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index e9079d91..e16b219e 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -72,9 +72,6 @@ def require_system_idx(system_idx: torch.Tensor | None) -> torch.Tensor: return system_idx -_EXTRAS_COMPAT_KEYS = frozenset({"charge", "spin"}) - - def _wrap_init_for_extras(cls: type) -> None: """Wrap a dataclass __init__ to route unknown kwargs into _system_extras.""" original_init = cls.__init__ @@ -202,17 +199,20 @@ def rng(self, value: PRNGLike) -> None: def __setattr__(self, name: str, value: object) -> None: # noqa: C901 """Coerce pbc and system_idx on every assignment. - Routes charge/spin writes to _system_extras for backward compatibility. + Routes writes to existing extras keys back into their extras dict. """ - if name in _EXTRAS_COMPAT_KEYS: - try: - extras = object.__getattribute__(self, "_system_extras") - except AttributeError: - extras = {} - super().__setattr__("_system_extras", extras) - if value is not None: - extras[name] = value - return + if not name.startswith("_"): + for extras_attr in ("_system_extras", "_atom_extras"): + try: + extras = object.__getattribute__(self, extras_attr) + except AttributeError: + continue + if name in extras: + if value is not None: + extras[name] = value + else: + del extras[name] + return if name == "pbc" and not isinstance(value, torch.Tensor): if isinstance(value, bool): value = [value] * 3 @@ -250,15 +250,6 @@ def __post_init__(self) -> None: # noqa: C901 if self.constraints: validate_constraints(self.constraints, state=self) - if "charge" not in self._system_extras: - self._system_extras["charge"] = torch.zeros( - n_systems, device=self.device, dtype=self.dtype - ) - if "spin" not in self._system_extras: - self._system_extras["spin"] = torch.zeros( - n_systems, device=self.device, dtype=self.dtype - ) - if self.cell.ndim != 3: self.cell = self.cell.unsqueeze(0) From 22476c5cba61e79f99bd513ad30d2cf267db679c Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 25 Mar 2026 20:02:34 -0400 Subject: [PATCH 3/5] fix: down to 28 test failures --- torch_sim/state.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index e16b219e..5eaba48e 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -722,13 +722,9 @@ def __init_subclass__(cls, **kwargs) -> None: Also enforce all of child classes's attributes are specified in _atom_attributes, _system_attributes, or _global_attributes. - - Also wraps __init__ to pop deprecated charge/spin kwargs and route them - to _system_extras for backward compatibility. """ cls._assert_no_tensor_attributes_can_be_none() cls._assert_all_attributes_have_defined_scope() - _wrap_init_for_extras(cls) super().__init_subclass__(**kwargs) @classmethod From 465c74f30d618770df95d0f03a269ff8fc2852ab Mon Sep 17 00:00:00 2001 From: Stefano Falletta <49149059+falletta@users.noreply.github.com> Date: Thu, 26 Mar 2026 13:11:34 -0400 Subject: [PATCH 4/5] Fixes to Extensible Extras PR (#526) --- tests/test_extras.py | 5 +++-- tests/test_nbody.py | 14 +++++++------- torch_sim/elastic.py | 2 ++ torch_sim/io.py | 30 ++++++++++++++++++++++-------- torch_sim/models/fairchem.py | 4 ++-- torch_sim/models/mace.py | 4 ++-- torch_sim/state.py | 20 +++++++++++++++++--- 7 files changed, 55 insertions(+), 24 deletions(-) diff --git a/tests/test_extras.py b/tests/test_extras.py index 75947706..8c1eef80 100644 --- a/tests/test_extras.py +++ b/tests/test_extras.py @@ -75,8 +75,9 @@ def test_store_model_extras_canonical_keys_not_stored( "stress": torch.randn(state.n_systems, 3, 3), } ) - assert not state._system_extras # noqa: SLF001 - assert not state._atom_extras # noqa: SLF001 + for key in ("energy", "forces", "stress"): + assert key not in state._system_extras # noqa: SLF001 + assert key not in state._atom_extras # noqa: SLF001 def test_store_model_extras_per_system(self, si_double_sim_state: ts.SimState): """Tensors with leading dim == n_systems go into system_extras.""" diff --git a/tests/test_nbody.py b/tests/test_nbody.py index e235cd62..5da5a45c 100644 --- a/tests/test_nbody.py +++ b/tests/test_nbody.py @@ -481,9 +481,9 @@ def test_build_triplets_device(device: str) -> None: result = build_triplets(edge_index, n_atoms) - assert result["trip_in"].device == dev - assert result["trip_out"].device == dev - assert result["center_atom"].device == dev + assert result["trip_in"].device.type == dev.type + assert result["trip_out"].device.type == dev.type + assert result["center_atom"].device.type == dev.type @pytest.mark.parametrize( @@ -507,10 +507,10 @@ def test_build_quadruplets_device(device: str) -> None: internal_cell_offsets, ) - assert result["quad_c_to_a_edge"].device == dev - assert result["quad_d_to_b_trip_idx"].device == dev - assert result["d_to_b_edge"].device == dev - assert result["c_to_a_edge"].device == dev + assert result["quad_c_to_a_edge"].device.type == dev.type + assert result["quad_d_to_b_trip_idx"].device.type == dev.type + assert result["d_to_b_edge"].device.type == dev.type + assert result["c_to_a_edge"].device.type == dev.type def test_build_triplets_jit_script() -> None: diff --git a/torch_sim/elastic.py b/torch_sim/elastic.py index 944cdb81..3efffe7e 100644 --- a/torch_sim/elastic.py +++ b/torch_sim/elastic.py @@ -680,6 +680,8 @@ def get_cart_deformed_cell(state: SimState, axis: int = 0, size: float = 1.0) -> masses=state.masses, pbc=state.pbc, atomic_numbers=state.atomic_numbers, + _system_extras=state._system_extras, + _atom_extras=state._atom_extras, ) diff --git a/torch_sim/io.py b/torch_sim/io.py index 6044c813..5ffe0fcd 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -95,16 +95,22 @@ def state_to_atoms( # Write system extras to atoms.info # charge/spin stored as int scalars for FairChem compatibility - if system_extras_keys is not None: - for key in system_extras_keys: - val = state.system_extras[key][sys_idx].detach().cpu().numpy() - atoms.info[key] = val + _sys_keys = ( + system_extras_keys + if system_extras_keys is not None + else list(state.system_extras) + ) + for key in _sys_keys: + val = state.system_extras[key][sys_idx].detach().cpu().numpy() + atoms.info[key] = val # Write atom extras to atoms.arrays - if atom_extras_keys is not None: - for key in atom_extras_keys: - val = state.atom_extras[key][mask].detach().cpu().numpy() - atoms.arrays[key] = val + _atom_keys = ( + atom_extras_keys if atom_extras_keys is not None else list(state.atom_extras) + ) + for key in _atom_keys: + val = state.atom_extras[key][mask].detach().cpu().numpy() + atoms.arrays[key] = val atoms_list.append(atoms) @@ -314,8 +320,16 @@ def atoms_to_state( raise ValueError("All systems must have the same periodic boundary conditions") _system_extras: dict[str, torch.Tensor] = {} + + # charge and spin always default to 0 for backward compatibility + for key in ("charge", "spin"): + vals = np.array([float(at.info.get(key, 0.0)) for at in atoms_list]) + _system_extras[key] = torch.tensor(vals, dtype=dtype, device=device) + if system_extras_keys: for key in system_extras_keys: + if key in _system_extras: + continue vals = [at.info.get(key) for at in atoms_list] non_none_vals = [v for v in vals if v is not None] if len(non_none_vals) == len(vals): diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 1c95b00d..56a6d156 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -223,8 +223,8 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] pbc=pbc_np if cell is not None else False, ) - charge = sim_state.charge - spin = sim_state.spin + charge = getattr(sim_state, "charge", None) + spin = getattr(sim_state, "spin", None) atoms.info["charge"] = charge[idx].item() if charge is not None else 0.0 atoms.info["spin"] = spin[idx].item() if spin is not None else 0.0 diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 75078670..01afdc51 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -304,8 +304,8 @@ def forward( # noqa: C901 edge_index=edge_index, unit_shifts=unit_shifts, shifts=shifts, - total_charge=state.charge, - total_spin=state.spin, + total_charge=getattr(state, "charge", None), + total_spin=getattr(state, "spin", None), ) # Get model output diff --git a/torch_sim/state.py b/torch_sim/state.py index 5eaba48e..cb36e34f 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -974,6 +974,11 @@ def get_attrs_for_scope( for attr_name in attr_names: yield attr_name, getattr(state, attr_name) + if scope == "per-system": + yield from state._system_extras.items() # noqa: SLF001 + elif scope == "per-atom": + yield from state._atom_extras.items() # noqa: SLF001 + def _filter_attrs_by_index( state: SimState, @@ -1029,11 +1034,15 @@ def _filter_attrs_by_index( c.system_idx = new_system_idx[c.system_idx] # ty: ignore[invalid-assignment] for name, val in get_attrs_for_scope(state, "per-atom"): + if name in state.atom_extras: + continue filtered_attrs[name] = ( system_remap[val[atom_indices]] if name == "system_idx" else val[atom_indices] ) for name, val in get_attrs_for_scope(state, "per-system"): + if name in state.system_extras: + continue filtered_attrs[name] = ( val[system_indices] if isinstance(val, torch.Tensor) else val ) @@ -1065,11 +1074,14 @@ def _split_state[T: SimState](state: T) -> list[T]: split_per_atom = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): - if attr_name != "system_idx": - split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) + if attr_name == "system_idx" or attr_name in state.atom_extras: + continue + split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) split_per_system = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): + if attr_name in state.system_extras: + continue if isinstance(attr_value, torch.Tensor): split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) else: # Non-tensor attributes are replicated for each split @@ -1277,13 +1289,15 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Collect per-atom properties for prop, val in get_attrs_for_scope(state, "per-atom"): - if prop == "system_idx": + if prop == "system_idx" or prop in state.atom_extras: # skip system_idx, it will be handled below continue per_atom_tensors[prop].append(val) # Collect per-system properties for prop, val in get_attrs_for_scope(state, "per-system"): + if prop in state.system_extras: + continue per_system_tensors[prop].append(val) # Collect extras From 1445ed02fe708d9895ff78862e1def43798835ab Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 2 Apr 2026 11:53:26 -0400 Subject: [PATCH 5/5] lint: avoid SLF001 for extras in elastic code. --- torch_sim/elastic.py | 24 +++++++----------------- torch_sim/io.py | 2 +- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/torch_sim/elastic.py b/torch_sim/elastic.py index 3efffe7e..7be97efb 100644 --- a/torch_sim/elastic.py +++ b/torch_sim/elastic.py @@ -666,23 +666,13 @@ def get_cart_deformed_cell(state: SimState, axis: int = 0, size: float = 1.0) -> else: # axis == 5 L[0, 1] += size # xy shear - # Convert positions to fractional coordinates - old_inv = torch.linalg.inv(row_vector_cell) - frac_coords = torch.matmul(positions, old_inv) - - # Apply transformation to cell and convert positions back to cartesian - row_vector_cell = torch.matmul(row_vector_cell, L) - new_positions = torch.matmul(frac_coords, row_vector_cell) - - return SimState( - positions=new_positions, - cell=row_vector_cell.mT.unsqueeze(0), - masses=state.masses, - pbc=state.pbc, - atomic_numbers=state.atomic_numbers, - _system_extras=state._system_extras, - _atom_extras=state._atom_extras, - ) + frac_coords = torch.matmul(positions, torch.linalg.inv(row_vector_cell)) + new_cell = torch.matmul(row_vector_cell, L) + + new_state = state.clone() + new_state.row_vector_cell = new_cell.unsqueeze(0) + new_state.positions = torch.matmul(frac_coords, new_cell) + return new_state def get_elementary_deformations( diff --git a/torch_sim/io.py b/torch_sim/io.py index 5ffe0fcd..029fd5a0 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -255,7 +255,7 @@ def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: description="ASE: Atomic Simulation Environment", path="ase", ) -def atoms_to_state( +def atoms_to_state( # noqa: C901 atoms: "Atoms | list[Atoms]", device: torch.device | None = None, dtype: torch.dtype | None = None,