Skip to content

feat: add NaN weight detection and input validation (#115)#117

Open
KuaaMU wants to merge 1 commit into
z-lab:mainfrom
KuaaMU:feat/nan-weight-detection
Open

feat: add NaN weight detection and input validation (#115)#117
KuaaMU wants to merge 1 commit into
z-lab:mainfrom
KuaaMU:feat/nan-weight-detection

Conversation

@KuaaMU
Copy link
Copy Markdown

@KuaaMU KuaaMU commented May 12, 2026

Summary

Add NaN weight detection and input validation to DFlashDraftModel to help diagnose and prevent issues like #115.

Changes

1. NaN Weight Detection (_check_weights_for_nan)

  • New method that checks all model parameters for NaN values
  • Warns once per forward pass to avoid flooding stderr
  • Suggests re-downloading the checkpoint as the most common fix
  • Helps users debug issues where layer norm weights become NaN

2. Input Validation

  • Validate that noise_embedding and target_hidden are provided
  • Clear error messages for missing required arguments

3. Type Hints

  • Add type hints to apply_rotary_pos_emb function
  • Improve code readability and IDE support

4. Documentation

  • Add comprehensive docstring to DFlashDraftModel class
  • Document the model architecture and configuration parameters

Why This Helps

Issue #115 reports layers.0.input_layernorm.weight became NaN. While the root cause may be in the checkpoint or training process, this detection helps users:

  1. Identify when weights are corrupted
  2. Get clear guidance on how to fix it (re-download checkpoint)
  3. Debug numerical issues during inference

Test Plan

  • py_compile passes ✅
  • No functional changes to the model logic
  • NaN detection only runs when NaN is found (minimal overhead)

AI assistance used: Claude (Anthropic) for code analysis and implementation. All changes reviewed and verified by the author.

Add _check_weights_for_nan() method to DFlashDraftModel that warns
when NaN values are detected in model weights during inference. This
helps diagnose issues like z-lab#115 where layer norm weights become NaN.

Additional improvements:
- Add input validation for required arguments (noise_embedding, target_hidden)
- Add type hints to apply_rotary_pos_emb function
- Add comprehensive docstring to DFlashDraftModel class

The NaN detection runs at the start of each forward pass and only
warns once to avoid flooding stderr. It suggests re-downloading the
checkpoint as the most common fix for corrupted weights.

Closes z-lab#115

Co-authored-by: Claude <claude@anthropic.com>
Signed-off-by: KuaaMU <XCM853629353@outlook.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant