From e7524f04a447ab48618c86bd5caf8f47a706c1bd Mon Sep 17 00:00:00 2001 From: falletta Date: Thu, 26 Mar 2026 09:42:27 -0700 Subject: [PATCH 1/4] fixes to PR --- tests/test_extras.py | 5 +++-- tests/test_nbody.py | 14 +++++++------- torch_sim/elastic.py | 2 ++ torch_sim/io.py | 31 ++++++++++++++++++++++--------- torch_sim/state.py | 25 +++++++++++++++++-------- 5 files changed, 51 insertions(+), 26 deletions(-) diff --git a/tests/test_extras.py b/tests/test_extras.py index 75947706..8c1eef80 100644 --- a/tests/test_extras.py +++ b/tests/test_extras.py @@ -75,8 +75,9 @@ def test_store_model_extras_canonical_keys_not_stored( "stress": torch.randn(state.n_systems, 3, 3), } ) - assert not state._system_extras # noqa: SLF001 - assert not state._atom_extras # noqa: SLF001 + for key in ("energy", "forces", "stress"): + assert key not in state._system_extras # noqa: SLF001 + assert key not in state._atom_extras # noqa: SLF001 def test_store_model_extras_per_system(self, si_double_sim_state: ts.SimState): """Tensors with leading dim == n_systems go into system_extras.""" diff --git a/tests/test_nbody.py b/tests/test_nbody.py index e235cd62..5da5a45c 100644 --- a/tests/test_nbody.py +++ b/tests/test_nbody.py @@ -481,9 +481,9 @@ def test_build_triplets_device(device: str) -> None: result = build_triplets(edge_index, n_atoms) - assert result["trip_in"].device == dev - assert result["trip_out"].device == dev - assert result["center_atom"].device == dev + assert result["trip_in"].device.type == dev.type + assert result["trip_out"].device.type == dev.type + assert result["center_atom"].device.type == dev.type @pytest.mark.parametrize( @@ -507,10 +507,10 @@ def test_build_quadruplets_device(device: str) -> None: internal_cell_offsets, ) - assert result["quad_c_to_a_edge"].device == dev - assert result["quad_d_to_b_trip_idx"].device == dev - assert result["d_to_b_edge"].device == dev - assert result["c_to_a_edge"].device == dev + assert result["quad_c_to_a_edge"].device.type == dev.type + assert result["quad_d_to_b_trip_idx"].device.type == dev.type + assert result["d_to_b_edge"].device.type == dev.type + assert result["c_to_a_edge"].device.type == dev.type def test_build_triplets_jit_script() -> None: diff --git a/torch_sim/elastic.py b/torch_sim/elastic.py index 944cdb81..3efffe7e 100644 --- a/torch_sim/elastic.py +++ b/torch_sim/elastic.py @@ -680,6 +680,8 @@ def get_cart_deformed_cell(state: SimState, axis: int = 0, size: float = 1.0) -> masses=state.masses, pbc=state.pbc, atomic_numbers=state.atomic_numbers, + _system_extras=state._system_extras, + _atom_extras=state._atom_extras, ) diff --git a/torch_sim/io.py b/torch_sim/io.py index 6044c813..106956f9 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -94,17 +94,22 @@ def state_to_atoms( ) # Write system extras to atoms.info - # charge/spin stored as int scalars for FairChem compatibility - if system_extras_keys is not None: - for key in system_extras_keys: - val = state.system_extras[key][sys_idx].detach().cpu().numpy() - atoms.info[key] = val + _sys_keys = ( + system_extras_keys + if system_extras_keys is not None + else list(state.system_extras) + ) + for key in _sys_keys: + val = state.system_extras[key][sys_idx].detach().cpu().numpy() + atoms.info[key] = val # Write atom extras to atoms.arrays - if atom_extras_keys is not None: - for key in atom_extras_keys: - val = state.atom_extras[key][mask].detach().cpu().numpy() - atoms.arrays[key] = val + _atom_keys = ( + atom_extras_keys if atom_extras_keys is not None else list(state.atom_extras) + ) + for key in _atom_keys: + val = state.atom_extras[key][mask].detach().cpu().numpy() + atoms.arrays[key] = val atoms_list.append(atoms) @@ -314,8 +319,16 @@ def atoms_to_state( raise ValueError("All systems must have the same periodic boundary conditions") _system_extras: dict[str, torch.Tensor] = {} + + # charge and spin always default to 0 for backward compatibility + for key in ("charge", "spin"): + vals = np.array([float(at.info.get(key, 0.0)) for at in atoms_list]) + _system_extras[key] = torch.tensor(vals, dtype=dtype, device=device) + if system_extras_keys: for key in system_extras_keys: + if key in _system_extras: + continue vals = [at.info.get(key) for at in atoms_list] non_none_vals = [v for v in vals if v is not None] if len(non_none_vals) == len(vals): diff --git a/torch_sim/state.py b/torch_sim/state.py index 5eaba48e..0afbb6da 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -323,8 +323,6 @@ def __getattr__(self, name: str) -> Any: if name in extras: return extras[name] - # Raise AttributeError so that Python's getattr(obj, name, default), - # hasattr(obj, name), and other descriptor-protocol machinery work correctly. raise AttributeError( f"'{type(self).__name__}' has no attribute or extra '{name}'" ) @@ -603,8 +601,6 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: if attr_name in cls._get_all_attributes(): attrs[attr_name] = cls._clone_attr(attr_value) - # Route additional_attrs: known attrs go directly, unknown tensor attrs - # go to _system_extras (backward compat for charge/spin and extensibility) all_known = cls._get_all_attributes() for key, val in additional_attrs.items(): if key in all_known: @@ -974,6 +970,11 @@ def get_attrs_for_scope( for attr_name in attr_names: yield attr_name, getattr(state, attr_name) + if scope == "per-system": + yield from state._system_extras.items() # noqa: SLF001 + elif scope == "per-atom": + yield from state._atom_extras.items() # noqa: SLF001 + def _filter_attrs_by_index( state: SimState, @@ -1029,11 +1030,15 @@ def _filter_attrs_by_index( c.system_idx = new_system_idx[c.system_idx] # ty: ignore[invalid-assignment] for name, val in get_attrs_for_scope(state, "per-atom"): + if name in state.atom_extras: + continue filtered_attrs[name] = ( system_remap[val[atom_indices]] if name == "system_idx" else val[atom_indices] ) for name, val in get_attrs_for_scope(state, "per-system"): + if name in state.system_extras: + continue filtered_attrs[name] = ( val[system_indices] if isinstance(val, torch.Tensor) else val ) @@ -1065,11 +1070,14 @@ def _split_state[T: SimState](state: T) -> list[T]: split_per_atom = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): - if attr_name != "system_idx": - split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) + if attr_name == "system_idx" or attr_name in state.atom_extras: + continue + split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) split_per_system = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): + if attr_name in state.system_extras: + continue if isinstance(attr_value, torch.Tensor): split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) else: # Non-tensor attributes are replicated for each split @@ -1277,13 +1285,14 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Collect per-atom properties for prop, val in get_attrs_for_scope(state, "per-atom"): - if prop == "system_idx": - # skip system_idx, it will be handled below + if prop == "system_idx" or prop in state.atom_extras: continue per_atom_tensors[prop].append(val) # Collect per-system properties for prop, val in get_attrs_for_scope(state, "per-system"): + if prop in state.system_extras: + continue per_system_tensors[prop].append(val) # Collect extras From 3a16d4416b8eed01735175af1379a22a57f69952 Mon Sep 17 00:00:00 2001 From: falletta Date: Thu, 26 Mar 2026 09:58:37 -0700 Subject: [PATCH 2/4] reinserted some comments --- torch_sim/io.py | 1 + torch_sim/state.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/torch_sim/io.py b/torch_sim/io.py index 106956f9..5ffe0fcd 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -94,6 +94,7 @@ def state_to_atoms( ) # Write system extras to atoms.info + # charge/spin stored as int scalars for FairChem compatibility _sys_keys = ( system_extras_keys if system_extras_keys is not None diff --git a/torch_sim/state.py b/torch_sim/state.py index 0afbb6da..cb36e34f 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -323,6 +323,8 @@ def __getattr__(self, name: str) -> Any: if name in extras: return extras[name] + # Raise AttributeError so that Python's getattr(obj, name, default), + # hasattr(obj, name), and other descriptor-protocol machinery work correctly. raise AttributeError( f"'{type(self).__name__}' has no attribute or extra '{name}'" ) @@ -601,6 +603,8 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: if attr_name in cls._get_all_attributes(): attrs[attr_name] = cls._clone_attr(attr_value) + # Route additional_attrs: known attrs go directly, unknown tensor attrs + # go to _system_extras (backward compat for charge/spin and extensibility) all_known = cls._get_all_attributes() for key, val in additional_attrs.items(): if key in all_known: @@ -1286,6 +1290,7 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Collect per-atom properties for prop, val in get_attrs_for_scope(state, "per-atom"): if prop == "system_idx" or prop in state.atom_extras: + # skip system_idx, it will be handled below continue per_atom_tensors[prop].append(val) From f9c43ddf3b61bd60f9267f8af68a21f53dd20a75 Mon Sep 17 00:00:00 2001 From: falletta Date: Thu, 26 Mar 2026 10:01:21 -0700 Subject: [PATCH 3/4] fix in getting charge and spin --- torch_sim/models/mace.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 75078670..01afdc51 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -304,8 +304,8 @@ def forward( # noqa: C901 edge_index=edge_index, unit_shifts=unit_shifts, shifts=shifts, - total_charge=state.charge, - total_spin=state.spin, + total_charge=getattr(state, "charge", None), + total_spin=getattr(state, "spin", None), ) # Get model output From d7810b2650405a3b78cd6b2c58f8ef1a27c040a8 Mon Sep 17 00:00:00 2001 From: falletta Date: Thu, 26 Mar 2026 10:03:23 -0700 Subject: [PATCH 4/4] fixing fairchem access to charge and spin --- torch_sim/models/fairchem.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 1c95b00d..56a6d156 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -223,8 +223,8 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] pbc=pbc_np if cell is not None else False, ) - charge = sim_state.charge - spin = sim_state.spin + charge = getattr(sim_state, "charge", None) + spin = getattr(sim_state, "spin", None) atoms.info["charge"] = charge[idx].item() if charge is not None else 0.0 atoms.info["spin"] = spin[idx].item() if spin is not None else 0.0