Reimplement dropout_weights() correctly for log-space weights#92
Merged
Conversation
The previous implementation held `weights` in log space but set masked entries to literal 0, which is `exp(0) = 1` in linear space -- the opposite of dropping. It then normalised by dividing masked log-weights by their sum and multiplying by `total_weight`, which is not a meaningful operation on logs. On realistic survey-weight scales (hundreds to thousands, log ~6-8) `masked_weights.sum()` could cross zero, producing Inf/NaN that crashed the forward pass via the NaN guard in loss(). Rewrite the function as standard inverted dropout performed in log space: dropped entries go to `-inf` (exp = 0) and surviving entries are shifted by `-log(1-p)` so the expected linear-space sum is preserved. Hoist the helper to module scope so it is directly testable and reject out-of-range probabilities explicitly. Adds tests/test_dropout_regression.py covering p=0 identity, p=1 zeros-all, expected-sum preservation on realistic-scale weights, approximate drop fraction, input validation, and an end-to-end smoke test that training with dropout on realistic-scale weights no longer poisons the loss. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Contributor
|
The latest updates on your projects. Learn more about Vercel for GitHub.
|
MaxGhenis
commented
Apr 17, 2026
Contributor
Author
MaxGhenis
left a comment
There was a problem hiding this comment.
LGTM (cannot self-approve; posting as comment).
Correct inverted-dropout semantics in log space:
p=0returns input unchanged (identity).p=1returns-inftensor →exp= 0 everywhere (all dropped).- Masked entries go to
-inf(not0, which was the original bug). - Survivors shifted by
-log(1-p)soE[exp(dropout(logw))] = exp(logw)in linear space, matching standard inverted dropout. pvalidated to[0, 1]with a clearValueError.
Tests cover identity, full dropout, approximate drop fraction, expected-sum preservation on realistic (hundreds-to-thousands) weight scales, input validation, and end-to-end reweight() with dropout_rate > 0. Good coverage of the original bug's exact failure mode (NaN/Inf on realistic scales).
Hoisting dropout_weights to module scope is also a small structural win — now directly unit-testable.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Finding #2 (HIGH) in the bug-hunt report.
dropout_weightsheldweightsin log space but set masked entries to literal0, which isexp(0) = 1in linear space — the opposite of dropping. It then normalised by dividing masked log-weights by their sum and multiplying bytotal_weight, which is not a meaningful operation on logs. On realistic survey-weight scales (hundreds to thousands,log ~6–8),masked_weights.sum()could cross zero, producing Inf/NaN that crashed the forward pass via the NaN guard inloss().Rewritten as standard inverted dropout performed in log space: dropped entries go to
-inf(exp = 0) and surviving entries are shifted by-log(1-p)so the expected linear-space sum is preserved. Hoisted to module scope so it is directly testable, and invalid probabilities now raise.Test plan
tests/test_dropout_regression.py:p=0identity,p=1zeroes all linear-space weights, expected-sum preservation on realistic-scale weights (Monte Carlo tolerance), approximate drop fraction, input validation, and end-to-end smoke test that training with dropout no longer poisons the loss.uv run pytest tests -x -q-> 21 passed).🤖 Generated with Claude Code