Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/benchmarking/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pymatgen.io.ase import AseAtomsAdaptor

import torch_sim as ts
from torch_sim.models.mace import MaceModel, MaceUrls
from torch_sim.models.mace import MaceModel
from torch_sim.telemetry import configure_logging, get_logger


Expand Down Expand Up @@ -57,7 +57,7 @@
def load_mace_model(device: torch.device) -> MaceModel:
"""Load MACE model for benchmarking."""
loaded_model = mace_mp(
model=MaceUrls.mace_mpa_medium,
model="medium",
return_raw_model=True,
default_dtype="float64",
device=str(device),
Expand Down
20 changes: 11 additions & 9 deletions examples/scripts/1_introduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
"""

# /// script
# dependencies = ["scipy>=1.15", "mace-torch>=0.3.12"]
# dependencies = [
# "torch_sim_atomistic[mace, io]",
# "mace-torch @ git+https://github.com/CompRhys/mace.git@main",
# ]
# ///

import itertools
Expand All @@ -18,7 +21,7 @@

import torch_sim as ts
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.models.mace import MaceModel, MaceUrls
from torch_sim.models.mace import MaceModel
from torch_sim.telemetry import configure_logging, get_logger


Expand All @@ -33,9 +36,9 @@
# ============================================================================
# SECTION 1: Lennard-Jones Model - Simple Classical Potential
# ============================================================================
log.info("=" * 70)

log.info("SECTION 1: Lennard-Jones Model")
log.info("=" * 70)


# Create face-centered cubic (FCC) Argon
# 5.26 Å is a typical lattice constant for Ar
Expand Down Expand Up @@ -118,13 +121,13 @@
# ============================================================================
# SECTION 2: MACE Model - Machine Learning Potential (Batched)
# ============================================================================
log.info("=" * 70)

log.info("SECTION 2: MACE Model with Batched Input")
log.info("=" * 70)


# Load the raw model from the downloaded model
loaded_model = mace_mp(
model=MaceUrls.mace_mpa_medium,
model="medium",
return_raw_model=True,
default_dtype=str(dtype).removeprefix("torch."),
device=str(device),
Expand Down Expand Up @@ -213,6 +216,5 @@
log.info(f"Max forces difference: {forces_diff}")
log.info(f"Max stress difference: {stress_diff}")

log.info("=" * 70)

log.info("Introduction examples completed!")
log.info("=" * 70)
43 changes: 22 additions & 21 deletions examples/scripts/2_structural_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
"""

# /// script
# dependencies = ["scipy>=1.15", "mace-torch>=0.3.12"]
# dependencies = [
# "torch_sim_atomistic[mace, io]",
# "mace-torch @ git+https://github.com/CompRhys/mace.git@main",
# ]
# ///

import itertools
Expand All @@ -21,7 +24,7 @@

import torch_sim as ts
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.models.mace import MaceModel, MaceUrls
from torch_sim.models.mace import MaceModel
from torch_sim.telemetry import configure_logging, get_logger
from torch_sim.units import UnitConversion

Expand All @@ -41,9 +44,9 @@
# ============================================================================
# SECTION 1: Lennard-Jones FIRE Optimization
# ============================================================================
log.info("=" * 70)

log.info("SECTION 1: Lennard-Jones FIRE Optimization")
log.info("=" * 70)


# Set up the random number generator
generator = torch.Generator(device=device)
Expand Down Expand Up @@ -127,13 +130,13 @@
# ============================================================================
# SECTION 2: Batched MACE FIRE Optimization (Atomic Positions Only)
# ============================================================================
log.info("=" * 70)

log.info("SECTION 2: Batched MACE FIRE - Positions Only")
log.info("=" * 70)


# Load MACE model
loaded_model = mace_mp(
model=MaceUrls.mace_mpa_medium,
model="medium",
return_raw_model=True,
default_dtype=str(dtype).removeprefix("torch."),
device=str(device),
Expand Down Expand Up @@ -189,9 +192,9 @@
# ============================================================================
# SECTION 3: Batched MACE Gradient Descent Optimization
# ============================================================================
log.info("=" * 70)

log.info("SECTION 3: Batched MACE Gradient Descent")
log.info("=" * 70)


# Reset structures with new perturbations
si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2))
Expand Down Expand Up @@ -222,9 +225,9 @@
# ============================================================================
# SECTION 4: Unit Cell Filter with Gradient Descent
# ============================================================================
log.info("=" * 70)

