Skip to content

Guard duplicate-key check and write in safetensors loader with a lock#1627

Open
AMR5210 wants to merge 2 commits into
google:mainfrom
AMR5210:fix-safetensors-dup-key-1259
Open

Guard duplicate-key check and write in safetensors loader with a lock#1627
AMR5210 wants to merge 2 commits into
google:mainfrom
AMR5210:fix-safetensors-dup-key-1259

Conversation

@AMR5210

@AMR5210 AMR5210 commented Jun 27, 2026

Copy link
Copy Markdown

Resolves #1259

This fixes a TOCTOU race in load_and_create_model_orig
(tunix/models/safetensors_loader.py). process_key runs across a
ThreadPoolExecutor(max_workers=os.cpu_count()), but the duplicate-key check
if jax_key_mapped in file_loaded_tensors and the following write were not
protected by a lock. Two threads whose source keys map to the same jax key can
both pass the check before either writes, so the second silently overwrites the
first and the intended ValueError never fires. The model then loads with
incorrect weights and no error.

The fix adds a dict_lock and wraps only the check and write in it. The tensor
read, transform, and dtype cast stay outside the lock, so loading remains
parallel with no meaningful speed impact.

Reference

Colab Notebook
N/A. This is a bug fix to existing loading code and adds no new API.

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.

AMR5210 added 2 commits June 27, 2026 16:58
Added a lock to prevent race conditions when checking for duplicate keys in loaded tensors.
Add a test to verify that loading a model raises an error on duplicate keys.
@AMR5210

AMR5210 commented Jun 27, 2026

Copy link
Copy Markdown
Author

Opened #1627 to fix this. The fix adds a dict_lock around the duplicate-key
check and write in load_and_create_model_orig so both operations are atomic
under the existing ThreadPoolExecutor. Includes a regression test that
confirms the duplicate is detected deterministically. All 7 tests pass on
Kaggle TPU infrastructure.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

TOCTOU race condition in concurrent safetensors loading bypasses duplicate key detection

2 participants