
Adding training functionalities to Toolkit#108

Open
laserkelvin wants to merge 328 commits into
NVIDIA:mainfrom
laserkelvin:training-epic

Conversation

@laserkelvin

Collaborator

ALCHEMI Toolkit Pull Request

Description

This PR introduces the core functionalities required to support training and fine-tuning of models in nvalchemi-toolkit.

This PR is still a WIP - do not merge!

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Performance improvement
  • Documentation update
  • Refactoring (no functional changes)
  • CI/CD or infrastructure change

Related Issues

Changes Made

  • Adds create_model_spec methods and dynamic pydantic model creation for pickle-less serialization of configuration
  • Adds a few base loss functions and a general loss abstraction covering individual losses and a composed loss function. The composed loss supports weight scheduling, so the relative weighting of different losses can change over the course of training
  • Adds a TrainingStrategy pydantic model that validates the training recipe and executes the loop. Execution is modular and extensible, with the aim of supporting arbitrarily complex training workflows, not only MLIPs
  • Adds a FineTuningStrategy that specializes TrainingStrategy for fine-tuning workflows by making pre-existing checkpoints and layer addition/modification integral to the workflow
  • Adds data loading optimizations; the main change is the addition of "batched" pre-fetching, which amortizes I/O for non-contiguous data samples. This is crucial for Zarr performance when shuffling data
  • Adds multidataset support, with a "meta" sampler that allows users to implement different cross-dataset sampling strategies (e.g. to account for dataset size imbalances)
  • Adds several training-related hooks, such as model averaging, mixed precision, checkpointing
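
To make the composed loss with weight scheduling concrete, here is a minimal pure-Python sketch. It is not the toolkit's API: `LinearRamp`, `ComposedLoss`, and `mse` are hypothetical names chosen for illustration, and plain lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass(frozen=True)
class LinearRamp:
    """Hypothetical schedule: ramp a weight linearly from start to end over `steps`."""

    start: float
    end: float
    steps: int

    def __call__(self, step: int) -> float:
        # Clamp the step so the weight stays at `end` once the ramp is finished.
        frac = min(max(step, 0), self.steps) / self.steps
        return self.start + frac * (self.end - self.start)


@dataclass
class ComposedLoss:
    """Hypothetical composed loss: a weighted sum of named loss terms."""

    terms: Dict[str, Callable[[dict, dict], float]]
    weights: Dict[str, Callable[[int], float]]

    def __call__(self, pred: dict, target: dict, step: int) -> float:
        # Each weight is re-evaluated at the current step before summing.
        return sum(
            self.weights[name](step) * fn(pred, target)
            for name, fn in self.terms.items()
        )


def mse(key: str) -> Callable[[dict, dict], float]:
    """Mean squared error over the values stored under `key`."""

    def _loss(pred: dict, target: dict) -> float:
        diffs = [(p - t) ** 2 for p, t in zip(pred[key], target[key])]
        return sum(diffs) / len(diffs)

    return _loss


loss = ComposedLoss(
    terms={"energy": mse("energy"), "forces": mse("forces")},
    weights={"energy": lambda step: 1.0, "forces": LinearRamp(0.0, 10.0, steps=100)},
)
pred = {"energy": [1.0], "forces": [0.0, 2.0]}
target = {"energy": [0.0], "forces": [0.0, 0.0]}

print(loss(pred, target, step=0))    # force term is switched off at the start
print(loss(pred, target, step=100))  # force term is fully weighted
```

In this sketch the energy term keeps a constant weight while the force term is phased in, which is one way the relative weighting could shift during training.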

Testing

  • Unit tests pass locally (make pytest)
  • Linting passes (make lint)
  • New tests added for new functionality meet coverage expectations

Checklist

  • I have read and understand the Contributing Guidelines
  • I have updated the CHANGELOG.md
  • I have performed a self-review of my code
  • I have added docstrings to new functions/classes
  • I have updated the documentation (if applicable)

Additional Notes

Tip

This repository uses Greptile, an AI code review service, to help conduct
pull request reviews. We encourage contributors to read and consider suggestions
made by Greptile, but note that human maintainers will provide the necessary
reviews for merging: Greptile's comments are not a qualitative judgement
of your code, nor are they an indication that the PR will be accepted/rejected.
We encourage the use of emoji reactions to Greptile comments, depending on
their usefulness and accuracy.

laserkelvin and others added 30 commits May 14, 2026 15:23
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Brings in 5 upstream commits from main:
- a85db34 Refactor hook contexts (NVIDIA#93) - splits HookContext into base +
  DynamicsContext + TrainContext
- 84d8119 chore: bumping torch minimum version to 2.8 (NVIDIA#85)
- 8f7e628 fix(dynamics): MTK NPT/NPH barostat thermostat coupling (NVIDIA#90)
- 001f1cb fix(models): tensile-positive stress convention (NVIDIA#87)
- 7fe7756 fix(models): merge force and stress autograd (NVIDIA#88)

This propagates the new TrainContext shape to all stacked PRs in the
training-epic series (#4, #5, #6, #7, #8, #9). Stacked PR branches will
need to be rebased or merged on top of this updated training-epic.
Brings training-epic up to date with origin/main on this branch:
- a85db34 Refactor hook contexts (NVIDIA#93) - HookContext/DynamicsContext/TrainContext split
- 84d8119 chore: torch>=2.8 (NVIDIA#85)
- 8f7e628 fix(dynamics): MTK NPT/NPH barostat (NVIDIA#90)
- 001f1cb fix(models): tensile-positive stress (NVIDIA#87)
- 7fe7756 fix(models): merge force and stress autograd (NVIDIA#88)

Replaces an earlier direct origin/main merge to ensure a single merge base
between this branch and training-epic, so the PR diff displays the true
contribution scope (training primitives only, not the upstream merge churn).

# Conflicts:
#	nvalchemi/hooks/_context.py
#	test/hooks/test_context.py
Add supporting functions for upcoming `TrainingStrategy`
Bring in PR #4 (training runtime primitives) and PR NVIDIA#93 (hook context refactor).

Conflict resolution:

- nvalchemi/hooks/_context.py: take upstream's split (HookContext base + DynamicsContext / TrainContext subclasses); keep our additions on TrainContext only:
  * grad_scaler: torch.amp.GradScaler | None = None
  * optimizers / lr_schedulers default to empty list (field(default_factory=list)) instead of None, so the orchestrator's gated-op consumers can iterate without None guards.
- test/hooks/test_context.py: take upstream verbatim, flip optimizers/lr_schedulers default assertions to == [], cover grad_scaler default + populated cases, and add test_optimizers_default_is_independent_per_instance to guard against shared-list aliasing.

Strategy + orchestrator wiring:
- TrainingStrategy._build_context now returns TrainContext and passes model=self.models["main"] to preserve the legacy ctx.model alias for hooks that read a single main model (upstream PR NVIDIA#93 decoupled model from models, so we re-establish the alias at the producer rather than via a property).
- TrainingUpdateHook / TrainingUpdateOrchestrator type hints narrowed from HookContext to TrainContext (no runtime change; TrainContext IS-A HookContext).

Verification: 146 targeted tests / 462 training / 1071 hooks+dynamics passing; make lint + make interrogate green.
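
The `field(default_factory=list)` choice and the shared-list aliasing test mentioned in the commit message above can be illustrated with a small standalone sketch. `TrainContextSketch` is a hypothetical stand-in, not the real `TrainContext`, and it only mirrors the three fields named in the message.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class TrainContextSketch:
    """Hypothetical stand-in for the fields discussed in the commit message."""

    # Each instance gets its own fresh list, so consumers can iterate
    # without None guards and without sharing state across instances.
    optimizers: list = field(default_factory=list)
    lr_schedulers: list = field(default_factory=list)
    grad_scaler: Optional[object] = None


a = TrainContextSketch()
b = TrainContextSketch()
a.optimizers.append("sgd")

# Mutating one instance must not leak into the other.
independent = b.optimizers == [] and a.optimizers is not b.optimizers

# A bare mutable default is rejected by dataclasses at class-definition time.
try:

    @dataclass
    class Broken:
        optimizers: list = []

    error = ""
except ValueError as exc:
    error = str(exc)

print(independent, error)
```

Defaulting to an empty list rather than `None` trades a small allocation per context for simpler call sites, since loops over `optimizers` need no special case for the unset state.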
…strategy-orchestration

Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Adds archetypal hook that drives torch.amp autocast and GradScaler via
the DO_BACKWARD and DO_OPTIMIZER_STEP stages, keeping TrainingStrategy
AMP-agnostic. Supports fp32/bf16/fp16 with skip-safe scheduler gating
and accepts both torch.dtype objects and canonical dtype strings.

Tests consolidated via precision x device parametrized fixtures plus a
CUDA end-to-end case that exercises real autocast and GradScaler
without mocking.
Pulls AtomicData/Batch/Dataset/model/optimizer/strategy construction into test/training/conftest.py as pure-value fixtures backed by private _build_* helpers, removing the cross-module import between test_mixed_precision and test_strategy. Adds an autouse fixture that seeds torch (and CUDA when visible) to 0 before each test, dropping 20 inline torch.manual_seed(0) calls. Training-fn symbols stay in test_strategy.py to preserve spec round-trip identity assertions.
Introduce a hook that lazily maintains a torch.optim.swa_utils.AveragedModel
over a selected training model at TrainingStage.AFTER_OPTIMIZER_STEP. Keeps
the hook a pure observer (no backward, no grad/optimizer/scheduler mutation,
no ctx.models mutation) and exposes averaged_model / get_averaged_model for
explicit eval and checkpoint use. state_dict / load_state_dict land in a
follow-up step.
…ration

Add `TrainingStrategy` orchestration
laserkelvin and others added 14 commits June 12, 2026 15:11
Add shared profiling hooks for training and dynamics
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
update pipeline to be compatible with ema
Signed-off-by: Ying Shi Teh <yteh@nvidia.com>
Fix unweighted validation loss reporting
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
left=_loss_weight_to_spec(weight.left),
right=_loss_weight_to_spec(weight.right),
)
if hasattr(weight, "model_dump"):

Collaborator


I notice that for objects that lack model_dump(), strategy.json can be missing/problematic and checkpoint resume can fail. Maybe raise an explicit error when a weight schedule isn't spec-serializable (e.g., require subclassing _BaseWeightSchedule) when initializing?
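
One possible shape for the explicit error suggested in this comment is sketched below. It is a standalone illustration: `loss_weight_to_spec_strict`, `SpecSerializationError`, and `ConstantWeight` are hypothetical names, and the spec layout is invented for the example rather than taken from the PR.

```python
class SpecSerializationError(TypeError):
    """Raised when a loss weight cannot be written into a checkpoint spec."""


def loss_weight_to_spec_strict(weight):
    """Convert a weight to a JSON-friendly dict, or fail loudly at setup time."""
    if isinstance(weight, (int, float)):
        return {"kind": "constant", "value": float(weight)}
    if hasattr(weight, "model_dump"):
        return {"kind": type(weight).__name__, **weight.model_dump()}
    # Failing here surfaces the problem when the strategy is built,
    # instead of later when a checkpoint is resumed.
    raise SpecSerializationError(
        f"{type(weight).__name__} is not spec-serializable; "
        "provide a model_dump() method or use a supported schedule type"
    )


class ConstantWeight:
    """Toy schedule exposing a pydantic-style model_dump()."""

    def __init__(self, value: float):
        self.value = value

    def model_dump(self) -> dict:
        return {"value": self.value}


ok_spec = loss_weight_to_spec_strict(ConstantWeight(2.0))

try:
    loss_weight_to_spec_strict(lambda step: 1.0)
    message = ""
except SpecSerializationError as exc:
    message = str(exc)

print(ok_spec, message)
```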

laserkelvin and others added 13 commits June 15, 2026 21:07
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Introduce a user-facing guide for the training API, walking from a minimal
script through the strategy lifecycle, configuration, counters, the per-batch
forward/loss/backward/update path, optimizer orchestration and update hooks,
validation and logging, and checkpointing.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Distill the prose style of the training strategy guide into the distributed
docs: an intention-first preamble, bridge-and-motivate section openings, and a
note for the single-process no-op behavior. Also correct the manager
initialization ordering and tighten the data loader/sampler walkthrough.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Rewrite the checkpoint overview and manual save/restart walkthrough for
clarity, and add a "Serialization scope" section that spells out what a
checkpoint manifest can and cannot embed: importable reconstruction specs and
JSON/registered-type arguments versus non-importable callables (the training
function being the common case), non-serializable spec arguments, and runtime
objects such as hooks and dataloaders that must be re-supplied at load time.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
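
The distinction between importable and non-importable callables in the "Serialization scope" commit above can be sketched with the standard library alone. The `module:qualname` spec string and both helper names are assumptions for illustration and are not the manifest format used by the toolkit.

```python
import importlib
import json


def callable_to_spec(fn):
    """Return a 'module:qualname' string, or None if fn cannot be re-imported."""
    module = getattr(fn, "__module__", None)
    qualname = getattr(fn, "__qualname__", "")
    # Lambdas and functions defined inside other functions have no import path.
    if not module or "<lambda>" in qualname or "<locals>" in qualname:
        return None
    return f"{module}:{qualname}"


def spec_to_callable(spec):
    """Resolve a 'module:qualname' string back to the object it names."""
    module_name, qualname = spec.split(":")
    obj = importlib.import_module(module_name)
    for part in qualname.split("."):
        obj = getattr(obj, part)
    return obj


importable_spec = callable_to_spec(json.dumps)
lambda_spec = callable_to_spec(lambda batch: batch)

print(importable_spec)  # a string that can be stored in a JSON manifest
print(lambda_spec)      # None: this callable must be re-supplied at load time
```

A training function defined inline in a script falls into the second case, which is why such callables have to be passed in again when a checkpoint is loaded.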
- validation.rst: fix a dropped-word sentence and an em-dash spacing typo.
- hooks.rst: note that most users register concrete TrainingUpdateHook
  instances directly, and that the orchestrator is created automatically.
- losses.rst: add a short composition overview and a runnable example ahead
  of the autosummary tables.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
List the training strategy guide in the userguide landing page and add it to
the toctree so it is reachable from navigation.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Convert prose triple-hyphens to spaced unicode em-dashes so the hooks guide
matches the em-dash convention used across the other userguide pages. The
ASCII em-dash inside the code-block comment is left as-is.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Call out that training examples are labeled explicitly in the intermediate
example list, and flag splitting training workflows into a dedicated section
if the collection grows.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>

# Conflicts:
#	CHANGELOG.md
@laserkelvin laserkelvin marked this pull request as ready for review June 17, 2026 00:19
@laserkelvin

Collaborator Author

/ok to test 6f08683


Labels

enhancement New feature or request


4 participants