diff --git a/pyproject.toml b/pyproject.toml index 90de3c4a..a56b49bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ [project.optional-dependencies] test = [ "torch-sim-atomistic[io,symmetry,vesin]", + "physical-validation>=1.0.5", "platformdirs>=4.0.0", "psutil>=7.0.0", "pymatgen>=2025.6.14", @@ -139,8 +140,11 @@ check-filenames = true ignore-words-list = ["convertor"] # codespell:ignore convertor [tool.pytest.ini_options] -addopts = ["-p no:warnings"] +addopts = ["-p no:warnings", "-m not physical_validation"] testpaths = ["tests"] +markers = [ + "physical_validation: long-running physical validation tests (run with: pytest -m physical_validation)", +] [tool.uv] # make these dependencies mutually exclusive since they use incompatible e3nn versions diff --git a/tests/conftest.py b/tests/conftest.py index 79b39b48..750edea9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,21 @@ torch.set_num_threads(4) + +def pytest_addoption(parser): + parser.addoption( + "--validation-plots", + action="store_true", + default=False, + help="Save physical validation plots to tests/physical_validation_data/plots/", + ) + parser.addoption( + "--clean-validation-data", + action="store_true", + default=False, + help="Delete saved physical validation data before running tests", + ) + DEVICE = torch.device("cpu") DTYPE = torch.float64 diff --git a/tests/test_integrators.py b/tests/test_integrators.py index dd9b49d0..0ada0f06 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -155,6 +155,66 @@ def test_npt_langevin( assert pos_diff > 0.0001 # Systems should remain separated +def test_npt_langevin_strain( + ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel +) -> None: + n_steps = 200 + dt = torch.tensor(0.001, dtype=DTYPE) * MetalUnits.time + kT = torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature + external_pressure = torch.tensor(10.0, dtype=DTYPE) * MetalUnits.pressure + alpha = 1 * dt + cell_alpha = 10 * dt + b_tau = 30 * dt + + ar_double_sim_state.rng = 42 + state = ts.npt_langevin_strain_init( + state=ar_double_sim_state, + model=lj_model, + dt=dt, + kT=kT, + alpha=alpha, + cell_alpha=cell_alpha, + b_tau=b_tau, + ) + + # Check strain state shape + assert state.cell_positions.shape == (2,) # scalar strain per system + + energies = [] + temperatures = [] + for _step in range(n_steps): + state = ts.npt_langevin_strain_step( + state=state, + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + ) + + temp = ts.calc_kT( + masses=state.masses, momenta=state.momenta, system_idx=state.system_idx + ) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + + temperatures_tensor = torch.stack(temperatures) + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + assert len(energies_list[0]) == n_steps + + mean_temps = torch.mean(temperatures_tensor, dim=0) + for mean_temp in mean_temps: + assert abs(mean_temp - kT.item() / MetalUnits.temperature) < 150.0 + + for traj in energies_list: + energy_std = torch.tensor(traj).std() + assert energy_std < 1.0 + + # Cell reconstruction is consistent + assert torch.allclose(state.cell, state.current_cell) + + def test_npt_langevin_multi_kt( ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel ): @@ -339,7 +399,7 @@ def test_nvt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone temperatures_list = [t.tolist() for t in temperatures_tensor.T] assert torch.allclose( temperatures_tensor[-1], - torch.tensor([300.0096, 299.7024], dtype=dtype), + torch.tensor([305.6400, 305.4556], dtype=dtype), ) energies_tensor = torch.stack(energies) @@ -728,7 +788,7 @@ def test_npt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone temperatures_list = [t.tolist() for t in temperatures_tensor.T] assert torch.allclose( temperatures_tensor[-1], - torch.tensor([298.2752, 297.9444], dtype=dtype), + torch.tensor([283.1162, 313.1624], dtype=dtype), ) energies_tensor = torch.stack(energies) @@ -1023,19 +1083,19 @@ def test_compute_cell_force_atoms_per_system(): atomic_numbers=torch.ones(72, dtype=torch.long), stress=torch.zeros((2, 3, 3)), reference_cell=torch.eye(3).repeat(2, 1, 1), - cell_positions=torch.ones((2, 3, 3)), - cell_velocities=torch.zeros((2, 3, 3)), + cell_positions=torch.zeros(2, 3), + cell_velocities=torch.zeros(2, 3), cell_masses=torch.ones(2), alpha=torch.ones(2), cell_alpha=torch.ones(2), b_tau=torch.ones(2), ) - # Get forces and compare ratio - cell_force = _compute_cell_force(state, torch.tensor(0.0), torch.tensor([1.0, 1.0])) - force_ratio = ( - torch.diagonal(cell_force[1]).mean() / torch.diagonal(cell_force[0]).mean() - ) + # Get forces and compare ratio (per-dimension force) + P_ext = torch.zeros(2, 3) + cell_force = _compute_cell_force(state, P_ext, torch.tensor([1.0, 1.0])) + # Check the first dimension's force ratio + force_ratio = cell_force[1, 0] / cell_force[0, 0] # Force ratio should match atom ratio (8:1) with the fix assert abs(force_ratio - 8.0) / 8.0 < 0.1 diff --git a/tests/test_nbody.py b/tests/test_nbody.py index e235cd62..5da5a45c 100644 --- a/tests/test_nbody.py +++ b/tests/test_nbody.py @@ -481,9 +481,9 @@ def test_build_triplets_device(device: str) -> None: result = build_triplets(edge_index, n_atoms) - assert result["trip_in"].device == dev - assert result["trip_out"].device == dev - assert result["center_atom"].device == dev + assert result["trip_in"].device.type == dev.type + assert result["trip_out"].device.type == dev.type + assert result["center_atom"].device.type == dev.type @pytest.mark.parametrize( @@ -507,10 +507,10 @@ def test_build_quadruplets_device(device: str) -> None: internal_cell_offsets, ) - assert result["quad_c_to_a_edge"].device == dev - assert result["quad_d_to_b_trip_idx"].device == dev - assert result["d_to_b_edge"].device == dev - assert result["c_to_a_edge"].device == dev + assert result["quad_c_to_a_edge"].device.type == dev.type + assert result["quad_d_to_b_trip_idx"].device.type == dev.type + assert result["d_to_b_edge"].device.type == dev.type + assert result["c_to_a_edge"].device.type == dev.type def test_build_triplets_jit_script() -> None: diff --git a/tests/test_physical_validation.py b/tests/test_physical_validation.py new file mode 100644 index 00000000..a25504c1 --- /dev/null +++ b/tests/test_physical_validation.py @@ -0,0 +1,746 @@ +"""Physical validation tests for torch-sim MD integrators. + +Uses the physical_validation library (https://github.com/shirtsgroup/physical_validation) +to verify that integrators produce physically correct results. These tests require CUDA +and are long-running. Run with: + + pytest -m physical_validation -v + +Options: + --validation-plots Save plots to tests/physical_validation_data/plots/ + --clean-validation-data Delete saved validation data before running + +Run a specific integrator: + + pytest -m physical_validation -v -k "nvt_langevin" + +Tested integrators: + + NVT: + - nvt_langevin + - nvt_nose_hoover + - nvt_vrescale + + NPT: + - npt_langevin (independent per-dimension strain, like LAMMPS couple=none) + - npt_langevin_strain (isotropic logarithmic strain) + - npt_nose_hoover + - npt_isotropic_crescale + - npt_anisotropic_crescale + +Clean up saved data programmatically: + + from tests.test_physical_validation import clean_validation_data + clean_validation_data() +""" + +import shutil +import warnings +from pathlib import Path + +import numpy as np +import pytest +import torch +from ase.build import bulk +from numpy.typing import NDArray + +import torch_sim as ts +from torch_sim.integrators.npt import npt_crescale_average_anisotropic_step +from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.units import MetalUnits + + +physical_validation = pytest.importorskip("physical_validation") + +if not torch.cuda.is_available(): + pytest.skip("CUDA not available", allow_module_level=True) + +# --------------------------------------------------------------------------- +# Device & dtype — CUDA required +# --------------------------------------------------------------------------- +DEVICE = torch.device("cuda") +DTYPE = torch.float64 + +# --------------------------------------------------------------------------- +# LJ Argon parameters +# --------------------------------------------------------------------------- +SIGMA = 3.405 +EPSILON = 0.0104 +CUTOFF = 2.5 * SIGMA + +# --------------------------------------------------------------------------- +# Simulation parameters (matched to fast_integrator_tests_batch) +# --------------------------------------------------------------------------- +TIMESTEP_PS = 0.005 +N_STEPS_NVT = 10_000 +N_STEPS_NPT = 10_000 +N_EQUILIBRATION_NVT = 4_000 +N_EQUILIBRATION_NPT = 5_000 +LOG_EVERY = 5 + +# Ensemble check temperatures and pressures (matched to fast_integrator_tests_batch) +TEMPERATURES = [58.3, 60.0] +EXTERNAL_PRESSURE = 0.0 +PRESSURE_SWEEP_TEMP = 60.0 +PRESSURE_SWEEP_BAR = 90.0 +PRESSURE_SWEEP_EVA3 = PRESSURE_SWEEP_BAR * float(MetalUnits.pressure) + +# Physical validation thresholds (in sigma units) +KE_SIGMA_WARNING = 2.0 +KE_SIGMA_THRESHOLD = 3.0 +ENSEMBLE_SIGMA_WARNING = 2.0 +ENSEMBLE_SIGMA_THRESHOLD = 3.0 + +# Data & plot directories +DATA_DIR = Path(__file__).parent / "physical_validation_data" +PLOTS_DIR = DATA_DIR / "plots" + +RunData = dict[str, NDArray[np.floating] | float | int] + +torch.set_num_threads(4) + + +# --------------------------------------------------------------------------- +# Cleanup utility +# --------------------------------------------------------------------------- +def clean_validation_data() -> None: + """Delete all saved physical validation data and plots.""" + if DATA_DIR.exists(): + shutil.rmtree(DATA_DIR) + print(f"Removed {DATA_DIR}") + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- +def _to_kT(temperature_K: float) -> float: + return temperature_K * float(MetalUnits.temperature) + + +def _to_dt(timestep_ps: float) -> float: + return timestep_ps * float(MetalUnits.time) + + +def _save_run_data(data: RunData, label: str) -> Path: + """Save run data to a .npz file and return the path.""" + DATA_DIR.mkdir(parents=True, exist_ok=True) + path = DATA_DIR / f"{label}.npz" + np.savez(path, **data) + return path + + +def _get_plot_path(request: pytest.FixtureRequest, name: str) -> str | None: + """Return plot file path if --validation-plots is enabled, else None.""" + if not request.config.getoption("--validation-plots", default=False): + return None + PLOTS_DIR.mkdir(parents=True, exist_ok=True) + return str(PLOTS_DIR / f"{name}.png") + + +def _pressure_to_bar(p_eva3: float) -> float: + """Convert eV/Ang^3 to bar.""" + return p_eva3 / float(MetalUnits.pressure) + + +# --------------------------------------------------------------------------- +# Helpers: unit data, model, structure +# --------------------------------------------------------------------------- +def _make_unit_data() -> physical_validation.data.UnitData: + """Create UnitData for torch-sim's MetalUnits system.""" + return physical_validation.data.UnitData( + kb=float(MetalUnits.temperature), # k_B in eV/K = 8.617e-5 + energy_str="eV", + energy_conversion=96.485, # Convert to kJ/mol + length_str="Ang", + length_conversion=1e-1, # Convert to nm + volume_str="Ang^3", + volume_conversion=1e-3, # Convert to nm^3 + temperature_str="K", + temperature_conversion=1.0, + pressure_str="bar", + pressure_conversion=1.0, + ) + + +def _make_lj_model(*, compute_stress: bool = False) -> LennardJonesModel: + """Create a Lennard-Jones model for Argon.""" + return LennardJonesModel( + sigma=SIGMA, + epsilon=EPSILON, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=compute_stress, + cutoff=CUTOFF, + ) + + +def _make_ar_supercell( + repeat: tuple[int, int, int] = (8, 8, 8), +) -> ts.SimState: + """Create an FCC Argon supercell SimState.""" + atoms = bulk("Ar", "fcc", a=5.26, cubic=True).repeat(repeat) + return ts.io.atoms_to_state(atoms, DEVICE, DTYPE) + + +# --------------------------------------------------------------------------- +# Generic NVT runner +# --------------------------------------------------------------------------- +def _run_nvt( + integrator_name: str, + sim_state: ts.SimState, + model: LennardJonesModel, + temperature: float, + timestep_ps: float = TIMESTEP_PS, + n_steps: int = N_STEPS_NVT, + n_equilibration: int = N_EQUILIBRATION_NVT, + log_every: int = LOG_EVERY, + seed: int = 42, +) -> RunData: + """Run an NVT simulation with the specified integrator.""" + kT = _to_kT(temperature) + dt = _to_dt(timestep_ps) + natoms = int(sim_state.positions.shape[0]) + + sim_state = sim_state.clone() + sim_state.rng = seed + + # Initialize (params matched to fast_integrator_tests_batch) + if integrator_name == "nvt_langevin": + state = ts.nvt_langevin_init(sim_state, model, kT=kT) + elif integrator_name == "nvt_nose_hoover": + state = ts.nvt_nose_hoover_init(sim_state, model, kT=kT, dt=dt, tau=10 * dt) + elif integrator_name == "nvt_vrescale": + state = ts.nvt_vrescale_init(sim_state, model, kT=kT) + else: + msg = f"Unknown NVT integrator: {integrator_name}" + raise ValueError(msg) + + def _step(s): + if integrator_name == "nvt_langevin": + return ts.nvt_langevin_step(s, model, dt=dt, kT=kT, gamma=1 / (50 * dt)) + if integrator_name == "nvt_nose_hoover": + return ts.nvt_nose_hoover_step(s, model, dt=dt, kT=kT) + return ts.nvt_vrescale_step(model, s, dt=dt, kT=kT, tau=10 * dt) + + # Equilibration + for _ in range(n_equilibration): + state = _step(state) + + # Production (subsampled every log_every steps) + ke_list, pe_list, total_e_list = [], [], [] + + for i in range(n_steps): + state = _step(state) + if (i + 1) % log_every == 0: + ke = float( + ts.calc_kinetic_energy(masses=state.masses, momenta=state.momenta) + ) + pe = float(state.energy.sum()) + ke_list.append(ke) + pe_list.append(pe) + total_e_list.append(ke + pe) + + cell = sim_state.cell[0].detach().cpu().numpy() + volume = float(np.abs(np.linalg.det(cell))) + + return { + "kinetic_energy": np.array(ke_list), + "potential_energy": np.array(pe_list), + "total_energy": np.array(total_e_list), + "volume": volume, + "masses": sim_state.masses.detach().cpu().numpy(), + "dt_internal": dt, + "natoms": natoms, + "target_temperature": temperature, + "timestep_ps": timestep_ps, + "integrator": integrator_name, + } + + +# --------------------------------------------------------------------------- +# Generic NPT runner +# --------------------------------------------------------------------------- +def _run_npt( + integrator_name: str, + sim_state: ts.SimState, + model: LennardJonesModel, + temperature: float, + external_pressure: float = 0.0, + timestep_ps: float = TIMESTEP_PS, + n_steps: int = N_STEPS_NPT, + n_equilibration: int = N_EQUILIBRATION_NPT, + log_every: int = LOG_EVERY, + seed: int = 42, +) -> RunData: + """Run an NPT simulation with the specified integrator.""" + kT = _to_kT(temperature) + dt = torch.tensor(_to_dt(timestep_ps), device=DEVICE, dtype=DTYPE) + ext_p = torch.tensor(external_pressure, device=DEVICE, dtype=DTYPE) + natoms = int(sim_state.positions.shape[0]) + + sim_state = sim_state.clone() + sim_state.rng = seed + + # Initialize (params matched to fast_integrator_tests_batch) + if integrator_name == "npt_langevin": + state = ts.npt_langevin_init( + sim_state, model, kT=kT, dt=dt, + alpha=1 / (5 * dt), cell_alpha=1 / (30 * dt), b_tau=300 * dt, + ) + elif integrator_name == "npt_langevin_strain": + state = ts.npt_langevin_strain_init( + sim_state, model, kT=kT, dt=dt, + alpha=1 / (5 * dt), cell_alpha=1 / (30 * dt), b_tau=300 * dt, + ) + elif integrator_name == "npt_nose_hoover": + state = ts.npt_nose_hoover_init( + sim_state, model, kT=kT, dt=dt, + t_tau=10 * dt, b_tau=100 * dt, + ) + elif integrator_name == "npt_isotropic_crescale": + state = ts.npt_crescale_init( + sim_state, model, kT=kT, dt=dt, + tau_p=3 * dt, + isothermal_compressibility=1e-6 / MetalUnits.pressure, + ) + elif integrator_name == "npt_anisotropic_crescale": + state = ts.npt_crescale_init( + sim_state, model, kT=kT, dt=dt, + tau_p=3 * dt, + isothermal_compressibility=1e-6 / MetalUnits.pressure, + ) + else: + msg = f"Unknown NPT integrator: {integrator_name}" + raise ValueError(msg) + + def _step(s): + if integrator_name == "npt_langevin": + return ts.npt_langevin_step( + s, model, dt=dt, kT=kT, external_pressure=ext_p, + ) + if integrator_name == "npt_langevin_strain": + return ts.npt_langevin_strain_step( + s, model, dt=dt, kT=kT, external_pressure=ext_p, + ) + if integrator_name == "npt_nose_hoover": + return ts.npt_nose_hoover_step( + s, model, dt=dt, kT=kT, external_pressure=ext_p, + ) + if integrator_name == "npt_anisotropic_crescale": + return npt_crescale_average_anisotropic_step( + s, model, dt=dt, kT=kT, external_pressure=ext_p, tau=1 * dt, + ) + return ts.npt_crescale_isotropic_step( + s, model, dt=dt, kT=kT, external_pressure=ext_p, tau=1 * dt, + ) + + # Equilibration + for _ in range(n_equilibration): + state = _step(state) + + # Production (subsampled every log_every steps) + ke_list, pe_list, total_e_list = [], [], [] + volume_list = [] + + for i in range(n_steps): + state = _step(state) + if (i + 1) % log_every == 0: + ke = float( + ts.calc_kinetic_energy(masses=state.masses, momenta=state.momenta) + ) + pe = float(state.energy.sum()) + cell = state.cell[0].detach().cpu().numpy() + vol = float(np.abs(np.linalg.det(cell))) + ke_list.append(ke) + pe_list.append(pe) + total_e_list.append(ke + pe) + volume_list.append(vol) + + return { + "kinetic_energy": np.array(ke_list), + "potential_energy": np.array(pe_list), + "total_energy": np.array(total_e_list), + "volumes": np.array(volume_list), + "masses": sim_state.masses.detach().cpu().numpy(), + "dt_internal": float(dt), + "natoms": natoms, + "target_temperature": temperature, + "external_pressure": external_pressure, + "timestep_ps": timestep_ps, + "integrator": integrator_name, + } + + +# --------------------------------------------------------------------------- +# SimulationData builders +# --------------------------------------------------------------------------- +def _build_nvt_simulation_data( + run_data: RunData, + temperature: float, +) -> physical_validation.data.SimulationData: + """Build a physical_validation SimulationData from NVT run results.""" + units = _make_unit_data() + + system = physical_validation.data.SystemData( + natoms=run_data["natoms"], + nconstraints=0, + ndof_reduction_tra=3, + ndof_reduction_rot=0, + mass=run_data["masses"], + ) + + ensemble_data = physical_validation.data.EnsembleData( + ensemble="NVT", + natoms=run_data["natoms"], + volume=run_data["volume"], + temperature=temperature, + ) + + observables = physical_validation.data.ObservableData( + kinetic_energy=run_data["kinetic_energy"], + potential_energy=run_data["potential_energy"], + total_energy=run_data["total_energy"], + ) + + return physical_validation.data.SimulationData( + units=units, + dt=run_data["timestep_ps"], + system=system, + ensemble=ensemble_data, + observables=observables, + ) + + +def _build_npt_simulation_data( + run_data: RunData, + temperature: float, + pressure: float, +) -> physical_validation.data.SimulationData: + """Build a physical_validation SimulationData from NPT run results.""" + units = _make_unit_data() + + system = physical_validation.data.SystemData( + natoms=run_data["natoms"], + nconstraints=0, + ndof_reduction_tra=3, + ndof_reduction_rot=0, + mass=run_data["masses"], + ) + + ensemble_data = physical_validation.data.EnsembleData( + ensemble="NPT", + natoms=run_data["natoms"], + pressure=pressure, + temperature=temperature, + ) + + observables = physical_validation.data.ObservableData( + kinetic_energy=run_data["kinetic_energy"], + potential_energy=run_data["potential_energy"], + total_energy=run_data["total_energy"], + volume=run_data["volumes"], + ) + + return physical_validation.data.SimulationData( + units=units, + dt=run_data["timestep_ps"], + system=system, + ensemble=ensemble_data, + observables=observables, + ) + + +# =========================================================================== +# Session fixture: cleanup saved data +# =========================================================================== +@pytest.fixture(autouse=True, scope="session") +def _manage_validation_data(request): + """Clean data directory if --clean-validation-data is set.""" + if request.config.getoption("--clean-validation-data", default=False): + clean_validation_data() + yield + + +# =========================================================================== +# Tests: KE distribution (Maxwell-Boltzmann) +# =========================================================================== +@pytest.mark.physical_validation +@pytest.mark.parametrize( + "integrator_name", + ["nvt_langevin", "nvt_nose_hoover", "nvt_vrescale"], +) +def test_nvt_ke_distribution(integrator_name: str, request) -> None: + """Test that KE follows the Maxwell-Boltzmann distribution for NVT.""" + sim_state = _make_ar_supercell(repeat=(8, 8, 8)) + model = _make_lj_model() + temperature = TEMPERATURES[1] + + run_data = _run_nvt( + integrator_name, sim_state, model, temperature=temperature, seed=42, + ) + _save_run_data(run_data, f"{integrator_name}_T{temperature:.1f}K_ke") + + data = _build_nvt_simulation_data(run_data, temperature) + plot_path = _get_plot_path(request, f"{integrator_name}_nvt_ke") + + kwargs = {} + if plot_path: + kwargs["filename"] = plot_path + d_mean, d_width = physical_validation.kinetic_energy.distribution( + data, strict=False, verbosity=0, **kwargs, + ) + + if abs(d_mean) > KE_SIGMA_WARNING: + warnings.warn( + f"[{integrator_name}] KE mean deviation {d_mean:.2f} sigma exceeds " + f"{KE_SIGMA_WARNING} sigma warning threshold", + stacklevel=1, + ) + if abs(d_width) > KE_SIGMA_WARNING: + warnings.warn( + f"[{integrator_name}] KE width deviation {d_width:.2f} sigma exceeds " + f"{KE_SIGMA_WARNING} sigma warning threshold", + stacklevel=1, + ) + assert abs(d_mean) < KE_SIGMA_THRESHOLD, ( + f"[{integrator_name}] KE mean deviation {d_mean:.2f} sigma" + ) + assert abs(d_width) < KE_SIGMA_THRESHOLD, ( + f"[{integrator_name}] KE width deviation {d_width:.2f} sigma" + ) + + +@pytest.mark.physical_validation +@pytest.mark.parametrize( + "integrator_name", + [ + "npt_langevin", + "npt_langevin_strain", + "npt_nose_hoover", + "npt_isotropic_crescale", + "npt_anisotropic_crescale", + ], +) +def test_npt_ke_distribution(integrator_name: str, request) -> None: + """Test that KE follows the Maxwell-Boltzmann distribution for NPT.""" + sim_state = _make_ar_supercell(repeat=(8, 8, 8)) + model = _make_lj_model(compute_stress=True) + temperature = TEMPERATURES[1] + + run_data = _run_npt( + integrator_name, sim_state, model, + temperature=temperature, external_pressure=EXTERNAL_PRESSURE, seed=42, + ) + _save_run_data(run_data, f"{integrator_name}_T{temperature:.1f}K_ke") + + # Use NVT builder with mean volume for KE distribution check + run_data_nvt = {**run_data, "volume": float(np.mean(run_data["volumes"]))} + data = _build_nvt_simulation_data(run_data_nvt, temperature) + plot_path = _get_plot_path(request, f"{integrator_name}_npt_ke") + + kwargs = {} + if plot_path: + kwargs["filename"] = plot_path + d_mean, d_width = physical_validation.kinetic_energy.distribution( + data, strict=False, verbosity=0, **kwargs, + ) + + if abs(d_mean) > KE_SIGMA_WARNING: + warnings.warn( + f"[{integrator_name}] KE mean deviation {d_mean:.2f} sigma exceeds " + f"{KE_SIGMA_WARNING} sigma warning threshold", + stacklevel=1, + ) + if abs(d_width) > KE_SIGMA_WARNING: + warnings.warn( + f"[{integrator_name}] KE width deviation {d_width:.2f} sigma exceeds " + f"{KE_SIGMA_WARNING} sigma warning threshold", + stacklevel=1, + ) + assert abs(d_mean) < KE_SIGMA_THRESHOLD, ( + f"[{integrator_name}] KE mean deviation {d_mean:.2f} sigma" + ) + assert abs(d_width) < KE_SIGMA_THRESHOLD, ( + f"[{integrator_name}] KE width deviation {d_width:.2f} sigma" + ) + + +# =========================================================================== +# Tests: ensemble validity (Boltzmann weight ratio at two temperatures) +# =========================================================================== +@pytest.mark.physical_validation +@pytest.mark.parametrize( + "integrator_name", + ["nvt_langevin", "nvt_nose_hoover", "nvt_vrescale"], +) +def test_nvt_ensemble_check(integrator_name: str, request) -> None: + """Test NVT ensemble validity at two temperatures.""" + sim_state = _make_ar_supercell(repeat=(8, 8, 8)) + model = _make_lj_model() + + temp_low, temp_high = TEMPERATURES + + run_low = _run_nvt( + integrator_name, sim_state, model, temperature=temp_low, seed=42, + ) + run_high = _run_nvt( + integrator_name, sim_state, model, temperature=temp_high, seed=123, + ) + _save_run_data(run_low, f"{integrator_name}_T{temp_low:.1f}K_ens") + _save_run_data(run_high, f"{integrator_name}_T{temp_high:.1f}K_ens") + + data_low = _build_nvt_simulation_data(run_low, temp_low) + data_high = _build_nvt_simulation_data(run_high, temp_high) + plot_path = _get_plot_path(request, f"{integrator_name}_nvt_ens") + + kwargs = {} + if plot_path: + kwargs["filename"] = plot_path + quantiles = physical_validation.ensemble.check( + data_low, data_high, + total_energy=True, + data_is_uncorrelated=True, + verbosity=0, + **kwargs, + ) + + for i, q in enumerate(quantiles): + if abs(q) > ENSEMBLE_SIGMA_WARNING: + warnings.warn( + f"[{integrator_name}] Ensemble quantile {i} = {q:.2f} sigma exceeds " + f"{ENSEMBLE_SIGMA_WARNING} sigma warning threshold", + stacklevel=1, + ) + assert abs(q) < ENSEMBLE_SIGMA_THRESHOLD, ( + f"[{integrator_name}] Ensemble quantile {i} = {q:.2f} sigma" + ) + + +@pytest.mark.physical_validation +@pytest.mark.parametrize( + "integrator_name", + [ + "npt_langevin", + "npt_langevin_strain", + "npt_nose_hoover", + "npt_isotropic_crescale", + "npt_anisotropic_crescale", + ], +) +def test_npt_ensemble_check(integrator_name: str, request) -> None: + """Test NPT ensemble validity at two temperatures. + + Uses temperatures both in the solid phase (below LJ Ar melting point ~84K) + to avoid the solid-liquid phase transition which causes non-overlapping + energy distributions in the NPT ensemble. + """ + sim_state = _make_ar_supercell(repeat=(8, 8, 8)) + model = _make_lj_model(compute_stress=True) + + temp_low, temp_high = TEMPERATURES + + run_low = _run_npt( + integrator_name, sim_state, model, + temperature=temp_low, external_pressure=EXTERNAL_PRESSURE, seed=42, + ) + run_high = _run_npt( + integrator_name, sim_state, model, + temperature=temp_high, external_pressure=EXTERNAL_PRESSURE, seed=123, + ) + _save_run_data(run_low, f"{integrator_name}_T{temp_low:.1f}K_ens") + _save_run_data(run_high, f"{integrator_name}_T{temp_high:.1f}K_ens") + + data_low = _build_npt_simulation_data(run_low, temp_low, EXTERNAL_PRESSURE) + data_high = _build_npt_simulation_data(run_high, temp_high, EXTERNAL_PRESSURE) + plot_path = _get_plot_path(request, f"{integrator_name}_npt_ens_temp") + + kwargs = {} + if plot_path: + kwargs["filename"] = plot_path + quantiles = physical_validation.ensemble.check( + data_low, data_high, + total_energy=True, + data_is_uncorrelated=True, + verbosity=0, + **kwargs, + ) + + for i, q in enumerate(quantiles): + if abs(q) > ENSEMBLE_SIGMA_WARNING: + warnings.warn( + f"[{integrator_name}] Ensemble quantile {i} = {q:.2f} sigma exceeds " + f"{ENSEMBLE_SIGMA_WARNING} sigma warning threshold", + stacklevel=1, + ) + assert abs(q) < ENSEMBLE_SIGMA_THRESHOLD, ( + f"[{integrator_name}] Ensemble quantile {i} = {q:.2f} sigma" + ) + + +# =========================================================================== +# Tests: ensemble validity (Boltzmann weight ratio at two pressures) +# =========================================================================== +@pytest.mark.physical_validation +@pytest.mark.parametrize( + "integrator_name", + [ + "npt_langevin", + "npt_langevin_strain", + "npt_nose_hoover", + "npt_isotropic_crescale", + "npt_anisotropic_crescale", + ], +) +def test_npt_pressure_ensemble_check(integrator_name: str, request) -> None: + """Test NPT ensemble validity at two pressures (fixed temperature).""" + sim_state = _make_ar_supercell(repeat=(8, 8, 8)) + model = _make_lj_model(compute_stress=True) + + p_low = EXTERNAL_PRESSURE + p_high = PRESSURE_SWEEP_EVA3 + + run_low = _run_npt( + integrator_name, sim_state, model, + temperature=PRESSURE_SWEEP_TEMP, external_pressure=p_low, seed=42, + ) + run_high = _run_npt( + integrator_name, sim_state, model, + temperature=PRESSURE_SWEEP_TEMP, external_pressure=p_high, seed=123, + ) + _save_run_data( + run_low, + f"{integrator_name}_T{PRESSURE_SWEEP_TEMP:.1f}K_P0bar", + ) + _save_run_data( + run_high, + f"{integrator_name}_T{PRESSURE_SWEEP_TEMP:.1f}K_P{PRESSURE_SWEEP_BAR:.0f}bar", + ) + + data_low = _build_npt_simulation_data(run_low, PRESSURE_SWEEP_TEMP, p_low) + data_high = _build_npt_simulation_data(run_high, PRESSURE_SWEEP_TEMP, p_high) + plot_path = _get_plot_path(request, f"{integrator_name}_npt_ens_press") + + kwargs = {} + if plot_path: + kwargs["filename"] = plot_path + quantiles = physical_validation.ensemble.check( + data_low, data_high, + total_energy=True, + data_is_uncorrelated=True, + verbosity=0, + **kwargs, + ) + + for i, q in enumerate(quantiles): + if abs(q) > ENSEMBLE_SIGMA_WARNING: + warnings.warn( + f"[{integrator_name}] Pressure ensemble quantile {i} = {q:.2f} sigma " + f"exceeds {ENSEMBLE_SIGMA_WARNING} sigma warning threshold", + stacklevel=1, + ) + assert abs(q) < ENSEMBLE_SIGMA_THRESHOLD, ( + f"[{integrator_name}] Pressure ensemble quantile {i} = {q:.2f} sigma" + ) + + diff --git a/tests/test_state.py b/tests/test_state.py index ed1b3c29..9d31c492 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1096,8 +1096,8 @@ def test_nptlangevinstate_instantiation() -> None: cell_alpha=torch.ones(1), b_tau=torch.ones(1), reference_cell=torch.eye(3).unsqueeze(0), - cell_positions=torch.zeros(1, 3, 3), - cell_velocities=torch.zeros(1, 3, 3), + cell_positions=torch.zeros(1, 3), + cell_velocities=torch.zeros(1, 3), cell_masses=torch.ones(1), ) _check_coercion(state) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index bac5ec98..d43ffd93 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1396,13 +1396,13 @@ def test_build_linked_cell_neighborhood_basic() -> None: def test_unwrap_positions(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): n_steps = 50 - dt = torch.tensor(0.001, dtype=DTYPE) + dt = torch.tensor(0.001, dtype=DTYPE) * MetalUnits.time kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature # Same cell state = ts.nvt_langevin_init(state=ar_double_sim_state, model=lj_model, kT=kT) state.positions = tst.pbc_wrap_batched(state.positions, state.cell, state.system_idx) - positions = [state.positions.detach().clone()] + positions = [state.positions.detach().clone()] # for _step in range(n_steps): state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) positions.append(state.positions.detach().clone()) @@ -1446,7 +1446,7 @@ def test_unwrap_positions(ar_double_sim_state: ts.SimState, lj_model: LennardJon ) unwrapped_positions = tst.unwrap_positions( wrapped_positions, - state.cell, + torch.stack(cells), state.system_idx, ) assert torch.allclose(unwrapped_positions, positions, atol=1e-4) diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index d9536754..a9d93950 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -39,12 +39,15 @@ ) from torch_sim.integrators.npt import ( NPTLangevinState, + NPTLangevinStrainState, NPTNoseHooverState, npt_crescale_anisotropic_step, npt_crescale_init, npt_crescale_isotropic_step, npt_langevin_init, npt_langevin_step, + npt_langevin_strain_init, + npt_langevin_strain_step, npt_nose_hoover_init, npt_nose_hoover_invariant, npt_nose_hoover_step, @@ -124,6 +127,7 @@ "LBFGSState", "NPTLangevinState", "NPTLangevinState", + "NPTLangevinStrainState", "NPTNoseHooverState", "NPTNoseHooverState", "NVTNoseHooverState", @@ -171,6 +175,8 @@ "npt_langevin_init", "npt_langevin_step", "npt_langevin_step", + "npt_langevin_strain_init", + "npt_langevin_strain_step", "npt_nose_hoover_init", "npt_nose_hoover_init", "npt_nose_hoover_invariant", diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index 950f1664..7ef07245 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -17,33 +17,40 @@ - Langevin barostat integrator :func:`npt.npt_langevin_step` [4, 5] - Nosé-Hoover barostat integrator :func:`npt.npt_nose_hoover_step` from [10] - Isotropic C-Rescale barostat integrator :func:`npt.npt_crescale_isotropic_step` - from [6, 8, 9] + from [6, 8, 9] - C-Rescale barostat integrator :func:`npt.npt_crescale_anisotropic_step` - from [7, 8, 9]. Available implementations include isotropic and - anisotropic cell rescaling, allowing to change cell lengths, and potentially angles - as well. + from [7, 8, 9]. Anisotropic NPT allows to change cell lengths as well as angles. References: [1] Bussi G, Donadio D, Parrinello M. "Canonical sampling through velocity rescaling." The Journal of chemical physics, 126(1), 014101 (2007). + [2] Leimkuhler B, Matthews C.2016 Efficient molecular dynamics using geodesic integration and solvent-solute splitting. Proc. R. Soc. A 472: 20160138 + [3] Martyna, G. J., Tuckerman, M. E., Tobias, D. J., & Klein, M. L. (1996). Explicit reversible integrators for extended systems dynamics. Molecular Physics, 87(5), 1117-1157. + [4] Grønbech-Jensen, N., & Farago, O. (2014). Constant pressure and temperature discrete-time Langevin molecular dynamics. The Journal of chemical physics, 141(19). + [5] LAMMPS: https://docs.lammps.org/fix_press_langevin.html + [6] Bernetti, Mattia, and Giovanni Bussi. "Pressure control using stochastic cell rescaling." The Journal of Chemical Physics 153.11 (2020). + [7] Del Tatto, Vittorio, et al. "Molecular dynamics of solids at constant pressure and stress using anisotropic stochastic cell rescaling." Applied Sciences 12.3 (2022): 1139. + [8] Bussi Anisotropic C-Rescale SimpleMD implementation: https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp + [9] Supplementary Information for [6]. + [10]Tuckerman, Mark E., et al. "A Liouville-operator derived measure-preserving integrator for molecular dynamics simulations in the isothermal-isobaric ensemble." Journal of Physics A: Mathematical and General 39.19 (2006): 5629-5651. @@ -74,12 +81,15 @@ from .md import MDState, initialize_momenta, momentum_step, position_step, velocity_verlet from .npt import ( NPTLangevinState, + NPTLangevinStrainState, NPTNoseHooverState, npt_crescale_anisotropic_step, npt_crescale_init, npt_crescale_isotropic_step, npt_langevin_init, npt_langevin_step, + npt_langevin_strain_init, + npt_langevin_strain_step, npt_nose_hoover_init, npt_nose_hoover_invariant, npt_nose_hoover_step, @@ -129,6 +139,7 @@ class Integrator(StrEnum): nvt_langevin = "nvt_langevin" nvt_nose_hoover = "nvt_nose_hoover" npt_langevin = "npt_langevin" + npt_langevin_strain = "npt_langevin_strain" npt_nose_hoover = "npt_nose_hoover" npt_isotropic_crescale = "npt_isotropic_crescale" npt_anisotropic_crescale = "npt_anisotropic_crescale" @@ -168,6 +179,7 @@ class Integrator(StrEnum): Integrator.nvt_langevin: (nvt_langevin_init, nvt_langevin_step), Integrator.nvt_nose_hoover: (nvt_nose_hoover_init, nvt_nose_hoover_step), Integrator.npt_langevin: (npt_langevin_init, npt_langevin_step), + Integrator.npt_langevin_strain: (npt_langevin_strain_init, npt_langevin_strain_step), Integrator.npt_nose_hoover: (npt_nose_hoover_init, npt_nose_hoover_step), Integrator.npt_isotropic_crescale: (npt_crescale_init, npt_crescale_isotropic_step), Integrator.npt_anisotropic_crescale: ( diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 76ff69c1..4a593c04 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -409,7 +409,7 @@ def init_fn( Q = ( kT_batched.unsqueeze(-1) - * torch.square(tau_batched).unsqueeze(-1) ** 2 + * torch.square(tau_batched).unsqueeze(-1) * torch.ones((n_systems, chain_length), dtype=dtype, device=device) ) Q[:, 0] *= degrees_of_freedom diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 4a86084c..3fa79d41 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -21,6 +21,7 @@ from torch_sim.integrators.nvt import _vrescale_update from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState +from torch_sim.units import MetalUnits logger = logging.getLogger(__name__) @@ -54,12 +55,15 @@ class and add their own auxiliary variables. @dataclass(kw_only=True) class NPTLangevinState(NPTState): - """State information for an NPT system with Langevin dynamics. + """State for NPT Langevin dynamics with independent per-dimension cell lengths. - This class represents the complete state of a molecular system being integrated - in the NPT (constant particle number, pressure, temperature) ensemble using - Langevin dynamics. In addition to particle positions and momenta, it tracks - cell dimensions and their dynamics for volume fluctuations. + Each spatial dimension has its own logarithmic strain coordinate + εi = ln(Li/Li0), driven by the corresponding diagonal pressure + component P_ii. This is analogous to LAMMPS ``fix press/langevin`` + with ``couple none``. + + With three identical target pressures the sum of forces equals the + isotropic strain force, so the isotropic limit is recovered. Attributes: positions (torch.Tensor): Particle positions [n_particles, n_dim] @@ -72,16 +76,19 @@ class NPTLangevinState(NPTState): system_idx (torch.Tensor): System indices [n_particles] atomic_numbers (torch.Tensor): Atomic numbers [n_particles] stress (torch.Tensor): Stress tensor [n_systems, n_dim, n_dim] - reference_cell (torch.Tensor): Original cell vectors used as reference for - scaling [n_systems, n_dim, n_dim] - cell_positions (torch.Tensor): Cell positions [n_systems, n_dim, n_dim] - cell_velocities (torch.Tensor): Cell velocities [n_systems, n_dim, n_dim] - cell_masses (torch.Tensor): Masses associated with the cell degrees of freedom - shape [n_systems] + reference_cell (torch.Tensor): Original cell [n_systems, d, d] + cell_positions (torch.Tensor): Per-dimension strain εi [n_systems, 3] + cell_velocities (torch.Tensor): dεi/dt [n_systems, 3] + cell_masses (torch.Tensor): Mass for strain DOFs [n_systems] + alpha (torch.Tensor): Particle friction [n_systems] + cell_alpha (torch.Tensor): Cell friction [n_systems] + b_tau (torch.Tensor): Barostat time constant [n_systems] Properties: momenta (torch.Tensor): Particle momenta calculated as velocities*masses with shape [n_particles, n_dimensions] + current_cell (torch.Tensor): Cell reconstructed from strain and reference_cell + volume (torch.Tensor): Current volume from cell determinant n_systems (int): Number of independent systems in the batch device (torch.device): Device on which tensors are stored dtype (torch.dtype): Data type of tensors @@ -93,8 +100,8 @@ class NPTLangevinState(NPTState): # Cell variables reference_cell: torch.Tensor - cell_positions: torch.Tensor - cell_velocities: torch.Tensor + cell_positions: torch.Tensor # (n_systems, 3) per-dimension strain + cell_velocities: torch.Tensor # (n_systems, 3) cell_masses: torch.Tensor _system_attributes = NPTState._system_attributes | { # noqa: SLF001 @@ -107,6 +114,17 @@ class NPTLangevinState(NPTState): "b_tau", } + @property + def current_cell(self) -> torch.Tensor: + """Compute cell from per-dimension strain: cell[i,:] = exp(εi) · ref[i,:].""" + scale = torch.exp(self.cell_positions) # (n_systems, 3) + return scale.unsqueeze(-1) * self.reference_cell + + @property + def volume(self) -> torch.Tensor: + """Current volume from cell determinant.""" + return torch.linalg.det(self.cell) + def _npt_langevin_beta( state: NPTLangevinState, @@ -121,8 +139,6 @@ def _npt_langevin_beta( Args: state (NPTLangevinState): Current NPT state - alpha (torch.Tensor): Friction coefficient, either scalar or - shape [n_systems] kT (torch.Tensor): Temperature in energy units, either scalar or shape [n_systems] dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] @@ -140,10 +156,15 @@ def _npt_langevin_beta( # Map system kT to atoms atom_kT = batch_kT[state.system_idx] + atom_alpha = state.alpha[state.system_idx] + + atom_dt = dt + if dt.ndim == 0: + atom_dt = dt.expand(state.n_systems)[state.system_idx] # Calculate the prefactor for each atom # The standard deviation should be sqrt(2*alpha*kB*T*dt) - prefactor = torch.sqrt(2 * state.alpha * atom_kT * dt) + prefactor = torch.sqrt(2 * atom_alpha * atom_kT * atom_dt) return prefactor.unsqueeze(-1) * noise @@ -153,214 +174,115 @@ def _npt_langevin_cell_beta( kT: torch.Tensor, dt: torch.Tensor, ) -> torch.Tensor: - """Generate random noise for cell fluctuations in NPT dynamics. - - This function creates properly scaled random noise for cell dynamics in NPT - simulations, following the fluctuation-dissipation theorem to ensure correct - thermal sampling of cell degrees of freedom. + """Generate per-dimension noise for cell length fluctuations. Args: - state (NPTLangevinState): Current NPT state - cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - with shape [n_systems] - kT (torch.Tensor): System temperature in energy units, either scalar or - with shape [n_systems] - dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] - device (torch.device): Device for tensor operations - dtype (torch.dtype): Data type for tensor operations + state: Current NPT state + kT: Temperature in energy units (scalar or [n_systems]) + dt: Timestep (scalar or [n_systems]) Returns: - torch.Tensor: Scaled random noise for cell dynamics with shape - [n_systems, n_dimensions, n_dimensions] + torch.Tensor: Noise [n_systems, 3] """ - # Generate standard normal distribution (zero mean, unit variance) - noise = _randn_for_state(state, state.cell_positions.shape) - - if kT.ndim == 0: - kT = kT.expand(state.n_systems) - - # Reshape for broadcasting - cell_alpha_expanded = state.cell_alpha.view(-1, 1, 1) # shape: (n_systems, 1, 1) - kT = kT.view(-1, 1, 1) # shape: (n_systems, 1, 1) - dt = dt.expand(state.n_systems).view(-1, 1, 1) if dt.ndim == 0 else dt.view(-1, 1, 1) - - # Scale to satisfy the fluctuation-dissipation theorem - # The standard deviation should be sqrt(2*alpha*kB*T*dt) - scaling_factor = torch.sqrt(2.0 * cell_alpha_expanded * kT * dt) - - return scaling_factor * noise + noise = _randn_for_state(state, (state.n_systems, 3)) + batch_kT = kT if kT.ndim > 0 else kT.expand(state.n_systems) + dt_expanded = dt if dt.ndim > 0 else dt.expand(state.n_systems) + scaling = torch.sqrt(2.0 * state.cell_alpha * batch_kT * dt_expanded) + return scaling.unsqueeze(-1) * noise def _npt_langevin_cell_position_step( state: NPTLangevinState, dt: torch.Tensor, - pressure_force: torch.Tensor, - kT: torch.Tensor, + strain_force: torch.Tensor, + cell_beta: torch.Tensor, ) -> NPTLangevinState: - """Update the cell position in NPT dynamics. - - This function updates the cell position (effectively the volume) in NPT dynamics - using the current cell velocities, pressure forces, and thermal noise. It - implements the position update part of the Langevin barostat algorithm. + """GJF position step for per-dimension strain εi. Args: - state (NPTLangevinState): Current NPT state - dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] - pressure_force (torch.Tensor): Pressure force for barostat - [n_systems, n_dim, n_dim] - kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_systems] - cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - with shape [n_systems] + state: Current NPT state + dt: Timestep + strain_force: F_εi [n_systems, 3] + cell_beta: Noise [n_systems, 3] Returns: - NPTLangevinState: Updated state with new cell positions + Updated state with new cell_positions (strain) """ - # Calculate effective mass term - Q_2 = 2 * state.cell_masses.view(-1, 1, 1) # shape: (n_systems, 1, 1) - - # Ensure parameters have batch dimension - if dt.ndim == 0: - dt = dt.expand(state.n_systems) - - # Reshape for broadcasting - dt_expanded = dt.view(-1, 1, 1) - cell_alpha_expanded = state.cell_alpha.view(-1, 1, 1) + Q_2 = (2 * state.cell_masses).unsqueeze(-1) # (n_systems, 1) + dt_expanded = dt if dt.ndim > 0 else dt.expand(state.n_systems) + dt_3 = dt_expanded.unsqueeze(-1) if dt_expanded.ndim > 0 else dt_expanded - # Calculate damping factor for cell position update - cell_b = 1 / (1 + ((cell_alpha_expanded * dt_expanded) / Q_2)) + cell_b = 1 / (1 + (state.cell_alpha.unsqueeze(-1) * dt_3) / Q_2) - # Deterministic velocity contribution - c_1 = cell_b * dt_expanded * state.cell_velocities - - # Force contribution - c_2 = cell_b * dt_expanded * dt_expanded * pressure_force / Q_2 + c_1 = cell_b * dt_3 * state.cell_velocities + c_2 = cell_b * dt_3 * dt_3 * strain_force / Q_2 + c_3 = cell_b * dt_3 * cell_beta / Q_2 - # Random noise contribution (thermal fluctuations) - c_3 = cell_b * dt_expanded * _npt_langevin_cell_beta(state, kT, dt) / Q_2 - - # Update cell positions with all contributions state.cell_positions = state.cell_positions + c_1 + c_2 + c_3 return state def _npt_langevin_cell_velocity_step( state: NPTLangevinState, - F_p_n: torch.Tensor, + F_eps_n: torch.Tensor, dt: torch.Tensor, - pressure_force: torch.Tensor, - kT: torch.Tensor, + strain_force: torch.Tensor, + cell_beta: torch.Tensor, ) -> NPTLangevinState: - """Update the cell velocities in NPT dynamics. - - This function updates the cell velocities using a Langevin-type integrator, - accounting for both deterministic forces from pressure differences and - stochastic thermal noise. It implements the velocity update part of the - Langevin barostat algorithm. + """GJF velocity step for per-dimension strain εi. Args: - state (NPTLangevinState): Current NPT state - F_p_n (torch.Tensor): Initial pressure force with shape - [n_systems, n_dimensions, n_dimensions] - dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] - pressure_force (torch.Tensor): Final pressure force - shape [n_systems, n_dim, n_dim] - cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - shape [n_systems] - kT (torch.Tensor): Temperature in energy units, either scalar or - shape [n_systems] + state: Current NPT state + F_eps_n: Initial strain force [n_systems, 3] + dt: Timestep + strain_force: Final strain force [n_systems, 3] + cell_beta: Noise (SAME as in position step) [n_systems, 3] Returns: - NPTLangevinState: Updated state with new cell velocities + Updated state with new cell_velocities """ - # Ensure parameters have batch dimension - if dt.ndim == 0: - dt = dt.expand(state.n_systems) - if kT.ndim == 0: - kT = kT.expand(state.n_systems) + dt_expanded = dt if dt.ndim > 0 else dt.expand(state.n_systems) + dt_3 = dt_expanded.unsqueeze(-1) if dt_expanded.ndim > 0 else dt_expanded - # Reshape for broadcasting - need to maintain 3x3 dimensions - dt_expanded = dt.view(-1, 1, 1) # shape: (n_systems, 1, 1) - cell_alpha_expanded = state.cell_alpha.view(-1, 1, 1) # shape: (n_systems, 1, 1) + Q = state.cell_masses.unsqueeze(-1) # (n_systems, 1) + alpha_c = state.cell_alpha.unsqueeze(-1) # (n_systems, 1) + a = (1 - (alpha_c * dt_3) / (2 * Q)) / (1 + (alpha_c * dt_3) / (2 * Q)) + b = 1 / (1 + (alpha_c * dt_3) / (2 * Q)) - # Calculate cell masses per system - reshape to match 3x3 cell matrices - cell_masses_expanded = state.cell_masses.view(-1, 1, 1) # shape: (n_systems, 1, 1) + c_1 = a * state.cell_velocities + c_2 = dt_3 * ((a * F_eps_n) + strain_force) / (2 * Q) + c_3 = b * cell_beta / Q - # These factors come from the Langevin integration scheme - a = (1 - (cell_alpha_expanded * dt_expanded) / cell_masses_expanded) / ( - 1 + (cell_alpha_expanded * dt_expanded) / cell_masses_expanded - ) - b = 1 / (1 + (cell_alpha_expanded * dt_expanded) / cell_masses_expanded) - - # Calculate the three terms for velocity update - # a will broadcast from (n_systems, 1, 1) to (n_systems, 3, 3) - c_1 = a * state.cell_velocities # Damped old velocity - - # Force contribution (average of initial and final forces) - c_2 = dt_expanded * ((a * F_p_n) + pressure_force) / (2 * cell_masses_expanded) - - # Generate system-specific cell noise with correct shape (n_systems, 3, 3) - cell_noise = _randn_for_state(state, state.cell_velocities.shape) - - # Calculate thermal noise amplitude - noise_prefactor = torch.sqrt( - 2 * cell_alpha_expanded * kT.view(-1, 1, 1) * dt_expanded - ) - noise_term = noise_prefactor * cell_noise / torch.sqrt(cell_masses_expanded) - - # Random noise contribution - c_3 = b * noise_term - - # Update velocities with all contributions state.cell_velocities = c_1 + c_2 + c_3 return state def _npt_langevin_position_step( state: NPTLangevinState, - L_n: torch.Tensor, # This should be shape (n_systems,) + eps_old: torch.Tensor, dt: torch.Tensor, - kT: torch.Tensor, + particle_beta: torch.Tensor, ) -> NPTLangevinState: - """Update the particle positions in NPT dynamics. + """Update particle positions with per-dimension strain scaling. - This function updates particle positions accounting for both the changing - cell dimensions and the particle velocities/forces. It handles the scaling - of positions due to volume changes as well as the normal position updates - from velocities. + Each component of position is scaled by exp(εi_new - εi_old). Args: - state (NPTLangevinState): Current NPT state - L_n (torch.Tensor): Previous cell length scale with shape [n_systems] - dt: Integration timestep, either scalar or with shape [n_systems] - kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_systems] - alpha (torch.Tensor | None): Friction coefficient, either scalar or with - shape [n_systems]. + state: Current state (cell_positions already updated) + eps_old: Previous strain [n_systems, 3] + dt: Timestep + particle_beta: Noise [n_particles, n_dim] Returns: - NPTLangevinState: Updated state with new positions + Updated state with new positions """ - # Calculate effective mass term by system - # Map masses to have batch dimension - M_2 = 2 * state.masses.unsqueeze(-1) # shape: (n_atoms, 1) - - # Calculate new cell length scale (cube root of volume for isotropic scaling) - L_n_new = torch.pow( - state.cell_positions.reshape(state.n_systems, -1)[:, 0], 1 / 3 - ) # shape: (n_systems,) - - # Map system-specific L_n and L_n_new to atom-level using system indices - # Make sure L_n is the right shape (n_systems,) before indexing - if L_n.ndim != 1 or L_n.shape[0] != state.n_systems: - # If L_n has wrong shape, calculate it again to ensure correct shape - L_n = torch.pow(state.cell_positions.reshape(state.n_systems, -1)[:, 0], 1 / 3) - - # Map system-specific values to atoms using system indices - L_n_atoms = L_n[state.system_idx] # shape: (n_atoms,) - L_n_new_atoms = L_n_new[state.system_idx] # shape: (n_atoms,) - - # Calculate damping factor + M_2 = 2 * state.masses.unsqueeze(-1) # (n_atoms, 1) + + # Per-dimension scale factor + scale = torch.exp(state.cell_positions - eps_old) # (n_systems, 3) + scale_atoms = scale[state.system_idx] # (n_atoms, 3) + + # Damping factor alpha_atoms = state.alpha[state.system_idx] dt_atoms = dt if dt.ndim > 0: @@ -368,30 +290,19 @@ def _npt_langevin_position_step( b = 1 / (1 + ((alpha_atoms * dt_atoms) / (2 * state.masses))) - # Scale positions due to cell volume change - c_1 = (L_n_new_atoms / L_n_atoms).unsqueeze(-1) * state.positions + # Scale each position component independently + c_1 = scale_atoms * state.positions # (n_atoms, 3) - # Time step factor with average length scale - c_2 = (2 * L_n_new_atoms / (L_n_new_atoms + L_n_atoms)) * b * dt_atoms + # Time step factor: 2·s/(s+1) per dimension + c_2 = (2 * scale_atoms / (scale_atoms + 1)) * b.unsqueeze(-1) * dt_atoms.unsqueeze(-1) - # Generate atom-specific noise - noise = _randn_for_state(state, state.momenta.shape) - batch_kT = kT - if kT.ndim == 0: - batch_kT = kT.expand(state.n_systems) - atom_kT = batch_kT[state.system_idx] - - # Calculate noise prefactor according to fluctuation-dissipation theorem - noise_prefactor = torch.sqrt(2 * alpha_atoms * atom_kT * dt_atoms) - noise_term = noise_prefactor.unsqueeze(-1) * noise - - # Velocity and force contributions with random noise c_3 = ( - state.velocities + dt_atoms.unsqueeze(-1) * state.forces / M_2 + noise_term / M_2 + state.velocities + + dt_atoms.unsqueeze(-1) * state.forces / M_2 + + particle_beta / M_2 ) - # Update positions with all contributions - state.set_constrained_positions(c_1 + c_2.unsqueeze(-1) * c_3) + state.set_constrained_positions(c_1 + c_2 * c_3) return state @@ -399,22 +310,20 @@ def _npt_langevin_velocity_step( state: NPTLangevinState, forces: torch.Tensor, dt: torch.Tensor, - kT: torch.Tensor, + particle_beta: torch.Tensor, ) -> NPTLangevinState: """Update the particle velocities in NPT dynamics. This function updates particle velocities using a Langevin-type integrator, - accounting for both deterministic forces and stochastic thermal noise. - It implements the velocity update part of the Langevin thermostat algorithm. + accounting for both deterministic forces and pre-generated thermal noise. Args: state (NPTLangevinState): Current NPT state - forces: Forces on particles + forces: Forces on particles (from before position update) dt: Integration timestep, either scalar or with shape [n_systems] - kT: Target temperature in energy units, either scalar or - with shape [n_systems] - alpha (torch.Tensor | None): Friction coefficient, either scalar or with - shape [n_systems]. + particle_beta (torch.Tensor): Pre-generated GJF noise term β for particle + dynamics. Must be the SAME realization used in the position step. + Shape [n_particles, n_dim] Returns: NPTLangevinState: Updated state with new velocities @@ -439,19 +348,8 @@ def _npt_langevin_velocity_step( # Force contribution (average of initial and final forces) c_2 = dt_atoms.unsqueeze(-1) * ((a * forces) + state.forces) / M_2.unsqueeze(-1) - # Generate atom-specific noise - noise = _randn_for_state(state, state.momenta.shape) - batch_kT = kT - if kT.ndim == 0: - batch_kT = kT.expand(state.n_systems) - atom_kT = batch_kT[state.system_idx] - - # Calculate noise prefactor according to fluctuation-dissipation theorem - noise_prefactor = torch.sqrt(2 * alpha_atoms * atom_kT * dt_atoms) - noise_term = noise_prefactor.unsqueeze(-1) * noise - - # Random noise contribution - c_3 = b * noise_term / state.masses.unsqueeze(-1) + # GJF noise term: b * β / m + c_3 = b * particle_beta / state.masses.unsqueeze(-1) # Update momenta (velocities * masses) with all contributions new_velocities = c_1 + c_2 + c_3 @@ -462,63 +360,41 @@ def _npt_langevin_velocity_step( def _compute_cell_force( state: NPTLangevinState, - external_pressure: float | torch.Tensor, - kT: float | torch.Tensor, + external_pressure: torch.Tensor, + kT: torch.Tensor, ) -> torch.Tensor: - """Compute forces on the cell for NPT dynamics. + """Compute per-dimension force on the strain coordinates. + + F_εi = V · (P_ii - P_ext_i) - This function calculates the forces acting on the simulation cell - based on the difference between internal stress and external pressure, - plus a kinetic contribution. These forces drive the volume changes - needed to maintain constant pressure. + where P_ii = -σ_ii + N·kT/V is the ii diagonal pressure component. + The force is in energy units (eV). Args: - state (NPTLangevinState): Current NPT state - external_pressure (torch.Tensor): Target external pressure, either scalar or - tensor with shape [n_systems, n_dimensions, n_dimensions] - kT (torch.Tensor): Temperature in energy units, either scalar or - shape [n_systems] + state: Current NPT state + external_pressure: Target pressure per dimension [3] or [n_systems, 3] + kT: Temperature in energy units (scalar or [n_systems]) Returns: - torch.Tensor: Force acting on the cell [n_systems, n_dim, n_dim] + torch.Tensor: Force per dimension [n_systems, 3] """ - external_pressure = torch.as_tensor( - external_pressure, device=state.device, dtype=state.dtype - ) - kT = torch.as_tensor(kT, device=state.device, dtype=state.dtype) - - # Get current volumes for each batch - volumes = torch.linalg.det(state.cell) # shape: (n_systems,) + volumes = state.volume # (n_systems,) - # Reshape for broadcasting - volumes = volumes.view(-1, 1, 1) # shape: (n_systems, 1, 1) + # Diagonal stress components \sigma_ii + stress_diag = torch.diagonal(state.stress, dim1=-2, dim2=-1) # (n_systems, 3) - # Create pressure tensor (diagonal with external pressure) - if external_pressure.ndim == 0: - # Scalar pressure - create diagonal pressure tensors for each batch - pressure_tensor = external_pressure * torch.eye( - 3, device=state.device, dtype=state.dtype - ) - pressure_tensor = pressure_tensor.unsqueeze(0).expand(state.n_systems, -1, -1) - else: - # Already a tensor with shape compatible with n_systems - pressure_tensor = external_pressure + # P_ii = -\sigma_ii (virial part) + P_virial_diag = -stress_diag # (n_systems, 3) - # Calculate virials from stress and external pressure - # Internal stress is negative of virial tensor divided by volume - virial = -volumes * (state.stress + pressure_tensor) + # Kinetic contribution per dimension: N·kT/V (target temperature) + batch_kT = kT if kT.ndim > 0 else kT.expand(state.n_systems) + n_atoms = state.n_atoms_per_system.to(dtype=state.dtype) + kinetic_pressure = (n_atoms * batch_kT / volumes).unsqueeze(-1) # (n_systems, 1) - # Add kinetic contribution (kT * Identity) - batch_kT = kT - if kT.ndim == 0: - batch_kT = kT.expand(state.n_systems) + P_diag = P_virial_diag + kinetic_pressure # (n_systems, 3) - e_kin_per_atom = batch_kT.view(-1, 1, 1) * torch.eye( - 3, device=state.device, dtype=state.dtype - ).unsqueeze(0) - - # Correct implementation with scaling by n_atoms_per_system - return virial + e_kin_per_atom * state.n_atoms_per_system.view(-1, 1, 1) + # F_εi = V · (P_ii - P_ext_i) + return volumes.unsqueeze(-1) * (P_diag - external_pressure) def npt_langevin_init( @@ -532,48 +408,34 @@ def npt_langevin_init( b_tau: float | torch.Tensor | None = None, **_kwargs: Any, ) -> NPTLangevinState: - """Initialize an NPT Langevin state from input data. + """Initialize NPT Langevin state with independent per-dimension cell lengths. - This function creates the initial state for NPT Langevin dynamics, - setting up all necessary variables including particle velocities, - cell parameters, and barostat variables. It computes initial forces - and stress using the provided model. + Each spatial dimension gets its own strain DOF εi = ln(Li/Li0), + driven by the corresponding diagonal pressure component. To seed the RNG set ``state.rng = seed`` before calling. Args: - model (ModelInterface): Neural network model that computes energies, forces, - and stress. Must return a dict with 'energy', 'forces', and 'stress' keys. - state (SimState): SimState containing positions, masses, cell, pbc - kT (torch.Tensor): Target temperature in energy units, either scalar or - with shape [n_systems] - dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] - alpha (torch.Tensor, optional): Friction coefficient for particle Langevin - thermostat, either scalar or shape [n_systems]. Defaults to 1/(100*dt). - cell_alpha (torch.Tensor, optional): Friction coefficient for cell Langevin - thermostat, either scalar or shape [n_systems]. Defaults to same as alpha. - b_tau (torch.Tensor, optional): Barostat time constant controlling how quickly - the system responds to pressure differences, either scalar or shape - [n_systems]. Defaults to 1/(1000*dt). + state: SimState containing positions, masses, cell, pbc + model: Model computing energy, forces, stress + kT: Target temperature in energy units + dt: Integration timestep + alpha: Particle friction. Defaults to 1/(5·dt). + cell_alpha: Cell friction. Defaults to 1/(30·dt). + b_tau: Barostat time constant. Defaults to 300·dt. Returns: - NPTLangevinState: Initialized state for NPT Langevin integration containing - all required attributes for particle and cell dynamics - - Notes: - - The model must provide stress tensor calculations for proper pressure coupling + NPTLangevinState with εi = 0 for all dimensions """ device, dtype = model.device, model.dtype - # Set default values if not provided if alpha is None: - alpha = 1.0 / (100 * dt) # Default friction based on timestep + alpha = 1.0 / (5 * dt) if cell_alpha is None: - cell_alpha = alpha # Use same friction for cell by default + cell_alpha = 1.0 / (30 * dt) if b_tau is None: - b_tau = 1 / (1000 * dt) # Default barostat time constant + b_tau = 300 * dt - # Convert all parameters to tensors with correct device and dtype alpha = torch.as_tensor(alpha, device=device, dtype=dtype) cell_alpha = torch.as_tensor(cell_alpha, device=device, dtype=dtype) b_tau = torch.as_tensor(b_tau, device=device, dtype=dtype) @@ -587,10 +449,8 @@ def npt_langevin_init( if b_tau.ndim == 0: b_tau = b_tau.expand(state.n_systems) - # Get model output to initialize forces and stress model_output = model(state) - # Initialize momenta if not provided momenta = getattr(state, "momenta", None) if momenta is None: momenta = initialize_momenta( @@ -601,38 +461,26 @@ def npt_langevin_init( state.rng, ) - # Initialize cell parameters reference_cell = state.cell.clone() + dim = state.positions.shape[1] - # Calculate initial cell_positions (volume) - cell_positions = ( - torch.linalg.det(state.cell).unsqueeze(-1).unsqueeze(-1) - ) # shape: (n_systems, 1, 1) - - # Initialize cell velocities to zero - cell_velocities = torch.zeros((state.n_systems, 3, 3), device=device, dtype=dtype) + # εi = 0 at initialization (V = V₀) + cell_positions = torch.zeros(state.n_systems, dim, device=device, dtype=dtype) + cell_velocities = torch.zeros(state.n_systems, dim, device=device, dtype=dtype) - # Calculate cell masses based on system size and temperature - # This follows standard NPT barostat mass scaling + batch_kT = kT.expand(state.n_systems) if kT.ndim == 0 else kT n_atoms_per_system = torch.bincount(state.system_idx) - batch_kT = ( - kT.expand(state.n_systems) - if isinstance(kT, torch.Tensor) and kT.ndim == 0 - else kT - ) cell_masses = (n_atoms_per_system + 1) * batch_kT * b_tau * b_tau if state.constraints: - # warn if constraints are present msg = ( "Constraints are present in the system. " - "Make sure they are compatible with NPT Langevin dynamics." + "Make sure they are compatible with NPT Langevin dynamics. " "We recommend not using constraints with NPT dynamics for now." ) warnings.warn(msg, UserWarning, stacklevel=3) logger.warning(msg) - # Create the initial state return NPTLangevinState.from_state( state, momenta=momenta, @@ -658,36 +506,57 @@ def npt_langevin_step( kT: float | torch.Tensor, external_pressure: float | torch.Tensor, ) -> NPTLangevinState: - """Perform one complete NPT Langevin dynamics integration step. + r"""Perform one NPT Langevin step with independent per-dimension cell lengths. + + Implements constant-pressure Langevin dynamics based on Gronbech-Jensen & + Farago (2014) [4]_ and the LAMMPS ``fix press/langevin`` scheme [5]_. + + Each spatial dimension *i* has its own logarithmic strain + :math:`\varepsilon_i = \ln(L_i/L_{i,0})` driven by the diagonal + pressure component :math:`P_{ii}`. + + **Per-dimension strain force:** + + .. math:: + + F_{\varepsilon_i} = V \cdot (P_{ii} - P_{\text{ext},i}) + + where :math:`P_{ii} = -\sigma_{ii} + N k_B T / V`. + + With three identical target pressures the sum + :math:`\sum_i F_{\varepsilon_i}` equals the isotropic strain force. - This function implements a modified integration scheme for NPT dynamics, - handling both atomic and cell updates with Langevin thermostats to maintain - constant temperature and pressure. The integration scheme couples particle - motion with cell volume fluctuations. + **Cell reconstruction:** + + .. math:: + + \mathbf{h}_i = e^{\varepsilon_i}\,\mathbf{h}_{i,0} + + **Particle scaling (per component):** + + .. math:: + + r_{k,i} \to e^{\varepsilon_i^{n+1} - \varepsilon_i^n}\, r_{k,i} Args: - model (ModelInterface): Neural network model that computes energies, forces, - and stress. Must return a dict with 'energy', 'forces', and 'stress' keys. - state (NPTLangevinState): Current NPT state with particle and cell variables - dt (float | torch.Tensor): Integration timestep, either scalar or - shape [n_systems] - kT (float | torch.Tensor): Target temperature in energy units, either scalar or - shape [n_systems] - external_pressure (float | torch.Tensor): Target external pressure, - either scalar or tensor with shape [n_systems, n_dim, n_dim] - alpha (torch.Tensor): Position friction coefficient, either scalar or - shape [n_systems] - cell_alpha (torch.Tensor): Cell friction coefficient, either scalar or - shape [n_systems] - b_tau (torch.Tensor): Barostat time constant, either scalar or shape [n_systems] + state: Current NPT state + model: Model computing energy, forces, stress + dt: Integration timestep + kT: Target temperature in energy units + external_pressure: Target pressure — scalar (same for all dims), + shape [3] (per-dimension), or [n_systems, 3] Returns: - NPTLangevinState: Updated NPT state after one timestep with new positions, - velocities, cell parameters, forces, energy, and stress + NPTLangevinState: Updated state + + References: + .. [4] Gronbech-Jensen, N. & Farago, O. "Constant pressure and temperature + discrete-time Langevin molecular dynamics." J. Chem. Phys. 141(19) (2014). + .. [5] LAMMPS fix press/langevin: + https://docs.lammps.org/fix_press_langevin.html """ device, dtype = model.device, model.dtype - # Convert any scalar parameters to tensors with batch dimension if needed state.alpha = torch.as_tensor(state.alpha, device=device, dtype=dtype) kT_tensor = torch.as_tensor(kT, device=device, dtype=dtype) state.cell_alpha = torch.as_tensor(state.cell_alpha, device=device, dtype=dtype) @@ -696,72 +565,510 @@ def npt_langevin_step( external_pressure, device=device, dtype=dtype ) - # Make sure parameters have batch dimension if they're scalars + # Broadcast external_pressure to (n_systems, 3) + if external_pressure_tensor.ndim == 0: + external_pressure_tensor = external_pressure_tensor.expand(state.n_systems, 3) + elif external_pressure_tensor.ndim == 1 and external_pressure_tensor.shape[0] == 3: + external_pressure_tensor = external_pressure_tensor.unsqueeze(0).expand( + state.n_systems, 3 + ) + batch_kT = kT_tensor.expand(state.n_systems) if kT_tensor.ndim == 0 else kT_tensor - # Update barostat mass based on current temperature - # This ensures proper coupling between system and barostat + # Update barostat mass n_atoms_per_system = torch.bincount(state.system_idx) state.cell_masses = (n_atoms_per_system + 1) * batch_kT * torch.square(state.b_tau) - # Compute model output for current state + # Store initial values + forces = state.forces + eps_old = state.cell_positions.clone() + + F_eps_n = _compute_cell_force( + state=state, + external_pressure=external_pressure_tensor, + kT=kT_tensor, + ) + + # Generate GJF noise ONCE + cell_beta = _npt_langevin_cell_beta(state, kT_tensor, dt_tensor) + particle_beta = _npt_langevin_beta(state, kT_tensor, dt_tensor) + + # Step 1: Update per-dimension strain + state = _npt_langevin_cell_position_step(state, dt_tensor, F_eps_n, cell_beta) + + # Reconstruct cell from updated strain + state.cell = state.current_cell + + # Step 2: Update particle positions + state = _npt_langevin_position_step(state, eps_old, dt_tensor, particle_beta) + + # Recompute model output model_output = model(state) + state.energy = model_output["energy"] state.forces = model_output["forces"] state.stress = model_output["stress"] - # Store initial values for integration - forces = state.forces - F_p_n = _compute_cell_force( + # Updated strain force + F_eps_new = _compute_cell_force( state=state, external_pressure=external_pressure_tensor, kT=kT_tensor, ) - L_n = torch.pow( - state.cell_positions.reshape(state.n_systems, -1)[:, 0], 1 / 3 - ) # shape: (n_systems,) - # Step 1: Update cell position - state = _npt_langevin_cell_position_step(state, dt_tensor, F_p_n, kT_tensor) + # Step 3: Update strain velocities (uses SAME cell_beta) + state = _npt_langevin_cell_velocity_step( + state, F_eps_n, dt_tensor, F_eps_new, cell_beta + ) - # Update cell (currently only isotropic fluctuations) - dim = state.positions.shape[1] # Usually 3 for 3D - # V_0 and V are shape: (n_systems,) - V_0 = torch.linalg.det(state.reference_cell) - V = state.cell_positions.reshape(state.n_systems, -1)[:, 0] + # Step 4: Update particle velocities (uses SAME particle_beta) + return _npt_langevin_velocity_step(state, forces, dt_tensor, particle_beta) - # Scale cell uniformly in all dimensions - scaling = (V / V_0) ** (1.0 / dim) # shape: (n_systems,) - # Apply scaling to reference cell to get new cell - new_cell = torch.zeros_like(state.cell) - for sys_idx in range(state.n_systems): - new_cell[sys_idx] = scaling[sys_idx] * state.reference_cell[sys_idx] +# ============================================================================= +# NPT Langevin Strain integrator — isotropic logarithmic strain coordinate +# ============================================================================= - state.cell = new_cell - # Step 2: Update particle positions - state = _npt_langevin_position_step(state, L_n, dt_tensor, kT_tensor) +@dataclass(kw_only=True) +class NPTLangevinStrainState(NPTState): + """State for NPT Langevin dynamics using logarithmic strain coordinate. + + The cell degree of freedom is the isotropic logarithmic strain + ε = (1/d)·ln(V/V₀), which is dimensionless. This guarantees V > 0 + and gives the conjugate force F_ε = d·V·(P_avg - P_ext) in energy units, + providing numerically well-scaled dynamics. + + Attributes: + reference_cell (torch.Tensor): Original cell [n_systems, d, d] + cell_positions (torch.Tensor): Strain ε = (1/d)·ln(V/V₀) [n_systems] + cell_velocities (torch.Tensor): dε/dt [n_systems] + cell_masses (torch.Tensor): Mass for strain DOF [n_systems] + alpha (torch.Tensor): Particle friction [n_systems] + cell_alpha (torch.Tensor): Cell friction [n_systems] + b_tau (torch.Tensor): Barostat time constant [n_systems] + """ + + alpha: torch.Tensor + cell_alpha: torch.Tensor + b_tau: torch.Tensor + + reference_cell: torch.Tensor + cell_positions: torch.Tensor # strain ε (dimensionless) + cell_velocities: torch.Tensor # dε/dt + cell_masses: torch.Tensor + + _system_attributes = NPTState._system_attributes | { # noqa: SLF001 + "cell_positions", + "cell_velocities", + "cell_masses", + "reference_cell", + "alpha", + "cell_alpha", + "b_tau", + } + + @property + def current_cell(self) -> torch.Tensor: + """Compute cell from strain: cell = exp(ε) · reference_cell.""" + scale = torch.exp(self.cell_positions) # exp(ε), shape (n_systems,) + return scale.unsqueeze(-1).unsqueeze(-1) * self.reference_cell + + @property + def volume(self) -> torch.Tensor: + """Current volume V = V₀ · exp(d·ε).""" + dim = self.positions.shape[1] + V_0 = torch.linalg.det(self.reference_cell) + return V_0 * torch.exp(dim * self.cell_positions) + + +def _compute_strain_cell_force( + state: NPTLangevinStrainState, + external_pressure: float | torch.Tensor, + kT: float | torch.Tensor, +) -> torch.Tensor: + """Compute force on the strain coordinate ε. + + F_ε = d · V · (P_avg - P_ext) + + where P_avg = -(1/3)Tr(σ) + NkT/V and d·V is the Jacobian dV/dε. + This force is in energy units (eV), making it numerically well-scaled. + + Args: + state: Current strain-based NPT state + external_pressure: Target pressure (scalar or [n_systems]) + kT: Temperature in energy units (scalar or [n_systems]) + + Returns: + torch.Tensor: Force on strain per system [n_systems] + """ + external_pressure = torch.as_tensor( + external_pressure, device=state.device, dtype=state.dtype + ) + kT = torch.as_tensor(kT, device=state.device, dtype=state.dtype) + + dim = state.positions.shape[1] + volumes = state.volume # (n_systems,) + + # Isotropic virial pressure: P_virial = -(1/3)Tr(stress) + stress_trace = torch.einsum("nii->n", state.stress) + avg_virial_pressure = -stress_trace / 3 # (n_systems,) + + # Kinetic contribution: NkT/V + batch_kT = kT if kT.ndim > 0 else kT.expand(state.n_systems) + n_atoms = state.n_atoms_per_system.to(dtype=state.dtype) + kinetic_pressure = n_atoms * batch_kT / volumes # (n_systems,) + + if external_pressure.ndim >= 2: + raise ValueError( + f"External pressure tensor provided with shape {external_pressure.shape}. " + "Only scalar or per-system external pressure is supported." + ) + + P_avg = avg_virial_pressure + kinetic_pressure + # F_ε = d · V · (P_avg - P_ext) + return dim * volumes * (P_avg - external_pressure) + - # Recompute model output after position updates +def _npt_langevin_strain_cell_beta( + state: NPTLangevinStrainState, + kT: torch.Tensor, + dt: torch.Tensor, +) -> torch.Tensor: + """Generate scalar random noise for isotropic strain fluctuations. + + Returns: + torch.Tensor: Noise [n_systems] + """ + noise = _randn_for_state(state, (state.n_systems,)) + batch_kT = kT if kT.ndim > 0 else kT.expand(state.n_systems) + dt_expanded = dt if dt.ndim > 0 else dt.expand(state.n_systems) + scaling = torch.sqrt(2.0 * state.cell_alpha * batch_kT * dt_expanded) + return scaling * noise + + +def _npt_langevin_strain_cell_position_step( + state: NPTLangevinStrainState, + dt: torch.Tensor, + strain_force: torch.Tensor, + cell_beta: torch.Tensor, +) -> NPTLangevinStrainState: + """GJF position step for the strain coordinate ε. + + ε_{n+1} = ε_n + b·dt·dε/dt + b·dt²·F_ε/(2Q) + b·dt·β/(2Q) + + Args: + state: Current state + dt: Timestep + strain_force: F_ε [n_systems] + cell_beta: Noise term β_c [n_systems] + + Returns: + Updated state with new cell_positions (strain) + """ + Q_2 = 2 * state.cell_masses + dt_expanded = dt if dt.ndim > 0 else dt.expand(state.n_systems) + + cell_b = 1 / (1 + (state.cell_alpha * dt_expanded) / Q_2) + + c_1 = cell_b * dt_expanded * state.cell_velocities + c_2 = cell_b * dt_expanded * dt_expanded * strain_force / Q_2 + c_3 = cell_b * dt_expanded * cell_beta / Q_2 + + state.cell_positions = state.cell_positions + c_1 + c_2 + c_3 + return state + + +def _npt_langevin_strain_cell_velocity_step( + state: NPTLangevinStrainState, + F_eps_n: torch.Tensor, + dt: torch.Tensor, + strain_force: torch.Tensor, + cell_beta: torch.Tensor, +) -> NPTLangevinStrainState: + """GJF velocity step for the strain coordinate ε. + + dε/dt_{n+1} = a·dε/dt_n + dt/(2Q)·(a·F_ε^n + F_ε^{n+1}) + b·β/Q + + Args: + state: Current state + F_eps_n: Initial strain force [n_systems] + dt: Timestep + strain_force: Final strain force [n_systems] + cell_beta: Noise term β_c (SAME as in position step) [n_systems] + + Returns: + Updated state with new cell_velocities (dε/dt) + """ + dt_expanded = dt if dt.ndim > 0 else dt.expand(state.n_systems) + + Q = state.cell_masses + a = (1 - (state.cell_alpha * dt_expanded) / (2 * Q)) / ( + 1 + (state.cell_alpha * dt_expanded) / (2 * Q) + ) + b = 1 / (1 + (state.cell_alpha * dt_expanded) / (2 * Q)) + + c_1 = a * state.cell_velocities + c_2 = dt_expanded * ((a * F_eps_n) + strain_force) / (2 * Q) + c_3 = b * cell_beta / Q + + state.cell_velocities = c_1 + c_2 + c_3 + return state + + +def _npt_langevin_strain_position_step( + state: NPTLangevinStrainState, + eps_old: torch.Tensor, + dt: torch.Tensor, + particle_beta: torch.Tensor, +) -> NPTLangevinStrainState: + """Update particle positions accounting for strain change. + + Positions are scaled by exp(ε_new - ε_old) for the volume change, + then the standard GJF position update is applied. + + Args: + state: Current state (cell_positions already updated to ε_new) + eps_old: Strain before the cell position step [n_systems] + dt: Timestep + particle_beta: Noise [n_particles, n_dim] + + Returns: + Updated state with new positions + """ + M_2 = 2 * state.masses.unsqueeze(-1) # (n_atoms, 1) + + # Scale factor from strain change: L_new/L_old = exp(ε_new - ε_old) + scale = torch.exp(state.cell_positions - eps_old) # (n_systems,) + scale_atoms = scale[state.system_idx] # (n_atoms,) + + # Damping factor + alpha_atoms = state.alpha[state.system_idx] + dt_atoms = dt + if dt.ndim > 0: + dt_atoms = dt[state.system_idx] + + b = 1 / (1 + ((alpha_atoms * dt_atoms) / (2 * state.masses))) + + # Scale positions due to volume change + c_1 = scale_atoms.unsqueeze(-1) * state.positions + + # Time step factor: 2·s/(s+1) where s = scale + c_2 = (2 * scale_atoms / (scale_atoms + 1)) * b * dt_atoms + + c_3 = ( + state.velocities + + dt_atoms.unsqueeze(-1) * state.forces / M_2 + + particle_beta / M_2 + ) + + state.set_constrained_positions(c_1 + c_2.unsqueeze(-1) * c_3) + return state + + +def npt_langevin_strain_init( + state: SimState, + model: ModelInterface, + *, + kT: float | torch.Tensor, + dt: float | torch.Tensor, + alpha: float | torch.Tensor | None = None, + cell_alpha: float | torch.Tensor | None = None, + b_tau: float | torch.Tensor | None = None, + **_kwargs: Any, +) -> NPTLangevinStrainState: + """Initialize an NPT Langevin state using logarithmic strain coordinate. + + The strain coordinate ε = (1/d)·ln(V/V₀) provides well-scaled dynamics + where the conjugate force F_ε = d·V·(P_avg - P_ext) is in energy units. + + Args: + state: Initial SimState + model: Model that computes energy, forces, stress + kT: Target temperature in energy units + dt: Integration timestep + alpha: Particle friction coefficient. Defaults to 1/(5·dt). + cell_alpha: Cell friction coefficient. Defaults to 1/(30·dt). + b_tau: Barostat time constant. Defaults to 300·dt. + + Returns: + NPTLangevinStrainState: Initialized state with ε = 0 + """ + device, dtype = model.device, model.dtype + + if alpha is None: + alpha = 1.0 / (5 * dt) + if cell_alpha is None: + cell_alpha = 1.0 / (30 * dt) + if b_tau is None: + b_tau = 300 * dt + + alpha = torch.as_tensor(alpha, device=device, dtype=dtype) + cell_alpha = torch.as_tensor(cell_alpha, device=device, dtype=dtype) + b_tau = torch.as_tensor(b_tau, device=device, dtype=dtype) + kT = torch.as_tensor(kT, device=device, dtype=dtype) + dt = torch.as_tensor(dt, device=device, dtype=dtype) + + if alpha.ndim == 0: + alpha = alpha.expand(state.n_systems) + if cell_alpha.ndim == 0: + cell_alpha = cell_alpha.expand(state.n_systems) + if b_tau.ndim == 0: + b_tau = b_tau.expand(state.n_systems) + + model_output = model(state) + + momenta = getattr(state, "momenta", None) + if momenta is None: + momenta = initialize_momenta( + state.positions, + state.masses, + state.system_idx, + kT, + state.rng, + ) + + reference_cell = state.cell.clone() + + # ε = 0 at initialization (V = V₀) + cell_positions = torch.zeros(state.n_systems, device=device, dtype=dtype) + cell_velocities = torch.zeros(state.n_systems, device=device, dtype=dtype) + + batch_kT = kT.expand(state.n_systems) if kT.ndim == 0 else kT + n_atoms_per_system = torch.bincount(state.system_idx) + cell_masses = (n_atoms_per_system + 1) * batch_kT * b_tau * b_tau + + if state.constraints: + msg = ( + "Constraints are present in the system. " + "Make sure they are compatible with NPT Langevin dynamics. " + "We recommend not using constraints with NPT dynamics for now." + ) + warnings.warn(msg, UserWarning, stacklevel=3) + logger.warning(msg) + + return NPTLangevinStrainState.from_state( + state, + momenta=momenta, + energy=model_output["energy"], + forces=model_output["forces"], + stress=model_output["stress"], + alpha=alpha, + b_tau=b_tau, + reference_cell=reference_cell, + cell_positions=cell_positions, + cell_velocities=cell_velocities, + cell_masses=cell_masses, + cell_alpha=cell_alpha, + ) + + +@dcite("10.1063/1.4901303") +def npt_langevin_strain_step( + state: NPTLangevinStrainState, + model: ModelInterface, + *, + dt: float | torch.Tensor, + kT: float | torch.Tensor, + external_pressure: float | torch.Tensor, +) -> NPTLangevinStrainState: + r"""Perform one NPT Langevin step using logarithmic strain coordinate. + + Uses the same GJF integrator as :func:`npt_langevin_step` but with the + cell degree of freedom being the isotropic logarithmic strain + :math:`\varepsilon = \frac{1}{d}\ln(V/V_0)` instead of the raw volume. + + **Strain force:** + + .. math:: + + F_\varepsilon = d \cdot V \cdot (P_{\text{avg}} - P_{\text{ext}}) + + where the Jacobian :math:`dV/d\varepsilon = d \cdot V` naturally provides + a volume factor that makes :math:`F_\varepsilon` an energy (eV), giving + numerically well-scaled dynamics. + + **Cell reconstruction:** + + .. math:: + + V = V_0 \exp(d\,\varepsilon), \quad + \mathbf{h} = e^\varepsilon \, \mathbf{h}_0 + + **Particle scaling:** + + .. math:: + + \mathbf{r}_i \to e^{\varepsilon_{n+1} - \varepsilon_n} \, \mathbf{r}_i + + Args: + state: Current strain-based NPT state + model: Model computing energy, forces, stress + dt: Integration timestep + kT: Target temperature in energy units + external_pressure: Target pressure + + Returns: + NPTLangevinStrainState: Updated state + """ + device, dtype = model.device, model.dtype + + state.alpha = torch.as_tensor(state.alpha, device=device, dtype=dtype) + kT_tensor = torch.as_tensor(kT, device=device, dtype=dtype) + state.cell_alpha = torch.as_tensor(state.cell_alpha, device=device, dtype=dtype) + dt_tensor = torch.as_tensor(dt, device=device, dtype=dtype) + external_pressure_tensor = torch.as_tensor( + external_pressure, device=device, dtype=dtype + ) + + batch_kT = kT_tensor.expand(state.n_systems) if kT_tensor.ndim == 0 else kT_tensor + + # Update barostat mass + n_atoms_per_system = torch.bincount(state.system_idx) + state.cell_masses = (n_atoms_per_system + 1) * batch_kT * torch.square(state.b_tau) + + # Store initial values + forces = state.forces + eps_old = state.cell_positions.clone() + + F_eps_n = _compute_strain_cell_force( + state=state, + external_pressure=external_pressure_tensor, + kT=kT_tensor, + ) + + # Generate GJF noise ONCE + cell_beta = _npt_langevin_strain_cell_beta(state, kT_tensor, dt_tensor) + particle_beta = _npt_langevin_beta(state, kT_tensor, dt_tensor) + + # Step 1: Update strain (cell position step) + state = _npt_langevin_strain_cell_position_step(state, dt_tensor, F_eps_n, cell_beta) + + # Reconstruct cell from updated strain + state.cell = state.current_cell + + # Step 2: Update particle positions (with strain-based scaling) + state = _npt_langevin_strain_position_step(state, eps_old, dt_tensor, particle_beta) + + # Recompute model output model_output = model(state) state.energy = model_output["energy"] state.forces = model_output["forces"] state.stress = model_output["stress"] - # Compute updated pressure force - F_p_n_new = _compute_cell_force( + # Compute updated strain force + F_eps_new = _compute_strain_cell_force( state=state, external_pressure=external_pressure_tensor, kT=kT_tensor, ) - # Step 3: Update cell velocities - state = _npt_langevin_cell_velocity_step( - state, F_p_n, dt_tensor, F_p_n_new, kT_tensor + # Step 3: Update strain velocity (uses SAME cell_beta) + state = _npt_langevin_strain_cell_velocity_step( + state, F_eps_n, dt_tensor, F_eps_new, cell_beta ) - # Step 4: Update particle velocities - return _npt_langevin_velocity_step(state, forces, dt_tensor, kT_tensor) + # Step 4: Update particle velocities (uses SAME particle_beta) + return _npt_langevin_velocity_step(state, forces, dt_tensor, particle_beta) @dataclass(kw_only=True) @@ -1179,9 +1486,9 @@ def _npt_nose_hoover_compute_cell_force( internal_pressure = torch.trace(stress).unsqueeze(0).expand(n_systems) # Compute force on cell coordinate per system - # F = alpha * KE - dU/dV - P*V*d + # F = alpha * (2 * KE) - dU/dV - P*V*d return ( - (alpha * KE_per_system) + (alpha * 2 * KE_per_system) - (internal_pressure * volume) - (external_pressure * volume * dim) ) @@ -1226,21 +1533,18 @@ def _npt_nose_hoover_inner_step( volume, volume_to_cell = _npt_nose_hoover_cell_info(state) cell = volume_to_cell(volume) - # Get model output - state.cell = cell - model_output = model(state) - # First half step: Update momenta - n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems) - alpha = 1 + 1 / n_atoms_per_system # [n_systems] + # alpha = 1 + dim / degrees_of_freedom (3 * natoms - 3) + alpha = 1 + 3 / state.get_number_of_degrees_of_freedom() # [n_systems] + # Reuse stress from previous step since positions and cell unchanged cell_force_val = _npt_nose_hoover_compute_cell_force( alpha=alpha, volume=volume, positions=positions, momenta=momenta, masses=masses, - stress=model_output["stress"], + stress=state.stress, external_pressure=external_pressure, system_idx=state.system_idx, ) @@ -1357,10 +1661,10 @@ def npt_nose_hoover_init( dt_tensor = torch.as_tensor(dt, device=device, dtype=dtype) kT_tensor = torch.as_tensor(kT, device=device, dtype=dtype) t_tau_tensor = torch.as_tensor( - 100 * dt_tensor if t_tau is None else t_tau, device=device, dtype=dtype + 10 * dt_tensor if t_tau is None else t_tau, device=device, dtype=dtype ) b_tau_tensor = torch.as_tensor( - 1000 * dt_tensor if b_tau is None else b_tau, device=device, dtype=dtype + 100 * dt_tensor if b_tau is None else b_tau, device=device, dtype=dtype ) # Setup thermostats with appropriate timescales @@ -1406,7 +1710,8 @@ def npt_nose_hoover_init( ) # Compute total DOF for thermostat initialization and a zero KE placeholder - dof_per_system = torch.bincount(state.system_idx, minlength=n_systems) * dim + dof_per_system = state.get_number_of_degrees_of_freedom() - 3 + KE_thermostat = ts.calc_kinetic_energy( masses=state.masses, momenta=momenta, system_idx=state.system_idx ) @@ -1472,26 +1777,102 @@ def npt_nose_hoover_step( kT: float | torch.Tensor, external_pressure: float | torch.Tensor, ) -> NPTNoseHooverState: - """Perform a complete NPT integration step with Nose-Hoover chain thermostats. + r"""Perform a complete NPT integration step with Nose-Hoover chain thermostats. + + Implements the MTK (Martyna-Tobias-Klein) NPT scheme from Tuckerman et al. + (2006) [10]_ with Nose-Hoover chains from Martyna et al. (1996) [3]_. + + **Equations of motion** (Tuckerman et al. 2006, Eqs. 1-6): + + .. math:: + + \dot{\mathbf{r}}_i &= \frac{\mathbf{p}_i}{m_i} + + \frac{p_\epsilon}{W}\,\mathbf{r}_i \\ + \dot{\mathbf{p}}_i &= \mathbf{F}_i + - \alpha\,\frac{p_\epsilon}{W}\,\mathbf{p}_i \\ + \dot{\epsilon} &= \frac{p_\epsilon}{W} \\ + \dot{p}_\epsilon &= G_\epsilon + = \alpha\,(2K) + \text{Tr}(\boldsymbol{\sigma}_{\text{int}})\,V + - P_{\text{ext}}\,V\,d + + where :math:`\epsilon = (1/d)\ln(V/V_0)` is the logarithmic cell coordinate, + :math:`\alpha = 1 + d/N_f`, :math:`d=3` is spatial dimension, and + :math:`N_f = 3N - 3` the degrees of freedom. + + **Symmetric propagator** (Trotter factorization): + + .. math:: + + e^{i\mathcal{L}\Delta t} = + e^{i\mathcal{L}_{\text{NHC-baro}}\frac{\Delta t}{2}} + \;e^{i\mathcal{L}_{\text{NHC-part}}\frac{\Delta t}{2}} + \;e^{i\mathcal{L}_2\frac{\Delta t}{2}} + \;e^{i\mathcal{L}_1\Delta t} + \;e^{i\mathcal{L}_2\frac{\Delta t}{2}} + \;e^{i\mathcal{L}_{\text{NHC-part}}\frac{\Delta t}{2}} + \;e^{i\mathcal{L}_{\text{NHC-baro}}\frac{\Delta t}{2}} + + **Position update** :math:`e^{i\mathcal{L}_1\Delta t}`: + + .. math:: + + \mathbf{r}_i \leftarrow \mathbf{r}_i + + \bigl(e^{v_\epsilon\Delta t} - 1\bigr)\,\mathbf{r}_i + + \Delta t\,\mathbf{v}_i\,e^{v_\epsilon\Delta t/2} + \,\frac{\sinh(v_\epsilon\Delta t/2)}{v_\epsilon\Delta t/2} + + **Momentum update** :math:`e^{i\mathcal{L}_2\Delta t/2}`: + + .. math:: + + \mathbf{p}_i \leftarrow \mathbf{p}_i\,e^{-\alpha v_\epsilon\Delta t/2} + + \frac{\Delta t}{2}\,\mathbf{F}_i\, + e^{-\alpha v_\epsilon\Delta t/4} + \,\frac{\sinh(\alpha v_\epsilon\Delta t/4)} + {\alpha v_\epsilon\Delta t/4} + + where :math:`v_\epsilon = p_\epsilon / W` is the cell velocity. + + **Variable mapping (equation -> code):** + + ============================================ ============================ + Equation symbol Code variable + ============================================ ============================ + :math:`\mathbf{r}_i` (positions) ``state.positions`` + :math:`\mathbf{p}_i` (momenta) ``state.momenta`` + :math:`m_i` (masses) ``state.masses`` + :math:`\mathbf{F}_i` (forces) ``state.forces`` + :math:`\epsilon` (log-cell coordinate) ``state.cell_position`` + :math:`p_\epsilon` (cell momentum) ``state.cell_momentum`` + :math:`W` (cell mass) ``state.cell_mass`` + :math:`\alpha` (1 + d/Nf) ``alpha`` (local) + :math:`v_\epsilon` (cell velocity) ``cell_velocities`` (local) + :math:`V_0` (reference volume) ``det(state.reference_cell)`` + :math:`G_\epsilon` (cell force) ``cell_force_val`` + :math:`P_{\text{ext}}` (target pressure) ``external_pressure`` + :math:`k_BT` (thermal energy) ``kT`` + :math:`\Delta t` (timestep) ``dt`` + ============================================ ============================ + If the center of mass motion is removed initially, it remains removed throughout the simulation, so the degrees of freedom decreases by 3. - This function performs a full NPT integration step including: - 1. Mass parameter updates for thermostats and cell - 2. Thermostat chain updates (half step) - 3. Inner NPT dynamics step - 4. Energy updates for thermostats - 5. Final thermostat chain updates (half step) - Args: - model (ModelInterface): Model to compute forces and energies - state (NPTNoseHooverState): Current system state - dt (float | torch.Tensor): Integration timestep - kT (float | torch.Tensor): Target temperature - external_pressure (float | torch.Tensor): Target external pressure + model: Model to compute forces and energies + state: Current system state + dt: Integration timestep + kT: Target temperature + external_pressure: Target external pressure Returns: NPTNoseHooverState: Updated state after complete integration step + + References: + .. [10] Tuckerman, M. E., et al. "A Liouville-operator derived + measure-preserving integrator for molecular dynamics simulations in + the isothermal-isobaric ensemble." J. Phys. A 39(19), 5629-5651 (2006). + .. [3] Martyna, G. J., et al. "Explicit reversible integrators for extended + systems dynamics." Mol. Phys. 87(5), 1117-1157 (1996). """ device, dtype = model.device, model.dtype dt_tensor = torch.as_tensor(dt, device=device, dtype=dtype) @@ -1612,13 +1993,12 @@ def npt_nose_hoover_invariant( ) # Calculate degrees of freedom per system - n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems) - dof_per_system = n_atoms_per_system * state.positions.shape[-1] # n_atoms * n_dim + dof_per_system = state.get_number_of_degrees_of_freedom() # Initialize total energy with PE + KE e_tot = e_pot + e_kin_per_system - # Add thermostat chain contributions (batched per system, DOF = n_atoms * 3) + # Add thermostat chain contributions (batched per system, DOF = 3 * n_atoms - 3) e_tot += _compute_chain_energy(state.thermostat, kT, e_tot, dof_per_system) # Add barostat chain contributions (batched per system, DOF = 1) @@ -1741,7 +2121,7 @@ def _crescale_anisotropic_barostat_step( ## Step 2: compute deformation matrix random_coeff = 2 * state.isothermal_compressibility * kT * dt / (3 * state.tau_p) prefactor_random_matrix = torch.sqrt(random_coeff) / new_sqrt_volume - a_tilde = -(state.isothermal_compressibility / (3 * state.tau_p))[:, None, None] * ( + a_tilde = (state.isothermal_compressibility / (3 * state.tau_p))[:, None, None] * ( P_int - trace_P_int[:, None, None] / 3 @@ -1783,7 +2163,8 @@ def _crescale_anisotropic_barostat_step( (vscaling + rscaling)[state.system_idx], state.momenta ) * dt / (2 * state.masses.unsqueeze(-1)) state.momenta = batch_matrix_vector(vscaling[state.system_idx], state.momenta) - state.cell = rscaling.mT @ state.cell + # Right multiply: cell @ rscaling^T preserves fractional coordinates + state.cell = state.cell @ rscaling.mT return state @@ -1818,7 +2199,7 @@ def _crescale_independent_lengths_barostat_step( ) # Note: it corresponds to using a diagonal isothermal compressibility tensor P_int_diagonal = torch.diagonal(P_int, dim1=-2, dim2=-1) - a_tilde = -(state.isothermal_compressibility / (3 * state.tau_p))[:, None] * ( + a_tilde = (state.isothermal_compressibility / (3 * state.tau_p))[:, None] * ( P_int_diagonal - trace_P_int[:, None] / 3 ) @@ -1889,8 +2270,7 @@ def _crescale_average_anisotropic_barostat_step( ) -> NPTCRescaleState: volume = torch.det(state.cell) # shape: (n_systems,) P_int = compute_average_pressure_tensor( - # Should it be degrees_of_freedom=state.get_number_of_degrees_of_freedom() / 3, - degrees_of_freedom=state.n_atoms_per_system, + degrees_of_freedom=state.get_number_of_degrees_of_freedom() / 3, kT=kT, stress=state.stress, volumes=volume, @@ -1910,7 +2290,7 @@ def _crescale_average_anisotropic_barostat_step( torch.sqrt(2 * state.isothermal_compressibility * kT * dt / (3 * state.tau_p)) / new_sqrt_volume ) - a_tilde = -(state.isothermal_compressibility / (3 * state.tau_p))[:, None, None] * ( + a_tilde = (state.isothermal_compressibility / (3 * state.tau_p))[:, None, None] * ( P_int - trace_P_int[:, None, None] / 3 @@ -1956,7 +2336,8 @@ def _crescale_average_anisotropic_barostat_step( )[state.system_idx], state.momenta, ) * dt / (2 * state.masses.unsqueeze(-1)) - state.cell = rscaling.mT @ state.cell + # Right multiply: cell @ rscaling^T preserves fractional coordinates + state.cell = state.cell @ rscaling.mT return state @@ -1982,7 +2363,9 @@ def _crescale_isotropic_barostat_step( prefactor = state.isothermal_compressibility * sqrt_vol / (2 * state.tau_p) change_sqrt_vol = -prefactor * ( external_pressure - trace_P_int / 3 - kT / (2 * volume) - ) * dt + prefactor_random * _randn_for_state(state, sqrt_vol.shape) + ) * dt + torch.sqrt( + 2 * torch.ones_like(sqrt_vol) + ) * prefactor_random * _randn_for_state(state, sqrt_vol.shape) new_sqrt_volume = sqrt_vol + change_sqrt_vol # Update positions and momenta (barostat + half momentum step) @@ -2013,7 +2396,7 @@ def _coerce_crescale_step_inputs( external_pressure, device=device, dtype=dtype ) tau_tensor = torch.as_tensor( - 100 * dt_tensor if tau is None else tau, device=device, dtype=dtype + 1 * dt_tensor if tau is None else tau, device=device, dtype=dtype ) return dt_tensor, kT_tensor, external_pressure_tensor, tau_tensor @@ -2029,41 +2412,96 @@ def npt_crescale_anisotropic_step( external_pressure: float | torch.Tensor, tau: float | torch.Tensor | None = None, ) -> NPTCRescaleState: - """Perform one NPT integration step with cell rescaling barostat. + r"""Perform one NPT integration step with anisotropic stochastic cell rescaling. - This function performs a single integration step for NPT dynamics using - a cell rescaling barostat. It updates particle positions, momenta, and - the simulation cell based on the target temperature and pressure. + Implements the anisotropic C-Rescale barostat from Del Tatto et al. + (2022) [7]_ extending the isotropic scheme of Bernetti & Bussi (2020) [6]_. + Cell lengths and angles can change independently. Uses instantaneous kinetic + energy. Both positions and momenta are scaled. - Trotter based splitting: - 1. Half Thermostat (velocity scaling) - 2. Half Update momenta with forces - 3. Barostat (cell rescaling) - 4. Update positions (from barostat + half momenta) - 5. Update forces with new positions and cell - 6. Compute forces - 7. Half Update momenta with forces - 8. Half Thermostat (velocity scaling) + **Trotter splitting:** - Only allow isotropic external stress. This method performs anisotropic - cell rescaling. Lengths and angles can change independently. Based on - pressure using kinetic energy. Positions and momenta are scaled when scaling the cell. + V-Rescale(dt/2) -> B(dt/2) -> Barostat(dt) -> Force eval -> B(dt/2) -> + V-Rescale(dt/2) - Inspired from: https://github.com/bussilab/crescale/blob/master/simplemd_anisotropic/simplemd.cpp - - Time reversible integrator - - Instantaneous kinetic energy (not not the average from equipartition) + **Barostat sub-steps** (3-step volume + deformation update): + + Step 1 -- Propagate :math:`\sqrt{V}` for :math:`\Delta t/2` (same SDE as + isotropic, Eq. 7 of [6]_): + + .. math:: + + \Delta\lambda = -\frac{\beta_T\lambda}{2\tau_p} + \left(P_0 - \frac{\text{Tr}(\mathbf{P}_{\text{int}})}{3} + - \frac{k_BT}{2V}\right)\frac{\Delta t}{2} + + \sqrt{\frac{k_BT\beta_T\Delta t}{4\tau_p}}\;R + + Step 2 -- Compute deviatoric deformation matrix: + + .. math:: + + \tilde{\mathbf{A}} &= \frac{\beta_T}{3\tau_p} + \left(\mathbf{P}_{\text{int}} + - \frac{\text{Tr}(\mathbf{P}_{\text{int}})}{3}\,\mathbf{I}\right) \\ + \boldsymbol{\mu}_{\text{dev}} &= \exp\bigl(\tilde{\mathbf{A}}\,\Delta t + + \sigma\,\tilde{\mathbf{R}}\bigr) + + where :math:`\sigma = \sqrt{2\beta_T k_BT\Delta t/(3\tau_p)}\;/\;\sqrt{V'}` + and :math:`\tilde{\mathbf{R}}` is a traceless random matrix. + + Step 3 -- Propagate :math:`\sqrt{V}` for :math:`\Delta t/2` (same as step 1). + + **Total scaling and update:** + + .. math:: + + \boldsymbol{\mu} &= \boldsymbol{\mu}_{\text{dev}} + \cdot (V'/V)^{1/3} \\ + \mathbf{r}_i &\leftarrow \boldsymbol{\mu}\,\mathbf{r}_i + + (\boldsymbol{\mu}^{-T} + \boldsymbol{\mu})\, + \frac{\mathbf{p}_i}{2m_i}\,\Delta t \\ + \mathbf{p}_i &\leftarrow \boldsymbol{\mu}^{-T}\,\mathbf{p}_i \\ + \mathbf{h} &\leftarrow \mathbf{h}\,\boldsymbol{\mu}^T + + **Variable mapping (equation -> code):** + + ============================================ ================================ + Equation symbol Code variable + ============================================ ================================ + :math:`V` (volume) ``volume`` + :math:`\lambda` (:math:`\sqrt{V}`) ``sqrt_vol`` + :math:`\beta_T` (compressibility) ``state.isothermal_compressibility`` + :math:`\tau_p` (barostat relax. time) ``state.tau_p`` + :math:`P_0` (target pressure) ``external_pressure`` + :math:`\mathbf{P}_{\text{int}}` (press. tensor) ``P_int`` + :math:`\tilde{\mathbf{A}}` (deviator drive) ``a_tilde`` + :math:`\boldsymbol{\mu}_{\text{dev}}` ``deformation_matrix`` + :math:`\boldsymbol{\mu}` (total scaling) ``rscaling`` + :math:`\boldsymbol{\mu}^{-T}` (mom. scaling) ``vscaling`` + :math:`\tilde{\mathbf{R}}` (traceless noise) ``random_matrix_tilde`` + :math:`\sigma` (noise prefactor) ``prefactor_random_matrix`` + :math:`k_BT` (thermal energy) ``kT`` + :math:`\Delta t` (timestep) ``dt`` + :math:`\tau` (thermostat relax.) ``tau`` (V-Rescale) + ============================================ ================================ Args: - model (ModelInterface): Model to compute forces and energies - state (NPTCRescaleState): Current system state - dt (torch.Tensor): Integration timestep - kT (torch.Tensor): Target temperature - external_pressure (torch.Tensor): Target external pressure - tau (torch.Tensor | None): V-Rescale thermostat relaxation time. If None, - defaults to 100*dt + model: Model to compute forces and energies + state: Current system state + dt: Integration timestep + kT: Target temperature + external_pressure: Target external pressure + tau: V-Rescale thermostat relaxation time. If None, defaults to 100*dt Returns: NPTCRescaleState: Updated state after one integration step + + References: + .. [7] Del Tatto, V., et al. "Molecular dynamics of solids at constant + pressure and stress using anisotropic stochastic cell rescaling." + Applied Sciences 12(3), 1139 (2022). + .. [6] Bernetti, M. & Bussi, G. "Pressure control using stochastic cell + rescaling." J. Chem. Phys. 153, 114107 (2020). """ dt_tensor, kT_tensor, external_pressure_tensor, tau_tensor = ( _coerce_crescale_step_inputs(state, dt, kT, external_pressure, tau) @@ -2219,7 +2657,7 @@ def npt_crescale_average_anisotropic_step( external_pressure = torch.as_tensor(external_pressure, device=device, dtype=dtype) # Note: would probably be better to have tau in NVTCRescaleState - tau = torch.as_tensor(tau or 100 * dt, device=device, dtype=dtype) + tau = torch.as_tensor(tau or 1 * dt, device=device, dtype=dtype) state = _vrescale_update(state, tau, kT, dt / 2) @@ -2251,44 +2689,74 @@ def npt_crescale_isotropic_step( external_pressure: float | torch.Tensor, tau: float | torch.Tensor | None = None, ) -> NPTCRescaleState: - """Perform one NPT integration step with cell rescaling barostat. + r"""Perform one NPT integration step with isotropic stochastic cell rescaling. - This function performs a single integration step for NPT dynamics using - a cell rescaling barostat. It updates particle positions, momenta, and - the simulation cell based on the target temperature and pressure. + Implements isotropic C-Rescale from Bernetti & Bussi (2020) [6]_. + Cell shape is preserved; cell lengths are scaled equally. - Trotter based splitting: - 1. Half Thermostat (velocity scaling) - 2. Half Update momenta with forces - 3. Barostat (cell rescaling) - 4. Update positions (from barostat + half momenta) - 5. Update forces with new positions and cell - 6. Compute forces - 7. Half Update momenta with forces - 8. Half Thermostat (velocity scaling) + **Trotter splitting:** - Only allow isotropic external stress. This performs isotropic - cell rescaling: cell shape is preserved, cell lengths are scaled equally. - For anisotropic cell rescaling, use npt_crescale_anisotropic_step. + V-Rescale(dt/2) -> B(dt/2) -> Barostat(dt) -> Force eval -> B(dt/2) -> + V-Rescale(dt/2) - References: - - Bernetti, Mattia, and Giovanni Bussi. - "Pressure control using stochastic cell rescaling." - The Journal of Chemical Physics 153.11 (2020). - - And the corresponding Supplementary Information which details - the integration scheme. Notice an error in scaling of positions in SI Eq. S13a. + **Isotropic volume SDE** (Eq. 7 of [6]_, using :math:`\lambda = \sqrt{V}`): + + .. math:: + + d\lambda = -\frac{\beta_T\lambda}{2\tau_p} + \left(P_0 - \frac{\text{Tr}(\mathbf{P}_{\text{int}})}{3} + - \frac{k_BT}{2V}\right) dt + + \sqrt{\frac{k_BT\,\beta_T}{2\tau_p}}\;dW + + where :math:`\beta_T` is the isothermal compressibility and + :math:`\mathbf{P}_{\text{int}}` is the instantaneous pressure tensor + (including the kinetic contribution). + + **Position and momentum scaling** (SI Eqs. S13a-b of [6]_, corrected): + + .. math:: + + \mathbf{r}_i &\leftarrow \mu\,\mathbf{r}_i + + (\mu + \mu^{-1})\,\frac{\mathbf{p}_i}{2m_i}\,\Delta t \\ + \mathbf{p}_i &\leftarrow \mu^{-1}\,\mathbf{p}_i \\ + \mathbf{h} &\leftarrow \mu\,\mathbf{h} + + where :math:`\mu = (V'/V)^{1/3}` is the isotropic scaling factor and + :math:`\mathbf{h}` is the cell matrix. + + **Variable mapping (equation -> code):** + + ============================================ ================================ + Equation symbol Code variable + ============================================ ================================ + :math:`V` (volume) ``volume`` + :math:`\lambda` (:math:`\sqrt{V}`) ``sqrt_vol`` + :math:`\beta_T` (compressibility) ``state.isothermal_compressibility`` + :math:`\tau_p` (barostat relax. time) ``state.tau_p`` + :math:`P_0` (target pressure) ``external_pressure`` + :math:`\mathbf{P}_{\text{int}}` (press. tensor) ``P_int`` + :math:`\text{Tr}(\mathbf{P}_{\text{int}})` ``trace_P_int`` + :math:`\mu` (scaling factor) ``rscaling`` + :math:`k_BT` (thermal energy) ``kT`` + :math:`\Delta t` (timestep) ``dt`` + :math:`\tau` (thermostat relax.) ``tau`` (V-Rescale) + ============================================ ================================ Args: - model (ModelInterface): Model to compute forces and energies - state (NPTCRescaleState): Current system state - dt (torch.Tensor): Integration timestep - kT (torch.Tensor): Target temperature - external_pressure (torch.Tensor): Target external pressure - tau (torch.Tensor | None): V-Rescale thermostat relaxation time. If None, - defaults to 100*dt + model: Model to compute forces and energies + state: Current system state + dt: Integration timestep + kT: Target temperature + external_pressure: Target external pressure + tau: V-Rescale thermostat relaxation time. If None, defaults to 100*dt Returns: NPTCRescaleState: Updated state after one integration step + + References: + .. [6] Bernetti, M. & Bussi, G. "Pressure control using stochastic cell + rescaling." J. Chem. Phys. 153, 114107 (2020). Note: SI Eq. S13a has a + typo (positions must also be scaled by mu). """ device, dtype = model.device, model.dtype dt = torch.as_tensor(dt, device=device, dtype=dtype) @@ -2296,7 +2764,7 @@ def npt_crescale_isotropic_step( external_pressure = torch.as_tensor(external_pressure, device=device, dtype=dtype) # Note: would probably be better to have tau in NVTCRescaleState - tau = torch.as_tensor(tau or 100 * dt, device=device, dtype=dtype) + tau = torch.as_tensor(tau or 1 * dt, device=device, dtype=dtype) state = _vrescale_update(state, tau, kT, dt / 2) @@ -2354,11 +2822,10 @@ def npt_crescale_init( kT = torch.as_tensor(kT, device=device, dtype=dtype) # Set default values if not provided - tau_p = torch.as_tensor( - tau_p or 5000 * dt, device=device, dtype=dtype - ) # 5ps for dt=1fs + tau_p = torch.as_tensor(tau_p or 3 * dt, device=device, dtype=dtype) # 5ps for dt=1fs isothermal_compressibility = torch.as_tensor( - isothermal_compressibility or 1e-1, + isothermal_compressibility + or 1e-6 / MetalUnits.pressure, # 1e-6 bar^-1 for metals device=device, dtype=dtype, # (eV/A^3)^-1 ) diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index 07f3064b..d37ff7ed 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -68,14 +68,35 @@ def nve_init( def nve_step( state: MDState, model: ModelInterface, *, dt: float | torch.Tensor, **_kwargs: Any ) -> MDState: - """Perform one complete NVE (microcanonical) integration step. + r"""Perform one complete NVE (microcanonical) integration step. - This function implements the velocity Verlet algorithm for NVE dynamics, - which provides energy-conserving time evolution. The integration sequence is: - 1. Half momentum update using current forces - 2. Full position update using updated momenta - 3. Force update at new positions - 4. Half momentum update using new forces + Implements the velocity Verlet algorithm for NVE dynamics, which provides + energy-conserving, time-reversible integration of Hamilton's equations of motion. + + **Equations** (standard velocity Verlet): + + .. math:: + + \mathbf{p}_i(t + \Delta t/2) &= \mathbf{p}_i(t) + + \frac{\Delta t}{2}\,\mathbf{F}_i(t) \\ + \mathbf{r}_i(t + \Delta t) &= \mathbf{r}_i(t) + + \Delta t\,\frac{\mathbf{p}_i(t + \Delta t/2)}{m_i} \\ + \mathbf{F}_i(t + \Delta t) &= -\nabla_{\mathbf{r}_i} U\bigl( + \mathbf{r}(t + \Delta t)\bigr) \\ + \mathbf{p}_i(t + \Delta t) &= \mathbf{p}_i(t + \Delta t/2) + + \frac{\Delta t}{2}\,\mathbf{F}_i(t + \Delta t) + + **Variable mapping (equation -> code):** + + ============================================ ============================ + Equation symbol Code variable + ============================================ ============================ + :math:`\mathbf{r}_i` (positions) ``state.positions`` + :math:`\mathbf{p}_i` (momenta) ``state.momenta`` + :math:`m_i` (masses) ``state.masses`` + :math:`\mathbf{F}_i` (forces) ``state.forces`` + :math:`\Delta t` (timestep) ``dt`` + ============================================ ============================ Args: model: Neural network model that computes energies and forces. @@ -88,10 +109,8 @@ def nve_step( momenta, forces, and energy Notes: - - Uses velocity Verlet algorithm for time reversible integration + - Symplectic, time-reversible integrator of second order accuracy O(dt^2) - Conserves energy in the absence of numerical errors - - Handles periodic boundary conditions if enabled in state - - Symplectic integrator preserving phase space volume """ dt = torch.as_tensor(dt, device=state.device, dtype=state.dtype) state = momentum_step(state, dt / 2) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 8e74bf85..eadfd4ad 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -143,16 +143,54 @@ def nvt_langevin_step( kT: float | torch.Tensor, gamma: float | torch.Tensor | None = None, ) -> MDState: - """Perform one complete Langevin dynamics integration step. - - This function implements the BAOAB splitting scheme for Langevin dynamics, - which provides accurate sampling of the canonical ensemble. The integration - sequence is: - 1. Half momentum update using forces (B step) - 2. Half position update using updated momenta (A step) - 3. Full stochastic update with noise and friction (O step) - 4. Half position update using updated momenta (A step) - 5. Half momentum update using new forces (B step) + r"""Perform one complete Langevin dynamics integration step using the BAOAB scheme. + + Implements the BAOAB splitting of the Langevin equation from + Leimkuhler & Matthews (2013, 2016) [2]_. The Langevin SDE is: + + .. math:: + + d\mathbf{q} &= M^{-1}\mathbf{p}\,dt \\ + d\mathbf{p} &= -\nabla U(\mathbf{q})\,dt + - \gamma\,\mathbf{p}\,dt + + \sigma M^{1/2}\,d\mathbf{W} + + where :math:`\sigma = \sqrt{2\gamma k_BT}` (fluctuation-dissipation relation). + + **BAOAB splitting** (B = kick, A = drift, O = Ornstein-Uhlenbeck): + + .. math:: + + \text{B:}\quad \mathbf{p} &\leftarrow \mathbf{p} + + \tfrac{\Delta t}{2}\,\mathbf{F}(\mathbf{q}) \\ + \text{A:}\quad \mathbf{q} &\leftarrow \mathbf{q} + + \tfrac{\Delta t}{2}\,M^{-1}\mathbf{p} \\ + \text{O:}\quad \mathbf{p} &\leftarrow + c_1\,\mathbf{p} + c_2\,M^{1/2}\mathbf{R}, + \quad \mathbf{R}\sim\mathcal{N}(0,I) \\ + \text{A:}\quad \mathbf{q} &\leftarrow \mathbf{q} + + \tfrac{\Delta t}{2}\,M^{-1}\mathbf{p} \\ + \text{B:}\quad \mathbf{p} &\leftarrow \mathbf{p} + + \tfrac{\Delta t}{2}\,\mathbf{F}(\mathbf{q}) + + with :math:`c_1 = e^{-\gamma\Delta t}` and + :math:`c_2 = \sqrt{k_BT\,(1-c_1^2)}`. + + **Variable mapping (equation -> code):** + + ============================================ ============================ + Equation symbol Code variable + ============================================ ============================ + :math:`\mathbf{q}` (positions) ``state.positions`` + :math:`\mathbf{p}` (momenta) ``state.momenta`` + :math:`M` (mass matrix) ``state.masses`` + :math:`\mathbf{F}` (forces) ``state.forces`` + :math:`\gamma` (friction coefficient) ``gamma`` + :math:`k_BT` (thermal energy) ``kT`` + :math:`\Delta t` (timestep) ``dt`` + :math:`c_1` ``c1`` in ``_ou_step`` + :math:`c_2` ``c2`` in ``_ou_step`` + ============================================ ============================ Args: state: Current system state containing positions, momenta, forces @@ -168,13 +206,12 @@ def nvt_langevin_step( MDState: Updated state after one complete Langevin step with new positions, momenta, forces, and energy - Notes: - - Uses BAOAB splitting scheme for Langevin dynamics - - Preserves detailed balance for correct NVT sampling - - Handles periodic boundary conditions if enabled in state - - Friction coefficient gamma controls the thermostat coupling strength - - Weak coupling (small gamma) preserves dynamics but with slower thermalization - - Strong coupling (large gamma) faster thermalization but may distort dynamics + References: + .. [2] Leimkuhler B, Matthews C. "Efficient molecular dynamics using geodesic + integration and solvent-solute splitting." Proc. R. Soc. A 472: 20160138 + (2016). Original BAOAB analysis in: Leimkuhler B, Matthews C. "Rational + construction of stochastic numerical methods for molecular sampling." + Appl. Math. Res. Express 2013(1), 34-56 (2013). """ device, dtype = model.device, model.dtype @@ -290,7 +327,7 @@ def nvt_nose_hoover_init( dt_tensor = torch.as_tensor(dt, device=state.device, dtype=state.dtype) kT_tensor = torch.as_tensor(kT, device=state.device, dtype=state.dtype) tau_tensor = torch.as_tensor( - 100.0 * dt_tensor if tau is None else tau, device=state.device, dtype=state.dtype + 10.0 * dt_tensor if tau is None else tau, device=state.device, dtype=state.dtype ) # Create thermostat functions @@ -314,11 +351,9 @@ def nvt_nose_hoover_init( masses=state.masses, momenta=momenta, system_idx=state.system_idx ) - # Calculate degrees of freedom per system - n_atoms_per_system = torch.bincount(state.system_idx) - dof_per_system = ( - n_atoms_per_system * state.positions.shape[-1] - ) # n_atoms * n_dimensions + # Calculate degrees of freedom per system (subtract 3 for COM motion, + # matching LAMMPS compute_temp which uses dof = 3N - 3) + dof_per_system = state.get_number_of_degrees_of_freedom() - 3 # Initialize state return NVTNoseHooverState.from_state( @@ -340,13 +375,63 @@ def nvt_nose_hoover_step( dt: float | torch.Tensor, kT: float | torch.Tensor, ) -> NVTNoseHooverState: - """Perform one complete Nose-Hoover chain integration step. - - This function performs one integration step for an NVT system using a Nose-Hoover - chain thermostat. The integration scheme is time-reversible and conserves an - extended energy quantity. If the center of mass motion is removed initially, - it remains removed throughout the simulation, so the degrees of freedom decreases - by 3. + r"""Perform one complete Nose-Hoover chain (NHC) integration step. + + Implements the NHC thermostat from Martyna et al. (1996) [3]_ with + Suzuki-Yoshida integration of the chain variables. + + **Equations of motion** (Martyna et al. 1996, Eqs. 13-18): + + .. math:: + + \dot{\mathbf{r}}_i &= \mathbf{p}_i / m_i \\ + \dot{\mathbf{p}}_i &= \mathbf{F}_i + - \frac{p_{\xi_1}}{Q_1}\,\mathbf{p}_i \\ + \dot{\xi}_j &= p_{\xi_j} / Q_j \\ + \dot{p}_{\xi_1} &= \bigl(2K - N_f k_BT\bigr) + - \frac{p_{\xi_2}}{Q_2}\,p_{\xi_1} \\ + \dot{p}_{\xi_j} &= \left(\frac{p_{\xi_{j-1}}^2}{Q_{j-1}} + - k_BT\right) - \frac{p_{\xi_{j+1}}}{Q_{j+1}}\,p_{\xi_j} + \quad (j = 2,\ldots,M{-}1) \\ + \dot{p}_{\xi_M} &= \frac{p_{\xi_{M-1}}^2}{Q_{M-1}} - k_BT + + where :math:`K = \sum_i p_i^2/(2m_i)` is the kinetic energy, + :math:`N_f = 3N - 3` the degrees of freedom, and :math:`Q_j = k_BT\tau^2` + (with :math:`Q_1 = N_f k_BT\tau^2`) are the chain masses. + + **Symmetric propagator** (Trotter factorization): + + .. math:: + + e^{i\mathcal{L}\Delta t} = e^{i\mathcal{L}_{\text{NHC}}\Delta t/2} + \;e^{i\mathcal{L}_{\text{VV}}\Delta t} + \;e^{i\mathcal{L}_{\text{NHC}}\Delta t/2} + + where :math:`i\mathcal{L}_{\text{VV}}` is the velocity Verlet propagator + and :math:`i\mathcal{L}_{\text{NHC}}` integrates the chain with + :math:`n_c \times n_{\text{sy}}` sub-steps (Suzuki-Yoshida decomposition). + + **Variable mapping (equation -> code):** + + ============================================ ============================ + Equation symbol Code variable + ============================================ ============================ + :math:`\mathbf{r}_i` (positions) ``state.positions`` + :math:`\mathbf{p}_i` (momenta) ``state.momenta`` + :math:`m_i` (masses) ``state.masses`` + :math:`\mathbf{F}_i` (forces) ``state.forces`` + :math:`\xi_j` (chain positions) ``state.chain.positions`` + :math:`p_{\xi_j}` (chain momenta) ``state.chain.momenta`` + :math:`Q_j` (chain masses) ``state.chain.masses`` + :math:`K` (kinetic energy) ``state.chain.kinetic_energy`` + :math:`N_f` (degrees of freedom) ``state.chain.degrees_of_freedom`` + :math:`\tau` (relaxation time) ``state.chain.tau`` + :math:`k_BT` (thermal energy) ``kT`` + :math:`\Delta t` (timestep) ``dt`` + :math:`M` (chain length) ``chain_length`` + :math:`n_c` (chain substeps) ``chain_steps`` + :math:`n_{\text{sy}}` (SY steps) ``sy_steps`` + ============================================ ============================ Args: state: Current system state containing positions, momenta, forces, and chain @@ -357,13 +442,10 @@ def nvt_nose_hoover_step( Returns: Updated state after one complete Nose-Hoover step - Notes: - Integration sequence: - 1. Update chain masses based on target temperature - 2. First half-step of chain evolution - 3. Full velocity Verlet step - 4. Update chain kinetic energy - 5. Second half-step of chain evolution + References: + .. [3] Martyna, G. J., Tuckerman, M. E., Tobias, D. J. & Klein, M. L. + "Explicit reversible integrators for extended systems dynamics." + Molecular Physics 87(5), 1117-1157 (1996). """ # Get chain functions from state chain_fns = state._chain_fns # noqa: SLF001 @@ -412,7 +494,6 @@ def nvt_nose_hoover_invariant( useful for validating the thermostat implementation. Args: - energy_fn: Function that computes system potential energy given positions state: Current state of the system including chain variables kT: Target temperature in energy units @@ -431,9 +512,8 @@ def nvt_nose_hoover_invariant( masses=state.masses, momenta=state.momenta, system_idx=state.system_idx ) - # Get system degrees of freedom per system - n_atoms_per_system = torch.bincount(state.system_idx) - dof = n_atoms_per_system * state.positions.shape[-1] # n_atoms * n_dimensions + # Get system degrees of freedom per system (3N - 3 for COM correction) + dof = state.get_number_of_degrees_of_freedom() # Start with system energy e_tot = e_pot + e_kin @@ -626,12 +706,53 @@ def nvt_vrescale_step( kT: float | torch.Tensor, tau: float | torch.Tensor | None = None, ) -> NVTVRescaleState: - """Perform one complete V-Rescale dynamics integration step. + r"""Perform one complete V-Rescale (CSVR) dynamics integration step. + + Implements canonical sampling through velocity rescaling from + Bussi, Donadio & Parrinello (2007) [1]_. + + **Stochastic differential equation** for kinetic energy (Eq. 7 of [1]_): - This function implements the canonical sampling through velocity rescaling (V-Rescale) - thermostat combined with velocity Verlet integration. The V-Rescale thermostat samples - the canonical distribution by rescaling velocities with a properly chosen random - factor that ensures correct canonical sampling. + .. math:: + + dK = \frac{\bar{K} - K}{\tau}\,dt + + 2\sqrt{\frac{K\bar{K}}{N_f\tau}}\,dW + + where :math:`\bar{K} = N_f k_BT/2` is the target kinetic energy. + + **Discrete rescaling factor** :math:`\alpha^2 = K'/K` (Eq. 22 of [1]_): + + .. math:: + + \alpha^2 = e^{-\Delta t/\tau} + + \frac{\bar{K}}{N_f K}\bigl(1-e^{-\Delta t/\tau}\bigr) + \Bigl(R_1^2 + \sum_{i=2}^{N_f} R_i^2\Bigr) + + 2\,e^{-\Delta t/(2\tau)} + \sqrt{\frac{\bar{K}}{N_f K} + \bigl(1-e^{-\Delta t/\tau}\bigr)}\;R_1 + + where :math:`R_1 \sim \mathcal{N}(0,1)` and + :math:`\sum_{i=2}^{N_f} R_i^2 \sim \text{Gamma}\bigl((N_f-1)/2,\,2\bigr)`. + Momenta are then rescaled as :math:`\mathbf{p} \leftarrow \alpha\,\mathbf{p}`. + + **Variable mapping (equation -> code):** + + ============================================ ============================ + Equation symbol Code variable + ============================================ ============================ + :math:`K` (kinetic energy) ``KE_old`` + :math:`\bar{K}` (target KE) ``KE_new`` + :math:`N_f` (degrees of freedom) ``dof`` + :math:`\tau` (relaxation time) ``tau`` + :math:`k_BT` (thermal energy) ``kT`` + :math:`\Delta t` (timestep) ``dt`` + :math:`e^{-\Delta t/\tau}` ``c1`` + :math:`(1-c_1)\bar{K}/(N_f K)` ``c2`` + :math:`R_1` ``r1`` + :math:`\sum_{i=2}^{N_f} R_i^2` ``r2`` + :math:`\alpha^2` (scale factor) ``scale`` + :math:`\alpha` (velocity rescaling) ``lam`` + ============================================ ============================ Args: model: Neural network model that computes energies and forces. @@ -647,19 +768,13 @@ def nvt_vrescale_step( MDState: Updated state after one complete V-Rescale step with new positions, momenta, forces, and energy - Notes: - - Uses V-Rescale thermostat for proper canonical ensemble sampling - - Unlike Berendsen thermostat, V-Rescale samples the true canonical distribution - - Integration sequence: V-Rescale rescaling + Velocity Verlet step - - The rescaling factor follows the distribution derived in Bussi et al. - References: - Bussi G, Donadio D, Parrinello M. "Canonical sampling through velocity rescaling." - The Journal of chemical physics, 126(1), 014101 (2007). + .. [1] Bussi G, Donadio D, Parrinello M. "Canonical sampling through velocity + rescaling." J. Chem. Phys. 126(1), 014101 (2007). """ device, dtype = model.device, model.dtype - tau = torch.as_tensor(100 * dt if tau is None else tau, device=device, dtype=dtype) + tau = torch.as_tensor(10 * dt if tau is None else tau, device=device, dtype=dtype) dt = torch.as_tensor(dt, device=device, dtype=dtype) kT = torch.as_tensor(kT, device=device, dtype=dtype) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 72c25c07..e2028f80 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -751,7 +751,6 @@ def static( Args: system (StateLike): Input system to calculate properties for model (ModelInterface): Neural network model module - unit_system (UnitSystem): Unit system for energy and forces trajectory_reporter (TrajectoryReporter | dict | None): Optional reporter for tracking trajectory. If a dict, will be passed to the TrajectoryReporter constructor and must include at least the "filenames" key. Any prop diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 5a73d580..60c3a2f2 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -283,9 +283,6 @@ def report( model (ModelInterface, optional): Model used for simulation. Defaults to None. Must be provided if any prop_calculators are provided. - write_to_file (bool, optional): Whether to write the state to the trajectory - files. Defaults to True. Should only be set to `False` if the props - are being collected separately. Returns: list[dict[str, torch.Tensor]]: Map of property names to tensors for each @@ -835,9 +832,6 @@ def get_steps( Args: name (str): Name of the array - start (int, optional): Starting frame index. Defaults to None. - stop (int, optional): Ending frame index (exclusive). Defaults to None. - step (int, optional): Step size between frames. Defaults to 1. Returns: np.ndarray: Array of step numbers with shape [n_selected_frames] diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index bdc41f1e..35fbaa45 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -1260,12 +1260,18 @@ def unwrap_positions( dfrac = frac[1:] - frac[:-1] dfrac -= torch.round(dfrac) - dcart = torch.einsum("tni,tnij->tnj", dfrac, box_atoms[:-1]) + # Reconstruct unwrapped fractional trajectory + unwrapped_frac = torch.empty_like(frac) + unwrapped_frac[0] = frac[0] + unwrapped_frac[1:] = torch.cumsum(dfrac, dim=0) + frac[0] + + # Convert back to Cartesian using each frame's cell + return torch.einsum("tni,tnij->tnj", unwrapped_frac, box_atoms) else: raise ValueError("box must have shape (n_systems,3,3) or (T,n_systems,3,3)") - # Cumulative reconstruction + # Cumulative reconstruction (constant cell path) unwrapped = torch.empty_like(positions) unwrapped[0] = positions[0] unwrapped[1:] = torch.cumsum(dcart, dim=0) + unwrapped[0] diff --git a/torch_sim/workflows/a2c.py b/torch_sim/workflows/a2c.py index 81d22078..1a360da1 100644 --- a/torch_sim/workflows/a2c.py +++ b/torch_sim/workflows/a2c.py @@ -241,7 +241,6 @@ def random_packed_structure( position when computing minimum distances. device: PyTorch device for calculations (CPU/GPU). dtype: PyTorch data type for numerical precision. - log: List to store positions at each iteration. Returns: FIREState: The optimized structure state containing positions, forces,