
Determinism implementation in nnunet #2871

Open
Luugaaa wants to merge 3 commits into MIC-DKFZ:master from Luugaaa:determinism_implementation

Conversation

Luugaaa commented on Jul 18, 2025

Hi MIC-DKFZ team,

This PR introduces updates to enable fully deterministic training in nnU-Net, which is crucial for reproducibility in research. The changes include adding a deterministic flag, implementing a seed_everything function, and ensuring the data augmentation pipeline is correctly seeded. I've been working on a similar fix for batchgenerators and believe these changes will work together to make the entire training process reproducible.

Problem

Achieving deterministic behavior in a multi-process environment can be tricky. Even with seeding, sources of randomness can persist, especially in the data augmentation pipeline. The original nnU-Net trainer used a non-deterministic data loader by default and didn't have a straightforward way to enforce reproducibility across all components, including PyTorch, NumPy, and the data loading workers. This could lead to slight variations in training results, even with the same initial seed.

Solution

To address this, I've implemented the following changes:

  1. deterministic Flag: A deterministic boolean flag has been added to the nnUNetTrainer's __init__ method. When set to True, it activates all the changes needed for reproducible training.
  2. seed_everything Function: A new helper function, seed_everything, is called when the deterministic flag is active. This function sets the seeds for random, numpy, and torch, and also configures cudnn for deterministic behavior to eliminate sources of randomness from the GPU.
  3. Seeded Data Loader: The get_dataloaders method now checks the deterministic flag. If True, it uses MultiThreadedAugmenter and passes a unique, generated seed to each worker process. This ensures that the data augmentation pipeline is fully deterministic and produces the same results in the same order on every run. If False, it continues to use the default NonDetMultiThreadedAugmenter. get_training_transforms checks the flag in the same way to enable or disable benchmarking.
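The seed_everything helper described in point 2 can be sketched roughly as follows (the exact implementation in the PR may differ; this is a minimal version covering the sources of randomness the PR mentions: random, numpy, torch, and cudnn):

```python
import os
import random

import numpy as np
import torch


def seed_everything(seed: int = 12345) -> None:
    """Seed all common sources of randomness for reproducible training.

    Sketch only; the PR's actual function may set additional state.
    """
    random.seed(seed)                 # Python's built-in RNG
    np.random.seed(seed)              # NumPy's global RNG
    torch.manual_seed(seed)           # PyTorch CPU (and current CUDA) RNG
    torch.cuda.manual_seed_all(seed)  # all CUDA devices (no-op without CUDA)
    os.environ["PYTHONHASHSEED"] = str(seed)
    # cuDNN: select deterministic kernels and disable autotuning,
    # which would otherwise pick kernels nondeterministically
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```

Calling this at the start of training makes subsequent draws from all three RNGs reproducible across runs with the same seed.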

How It's Tested

To validate these changes, I've developed a determinism test pipeline that sets up a dummy dataset and runs the entire preprocessing and training pipeline for two epochs with both the 2d and 3d_fullres configurations. The test works as follows:

  1. A dummy dataset is generated, and nnU-Net's preprocessing is run.
  2. The nnUNetTrainer is instantiated with deterministic=True, and a short training session is run. The final checkpoint is saved.
  3. This process is repeated a second time, with a new trainer instance but the same deterministic settings.
  4. The two resulting checkpoints are then compared byte-for-byte to ensure that the model weights and training/validation losses are identical.
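The comparison in step 4 can be sketched as below. This is illustrative only, not the PR's actual test script; the checkpoint keys ('network_weights', 'logging', 'train_losses', 'val_losses') are assumptions about the checkpoint layout:

```python
import torch


def checkpoints_identical(path_a: str, path_b: str) -> bool:
    """Compare two checkpoints for exactly equal weights and loss curves.

    Illustrative sketch; key names below are assumptions, not the
    guaranteed nnU-Net checkpoint schema.
    """
    ckpt_a = torch.load(path_a, map_location="cpu")
    ckpt_b = torch.load(path_b, map_location="cpu")

    # Model weights must match exactly, tensor by tensor
    wa, wb = ckpt_a["network_weights"], ckpt_b["network_weights"]
    if wa.keys() != wb.keys():
        return False
    if not all(torch.equal(wa[k], wb[k]) for k in wa):
        return False

    # Training and validation loss histories must also be identical
    la, lb = ckpt_a["logging"], ckpt_b["logging"]
    return (la["train_losses"] == lb["train_losses"]
            and la["val_losses"] == lb["val_losses"])
```

torch.equal requires exact element-wise equality (no tolerance), so this is as strict as a byte-for-byte comparison for the purposes of the test.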

With these changes, the trainer now passes this test, confirming that the training process is fully reproducible when the deterministic flag is enabled. This should be a significant help for researchers who need to ensure their results are perfectly reproducible. The changes are self-contained and don't affect the default behavior of the trainer.

Notes:

  • Unfortunately, I can only test on CPU; there is a slight chance CUDA behaves differently, even when seeded.
  • I've run these tests with my deterministic version of batchgeneratorsv2; please see this PR

I hope these changes are helpful. Thanks for maintaining this great project, and I look forward to your feedback!

Post Scriptum

Here is a partial output of the determinism test pipeline :

...
2025-07-18 16:41:32.447456: Unable to plot network architecture:
2025-07-18 16:41:32.447667: module 'torch.onnx' has no attribute '_optimize_trace'
2025-07-18 16:41:32.467535: 
2025-07-18 16:41:32.467720: Epoch 0
2025-07-18 16:41:32.468179: Current learning rate: 0.01
2025-07-18 16:41:36.461223: train_loss -0.3063
2025-07-18 16:41:36.461562: val_loss -0.5951
2025-07-18 16:41:36.461810: Pseudo dice [np.float32(0.9153)]
2025-07-18 16:41:36.461934: Epoch time: 3.99 s
2025-07-18 16:41:36.462014: Yayy! New best EMA pseudo Dice: 0.9153000116348267
2025-07-18 16:41:36.814825: 
2025-07-18 16:41:36.815006: Epoch 1
2025-07-18 16:41:36.815104: Current learning rate: 0.00536
2025-07-18 16:41:40.652067: train_loss -0.4862
2025-07-18 16:41:40.652299: val_loss -0.7087
2025-07-18 16:41:40.652573: Pseudo dice [np.float32(0.9449)]
2025-07-18 16:41:40.652686: Epoch time: 3.84 s
2025-07-18 16:41:40.652763: Yayy! New best EMA pseudo Dice: 0.9182000160217285
2025-07-18 16:41:41.638668: Training done.

...
2025-07-18 16:41:45.915659: Unable to plot network architecture:
2025-07-18 16:41:45.915884: module 'torch.onnx' has no attribute '_optimize_trace'
2025-07-18 16:41:45.934981: 
2025-07-18 16:41:45.935152: Epoch 0
2025-07-18 16:41:45.935414: Current learning rate: 0.01
2025-07-18 16:41:50.681323: train_loss -0.3063
2025-07-18 16:41:50.681636: val_loss -0.5951
2025-07-18 16:41:50.681895: Pseudo dice [np.float32(0.9153)]
2025-07-18 16:41:50.682026: Epoch time: 4.75 s
2025-07-18 16:41:50.682148: Yayy! New best EMA pseudo Dice: 0.9153000116348267
2025-07-18 16:41:51.014654: 
2025-07-18 16:41:51.014840: Epoch 1
2025-07-18 16:41:51.014943: Current learning rate: 0.00536
2025-07-18 16:41:55.595253: train_loss -0.4862
2025-07-18 16:41:55.595432: val_loss -0.7087
2025-07-18 16:41:55.595511: Pseudo dice [np.float32(0.9449)]
2025-07-18 16:41:55.595585: Epoch time: 4.58 s
2025-07-18 16:41:55.595646: Yayy! New best EMA pseudo Dice: 0.9182000160217285
2025-07-18 16:41:56.397161: Training done.
Finished training run: run2_2d

--- Comparing checkpoints ---
Model weights are identical: True
Training losses are identical: True
Validation losses are identical: True
PASSED: The training process for '2d' is deterministic.
========== FINISHED TEST FOR CONFIGURATION: 2D ==========

...
2025-07-18 16:42:01.014618: Unable to plot network architecture:
2025-07-18 16:42:01.014909: module 'torch.onnx' has no attribute '_optimize_trace'
2025-07-18 16:42:01.044617: 
2025-07-18 16:42:01.044771: Epoch 0
2025-07-18 16:42:01.045026: Current learning rate: 0.01
2025-07-18 16:42:46.660177: train_loss -0.1917
2025-07-18 16:42:46.660746: val_loss -0.494
2025-07-18 16:42:46.660841: Pseudo dice [np.float32(0.8838)]
2025-07-18 16:42:46.661161: Epoch time: 45.62 s
2025-07-18 16:42:46.661241: Yayy! New best EMA pseudo Dice: 0.8838000297546387
2025-07-18 16:42:47.174816: 
2025-07-18 16:42:47.174935: Epoch 1
2025-07-18 16:42:47.175024: Current learning rate: 0.00536
2025-07-18 16:43:32.276000: train_loss -0.4466
2025-07-18 16:43:32.277203: val_loss -0.6225
2025-07-18 16:43:32.277357: Pseudo dice [np.float32(0.9225)]
2025-07-18 16:43:32.277469: Epoch time: 45.1 s
2025-07-18 16:43:32.277545: Yayy! New best EMA pseudo Dice: 0.8877000212669373
2025-07-18 16:43:33.464729: Training done.

...
2025-07-18 16:43:38.122881: Unable to plot network architecture:
2025-07-18 16:43:38.123304: module 'torch.onnx' has no attribute '_optimize_trace'
2025-07-18 16:43:38.144184: 
2025-07-18 16:43:38.144306: Epoch 0
2025-07-18 16:43:38.144516: Current learning rate: 0.01
2025-07-18 16:44:25.789200: train_loss -0.1917
2025-07-18 16:44:25.789970: val_loss -0.494
2025-07-18 16:44:25.790100: Pseudo dice [np.float32(0.8838)]
2025-07-18 16:44:25.790242: Epoch time: 47.65 s
2025-07-18 16:44:25.790316: Yayy! New best EMA pseudo Dice: 0.8838000297546387
2025-07-18 16:44:26.444514: 
2025-07-18 16:44:26.444639: Epoch 1
2025-07-18 16:44:26.444731: Current learning rate: 0.00536
2025-07-18 16:45:12.907103: train_loss -0.4466
2025-07-18 16:45:12.909298: val_loss -0.6225
2025-07-18 16:45:12.909446: Pseudo dice [np.float32(0.9225)]
2025-07-18 16:45:12.909584: Epoch time: 46.46 s
2025-07-18 16:45:12.909658: Yayy! New best EMA pseudo Dice: 0.8877000212669373
2025-07-18 16:45:14.147640: Training done.
Finished training run: run2_3d_fullres

--- Comparing checkpoints ---
Model weights are identical: True
Training losses are identical: True
Validation losses are identical: True
PASSED: The training process for '3d_fullres' is deterministic.
========== FINISHED TEST FOR CONFIGURATION: 3D_FULLRES ==========


=========================
  DETERMINISM TEST REPORT
=========================
Configuration '2d': ✅ PASSED
Configuration '3d_fullres': ✅ PASSED
=========================


Luugaaa (Author) commented on Jul 28, 2025

Note: I was finally able to test the determinism fix on CUDA. Although reproducibility is greatly improved, training is not deterministic there. I've pinpointed the source of the non-determinism to something happening in the backward pass from batch_id = 1 onwards. I don't plan on investigating further.

FabianIsensee self-assigned this on Oct 16, 2025
FabianIsensee (Member) commented:
Dear @Luugaaa, thank you for this PR. I appreciate the effort you put into making it deterministic. Fun fact: a long time ago there used to be a deterministic flag in nnU-Net, which we removed later on.

Wait what?

Blasphemy!

Indeed. The lack of determinism in nnU-Net is intentional. I do not believe determinism is needed. In fact, I think it is dangerous and detrimental to progress. Strange words, no? But there is a reason: determinism gives you a false sense of security in producing and interpreting results. When you run an experiment once and get +0.5 Dice vs the baseline: is it a real improvement or not? With determinism you can run an entire series of experiments and collect incremental improvements. And what happens when the seed changes, or determinism breaks? All your nice and tidy improvements collapse like a house of cards. Sure, you could have run multiple seeds to begin with, but who really does that when compute is always the constraint?
When nnU-Net is nondeterministic you cannot reproduce things that do not truly improve nnU-Net. It's a great feature. Yes you get a slightly different result each time you train, but each of these results is a perfectly legitimate outcome of the training pipeline and reflects the natural variability of the underlying problem. Be that data uncertainty, initialization randomness or whatever else. Trust me, this setup prevents you from jumping to conclusions and will make your research a lot more robust!

So I would very much prefer to keep things nondeterministic in nnU-Net. It's not a bug - it's a feature :-)
