From eea3bcc2504a40e24035521aa2b5819e164d8848 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 14 Mar 2026 22:56:16 -0400 Subject: [PATCH 1/6] wip --- .../reproducible_restart_tutorial.py | 488 ++++++++++++++++++ 1 file changed, 488 insertions(+) create mode 100644 examples/tutorials/reproducible_restart_tutorial.py diff --git a/examples/tutorials/reproducible_restart_tutorial.py b/examples/tutorials/reproducible_restart_tutorial.py new file mode 100644 index 00000000..b514886c --- /dev/null +++ b/examples/tutorials/reproducible_restart_tutorial.py @@ -0,0 +1,488 @@ +# %% +# /// script +# dependencies = [ +# "torch_sim_atomistic[mace, io]" +# ] +# /// + + +# %% [markdown] +""" +# Reproducible Restarts from Stopped Simulations + +This tutorial demonstrates how to save and restore simulation state to enable +reproducible restarts from stopped simulations. This is essential for long-running +simulations that may need to be paused and resumed, or for checkpointing workflows. + +## Introduction + +When running molecular dynamics simulations, you may need to: +- Pause a simulation and resume it later +- Create checkpoints for long-running simulations +- Ensure that a restarted simulation produces identical results to a continuous run + +To achieve reproducible restarts, you must save not only the atomic positions, velocities, +and other state variables, but also the random number generator (RNG) state. This is +especially important for stochastic integrators like Langevin dynamics, which use random +numbers for both initial momenta sampling and per-step stochastic noise. + +## Key Concepts + +1. **State Saving**: Save the complete simulation state including positions, velocities, + momenta, cell parameters, and RNG state +2. **RNG State**: The `torch.Generator` state must be saved separately using + `get_state()` and restored using `set_state()` +3. **Trajectory Comparison**: Verify that restarted simulations produce identical + trajectories to continuous runs +""" + +# %% [markdown] +""" +## Setup: Initial System and Model + +Let's start by creating a simple system and model for our demonstration: +""" + +# %% +import torch +import torch_sim as ts +from ase.build import bulk +from torch_sim.models.lennard_jones import LennardJonesModel + +# Set up deterministic mode for reproducibility +# Note: This seeds the global RNG, but we'll also seed SimState.rng explicitly +seed = 42 +torch.manual_seed(seed) + +# Create a Lennard-Jones model +lj_model = LennardJonesModel( + sigma=2.0, + epsilon=0.1, + device=torch.device("cpu"), + dtype=torch.float64, +) + +# Create a silicon crystal structure +si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) + +# Initialize state and seed the RNG +initial_state = ts.initialize_state( + si_atoms, device=torch.device("cpu"), dtype=torch.float64 +) +initial_state.rng = seed # Critical: seed the SimState RNG for reproducibility + +print(f"Initial state has {initial_state.n_atoms} atoms") +print(f"RNG device: {initial_state.rng.device}") + + +# %% [markdown] +""" +## Part 1: Run 50 Steps, Save State, and Resume + +First, we'll run 50 steps of MD, save the complete state (including RNG state), +then resume for another 50 steps: +""" + +# %% +# Run first 50 steps with trajectory reporting +trajectory_file_restart = "restart_trajectory.h5" +reporter_restart = ts.TrajectoryReporter( + filenames=trajectory_file_restart, + state_frequency=10, # Save state every 10 steps + state_kwargs={"save_velocities": True}, # Save velocities (momenta) for comparison +) + +# Run 50 steps +state_after_50 = ts.integrate( + system=initial_state.clone(), + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=50, + temperature=300, + timestep=0.001, + trajectory_reporter=reporter_restart, +) + +reporter_restart.close() + +# Check RNG state after 50 steps (before saving) +test_random_after_50_restart = torch.randn(1, generator=state_after_50.rng).item() +print(f"After 50 steps (restart) - Test random: {test_random_after_50_restart:.6f}") + +# Save the complete state including RNG state +state_save_file = "saved_state.pt" +rng_state_save_file = "saved_rng_state.pt" + +# Save the RNG state separately (as recommended in reproducibility docs) +rng_state = state_after_50.rng.get_state() +torch.save(rng_state, rng_state_save_file) + +# Save the state object (this doesn't include the RNG generator itself) +# We'll need to restore the RNG state separately +torch.save( + { + "positions": state_after_50.positions, + "momenta": state_after_50.momenta, + "cell": state_after_50.cell, + "atomic_numbers": state_after_50.atomic_numbers, + "masses": state_after_50.masses, + "pbc": state_after_50.pbc, + "system_idx": state_after_50.system_idx, + "energy": state_after_50.energy, + "forces": state_after_50.forces, + }, + state_save_file, +) + +print(f"Saved state after 50 steps") +print(f"State file: {state_save_file}") +print(f"RNG state file: {rng_state_save_file}") + + +# %% [markdown] +""" +Now let's restore the state and continue for another 50 steps: +""" + +# %% +# Load the saved state +# Note: PyTorch 2.6+ defaults to weights_only=True. Since we're loading our own +# checkpoints, we use weights_only=False to allow loading Generator objects. +saved_data = torch.load(state_save_file, weights_only=False) +rng_state_loaded = torch.load(rng_state_save_file, weights_only=False) + +# Reconstruct the state from saved data +# We need to create an MDState since we have momenta, forces, and energy +from torch_sim.integrators import MDState + +restored_state = MDState( + positions=saved_data["positions"], + momenta=saved_data["momenta"], + cell=saved_data["cell"], + atomic_numbers=saved_data["atomic_numbers"], + masses=saved_data["masses"], + pbc=saved_data["pbc"], + system_idx=saved_data["system_idx"], + energy=saved_data["energy"], + forces=saved_data["forces"], +) + +# Restore the RNG state - this is critical for reproducibility! +gen = torch.Generator(device=restored_state.device) +gen.set_state(rng_state_loaded) +restored_state.rng = gen + +# Verify RNG state was restored correctly +# Draw a random number to verify the RNG state matches +test_random_restored = torch.randn(1, generator=restored_state.rng).item() +print(f"Restored state with {restored_state.n_atoms} atoms") +print(f"Restored RNG device: {restored_state.rng.device}") +print(f"Test random from restored RNG: {test_random_restored:.6f}") + +# Verify the RNG state matches what we saved +# (We need to reload the saved state to compare) +saved_rng_check = torch.Generator(device=restored_state.device) +saved_rng_check.set_state(rng_state_loaded) +test_random_saved = torch.randn(1, generator=saved_rng_check).item() +print(f"Test random from saved RNG state: {test_random_saved:.6f}") +assert abs(test_random_restored - test_random_saved) < 1e-10, "RNG state mismatch!" + +# Continue simulation for another 50 steps +# Use append mode to continue the trajectory +reporter_restart_continued = ts.TrajectoryReporter( + filenames=trajectory_file_restart, + state_frequency=10, + state_kwargs={"save_velocities": True}, # Save velocities (momenta) for comparison + trajectory_kwargs={"mode": "a"}, # Append mode to continue existing trajectory +) + +# IMPORTANT: When integrate() is called, it internally calls initialize_state() which +# clones the state. The clone() method should preserve the RNG state, but we verify +# that the RNG state is correctly set before calling integrate. +rng_state_before_integrate = restored_state.rng.get_state() + +state_after_100_restart = ts.integrate( + system=restored_state, + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=50, # Additional 50 steps + temperature=300, + timestep=0.001, + trajectory_reporter=reporter_restart_continued, +) + +# Check RNG state after 100 steps (restart) +test_random_after_100_restart = torch.randn( + 1, generator=state_after_100_restart.rng +).item() +print(f"After 100 steps (restart) - Test random: {test_random_after_100_restart:.6f}") + +reporter_restart_continued.close() + +print(f"Completed restart simulation: 50 + 50 = 100 steps total") + + +# %% [markdown] +""" +## Part 2: Run 100 Steps Continuously for Comparison + +Now let's run a continuous simulation that matches the restart scenario: run 50 steps, +then continue for another 50 steps (without saving/restoring in between). This ensures +we're comparing apples to apples: +""" + +# %% +# Run continuous simulation: 50 steps, then 50 more steps +trajectory_file_continuous = "continuous_trajectory.h5" +reporter_continuous = ts.TrajectoryReporter( + filenames=trajectory_file_continuous, + state_frequency=10, + state_kwargs={"save_velocities": True}, # Save velocities (momenta) for comparison +) + +# Create a fresh initial state with the same seed +initial_state_continuous = ts.initialize_state( + si_atoms, device=torch.device("cpu"), dtype=torch.float64 +) +initial_state_continuous.rng = seed # Same seed as before + +# Run first 50 steps (matching the restart scenario) +print("Running first 50 steps of continuous simulation...") +state_after_50_continuous = ts.integrate( + system=initial_state_continuous, + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=50, + temperature=300, + timestep=0.001, + trajectory_reporter=reporter_continuous, +) + +# Check RNG state after 50 steps +rng_state_after_50_continuous = state_after_50_continuous.rng.get_state() +test_random_after_50_continuous = torch.randn( + 1, generator=state_after_50_continuous.rng +).item() +print(f"After 50 steps - Test random: {test_random_after_50_continuous:.6f}") + +# Continue for another 50 steps (without saving/restoring) +# Use append mode to continue the same trajectory +reporter_continuous_continued = ts.TrajectoryReporter( + filenames=trajectory_file_continuous, + state_frequency=10, + state_kwargs={"save_velocities": True}, + trajectory_kwargs={"mode": "a"}, # Append mode +) + +print("Running second 50 steps of continuous simulation...") +state_after_100_continuous = ts.integrate( + system=state_after_50_continuous, + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=50, # Additional 50 steps + temperature=300, + timestep=0.001, + trajectory_reporter=reporter_continuous_continued, +) + +reporter_continuous_continued.close() + +# Check RNG state after 100 steps +test_random_after_100_continuous = torch.randn( + 1, generator=state_after_100_continuous.rng +).item() +print(f"After 100 steps - Test random: {test_random_after_100_continuous:.6f}") +print(f"Completed continuous simulation: 50 + 50 = 100 steps total") + + +# %% [markdown] +""" +## Part 3: Compare Trajectories + +Let's compare the trajectories from the restarted simulation and the continuous +simulation to verify they are identical: +""" + +# %% +# Load both trajectories +with ts.TorchSimTrajectory(trajectory_file_restart, mode="r") as traj_restart: + positions_restart = traj_restart.get_array("positions") + steps_restart = traj_restart.get_steps("positions") + velocities_restart = traj_restart.get_array("velocities") + +with ts.TorchSimTrajectory(trajectory_file_continuous, mode="r") as traj_continuous: + positions_continuous = traj_continuous.get_array("positions") + steps_continuous = traj_continuous.get_steps("positions") + velocities_continuous = traj_continuous.get_array("velocities") + +print(f"Restart trajectory: {len(steps_restart)} frames") +print(f"Continuous trajectory: {len(steps_continuous)} frames") + +# Compare positions at matching steps +# Both should have states at steps 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100 +matching_steps = sorted(set(steps_restart) & set(steps_continuous)) +print(f"\nMatching steps: {matching_steps}") + +# Compare positions and momenta at each matching step +max_pos_diff = 0.0 +max_mom_diff = 0.0 +all_match = True + +for step in matching_steps: + idx_restart = steps_restart.tolist().index(step) + idx_continuous = steps_continuous.tolist().index(step) + + # Convert numpy arrays to torch tensors for comparison + pos_restart = torch.tensor(positions_restart[idx_restart]) + pos_continuous = torch.tensor(positions_continuous[idx_continuous]) + vel_restart = torch.tensor(velocities_restart[idx_restart]) + vel_continuous = torch.tensor(velocities_continuous[idx_continuous]) + + pos_diff = torch.max(torch.abs(pos_restart - pos_continuous)).item() + vel_diff = torch.max(torch.abs(vel_restart - vel_continuous)).item() + + max_pos_diff = max(max_pos_diff, pos_diff) + max_mom_diff = max(max_mom_diff, vel_diff) + + # Check if they match exactly (within floating point precision) + if not torch.allclose(pos_restart, pos_continuous, atol=1e-10, rtol=1e-10): + print(f" Step {step}: Position mismatch! Max diff: {pos_diff:.2e}") + all_match = False + if not torch.allclose(vel_restart, vel_continuous, atol=1e-10, rtol=1e-10): + print(f" Step {step}: Velocity mismatch! Max diff: {vel_diff:.2e}") + all_match = False + +print(f"\nMaximum position difference: {max_pos_diff:.2e}") +print(f"Maximum velocity difference: {max_mom_diff:.2e}") + +if all_match: + print("\n✓ SUCCESS: Restarted and continuous trajectories match exactly!") +else: + print( + "\n✗ WARNING: Trajectories differ - this may indicate an issue with state saving/restoration" + ) + + +# %% [markdown] +""" +## Part 4: Simplified State Saving with torch.save + +For convenience, you can save the entire state object directly using `torch.save()`. +Since `torch.save()` uses pickle, the `torch.Generator` will be saved along with +everything else automatically. + +**Note for PyTorch 2.6+**: PyTorch 2.6 changed the default `weights_only` parameter +in `torch.load()` from `False` to `True` for security. When loading checkpoints that +contain `torch.Generator` objects, you need to set `weights_only=False`. This is safe +when loading your own checkpoints, but be cautious when loading files from untrusted +sources as it can result in arbitrary code execution. +""" + +# %% +# Simplified approach: save everything together +# Create a fresh state for demonstration +demo_state = ts.integrate( + system=initial_state.clone(), + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=25, + temperature=300, + timestep=0.001, +) + +# Save the entire state dict - Generator is included automatically +from dataclasses import asdict + +state_dict = asdict(demo_state) +torch.save(state_dict, "demo_state.pt") + +# Restore +# Note: PyTorch 2.6+ defaults to weights_only=True for security, which doesn't allow +# loading Generator objects. Since we're loading our own checkpoint, we set +# weights_only=False. Alternatively, you can use torch.serialization.add_safe_globals() +loaded_dict = torch.load("demo_state.pt", weights_only=False) + +# Reconstruct MDState - the Generator is restored automatically +restored_demo = MDState(**loaded_dict) + +# Verify restoration +print(f"Original state energy: {demo_state.energy.item():.6f} eV") +print(f"Restored state energy: {restored_demo.energy.item():.6f} eV") +print(f"Positions match: {torch.allclose(demo_state.positions, restored_demo.positions)}") +print(f"Momenta match: {torch.allclose(demo_state.momenta, restored_demo.momenta)}") +print(f"RNG restored: {restored_demo.rng is not None}") + + +# %% [markdown] +""" +### When to Save RNG State Separately + +The approach above works great when using `torch.save()` (which uses pickle). However, +you may need to save RNG state separately if: + +1. **Using non-pickle formats**: If you're saving to HDF5, JSON, or other formats that + don't support pickling, you'll need to extract the RNG state using `get_state()` and + save it separately. + +2. **Device portability**: If you need to restore to a different device, saving the + state tensor separately gives you more control. + +3. **Explicit documentation**: Some workflows prefer explicit RNG state handling for + clarity and debugging. + +For most use cases with `torch.save()`, the simple approach above is sufficient. +""" + + +# %% [markdown] +""" +## Key Takeaways + +1. **RNG State is Critical**: For stochastic integrators (Langevin, NPT with barostat), + you must save and restore the RNG state. With `torch.save()`, the Generator is + pickled automatically, but you can also save the state separately using `get_state()` + and `set_state()` if needed (e.g., for non-pickle formats or device portability). + +2. **Complete State Saving**: Save all relevant state variables including positions, + momenta, cell parameters, energy, and forces. + +3. **Trajectory Continuity**: When resuming, use append mode (`trajectory_kwargs={"mode": "a"}`) + in `TrajectoryReporter` to continue the existing trajectory file. + +4. **Verification**: Always compare restarted trajectories to continuous runs to ensure + reproducibility. + +5. **Deterministic Integrators**: For deterministic integrators (NVE, NVT Nosé-Hoover), + you don't need to save RNG state, but it's still good practice for consistency. + +## Best Practices + +- Save checkpoints regularly during long simulations +- Include metadata (step number, simulation parameters) with saved states +- Verify reproducibility by comparing trajectories +- Use the same seed and device when restoring states +- Consider saving to a format that's easy to inspect (HDF5 via TorchSimTrajectory) + +For more information on reproducibility in TorchSim, see the +[reproducibility documentation](../../../docs/user/reproducibility.md). + +## Troubleshooting: Trajectories Don't Match + +If your restarted and continuous trajectories don't match exactly, check: + +1. **RNG State Preservation**: Verify that the RNG state is correctly restored before + calling `integrate()`. The RNG state must be set on the state object before it's + passed to `integrate()`, as `integrate()` will clone the state internally. + +2. **Use the Simplified Approach**: If you're having issues with manual RNG state + management, try using the simplified approach from Part 4 (saving everything with + `torch.save()`), which automatically preserves the RNG state. + +3. **Check Initial Step**: When using append mode, ensure the trajectory reporter + correctly detects the last step. The `integrate()` function should automatically + detect this and start from the correct step. + +4. **Verify State Restoration**: Print and compare key state variables (positions, + momenta, RNG state) before continuing the simulation to ensure they match what + was saved. +""" From a4defdac3157dae89c6f2ff0037e3c8850eb47f7 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 15 Mar 2026 08:28:08 -0400 Subject: [PATCH 2/6] fix bug in integrators that always advanced rng. Add in tutorial showing saving statedict and restarting same as running 100 steps --- .gitignore | 3 + docs/user/reproducibility.md | 23 +- .../reproducible_restart_tutorial.py | 413 ++++-------------- torch_sim/integrators/npt.py | 31 +- torch_sim/integrators/nve.py | 10 +- torch_sim/integrators/nvt.py | 31 +- 6 files changed, 143 insertions(+), 368 deletions(-) diff --git a/.gitignore b/.gitignore index b0bdd821..b43f2fef 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,9 @@ docs/reference/torch_sim.* *.hdf5 *.traj +# ignore torch.save outputs +*.pt + # coverage coverage.xml .coverage* diff --git a/docs/user/reproducibility.md b/docs/user/reproducibility.md index 8b348e8f..8c797d11 100644 --- a/docs/user/reproducibility.md +++ b/docs/user/reproducibility.md @@ -79,17 +79,34 @@ Because TorchSim runs batched simulations, all systems in a batch share a single If strict reproducibility is required, keep your batching setup fixed. -### Serialising the RNG state +### Serialising state for reproducible restarts -If you wish to be able to resume a session and ensure determinism you need to persist and reload the `torch.Generator` state. This can be done using `torch.save()` and `torch.Generator().set_state()`: +To resume a simulation and ensure determinism you need to persist and reload the complete state, including the `torch.Generator` RNG state. The simplest approach is to save the full state dict with `torch.save()`: ```python +from dataclasses import asdict +from torch_sim.integrators import MDState + # save +torch.save(asdict(state), "checkpoint.pt") + +# restore (weights_only=False needed for torch.Generator in PyTorch 2.6+) +restored = MDState(**torch.load("checkpoint.pt", weights_only=False)) +``` + +This captures positions, momenta, forces, energy, cell, and the `torch.Generator` in a single file. Since `torch.save()` uses pickle, the generator is serialised automatically. + +> **Pickle caveat**: The `torch.Generator` object in the dict requires `weights_only=False` and may not unpickle across PyTorch versions. For portable checkpoints, save the tensors normally and extract the RNG state as a plain `uint8` tensor via `get_state()` — this loads with `weights_only=True` and is version-safe: + +```python +# save RNG state as a plain uint8 tensor (no pickle needed) rng_state = state.rng.get_state() torch.save(rng_state, "rng_state.pt") # restore gen = torch.Generator(device=state.device) -gen.set_state(torch.load("rng_state.pt")) +gen.set_state(torch.load("rng_state.pt", weights_only=True)) state.rng = gen ``` + +See the [reproducible restart tutorial](../../examples/tutorials/reproducible_restart_tutorial.py) for a complete worked example. diff --git a/examples/tutorials/reproducible_restart_tutorial.py b/examples/tutorials/reproducible_restart_tutorial.py index b514886c..fc855032 100644 --- a/examples/tutorials/reproducible_restart_tutorial.py +++ b/examples/tutorials/reproducible_restart_tutorial.py @@ -1,7 +1,7 @@ # %% # /// script # dependencies = [ -# "torch_sim_atomistic[mace, io]" +# "torch_sim_atomistic[io]" # ] # /// @@ -11,50 +11,37 @@ # Reproducible Restarts from Stopped Simulations This tutorial demonstrates how to save and restore simulation state to enable -reproducible restarts from stopped simulations. This is essential for long-running -simulations that may need to be paused and resumed, or for checkpointing workflows. +reproducible restarts. We run 50 steps of MD, save the state (including RNG state), +resume for another 50 steps, and verify the result is identical to 100 uninterrupted +steps. -## Introduction - -When running molecular dynamics simulations, you may need to: -- Pause a simulation and resume it later -- Create checkpoints for long-running simulations -- Ensure that a restarted simulation produces identical results to a continuous run - -To achieve reproducible restarts, you must save not only the atomic positions, velocities, -and other state variables, but also the random number generator (RNG) state. This is -especially important for stochastic integrators like Langevin dynamics, which use random -numbers for both initial momenta sampling and per-step stochastic noise. - -## Key Concepts - -1. **State Saving**: Save the complete simulation state including positions, velocities, - momenta, cell parameters, and RNG state -2. **RNG State**: The `torch.Generator` state must be saved separately using - `get_state()` and restored using `set_state()` -3. **Trajectory Comparison**: Verify that restarted simulations produce identical - trajectories to continuous runs +For stochastic integrators like Langevin dynamics, you must save the random number +generator (RNG) state alongside positions, momenta, and other state variables. +Without it, the stochastic noise will differ on restart and the trajectory will diverge. """ # %% [markdown] """ -## Setup: Initial System and Model - -Let's start by creating a simple system and model for our demonstration: +## Setup """ # %% +from dataclasses import asdict +from pathlib import Path + import torch import torch_sim as ts from ase.build import bulk +from torch_sim.integrators import MDState from torch_sim.models.lennard_jones import LennardJonesModel -# Set up deterministic mode for reproducibility -# Note: This seeds the global RNG, but we'll also seed SimState.rng explicitly +# All generated files go in this directory +restart_dir = Path("restart_files") +restart_dir.mkdir(exist_ok=True) + seed = 42 torch.manual_seed(seed) -# Create a Lennard-Jones model lj_model = LennardJonesModel( sigma=2.0, epsilon=0.1, @@ -62,37 +49,36 @@ dtype=torch.float64, ) -# Create a silicon crystal structure si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) -# Initialize state and seed the RNG initial_state = ts.initialize_state( si_atoms, device=torch.device("cpu"), dtype=torch.float64 ) -initial_state.rng = seed # Critical: seed the SimState RNG for reproducibility - -print(f"Initial state has {initial_state.n_atoms} atoms") -print(f"RNG device: {initial_state.rng.device}") +initial_state.rng = seed # seed the SimState RNG for reproducibility +print(f"Initial state: {initial_state.n_atoms} atoms") # %% [markdown] """ -## Part 1: Run 50 Steps, Save State, and Resume +## Part 1: Run 50 Steps, Save State, Resume for 50 More + +We save the complete state with `asdict()` + `torch.save()`. Since `torch.save()` +uses pickle, the `torch.Generator` (RNG) is included automatically — no need to +save it separately. -First, we'll run 50 steps of MD, save the complete state (including RNG state), -then resume for another 50 steps: +**PyTorch 2.6+**: You must pass `weights_only=False` to `torch.load()` when loading +checkpoints that contain `torch.Generator` objects. """ # %% -# Run first 50 steps with trajectory reporting -trajectory_file_restart = "restart_trajectory.h5" +# Run first 50 steps +trajectory_file_restart = str(restart_dir / "restart_trajectory.h5") reporter_restart = ts.TrajectoryReporter( filenames=trajectory_file_restart, - state_frequency=10, # Save state every 10 steps - state_kwargs={"save_velocities": True}, # Save velocities (momenta) for comparison + state_frequency=10, + state_kwargs={"save_velocities": True}, ) -# Run 50 steps state_after_50 = ts.integrate( system=initial_state.clone(), model=lj_model, @@ -102,209 +88,95 @@ timestep=0.001, trajectory_reporter=reporter_restart, ) - reporter_restart.close() -# Check RNG state after 50 steps (before saving) -test_random_after_50_restart = torch.randn(1, generator=state_after_50.rng).item() -print(f"After 50 steps (restart) - Test random: {test_random_after_50_restart:.6f}") - -# Save the complete state including RNG state -state_save_file = "saved_state.pt" -rng_state_save_file = "saved_rng_state.pt" - -# Save the RNG state separately (as recommended in reproducibility docs) -rng_state = state_after_50.rng.get_state() -torch.save(rng_state, rng_state_save_file) - -# Save the state object (this doesn't include the RNG generator itself) -# We'll need to restore the RNG state separately -torch.save( - { - "positions": state_after_50.positions, - "momenta": state_after_50.momenta, - "cell": state_after_50.cell, - "atomic_numbers": state_after_50.atomic_numbers, - "masses": state_after_50.masses, - "pbc": state_after_50.pbc, - "system_idx": state_after_50.system_idx, - "energy": state_after_50.energy, - "forces": state_after_50.forces, - }, - state_save_file, -) - -print(f"Saved state after 50 steps") -print(f"State file: {state_save_file}") -print(f"RNG state file: {rng_state_save_file}") - +# Save the complete state (including RNG) in one file +checkpoint_file = str(restart_dir / "checkpoint.pt") +torch.save(asdict(state_after_50), checkpoint_file) +print(f"Saved checkpoint after 50 steps to {checkpoint_file}") # %% [markdown] """ -Now let's restore the state and continue for another 50 steps: +Now restore the state and continue for another 50 steps: """ # %% -# Load the saved state -# Note: PyTorch 2.6+ defaults to weights_only=True. Since we're loading our own -# checkpoints, we use weights_only=False to allow loading Generator objects. -saved_data = torch.load(state_save_file, weights_only=False) -rng_state_loaded = torch.load(rng_state_save_file, weights_only=False) - -# Reconstruct the state from saved data -# We need to create an MDState since we have momenta, forces, and energy -from torch_sim.integrators import MDState +# Load checkpoint (weights_only=False needed for torch.Generator in PyTorch 2.6+) +loaded = torch.load(checkpoint_file, weights_only=False) +restored_state = MDState(**loaded) -restored_state = MDState( - positions=saved_data["positions"], - momenta=saved_data["momenta"], - cell=saved_data["cell"], - atomic_numbers=saved_data["atomic_numbers"], - masses=saved_data["masses"], - pbc=saved_data["pbc"], - system_idx=saved_data["system_idx"], - energy=saved_data["energy"], - forces=saved_data["forces"], -) +# Verify RNG was restored +assert torch.equal(restored_state.rng.get_state(), state_after_50.rng.get_state()) +print(f"Restored state: {restored_state.n_atoms} atoms, RNG matches ✓") -# Restore the RNG state - this is critical for reproducibility! -gen = torch.Generator(device=restored_state.device) -gen.set_state(rng_state_loaded) -restored_state.rng = gen - -# Verify RNG state was restored correctly -# Draw a random number to verify the RNG state matches -test_random_restored = torch.randn(1, generator=restored_state.rng).item() -print(f"Restored state with {restored_state.n_atoms} atoms") -print(f"Restored RNG device: {restored_state.rng.device}") -print(f"Test random from restored RNG: {test_random_restored:.6f}") - -# Verify the RNG state matches what we saved -# (We need to reload the saved state to compare) -saved_rng_check = torch.Generator(device=restored_state.device) -saved_rng_check.set_state(rng_state_loaded) -test_random_saved = torch.randn(1, generator=saved_rng_check).item() -print(f"Test random from saved RNG state: {test_random_saved:.6f}") -assert abs(test_random_restored - test_random_saved) < 1e-10, "RNG state mismatch!" - -# Continue simulation for another 50 steps -# Use append mode to continue the trajectory +# Continue for another 50 steps (append to existing trajectory) reporter_restart_continued = ts.TrajectoryReporter( filenames=trajectory_file_restart, state_frequency=10, - state_kwargs={"save_velocities": True}, # Save velocities (momenta) for comparison - trajectory_kwargs={"mode": "a"}, # Append mode to continue existing trajectory + state_kwargs={"save_velocities": True}, + trajectory_kwargs={"mode": "a"}, ) -# IMPORTANT: When integrate() is called, it internally calls initialize_state() which -# clones the state. The clone() method should preserve the RNG state, but we verify -# that the RNG state is correctly set before calling integrate. -rng_state_before_integrate = restored_state.rng.get_state() - state_after_100_restart = ts.integrate( system=restored_state, model=lj_model, integrator=ts.Integrator.nvt_langevin, - n_steps=50, # Additional 50 steps + n_steps=50, temperature=300, timestep=0.001, trajectory_reporter=reporter_restart_continued, ) - -# Check RNG state after 100 steps (restart) -test_random_after_100_restart = torch.randn( - 1, generator=state_after_100_restart.rng -).item() -print(f"After 100 steps (restart) - Test random: {test_random_after_100_restart:.6f}") - reporter_restart_continued.close() - -print(f"Completed restart simulation: 50 + 50 = 100 steps total") - +print(f"Completed restart simulation: 50 + 50 = 100 steps") # %% [markdown] """ ## Part 2: Run 100 Steps Continuously for Comparison - -Now let's run a continuous simulation that matches the restart scenario: run 50 steps, -then continue for another 50 steps (without saving/restoring in between). This ensures -we're comparing apples to apples: """ # %% -# Run continuous simulation: 50 steps, then 50 more steps -trajectory_file_continuous = "continuous_trajectory.h5" +trajectory_file_continuous = str(restart_dir / "continuous_trajectory.h5") reporter_continuous = ts.TrajectoryReporter( filenames=trajectory_file_continuous, state_frequency=10, - state_kwargs={"save_velocities": True}, # Save velocities (momenta) for comparison + state_kwargs={"save_velocities": True}, ) -# Create a fresh initial state with the same seed initial_state_continuous = ts.initialize_state( si_atoms, device=torch.device("cpu"), dtype=torch.float64 ) -initial_state_continuous.rng = seed # Same seed as before +initial_state_continuous.rng = seed -# Run first 50 steps (matching the restart scenario) -print("Running first 50 steps of continuous simulation...") -state_after_50_continuous = ts.integrate( +state_after_100_continuous = ts.integrate( system=initial_state_continuous, model=lj_model, integrator=ts.Integrator.nvt_langevin, - n_steps=50, + n_steps=100, temperature=300, timestep=0.001, trajectory_reporter=reporter_continuous, ) - -# Check RNG state after 50 steps -rng_state_after_50_continuous = state_after_50_continuous.rng.get_state() -test_random_after_50_continuous = torch.randn( - 1, generator=state_after_50_continuous.rng -).item() -print(f"After 50 steps - Test random: {test_random_after_50_continuous:.6f}") - -# Continue for another 50 steps (without saving/restoring) -# Use append mode to continue the same trajectory -reporter_continuous_continued = ts.TrajectoryReporter( - filenames=trajectory_file_continuous, - state_frequency=10, - state_kwargs={"save_velocities": True}, - trajectory_kwargs={"mode": "a"}, # Append mode -) - -print("Running second 50 steps of continuous simulation...") -state_after_100_continuous = ts.integrate( - system=state_after_50_continuous, - model=lj_model, - integrator=ts.Integrator.nvt_langevin, - n_steps=50, # Additional 50 steps - temperature=300, - timestep=0.001, - trajectory_reporter=reporter_continuous_continued, -) - -reporter_continuous_continued.close() - -# Check RNG state after 100 steps -test_random_after_100_continuous = torch.randn( - 1, generator=state_after_100_continuous.rng -).item() -print(f"After 100 steps - Test random: {test_random_after_100_continuous:.6f}") -print(f"Completed continuous simulation: 50 + 50 = 100 steps total") - +reporter_continuous.close() +print(f"Completed continuous simulation: 100 steps") # %% [markdown] """ ## Part 3: Compare Trajectories -Let's compare the trajectories from the restarted simulation and the continuous -simulation to verify they are identical: +Both runs started from the same initial state and seed. The restarted run saved and +restored the RNG state at step 50. If everything is correct, the trajectories should +match exactly: """ # %% -# Load both trajectories +# Compare final RNG states +rng_match = torch.equal( + state_after_100_restart.rng.get_state(), + state_after_100_continuous.rng.get_state(), +) +print(f"Final RNG states match: {rng_match}") + +# Compare trajectories frame by frame with ts.TorchSimTrajectory(trajectory_file_restart, mode="r") as traj_restart: positions_restart = traj_restart.get_array("positions") steps_restart = traj_restart.get_steps("positions") @@ -315,24 +187,17 @@ steps_continuous = traj_continuous.get_steps("positions") velocities_continuous = traj_continuous.get_array("velocities") -print(f"Restart trajectory: {len(steps_restart)} frames") -print(f"Continuous trajectory: {len(steps_continuous)} frames") - -# Compare positions at matching steps -# Both should have states at steps 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100 matching_steps = sorted(set(steps_restart) & set(steps_continuous)) -print(f"\nMatching steps: {matching_steps}") +print(f"Comparing {len(matching_steps)} frames at steps: {matching_steps}") -# Compare positions and momenta at each matching step max_pos_diff = 0.0 -max_mom_diff = 0.0 +max_vel_diff = 0.0 all_match = True for step in matching_steps: idx_restart = steps_restart.tolist().index(step) idx_continuous = steps_continuous.tolist().index(step) - # Convert numpy arrays to torch tensors for comparison pos_restart = torch.tensor(positions_restart[idx_restart]) pos_continuous = torch.tensor(positions_continuous[idx_continuous]) vel_restart = torch.tensor(velocities_restart[idx_restart]) @@ -340,11 +205,9 @@ pos_diff = torch.max(torch.abs(pos_restart - pos_continuous)).item() vel_diff = torch.max(torch.abs(vel_restart - vel_continuous)).item() - max_pos_diff = max(max_pos_diff, pos_diff) - max_mom_diff = max(max_mom_diff, vel_diff) + max_vel_diff = max(max_vel_diff, vel_diff) - # Check if they match exactly (within floating point precision) if not torch.allclose(pos_restart, pos_continuous, atol=1e-10, rtol=1e-10): print(f" Step {step}: Position mismatch! Max diff: {pos_diff:.2e}") all_match = False @@ -352,137 +215,37 @@ print(f" Step {step}: Velocity mismatch! Max diff: {vel_diff:.2e}") all_match = False -print(f"\nMaximum position difference: {max_pos_diff:.2e}") -print(f"Maximum velocity difference: {max_mom_diff:.2e}") - -if all_match: - print("\n✓ SUCCESS: Restarted and continuous trajectories match exactly!") -else: - print( - "\n✗ WARNING: Trajectories differ - this may indicate an issue with state saving/restoration" - ) - - -# %% [markdown] -""" -## Part 4: Simplified State Saving with torch.save - -For convenience, you can save the entire state object directly using `torch.save()`. -Since `torch.save()` uses pickle, the `torch.Generator` will be saved along with -everything else automatically. - -**Note for PyTorch 2.6+**: PyTorch 2.6 changed the default `weights_only` parameter -in `torch.load()` from `False` to `True` for security. When loading checkpoints that -contain `torch.Generator` objects, you need to set `weights_only=False`. This is safe -when loading your own checkpoints, but be cautious when loading files from untrusted -sources as it can result in arbitrary code execution. -""" - -# %% -# Simplified approach: save everything together -# Create a fresh state for demonstration -demo_state = ts.integrate( - system=initial_state.clone(), - model=lj_model, - integrator=ts.Integrator.nvt_langevin, - n_steps=25, - temperature=300, - timestep=0.001, +assert all_match, ( + f"Restarted and continuous trajectories differ! " + f"Max position difference: {max_pos_diff:.2e}, max velocity difference: {max_vel_diff:.2e}" ) - -# Save the entire state dict - Generator is included automatically -from dataclasses import asdict - -state_dict = asdict(demo_state) -torch.save(state_dict, "demo_state.pt") - -# Restore -# Note: PyTorch 2.6+ defaults to weights_only=True for security, which doesn't allow -# loading Generator objects. Since we're loading our own checkpoint, we set -# weights_only=False. Alternatively, you can use torch.serialization.add_safe_globals() -loaded_dict = torch.load("demo_state.pt", weights_only=False) - -# Reconstruct MDState - the Generator is restored automatically -restored_demo = MDState(**loaded_dict) - -# Verify restoration -print(f"Original state energy: {demo_state.energy.item():.6f} eV") -print(f"Restored state energy: {restored_demo.energy.item():.6f} eV") -print(f"Positions match: {torch.allclose(demo_state.positions, restored_demo.positions)}") -print(f"Momenta match: {torch.allclose(demo_state.momenta, restored_demo.momenta)}") -print(f"RNG restored: {restored_demo.rng is not None}") - - -# %% [markdown] -""" -### When to Save RNG State Separately - -The approach above works great when using `torch.save()` (which uses pickle). However, -you may need to save RNG state separately if: - -1. **Using non-pickle formats**: If you're saving to HDF5, JSON, or other formats that - don't support pickling, you'll need to extract the RNG state using `get_state()` and - save it separately. - -2. **Device portability**: If you need to restore to a different device, saving the - state tensor separately gives you more control. - -3. **Explicit documentation**: Some workflows prefer explicit RNG state handling for - clarity and debugging. - -For most use cases with `torch.save()`, the simple approach above is sufficient. -""" - +print("\n✓ Restarted and continuous trajectories match exactly.") # %% [markdown] """ ## Key Takeaways -1. **RNG State is Critical**: For stochastic integrators (Langevin, NPT with barostat), - you must save and restore the RNG state. With `torch.save()`, the Generator is - pickled automatically, but you can also save the state separately using `get_state()` - and `set_state()` if needed (e.g., for non-pickle formats or device portability). - -2. **Complete State Saving**: Save all relevant state variables including positions, - momenta, cell parameters, energy, and forces. - -3. **Trajectory Continuity**: When resuming, use append mode (`trajectory_kwargs={"mode": "a"}`) - in `TrajectoryReporter` to continue the existing trajectory file. - -4. **Verification**: Always compare restarted trajectories to continuous runs to ensure - reproducibility. - -5. **Deterministic Integrators**: For deterministic integrators (NVE, NVT Nosé-Hoover), - you don't need to save RNG state, but it's still good practice for consistency. - -## Best Practices +1. **Save with `asdict()` + `torch.save()`**: This captures everything — positions, + momenta, forces, energy, cell, and the `torch.Generator` RNG state — in a single + checkpoint file. -- Save checkpoints regularly during long simulations -- Include metadata (step number, simulation parameters) with saved states -- Verify reproducibility by comparing trajectories -- Use the same seed and device when restoring states -- Consider saving to a format that's easy to inspect (HDF5 via TorchSimTrajectory) +2. **Restore with `MDState(**torch.load(...))`**: The `torch.Generator` is unpickled + automatically, so the RNG state is restored without any extra steps. -For more information on reproducibility in TorchSim, see the -[reproducibility documentation](../../../docs/user/reproducibility.md). +3. **Use append mode** (`trajectory_kwargs={"mode": "a"}`) in `TrajectoryReporter` + to continue an existing trajectory file. -## Troubleshooting: Trajectories Don't Match +4. **Pickle caveats**: The `torch.Generator` object in the checkpoint requires pickle + (`weights_only=False`) and may not load across PyTorch versions. For portable + checkpoints, save tensors normally and use `state.rng.get_state()` to extract the + RNG state as a plain `uint8` tensor that works with `weights_only=True`. -If your restarted and continuous trajectories don't match exactly, check: - -1. **RNG State Preservation**: Verify that the RNG state is correctly restored before - calling `integrate()`. The RNG state must be set on the state object before it's - passed to `integrate()`, as `integrate()` will clone the state internally. - -2. **Use the Simplified Approach**: If you're having issues with manual RNG state - management, try using the simplified approach from Part 4 (saving everything with - `torch.save()`), which automatically preserves the RNG state. +5. **Verify**: Always compare restarted trajectories to continuous runs. +""" -3. **Check Initial Step**: When using append mode, ensure the trajectory reporter - correctly detects the last step. The `integrate()` function should automatically - detect this and start from the correct step. +# %% +# Cleanup +import shutil -4. **Verify State Restoration**: Print and compare key state variables (positions, - momenta, RNG state) before continuing the simulation to ensure they match what - was saved. -""" +shutil.rmtree(restart_dir) +print(f"Cleaned up {restart_dir}/") diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 9b1ed810..094a7deb 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -574,17 +574,15 @@ def npt_langevin_init( model_output = model(state) # Initialize momenta if not provided - momenta = getattr( - state, - "momenta", - initialize_momenta( + momenta = getattr(state, "momenta", None) + if momenta is None: + momenta = initialize_momenta( state.positions, state.masses, state.system_idx, kT, state.rng, - ), - ) + ) # Initialize cell parameters reference_cell = state.cell.clone() @@ -1375,16 +1373,17 @@ def npt_nose_hoover_init( KE_cell = (cell_momentum.squeeze(-1) ** 2) / (2 * cell_mass) # Initialize momenta - momenta = kwargs.get( - "momenta", - initialize_momenta( + momenta = kwargs.get("momenta") + if momenta is None: + momenta = getattr(state, "momenta", None) + if momenta is None: + momenta = initialize_momenta( state.positions, state.masses, state.system_idx, kT_tensor, state.rng, - ), - ) + ) # Compute total DOF for thermostat initialization and a zero KE placeholder dof_per_system = torch.bincount(state.system_idx, minlength=n_systems) * dim @@ -2353,17 +2352,15 @@ def npt_crescale_init( model_output = model(state) # Initialize momenta if not provided - momenta = getattr( - state, - "momenta", - initialize_momenta( + momenta = getattr(state, "momenta", None) + if momenta is None: + momenta = initialize_momenta( state.positions, state.masses, state.system_idx, kT, state.rng, - ), - ) + ) # Create the initial state return NPTCRescaleState.from_state( diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index ca724126..07f3064b 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -47,17 +47,15 @@ def nve_init( """ model_output = model(state) - momenta = getattr( - state, - "momenta", - initialize_momenta( + momenta = getattr(state, "momenta", None) + if momenta is None: + momenta = initialize_momenta( state.positions, state.masses, state.system_idx, kT, state.rng, - ), - ) + ) return MDState.from_state( state, diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 10f2509a..8e74bf85 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -117,17 +117,15 @@ def nvt_langevin_init( """ model_output = model(state) - momenta = getattr( - state, - "momenta", - initialize_momenta( + momenta = getattr(state, "momenta", None) + if momenta is None: + momenta = initialize_momenta( state.positions, state.masses, state.system_idx, kT, state.rng, - ), - ) + ) return MDState.from_state( state, momenta=momenta, @@ -303,12 +301,13 @@ def nvt_nose_hoover_init( atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers) model_output = model(state) - momenta = kwargs.get( - "momenta", - initialize_momenta( + momenta = kwargs.get("momenta") + if momenta is None: + momenta = getattr(state, "momenta", None) + if momenta is None: + momenta = initialize_momenta( state.positions, state.masses, state.system_idx, kT_tensor, state.rng - ), - ) + ) # Calculate initial kinetic energy per system KE = ts.calc_kinetic_energy( @@ -600,17 +599,15 @@ def nvt_vrescale_init( """ model_output = model(state) - momenta = getattr( - state, - "momenta", - initialize_momenta( + momenta = getattr(state, "momenta", None) + if momenta is None: + momenta = initialize_momenta( state.positions, state.masses, state.system_idx, kT, state.rng, - ), - ) + ) return NVTVRescaleState.from_state( state, From ddc88f1790d72f53400c3e74f1fbd7fda6963af0 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 15 Mar 2026 08:31:28 -0400 Subject: [PATCH 3/6] address #485 --- tests/test_state.py | 29 +++++++++++++++++++++++++++++ torch_sim/state.py | 8 ++------ 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/tests/test_state.py b/tests/test_state.py index 6b40e68e..891d53aa 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -348,6 +348,35 @@ def test_initialize_state_from_state(ar_supercell_sim_state: SimState) -> None: assert state.cell.shape == ar_supercell_sim_state.cell.shape +def test_initialize_state_from_list_of_states_with_multiple_systems( + si_double_sim_state: SimState, fe_supercell_sim_state: SimState +) -> None: + """Test initialize_state with list of states that have n_systems > 1.""" + # This should work now that we've removed the arbitrary n_systems == 1 constraint + concatenated = ts.initialize_state([si_double_sim_state, fe_supercell_sim_state]) + + # Should have 3 systems total (2 from si_double + 1 from fe) + assert concatenated.n_systems == 3 + assert concatenated.cell.shape[0] == 3 + + # Check system indices are correct + fe_atoms = fe_supercell_sim_state.n_atoms + expected_system_indices = torch.cat( + [ + si_double_sim_state.system_idx, + torch.full( + (fe_atoms,), 2, dtype=torch.int64, device=fe_supercell_sim_state.device + ), + ] + ) + assert torch.all(concatenated.system_idx == expected_system_indices) + + # Verify we can slice back to original states + assert torch.allclose(concatenated[0].positions, si_double_sim_state[0].positions) + assert torch.allclose(concatenated[1].positions, si_double_sim_state[1].positions) + assert torch.allclose(concatenated[2].positions, fe_supercell_sim_state.positions) + + def test_initialize_state_from_atoms(si_atoms: "Atoms") -> None: """Test conversion from ASE Atoms to SimState.""" state = ts.initialize_state([si_atoms], DEVICE, torch.float64) diff --git a/torch_sim/state.py b/torch_sim/state.py index 52b97e79..f34bb5af 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -1225,12 +1225,8 @@ def initialize_state( if isinstance(system, list | tuple) and all(isinstance(s, SimState) for s in system): system: list[SimState] = typing.cast("list[SimState]", system) - if not all(state.n_systems == 1 for state in system): - raise ValueError( - "When providing a list of states, to the initialize_state function, " - "all states must have n_systems == 1. To fix this, you can split the " - "states into individual states with the split_state function." - ) + if len(system) == 0: + raise ValueError("Cannot initialize state from an empty list.") return ts.concatenate_states(system) converters = [ From 11795ff7f5f75314f334ba8bc9cc5810d6746989 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 15 Mar 2026 08:46:16 -0400 Subject: [PATCH 4/6] address #379 --- examples/tutorials/hybrid_swap_tutorial.py | 16 +++++------- tests/test_monte_carlo.py | 20 +++++++++++++++ torch_sim/monte_carlo.py | 29 +++++++++++++++++++--- 3 files changed, 51 insertions(+), 14 deletions(-) diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index c7fc6362..9a7a2ede 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -107,6 +107,10 @@ class HybridSwapMCState(SwapMCState, MDState): ts.SwapMCState._system_attributes | MDState._system_attributes # noqa: SLF001 ) + def __post_init__(self) -> None: + """Initialize HybridSwapMCState and ensure last_permutation is set.""" + super().__post_init__() + # %% [markdown] """ @@ -127,16 +131,8 @@ class HybridSwapMCState(SwapMCState, MDState): state.rng = 42 md_state = ts.nvt_langevin_init(state=state, model=mace_model, kT=kT) -# Initialize swap Monte Carlo state -swap_state = ts.swap_mc_init(state=md_state, model=mace_model) - -# Create hybrid state combining both -hybrid_state = HybridSwapMCState( - **md_state.attributes, - last_permutation=torch.arange( - md_state.n_atoms, device=md_state.device, dtype=torch.long - ), -) +# Create hybrid state from MD state +hybrid_state = HybridSwapMCState(**md_state.attributes) # %% [markdown] diff --git a/tests/test_monte_carlo.py b/tests/test_monte_carlo.py index ab0e1bc8..ee8feef6 100644 --- a/tests/test_monte_carlo.py +++ b/tests/test_monte_carlo.py @@ -191,6 +191,26 @@ def test_monte_carlo_integration( assert torch.all(orig_counts == result_counts) +def test_swap_mc_state_default_last_permutation( + batched_diverse_state: ts.SimState, +) -> None: + """Test that SwapMCState initializes last_permutation to identity if not provided.""" + from torch_sim.monte_carlo import SwapMCState + + state = SwapMCState( + positions=batched_diverse_state.positions, + masses=batched_diverse_state.masses, + cell=batched_diverse_state.cell, + pbc=batched_diverse_state.pbc, + atomic_numbers=batched_diverse_state.atomic_numbers, + system_idx=batched_diverse_state.system_idx, + energy=torch.zeros(batched_diverse_state.n_systems), + ) + assert state.last_permutation is not None + expected_identity = torch.arange(batched_diverse_state.n_atoms, device=DEVICE) + assert torch.equal(state.last_permutation, expected_identity) + + def test_swap_mc_state_attributes(): """Test SwapMCState class structure and inheritance.""" from torch_sim.state import SimState diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 149bd406..04dfde31 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -16,7 +16,7 @@ ... mc_state = ts.swap_mc_step(model, mc_state, kT=0.1 * units.energy) """ -from dataclasses import dataclass +from dataclasses import dataclass, field import torch @@ -24,6 +24,15 @@ from torch_sim.state import SimState +# Sentinel value for uninitialized last_permutation +_UNINITIALIZED_PERMUTATION = torch.empty(0, dtype=torch.long) + + +def _create_uninitialized_permutation() -> torch.Tensor: + """Create a sentinel tensor for uninitialized last_permutation.""" + return _UNINITIALIZED_PERMUTATION.clone() + + @dataclass(kw_only=True) class SwapMCState(SimState): """State for Monte Carlo simulations with swap moves. @@ -35,15 +44,28 @@ class SwapMCState(SimState): Attributes: energy (torch.Tensor): Energy of the system with shape [batch_size] last_permutation (torch.Tensor): Last permutation applied to the system, - with shape [n_atoms], tracking the moves made for analysis or reversal + with shape [n_atoms], tracking the moves made for analysis or reversal. + If not provided, will be initialized to identity permutation + (torch.arange(n_atoms)) in __post_init__. """ energy: torch.Tensor - last_permutation: torch.Tensor + last_permutation: torch.Tensor = field( + default_factory=_create_uninitialized_permutation + ) _atom_attributes = SimState._atom_attributes | {"last_permutation"} # noqa: SLF001 _system_attributes = SimState._system_attributes | {"energy"} # noqa: SLF001 + def __post_init__(self) -> None: + """Initialize SwapMCState and set default last_permutation if needed.""" + super().__post_init__() + # Check if last_permutation is the sentinel (empty tensor) + if self.last_permutation.numel() == 0: + self.last_permutation = torch.arange( + self.n_atoms, device=self.device, dtype=torch.long + ) + def generate_swaps(state: SimState, rng: torch.Generator | None = None) -> torch.Tensor: """Generate atom swaps for a given batched system. @@ -209,7 +231,6 @@ def swap_mc_init( atomic_numbers=state.atomic_numbers, system_idx=state.system_idx, energy=model_output["energy"], - last_permutation=torch.arange(state.n_atoms, device=state.device), _constraints=state.constraints, ) From 3b49e6fca80a3e4435eaabf11c156845a44fe889 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 15 Mar 2026 08:56:50 -0400 Subject: [PATCH 5/6] direct batch 2 batch graphpes --- torch_sim/models/graphpes_framework.py | 68 +++++++++++++------------- 1 file changed, 33 insertions(+), 35 deletions(-) diff --git a/torch_sim/models/graphpes_framework.py b/torch_sim/models/graphpes_framework.py index ab8af45d..800fe819 100644 --- a/torch_sim/models/graphpes_framework.py +++ b/torch_sim/models/graphpes_framework.py @@ -27,7 +27,7 @@ try: from graph_pes import AtomicGraph, GraphPESModel - from graph_pes.atomic_graph import PropertyKey, to_batch + from graph_pes.atomic_graph import PropertyKey from graph_pes.models import load_model except ImportError as exc: @@ -63,40 +63,38 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra Returns: AtomicGraph object representing the batched structures """ - graphs = [] - - for sys_idx in range(state.n_systems): - system_mask = state.system_idx == sys_idx - R = state.positions[system_mask] - Z = state.atomic_numbers[system_mask] - cell = state.row_vector_cell[sys_idx] - # graph-pes models internally trim the neighbor list to the - # model's cutoff value. To ensure no strange edge effects whereby - # edges that are exactly `cutoff` long are included/excluded, - # we bump cutoff + 1e-5 up slightly - - # Create system_idx for this single system (all atoms belong to system 0) - system_idx_single = torch.zeros(R.shape[0], dtype=torch.long, device=R.device) - nl, _system_mapping, shifts = torchsim_nl( - R, cell, state.pbc, cutoff + 1e-5, system_idx_single - ) - - atomic_graph = AtomicGraph( - Z=Z.long(), - R=R, - cell=cell, - neighbour_list=nl.long(), - neighbour_cell_offsets=shifts, - properties={}, - cutoff=cutoff.item(), - other={ - "total_charge": torch.tensor(0.0).to(state.device), - "total_spin": torch.tensor(0.0).to(state.device), - }, - ) - graphs.append(atomic_graph) - - return to_batch(graphs) + # graph-pes models internally trim the neighbor list to the + # model's cutoff value. To ensure no strange edge effects whereby + # edges that are exactly `cutoff` long are included/excluded, + # we bump cutoff + 1e-5 up slightly + nl, _system_mapping, shifts = torchsim_nl( + state.positions, + state.row_vector_cell, + state.pbc, + cutoff + 1e-5, + state.system_idx, + ) + n_atoms_per_system = torch.bincount(state.system_idx) + ptr = torch.zeros(state.n_systems + 1, dtype=torch.long, device=state.device) + ptr[1:] = n_atoms_per_system.cumsum(dim=0) + n_sys = state.n_systems + total_charge = torch.zeros(n_sys, device=state.device) + total_spin = torch.zeros(n_sys, device=state.device) + return AtomicGraph( + Z=state.atomic_numbers.long(), + R=state.positions, + cell=state.row_vector_cell, + neighbour_list=nl.long(), + neighbour_cell_offsets=shifts, + properties={}, + cutoff=cutoff.item(), + other={ + "total_charge": total_charge, + "total_spin": total_spin, + }, + batch=state.system_idx, + ptr=ptr, + ) class GraphPESWrapper(ModelInterface): From 44c3bbdaa0b7abe1bb5cc3a24a9c59a397159060 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 15 Mar 2026 09:06:14 -0400 Subject: [PATCH 6/6] Apply suggestion from @CompRhys Signed-off-by: Rhys Goodall --- torch_sim/state.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index f34bb5af..b3cb8f07 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -1225,8 +1225,6 @@ def initialize_state( if isinstance(system, list | tuple) and all(isinstance(s, SimState) for s in system): system: list[SimState] = typing.cast("list[SimState]", system) - if len(system) == 0: - raise ValueError("Cannot initialize state from an empty list.") return ts.concatenate_states(system) converters = [