diff --git a/.gitignore b/.gitignore index b0bdd821..b43f2fef 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,9 @@ docs/reference/torch_sim.* *.hdf5 *.traj +# ignore torch.save outputs +*.pt + # coverage coverage.xml .coverage* diff --git a/docs/user/reproducibility.md b/docs/user/reproducibility.md index 8b348e8f..8c797d11 100644 --- a/docs/user/reproducibility.md +++ b/docs/user/reproducibility.md @@ -79,17 +79,34 @@ Because TorchSim runs batched simulations, all systems in a batch share a single If strict reproducibility is required, keep your batching setup fixed. -### Serialising the RNG state +### Serialising state for reproducible restarts -If you wish to be able to resume a session and ensure determinism you need to persist and reload the `torch.Generator` state. This can be done using `torch.save()` and `torch.Generator().set_state()`: +To resume a simulation and ensure determinism you need to persist and reload the complete state, including the `torch.Generator` RNG state. The simplest approach is to save the full state dict with `torch.save()`: ```python +from dataclasses import asdict +from torch_sim.integrators import MDState + # save +torch.save(asdict(state), "checkpoint.pt") + +# restore (weights_only=False needed for torch.Generator in PyTorch 2.6+) +restored = MDState(**torch.load("checkpoint.pt", weights_only=False)) +``` + +This captures positions, momenta, forces, energy, cell, and the `torch.Generator` in a single file. Since `torch.save()` uses pickle, the generator is serialised automatically. + +> **Pickle caveat**: The `torch.Generator` object in the dict requires `weights_only=False` and may not unpickle across PyTorch versions. For portable checkpoints, save the tensors normally and extract the RNG state as a plain `uint8` tensor via `get_state()` — this loads with `weights_only=True` and is version-safe: + +```python +# save RNG state as a plain uint8 tensor (no pickle needed) rng_state = state.rng.get_state() torch.save(rng_state, "rng_state.pt") # restore gen = torch.Generator(device=state.device) -gen.set_state(torch.load("rng_state.pt")) +gen.set_state(torch.load("rng_state.pt", weights_only=True)) state.rng = gen ``` + +See the [reproducible restart tutorial](../../examples/tutorials/reproducible_restart_tutorial.py) for a complete worked example. diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index c7fc6362..9a7a2ede 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -107,6 +107,10 @@ class HybridSwapMCState(SwapMCState, MDState): ts.SwapMCState._system_attributes | MDState._system_attributes # noqa: SLF001 ) + def __post_init__(self) -> None: + """Initialize HybridSwapMCState and ensure last_permutation is set.""" + super().__post_init__() + # %% [markdown] """ @@ -127,16 +131,8 @@ class HybridSwapMCState(SwapMCState, MDState): state.rng = 42 md_state = ts.nvt_langevin_init(state=state, model=mace_model, kT=kT) -# Initialize swap Monte Carlo state -swap_state = ts.swap_mc_init(state=md_state, model=mace_model) - -# Create hybrid state combining both -hybrid_state = HybridSwapMCState( - **md_state.attributes, - last_permutation=torch.arange( - md_state.n_atoms, device=md_state.device, dtype=torch.long - ), -) +# Create hybrid state from MD state +hybrid_state = HybridSwapMCState(**md_state.attributes) # %% [markdown] diff --git a/examples/tutorials/reproducible_restart_tutorial.py b/examples/tutorials/reproducible_restart_tutorial.py new file mode 100644 index 00000000..fc855032 --- /dev/null +++ b/examples/tutorials/reproducible_restart_tutorial.py @@ -0,0 +1,251 @@ +# %% +# /// script +# dependencies = [ +# "torch_sim_atomistic[io]" +# ] +# /// + + +# %% [markdown] +""" +# Reproducible Restarts from Stopped Simulations + +This tutorial demonstrates how to save and restore simulation state to enable +reproducible restarts. We run 50 steps of MD, save the state (including RNG state), +resume for another 50 steps, and verify the result is identical to 100 uninterrupted +steps. + +For stochastic integrators like Langevin dynamics, you must save the random number +generator (RNG) state alongside positions, momenta, and other state variables. +Without it, the stochastic noise will differ on restart and the trajectory will diverge. +""" + +# %% [markdown] +""" +## Setup +""" + +# %% +from dataclasses import asdict +from pathlib import Path + +import torch +import torch_sim as ts +from ase.build import bulk +from torch_sim.integrators import MDState +from torch_sim.models.lennard_jones import LennardJonesModel + +# All generated files go in this directory +restart_dir = Path("restart_files") +restart_dir.mkdir(exist_ok=True) + +seed = 42 +torch.manual_seed(seed) + +lj_model = LennardJonesModel( + sigma=2.0, + epsilon=0.1, + device=torch.device("cpu"), + dtype=torch.float64, +) + +si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) + +initial_state = ts.initialize_state( + si_atoms, device=torch.device("cpu"), dtype=torch.float64 +) +initial_state.rng = seed # seed the SimState RNG for reproducibility + +print(f"Initial state: {initial_state.n_atoms} atoms") + +# %% [markdown] +""" +## Part 1: Run 50 Steps, Save State, Resume for 50 More + +We save the complete state with `asdict()` + `torch.save()`. Since `torch.save()` +uses pickle, the `torch.Generator` (RNG) is included automatically — no need to +save it separately. + +**PyTorch 2.6+**: You must pass `weights_only=False` to `torch.load()` when loading +checkpoints that contain `torch.Generator` objects. +""" + +# %% +# Run first 50 steps +trajectory_file_restart = str(restart_dir / "restart_trajectory.h5") +reporter_restart = ts.TrajectoryReporter( + filenames=trajectory_file_restart, + state_frequency=10, + state_kwargs={"save_velocities": True}, +) + +state_after_50 = ts.integrate( + system=initial_state.clone(), + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=50, + temperature=300, + timestep=0.001, + trajectory_reporter=reporter_restart, +) +reporter_restart.close() + +# Save the complete state (including RNG) in one file +checkpoint_file = str(restart_dir / "checkpoint.pt") +torch.save(asdict(state_after_50), checkpoint_file) +print(f"Saved checkpoint after 50 steps to {checkpoint_file}") + +# %% [markdown] +""" +Now restore the state and continue for another 50 steps: +""" + +# %% +# Load checkpoint (weights_only=False needed for torch.Generator in PyTorch 2.6+) +loaded = torch.load(checkpoint_file, weights_only=False) +restored_state = MDState(**loaded) + +# Verify RNG was restored +assert torch.equal(restored_state.rng.get_state(), state_after_50.rng.get_state()) +print(f"Restored state: {restored_state.n_atoms} atoms, RNG matches ✓") + +# Continue for another 50 steps (append to existing trajectory) +reporter_restart_continued = ts.TrajectoryReporter( + filenames=trajectory_file_restart, + state_frequency=10, + state_kwargs={"save_velocities": True}, + trajectory_kwargs={"mode": "a"}, +) + +state_after_100_restart = ts.integrate( + system=restored_state, + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=50, + temperature=300, + timestep=0.001, + trajectory_reporter=reporter_restart_continued, +) +reporter_restart_continued.close() +print(f"Completed restart simulation: 50 + 50 = 100 steps") + +# %% [markdown] +""" +## Part 2: Run 100 Steps Continuously for Comparison +""" + +# %% +trajectory_file_continuous = str(restart_dir / "continuous_trajectory.h5") +reporter_continuous = ts.TrajectoryReporter( + filenames=trajectory_file_continuous, + state_frequency=10, + state_kwargs={"save_velocities": True}, +) + +initial_state_continuous = ts.initialize_state( + si_atoms, device=torch.device("cpu"), dtype=torch.float64 +) +initial_state_continuous.rng = seed + +state_after_100_continuous = ts.integrate( + system=initial_state_continuous, + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=100, + temperature=300, + timestep=0.001, + trajectory_reporter=reporter_continuous, +) +reporter_continuous.close() +print(f"Completed continuous simulation: 100 steps") + +# %% [markdown] +""" +## Part 3: Compare Trajectories + +Both runs started from the same initial state and seed. The restarted run saved and +restored the RNG state at step 50. If everything is correct, the trajectories should +match exactly: +""" + +# %% +# Compare final RNG states +rng_match = torch.equal( + state_after_100_restart.rng.get_state(), + state_after_100_continuous.rng.get_state(), +) +print(f"Final RNG states match: {rng_match}") + +# Compare trajectories frame by frame +with ts.TorchSimTrajectory(trajectory_file_restart, mode="r") as traj_restart: + positions_restart = traj_restart.get_array("positions") + steps_restart = traj_restart.get_steps("positions") + velocities_restart = traj_restart.get_array("velocities") + +with ts.TorchSimTrajectory(trajectory_file_continuous, mode="r") as traj_continuous: + positions_continuous = traj_continuous.get_array("positions") + steps_continuous = traj_continuous.get_steps("positions") + velocities_continuous = traj_continuous.get_array("velocities") + +matching_steps = sorted(set(steps_restart) & set(steps_continuous)) +print(f"Comparing {len(matching_steps)} frames at steps: {matching_steps}") + +max_pos_diff = 0.0 +max_vel_diff = 0.0 +all_match = True + +for step in matching_steps: + idx_restart = steps_restart.tolist().index(step) + idx_continuous = steps_continuous.tolist().index(step) + + pos_restart = torch.tensor(positions_restart[idx_restart]) + pos_continuous = torch.tensor(positions_continuous[idx_continuous]) + vel_restart = torch.tensor(velocities_restart[idx_restart]) + vel_continuous = torch.tensor(velocities_continuous[idx_continuous]) + + pos_diff = torch.max(torch.abs(pos_restart - pos_continuous)).item() + vel_diff = torch.max(torch.abs(vel_restart - vel_continuous)).item() + max_pos_diff = max(max_pos_diff, pos_diff) + max_vel_diff = max(max_vel_diff, vel_diff) + + if not torch.allclose(pos_restart, pos_continuous, atol=1e-10, rtol=1e-10): + print(f" Step {step}: Position mismatch! Max diff: {pos_diff:.2e}") + all_match = False + if not torch.allclose(vel_restart, vel_continuous, atol=1e-10, rtol=1e-10): + print(f" Step {step}: Velocity mismatch! Max diff: {vel_diff:.2e}") + all_match = False + +assert all_match, ( + f"Restarted and continuous trajectories differ! " + f"Max position difference: {max_pos_diff:.2e}, max velocity difference: {max_vel_diff:.2e}" +) +print("\n✓ Restarted and continuous trajectories match exactly.") + +# %% [markdown] +""" +## Key Takeaways + +1. **Save with `asdict()` + `torch.save()`**: This captures everything — positions, + momenta, forces, energy, cell, and the `torch.Generator` RNG state — in a single + checkpoint file. + +2. **Restore with `MDState(**torch.load(...))`**: The `torch.Generator` is unpickled + automatically, so the RNG state is restored without any extra steps. + +3. **Use append mode** (`trajectory_kwargs={"mode": "a"}`) in `TrajectoryReporter` + to continue an existing trajectory file. + +4. **Pickle caveats**: The `torch.Generator` object in the checkpoint requires pickle + (`weights_only=False`) and may not load across PyTorch versions. For portable + checkpoints, save tensors normally and use `state.rng.get_state()` to extract the + RNG state as a plain `uint8` tensor that works with `weights_only=True`. + +5. **Verify**: Always compare restarted trajectories to continuous runs. +""" + +# %% +# Cleanup +import shutil + +shutil.rmtree(restart_dir) +print(f"Cleaned up {restart_dir}/") diff --git a/tests/test_monte_carlo.py b/tests/test_monte_carlo.py index ab0e1bc8..ee8feef6 100644 --- a/tests/test_monte_carlo.py +++ b/tests/test_monte_carlo.py @@ -191,6 +191,26 @@ def test_monte_carlo_integration( assert torch.all(orig_counts == result_counts) +def test_swap_mc_state_default_last_permutation( + batched_diverse_state: ts.SimState, +) -> None: + """Test that SwapMCState initializes last_permutation to identity if not provided.""" + from torch_sim.monte_carlo import SwapMCState + + state = SwapMCState( + positions=batched_diverse_state.positions, + masses=batched_diverse_state.masses, + cell=batched_diverse_state.cell, + pbc=batched_diverse_state.pbc, + atomic_numbers=batched_diverse_state.atomic_numbers, + system_idx=batched_diverse_state.system_idx, + energy=torch.zeros(batched_diverse_state.n_systems), + ) + assert state.last_permutation is not None + expected_identity = torch.arange(batched_diverse_state.n_atoms, device=DEVICE) + assert torch.equal(state.last_permutation, expected_identity) + + def test_swap_mc_state_attributes(): """Test SwapMCState class structure and inheritance.""" from torch_sim.state import SimState diff --git a/tests/test_state.py b/tests/test_state.py index 6b40e68e..891d53aa 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -348,6 +348,35 @@ def test_initialize_state_from_state(ar_supercell_sim_state: SimState) -> None: assert state.cell.shape == ar_supercell_sim_state.cell.shape +def test_initialize_state_from_list_of_states_with_multiple_systems( + si_double_sim_state: SimState, fe_supercell_sim_state: SimState +) -> None: + """Test initialize_state with list of states that have n_systems > 1.""" + # This should work now that we've removed the arbitrary n_systems == 1 constraint + concatenated = ts.initialize_state([si_double_sim_state, fe_supercell_sim_state]) + + # Should have 3 systems total (2 from si_double + 1 from fe) + assert concatenated.n_systems == 3 + assert concatenated.cell.shape[0] == 3 + + # Check system indices are correct + fe_atoms = fe_supercell_sim_state.n_atoms + expected_system_indices = torch.cat( + [ + si_double_sim_state.system_idx, + torch.full( + (fe_atoms,), 2, dtype=torch.int64, device=fe_supercell_sim_state.device + ), + ] + ) + assert torch.all(concatenated.system_idx == expected_system_indices) + + # Verify we can slice back to original states + assert torch.allclose(concatenated[0].positions, si_double_sim_state[0].positions) + assert torch.allclose(concatenated[1].positions, si_double_sim_state[1].positions) + assert torch.allclose(concatenated[2].positions, fe_supercell_sim_state.positions) + + def test_initialize_state_from_atoms(si_atoms: "Atoms") -> None: """Test conversion from ASE Atoms to SimState.""" state = ts.initialize_state([si_atoms], DEVICE, torch.float64) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 9b1ed810..094a7deb 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -574,17 +574,15 @@ def npt_langevin_init( model_output = model(state) # Initialize momenta if not provided - momenta = getattr( - state, - "momenta", - initialize_momenta( + momenta = getattr(state, "momenta", None) + if momenta is None: + momenta = initialize_momenta( state.positions, state.masses, state.system_idx, kT, state.rng, - ), - ) + ) # Initialize cell parameters reference_cell = state.cell.clone() @@ -1375,16 +1373,17 @@ def npt_nose_hoover_init( KE_cell = (cell_momentum.squeeze(-1) ** 2) / (2 * cell_mass) # Initialize momenta - momenta = kwargs.get( - "momenta", - initialize_momenta( + momenta = kwargs.get("momenta") + if momenta is None: + momenta = getattr(state, "momenta", None) + if momenta is None: + momenta = initialize_momenta( state.positions, state.masses, state.system_idx, kT_tensor, state.rng, - ), - ) + ) # Compute total DOF for thermostat initialization and a zero KE placeholder dof_per_system = torch.bincount(state.system_idx, minlength=n_systems) * dim @@ -2353,17 +2352,15 @@ def npt_crescale_init( model_output = model(state) # Initialize momenta if not provided - momenta = getattr( - state, - "momenta", - initialize_momenta( + momenta = getattr(state, "momenta", None) + if momenta is None: + momenta = initialize_momenta( state.positions, state.masses, state.system_idx, kT, state.rng, - ), - ) + ) # Create the initial state return NPTCRescaleState.from_state( diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index ca724126..07f3064b 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -47,17 +47,15 @@ def nve_init( """ model_output = model(state) - momenta = getattr( - state, - "momenta", - initialize_momenta( + momenta = getattr(state, "momenta", None) + if momenta is None: + momenta = initialize_momenta( state.positions, state.masses, state.system_idx, kT, state.rng, - ), - ) + ) return MDState.from_state( state, diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 10f2509a..8e74bf85 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -117,17 +117,15 @@ def nvt_langevin_init( """ model_output = model(state) - momenta = getattr( - state, - "momenta", - initialize_momenta( + momenta = getattr(state, "momenta", None) + if momenta is None: + momenta = initialize_momenta( state.positions, state.masses, state.system_idx, kT, state.rng, - ), - ) + ) return MDState.from_state( state, momenta=momenta, @@ -303,12 +301,13 @@ def nvt_nose_hoover_init( atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers) model_output = model(state) - momenta = kwargs.get( - "momenta", - initialize_momenta( + momenta = kwargs.get("momenta") + if momenta is None: + momenta = getattr(state, "momenta", None) + if momenta is None: + momenta = initialize_momenta( state.positions, state.masses, state.system_idx, kT_tensor, state.rng - ), - ) + ) # Calculate initial kinetic energy per system KE = ts.calc_kinetic_energy( @@ -600,17 +599,15 @@ def nvt_vrescale_init( """ model_output = model(state) - momenta = getattr( - state, - "momenta", - initialize_momenta( + momenta = getattr(state, "momenta", None) + if momenta is None: + momenta = initialize_momenta( state.positions, state.masses, state.system_idx, kT, state.rng, - ), - ) + ) return NVTVRescaleState.from_state( state, diff --git a/torch_sim/models/graphpes_framework.py b/torch_sim/models/graphpes_framework.py index ab8af45d..800fe819 100644 --- a/torch_sim/models/graphpes_framework.py +++ b/torch_sim/models/graphpes_framework.py @@ -27,7 +27,7 @@ try: from graph_pes import AtomicGraph, GraphPESModel - from graph_pes.atomic_graph import PropertyKey, to_batch + from graph_pes.atomic_graph import PropertyKey from graph_pes.models import load_model except ImportError as exc: @@ -63,40 +63,38 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra Returns: AtomicGraph object representing the batched structures """ - graphs = [] - - for sys_idx in range(state.n_systems): - system_mask = state.system_idx == sys_idx - R = state.positions[system_mask] - Z = state.atomic_numbers[system_mask] - cell = state.row_vector_cell[sys_idx] - # graph-pes models internally trim the neighbor list to the - # model's cutoff value. To ensure no strange edge effects whereby - # edges that are exactly `cutoff` long are included/excluded, - # we bump cutoff + 1e-5 up slightly - - # Create system_idx for this single system (all atoms belong to system 0) - system_idx_single = torch.zeros(R.shape[0], dtype=torch.long, device=R.device) - nl, _system_mapping, shifts = torchsim_nl( - R, cell, state.pbc, cutoff + 1e-5, system_idx_single - ) - - atomic_graph = AtomicGraph( - Z=Z.long(), - R=R, - cell=cell, - neighbour_list=nl.long(), - neighbour_cell_offsets=shifts, - properties={}, - cutoff=cutoff.item(), - other={ - "total_charge": torch.tensor(0.0).to(state.device), - "total_spin": torch.tensor(0.0).to(state.device), - }, - ) - graphs.append(atomic_graph) - - return to_batch(graphs) + # graph-pes models internally trim the neighbor list to the + # model's cutoff value. To ensure no strange edge effects whereby + # edges that are exactly `cutoff` long are included/excluded, + # we bump cutoff + 1e-5 up slightly + nl, _system_mapping, shifts = torchsim_nl( + state.positions, + state.row_vector_cell, + state.pbc, + cutoff + 1e-5, + state.system_idx, + ) + n_atoms_per_system = torch.bincount(state.system_idx) + ptr = torch.zeros(state.n_systems + 1, dtype=torch.long, device=state.device) + ptr[1:] = n_atoms_per_system.cumsum(dim=0) + n_sys = state.n_systems + total_charge = torch.zeros(n_sys, device=state.device) + total_spin = torch.zeros(n_sys, device=state.device) + return AtomicGraph( + Z=state.atomic_numbers.long(), + R=state.positions, + cell=state.row_vector_cell, + neighbour_list=nl.long(), + neighbour_cell_offsets=shifts, + properties={}, + cutoff=cutoff.item(), + other={ + "total_charge": total_charge, + "total_spin": total_spin, + }, + batch=state.system_idx, + ptr=ptr, + ) class GraphPESWrapper(ModelInterface): diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 149bd406..04dfde31 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -16,7 +16,7 @@ ... mc_state = ts.swap_mc_step(model, mc_state, kT=0.1 * units.energy) """ -from dataclasses import dataclass +from dataclasses import dataclass, field import torch @@ -24,6 +24,15 @@ from torch_sim.state import SimState +# Sentinel value for uninitialized last_permutation +_UNINITIALIZED_PERMUTATION = torch.empty(0, dtype=torch.long) + + +def _create_uninitialized_permutation() -> torch.Tensor: + """Create a sentinel tensor for uninitialized last_permutation.""" + return _UNINITIALIZED_PERMUTATION.clone() + + @dataclass(kw_only=True) class SwapMCState(SimState): """State for Monte Carlo simulations with swap moves. @@ -35,15 +44,28 @@ class SwapMCState(SimState): Attributes: energy (torch.Tensor): Energy of the system with shape [batch_size] last_permutation (torch.Tensor): Last permutation applied to the system, - with shape [n_atoms], tracking the moves made for analysis or reversal + with shape [n_atoms], tracking the moves made for analysis or reversal. + If not provided, will be initialized to identity permutation + (torch.arange(n_atoms)) in __post_init__. """ energy: torch.Tensor - last_permutation: torch.Tensor + last_permutation: torch.Tensor = field( + default_factory=_create_uninitialized_permutation + ) _atom_attributes = SimState._atom_attributes | {"last_permutation"} # noqa: SLF001 _system_attributes = SimState._system_attributes | {"energy"} # noqa: SLF001 + def __post_init__(self) -> None: + """Initialize SwapMCState and set default last_permutation if needed.""" + super().__post_init__() + # Check if last_permutation is the sentinel (empty tensor) + if self.last_permutation.numel() == 0: + self.last_permutation = torch.arange( + self.n_atoms, device=self.device, dtype=torch.long + ) + def generate_swaps(state: SimState, rng: torch.Generator | None = None) -> torch.Tensor: """Generate atom swaps for a given batched system. @@ -209,7 +231,6 @@ def swap_mc_init( atomic_numbers=state.atomic_numbers, system_idx=state.system_idx, energy=model_output["energy"], - last_permutation=torch.arange(state.n_atoms, device=state.device), _constraints=state.constraints, ) diff --git a/torch_sim/state.py b/torch_sim/state.py index 52b97e79..b3cb8f07 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -1225,12 +1225,6 @@ def initialize_state( if isinstance(system, list | tuple) and all(isinstance(s, SimState) for s in system): system: list[SimState] = typing.cast("list[SimState]", system) - if not all(state.n_systems == 1 for state in system): - raise ValueError( - "When providing a list of states, to the initialize_state function, " - "all states must have n_systems == 1. To fix this, you can split the " - "states into individual states with the split_state function." - ) return ts.concatenate_states(system) converters = [