diff --git a/tests/test_extras.py b/tests/test_extras.py index 75947706..8c1eef80 100644 --- a/tests/test_extras.py +++ b/tests/test_extras.py @@ -75,8 +75,9 @@ def test_store_model_extras_canonical_keys_not_stored( "stress": torch.randn(state.n_systems, 3, 3), } ) - assert not state._system_extras # noqa: SLF001 - assert not state._atom_extras # noqa: SLF001 + for key in ("energy", "forces", "stress"): + assert key not in state._system_extras # noqa: SLF001 + assert key not in state._atom_extras # noqa: SLF001 def test_store_model_extras_per_system(self, si_double_sim_state: ts.SimState): """Tensors with leading dim == n_systems go into system_extras.""" diff --git a/tests/test_nbody.py b/tests/test_nbody.py index e235cd62..5da5a45c 100644 --- a/tests/test_nbody.py +++ b/tests/test_nbody.py @@ -481,9 +481,9 @@ def test_build_triplets_device(device: str) -> None: result = build_triplets(edge_index, n_atoms) - assert result["trip_in"].device == dev - assert result["trip_out"].device == dev - assert result["center_atom"].device == dev + assert result["trip_in"].device.type == dev.type + assert result["trip_out"].device.type == dev.type + assert result["center_atom"].device.type == dev.type @pytest.mark.parametrize( @@ -507,10 +507,10 @@ def test_build_quadruplets_device(device: str) -> None: internal_cell_offsets, ) - assert result["quad_c_to_a_edge"].device == dev - assert result["quad_d_to_b_trip_idx"].device == dev - assert result["d_to_b_edge"].device == dev - assert result["c_to_a_edge"].device == dev + assert result["quad_c_to_a_edge"].device.type == dev.type + assert result["quad_d_to_b_trip_idx"].device.type == dev.type + assert result["d_to_b_edge"].device.type == dev.type + assert result["c_to_a_edge"].device.type == dev.type def test_build_triplets_jit_script() -> None: diff --git a/torch_sim/elastic.py b/torch_sim/elastic.py index 944cdb81..3efffe7e 100644 --- a/torch_sim/elastic.py +++ b/torch_sim/elastic.py @@ -680,6 +680,8 @@ def get_cart_deformed_cell(state: SimState, axis: int = 0, size: float = 1.0) -> masses=state.masses, pbc=state.pbc, atomic_numbers=state.atomic_numbers, + _system_extras=state._system_extras, + _atom_extras=state._atom_extras, ) diff --git a/torch_sim/io.py b/torch_sim/io.py index 6044c813..5ffe0fcd 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -95,16 +95,22 @@ def state_to_atoms( # Write system extras to atoms.info # charge/spin stored as int scalars for FairChem compatibility - if system_extras_keys is not None: - for key in system_extras_keys: - val = state.system_extras[key][sys_idx].detach().cpu().numpy() - atoms.info[key] = val + _sys_keys = ( + system_extras_keys + if system_extras_keys is not None + else list(state.system_extras) + ) + for key in _sys_keys: + val = state.system_extras[key][sys_idx].detach().cpu().numpy() + atoms.info[key] = val # Write atom extras to atoms.arrays - if atom_extras_keys is not None: - for key in atom_extras_keys: - val = state.atom_extras[key][mask].detach().cpu().numpy() - atoms.arrays[key] = val + _atom_keys = ( + atom_extras_keys if atom_extras_keys is not None else list(state.atom_extras) + ) + for key in _atom_keys: + val = state.atom_extras[key][mask].detach().cpu().numpy() + atoms.arrays[key] = val atoms_list.append(atoms) @@ -314,8 +320,16 @@ def atoms_to_state( raise ValueError("All systems must have the same periodic boundary conditions") _system_extras: dict[str, torch.Tensor] = {} + + # charge and spin always default to 0 for backward compatibility + for key in ("charge", "spin"): + vals = np.array([float(at.info.get(key, 0.0)) for at in atoms_list]) + _system_extras[key] = torch.tensor(vals, dtype=dtype, device=device) + if system_extras_keys: for key in system_extras_keys: + if key in _system_extras: + continue vals = [at.info.get(key) for at in atoms_list] non_none_vals = [v for v in vals if v is not None] if len(non_none_vals) == len(vals): diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 1c95b00d..56a6d156 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -223,8 +223,8 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] pbc=pbc_np if cell is not None else False, ) - charge = sim_state.charge - spin = sim_state.spin + charge = getattr(sim_state, "charge", None) + spin = getattr(sim_state, "spin", None) atoms.info["charge"] = charge[idx].item() if charge is not None else 0.0 atoms.info["spin"] = spin[idx].item() if spin is not None else 0.0 diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 75078670..01afdc51 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -304,8 +304,8 @@ def forward( # noqa: C901 edge_index=edge_index, unit_shifts=unit_shifts, shifts=shifts, - total_charge=state.charge, - total_spin=state.spin, + total_charge=getattr(state, "charge", None), + total_spin=getattr(state, "spin", None), ) # Get model output diff --git a/torch_sim/state.py b/torch_sim/state.py index 5eaba48e..cb36e34f 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -974,6 +974,11 @@ def get_attrs_for_scope( for attr_name in attr_names: yield attr_name, getattr(state, attr_name) + if scope == "per-system": + yield from state._system_extras.items() # noqa: SLF001 + elif scope == "per-atom": + yield from state._atom_extras.items() # noqa: SLF001 + def _filter_attrs_by_index( state: SimState, @@ -1029,11 +1034,15 @@ def _filter_attrs_by_index( c.system_idx = new_system_idx[c.system_idx] # ty: ignore[invalid-assignment] for name, val in get_attrs_for_scope(state, "per-atom"): + if name in state.atom_extras: + continue filtered_attrs[name] = ( system_remap[val[atom_indices]] if name == "system_idx" else val[atom_indices] ) for name, val in get_attrs_for_scope(state, "per-system"): + if name in state.system_extras: + continue filtered_attrs[name] = ( val[system_indices] if isinstance(val, torch.Tensor) else val ) @@ -1065,11 +1074,14 @@ def _split_state[T: SimState](state: T) -> list[T]: split_per_atom = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): - if attr_name != "system_idx": - split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) + if attr_name == "system_idx" or attr_name in state.atom_extras: + continue + split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) split_per_system = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): + if attr_name in state.system_extras: + continue if isinstance(attr_value, torch.Tensor): split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) else: # Non-tensor attributes are replicated for each split @@ -1277,13 +1289,15 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Collect per-atom properties for prop, val in get_attrs_for_scope(state, "per-atom"): - if prop == "system_idx": + if prop == "system_idx" or prop in state.atom_extras: # skip system_idx, it will be handled below continue per_atom_tensors[prop].append(val) # Collect per-system properties for prop, val in get_attrs_for_scope(state, "per-system"): + if prop in state.system_extras: + continue per_system_tensors[prop].append(val) # Collect extras