log.info("SECTION 4: Unit Cell Filter with Gradient Descent")
log.info("=" * 70)


# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21, cubic=True).repeat((2, 2, 2))
Expand Down Expand Up @@ -278,9 +281,9 @@
# ============================================================================
# SECTION 5: Unit Cell Filter with FIRE
# ============================================================================
log.info("=" * 70)

log.info("SECTION 5: Unit Cell Filter with FIRE")
log.info("=" * 70)


# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21, cubic=True).repeat((2, 2, 2))
Expand Down Expand Up @@ -330,9 +333,9 @@
# ============================================================================
# SECTION 6: Frechet Cell Filter with FIRE
# ============================================================================
log.info("=" * 70)

log.info("SECTION 6: Frechet Cell Filter with FIRE")
log.info("=" * 70)


# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21, cubic=True).repeat((2, 2, 2))
Expand Down Expand Up @@ -392,9 +395,9 @@
# ============================================================================
# SECTION 7: Batched MACE L-BFGS
# ============================================================================
log.info("=" * 70)

log.info("SECTION 7: Batched MACE L-BFGS")
log.info("=" * 70)


# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2))
Expand Down Expand Up @@ -425,9 +428,9 @@
# ============================================================================
# SECTION 8: Batched MACE BFGS
# ============================================================================
log.info("=" * 70)

log.info("SECTION 8: Batched MACE BFGS")
log.info("=" * 70)


# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2))
Expand Down Expand Up @@ -455,6 +458,4 @@
log.info(f"Final energies: {[energy.item() for energy in state.energy]} eV")


log.info("=" * 70)
log.info("Structural optimization examples completed!")
log.info("=" * 70)
32 changes: 17 additions & 15 deletions examples/scripts/3_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
"""

# /// script
# dependencies = ["scipy>=1.15", "mace-torch>=0.3.12"]
# dependencies = [
# "torch_sim_atomistic[mace, io]",
# "mace-torch @ git+https://github.com/CompRhys/mace.git@main",
# ]
# ///

import itertools
Expand All @@ -21,7 +24,7 @@

import torch_sim as ts
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.models.mace import MaceModel, MaceUrls
from torch_sim.models.mace import MaceModel
from torch_sim.telemetry import configure_logging, get_logger
from torch_sim.units import MetalUnits as Units

Expand Down Expand Up @@ -52,9 +55,9 @@
# ============================================================================
# SECTION 1: Lennard-Jones NVE (Microcanonical Ensemble)
# ============================================================================
log.info("=" * 70)

log.info("SECTION 1: Lennard-Jones NVE Simulation")
log.info("=" * 70)


# Create face-centered cubic (FCC) Argon
a_len = 5.26 # Lattice constant
Expand Down Expand Up @@ -139,13 +142,13 @@
# ============================================================================
# SECTION 2: MACE NVE Simulation
# ============================================================================
log.info("=" * 70)

log.info("SECTION 2: MACE NVE Simulation")
log.info("=" * 70)


# Load MACE model
loaded_model = mace_mp(
model=MaceUrls.mace_mpa_medium,
model="medium",
return_raw_model=True,
default_dtype=str(dtype).removeprefix("torch."),
device=str(device),
Expand Down Expand Up @@ -205,9 +208,9 @@
# ============================================================================
# SECTION 3: MACE NVT Langevin Simulation
# ============================================================================
log.info("=" * 70)

log.info("SECTION 3: MACE NVT Langevin Simulation")
log.info("=" * 70)


# Create diamond cubic Silicon
si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2))
Expand Down Expand Up @@ -252,9 +255,9 @@
# ============================================================================
# SECTION 4: MACE NVT Nose-Hoover Simulation
# ============================================================================
log.info("=" * 70)

log.info("SECTION 4: MACE NVT Nose-Hoover Simulation")
log.info("=" * 70)


# Create diamond cubic Silicon
si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2))
Expand Down Expand Up @@ -294,9 +297,9 @@
# ============================================================================
# SECTION 5: MACE NPT Nose-Hoover Simulation
# ============================================================================
log.info("=" * 70)

log.info("SECTION 5: MACE NPT Nose-Hoover Simulation")
log.info("=" * 70)


# Create diamond cubic Silicon
si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2))
Expand Down Expand Up @@ -404,6 +407,5 @@
)
log.info(f"Final pressure: {final_pressure.item():.4f} eV/ų")

log.info("=" * 70)

log.info("Molecular dynamics examples completed!")
log.info("=" * 70)
Loading
Loading