Skip to content

test: add SDPA gradient health tests (fixes #44928)#47024

Open
Lemniscate-world wants to merge 1 commit into
huggingface:mainfrom
Lemniscate-world:main
Open

test: add SDPA gradient health tests (fixes #44928)#47024
Lemniscate-world wants to merge 1 commit into
huggingface:mainfrom
Lemniscate-world:main

Conversation

@Lemniscate-world

@Lemniscate-world Lemniscate-world commented Jul 2, 2026

Copy link
Copy Markdown

CI

Fixes #44928

Summary

Adds gradient health validation tests for scaled_dot_product_attention with dense attention masks. These tests catch gradient explosion bugs such as #44928, where 3D position_ids force SDPA Math fallback, causing BF16 collapse and NaN gradients in Qwen3.5 RLHF training.

Tests added

  • test_sdpa_finite_gradient_standard: SDPA must produce finite gradients with standard inputs
  • test_sdpa_finite_gradient_dense_mask: Dense causal mask must not cause NaN/exploding gradients
  • test_sdpa_gradient_scaling_consistency: Gradient scales approximately linearly with query scaling

Motivation

Issue #44928 reports catastrophic gradient explosion (NaN) when using Qwen3.5 with 3D position_ids that force SDPA Math backend fallback under BF16. These tests establish baseline gradient health expectations and would have caught the regression.

Detection

These patterns are monitored by NeuralDBG, a causal inference engine for DL training, via its gradient health transition events.

@Lemniscate-world

Copy link
Copy Markdown
Author

Detected via NeuralDBG gradient health monitoring. The test patterns here (finite check, dense mask safety, scaling consistency) mirror the gradient_health_transition events that NeuralDBG captures at runtime.

CC @ArthurZucker for visibility (related to SDPA attention).

@github-actions

github-actions Bot commented Jul 2, 2026

Copy link
Copy Markdown
Contributor

CI recap

Dashboard: View test results in Grafana
Latest run: 28593785482
Result: failure | Grafana metrics are not available yet.

Code quality check failed: test jobs were skipped. Fix the code quality issues and push again to run tests.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] Catastrophic gradient explosion (NaN) in RLHF with Qwen3.5 due to 3D position_ids forcing SDPA Math fallback and BF16 collapse

1 participant