Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def test_nvt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone
temperatures_list = [t.tolist() for t in temperatures_tensor.T]
assert torch.allclose(
temperatures_tensor[-1],
torch.tensor([300.0096, 299.7024], dtype=dtype),
torch.tensor([290.3553, 289.9699], dtype=dtype),
)

energies_tensor = torch.stack(energies)
Expand Down Expand Up @@ -728,7 +728,7 @@ def test_npt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone
temperatures_list = [t.tolist() for t in temperatures_tensor.T]
assert torch.allclose(
temperatures_tensor[-1],
torch.tensor([298.2752, 297.9444], dtype=dtype),
torch.tensor([287.5729, 287.1330], dtype=dtype),
)

energies_tensor = torch.stack(energies)
Expand Down
2 changes: 1 addition & 1 deletion torch_sim/integrators/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def init_fn(

Q = (
kT_batched.unsqueeze(-1)
* torch.square(tau_batched).unsqueeze(-1) ** 2
* torch.square(tau_batched).unsqueeze(-1)
* torch.ones((n_systems, chain_length), dtype=dtype, device=device)
)
Q[:, 0] *= degrees_of_freedom
Expand Down
23 changes: 10 additions & 13 deletions torch_sim/integrators/npt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,9 +1179,9 @@ def _npt_nose_hoover_compute_cell_force(
internal_pressure = torch.trace(stress).unsqueeze(0).expand(n_systems)

# Compute force on cell coordinate per system
# F = alpha * KE - dU/dV - P*V*d
# F = alpha * (2 * KE) - dU/dV - P*V*d
return (
(alpha * KE_per_system)
(alpha * 2 * KE_per_system)
- (internal_pressure * volume)
- (external_pressure * volume * dim)
)
Expand Down Expand Up @@ -1226,21 +1226,18 @@ def _npt_nose_hoover_inner_step(
volume, volume_to_cell = _npt_nose_hoover_cell_info(state)
cell = volume_to_cell(volume)

# Get model output
state.cell = cell
model_output = model(state)

# First half step: Update momenta
n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems)
alpha = 1 + 1 / n_atoms_per_system # [n_systems]
# alpha = 1 + dim / degrees_of_freedom (3 * natoms - 3)
alpha = 1 + 3 / state.get_number_of_degrees_of_freedom() # [n_systems]

# Reuse stress from previous step since positions and cell unchanged
cell_force_val = _npt_nose_hoover_compute_cell_force(
alpha=alpha,
volume=volume,
positions=positions,
momenta=momenta,
masses=masses,
stress=model_output["stress"],
stress=state.stress,
external_pressure=external_pressure,
system_idx=state.system_idx,
)
Expand Down Expand Up @@ -1406,7 +1403,8 @@ def npt_nose_hoover_init(
)

# Compute total DOF for thermostat initialization and a zero KE placeholder
dof_per_system = torch.bincount(state.system_idx, minlength=n_systems) * dim
dof_per_system = state.get_number_of_degrees_of_freedom() - 3

KE_thermostat = ts.calc_kinetic_energy(
masses=state.masses, momenta=momenta, system_idx=state.system_idx
)
Expand Down Expand Up @@ -1612,13 +1610,12 @@ def npt_nose_hoover_invariant(
)

# Calculate degrees of freedom per system
n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems)
dof_per_system = n_atoms_per_system * state.positions.shape[-1] # n_atoms * n_dim
dof_per_system = state.get_number_of_degrees_of_freedom()

# Initialize total energy with PE + KE
e_tot = e_pot + e_kin_per_system

# Add thermostat chain contributions (batched per system, DOF = n_atoms * 3)
# Add thermostat chain contributions (batched per system, DOF = 3 * n_atoms - 3)
e_tot += _compute_chain_energy(state.thermostat, kT, e_tot, dof_per_system)

# Add barostat chain contributions (batched per system, DOF = 1)
Expand Down
13 changes: 5 additions & 8 deletions torch_sim/integrators/nvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,11 +314,9 @@ def nvt_nose_hoover_init(
masses=state.masses, momenta=momenta, system_idx=state.system_idx
)

# Calculate degrees of freedom per system
n_atoms_per_system = torch.bincount(state.system_idx)
dof_per_system = (
n_atoms_per_system * state.positions.shape[-1]
) # n_atoms * n_dimensions
# Calculate degrees of freedom per system (subtract 3 for COM motion,
# matching LAMMPS compute_temp which uses dof = 3N - 3)
dof_per_system = state.get_number_of_degrees_of_freedom() - 3

# Initialize state
return NVTNoseHooverState.from_state(
Expand Down Expand Up @@ -431,9 +429,8 @@ def nvt_nose_hoover_invariant(
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
)

# Get system degrees of freedom per system
n_atoms_per_system = torch.bincount(state.system_idx)
dof = n_atoms_per_system * state.positions.shape[-1] # n_atoms * n_dimensions
# Get system degrees of freedom per system (3N - 3 for COM correction)
dof = state.get_number_of_degrees_of_freedom()

# Start with system energy
e_tot = e_pot + e_kin
Expand Down
Loading