From d05796262e7e3a1ecc32a9f30849239f650cacd6 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 17 Apr 2026 08:29:38 -0400 Subject: [PATCH] Reimplement dropout_weights() correctly for log-space weights The previous implementation held `weights` in log space but set masked entries to literal 0, which is `exp(0) = 1` in linear space -- the opposite of dropping. It then normalised by dividing masked log-weights by their sum and multiplying by `total_weight`, which is not a meaningful operation on logs. On realistic survey-weight scales (hundreds to thousands, log ~6-8) `masked_weights.sum()` could cross zero, producing Inf/NaN that crashed the forward pass via the NaN guard in loss(). Rewrite the function as standard inverted dropout performed in log space: dropped entries go to `-inf` (exp = 0) and surviving entries are shifted by `-log(1-p)` so the expected linear-space sum is preserved. Hoist the helper to module scope so it is directly testable and reject out-of-range probabilities explicitly. Adds tests/test_dropout_regression.py covering p=0 identity, p=1 zeros-all, expected-sum preservation on realistic-scale weights, approximate drop fraction, input validation, and an end-to-end smoke test that training with dropout on realistic-scale weights no longer poisons the loss. Co-Authored-By: Claude Opus 4.7 (1M context) --- changelog.d/fix-dropout-log-space.fixed.md | 1 + src/microcalibrate/reweight.py | 55 ++++++---- tests/test_dropout_regression.py | 115 +++++++++++++++++++++ 3 files changed, 152 insertions(+), 19 deletions(-) create mode 100644 changelog.d/fix-dropout-log-space.fixed.md create mode 100644 tests/test_dropout_regression.py diff --git a/changelog.d/fix-dropout-log-space.fixed.md b/changelog.d/fix-dropout-log-space.fixed.md new file mode 100644 index 0000000..beb2dad --- /dev/null +++ b/changelog.d/fix-dropout-log-space.fixed.md @@ -0,0 +1 @@ +Reimplement dropout_weights() correctly for log-space weights. The previous implementation set masked log-entries to 0 (exp(0) = 1, not dropped) and divided by the sum of logs, which on realistic weight scales could cross zero and inject Inf/NaN into training. The new implementation applies standard inverted dropout: dropped entries go to -inf in log space (and therefore zero in linear space), survivors are scaled by 1/(1-p) so the expected linear-space sum is preserved. diff --git a/src/microcalibrate/reweight.py b/src/microcalibrate/reweight.py index 018ccba..2fe8bd2 100644 --- a/src/microcalibrate/reweight.py +++ b/src/microcalibrate/reweight.py @@ -13,6 +13,42 @@ from .utils.metrics import loss, pct_close +def dropout_weights(weights: torch.Tensor, p: float) -> torch.Tensor: + """Apply inverted dropout to weights held in log space. + + ``weights`` represents log(w); downstream code computes + ``torch.exp(weights_)`` to recover linear-space weights. Dropping + an entry therefore means sending its log to ``-inf`` so that + ``exp`` returns 0. Surviving entries are scaled by ``1/(1-p)`` in + linear space (equivalently, ``-log(1-p)`` added in log space) so + the expected linear-space sum is preserved, matching standard + inverted dropout semantics. + + Args: + weights (torch.Tensor): Current weights in log space. + p (float): Probability of dropping each weight, in [0, 1]. + + Returns: + torch.Tensor: Weights in log space after applying dropout. + """ + if p == 0: + return weights + if p < 0 or p > 1: + raise ValueError(f"dropout_rate must be in [0, 1]; got {p}.") + if p == 1: + # Everything is dropped: zero all linear-space weights. The + # result has no gradient path back to ``weights`` because every + # entry is a constant -inf; callers must not rely on training + # under full dropout. + return torch.full_like(weights, float("-inf")) + # ``survive_mask`` is True where an entry SURVIVES. + survive_mask = torch.rand_like(weights) >= p + neg_inf = torch.full_like(weights, float("-inf")) + scale = -float(np.log1p(-p)) # == log(1/(1-p)) + scaled = weights + scale + return torch.where(survive_mask, scaled, neg_inf) + + def reweight( original_weights: np.ndarray, estimate_function: Callable[[Tensor], Tensor], @@ -91,25 +127,6 @@ def reweight( f"std: {torch.exp(weights).std():.4f}" ) - def dropout_weights(weights: torch.Tensor, p: float) -> torch.Tensor: - """Apply dropout to the weights. - - Args: - weights (torch.Tensor): Current weights in log space. - p (float): Probability of dropping weights. - - Returns: - torch.Tensor: Weights after applying dropout. - """ - if p == 0: - return weights - total_weight = weights.sum() - mask = torch.rand_like(weights) < p - masked_weights = weights.clone() - masked_weights[mask] = 0 - masked_weights = masked_weights / masked_weights.sum() * total_weight - return masked_weights - optimizer = torch.optim.Adam([weights], lr=learning_rate) iterator = tqdm(range(epochs), desc="Reweighting progress", unit="epoch") diff --git a/tests/test_dropout_regression.py b/tests/test_dropout_regression.py new file mode 100644 index 0000000..ece12ba --- /dev/null +++ b/tests/test_dropout_regression.py @@ -0,0 +1,115 @@ +"""Regression tests for the log-space dropout bug (finding #2). + +The previous implementation set masked entries in log space to ``0``, +which corresponds to ``exp(0) = 1`` in linear space -- the opposite of +dropping them. It then divided masked log-weights by their sum, which +is not a meaningful normalisation on logs and could cross zero, +producing ``inf`` / ``NaN`` on realistic weight scales. +""" + +import logging + +import numpy as np +import pytest +import torch + +from microcalibrate.reweight import dropout_weights, reweight + + +def test_dropout_p_zero_is_identity() -> None: + """p=0 must return the input tensor unchanged (no dropout).""" + weights = torch.log(torch.tensor([10.0, 100.0, 1000.0])) + result = dropout_weights(weights, 0.0) + assert torch.equal(result, weights) + + +def test_dropout_p_one_zeroes_all_linear_weights() -> None: + """p=1 must zero every linear-space weight.""" + weights = torch.log(torch.tensor([10.0, 100.0, 1000.0])) + result = dropout_weights(weights, 1.0) + linear = torch.exp(result) + assert torch.all(linear == 0) + + +def test_dropout_preserves_sum_in_expectation_on_realistic_scale() -> None: + """On realistic (non-unit) weights dropout must not produce NaN/Inf + and the expected linear-space sum must be preserved. + + Regression: on realistic survey weights (hundreds to thousands), + ``log(w)`` is ~6-8 and the previous normalisation step could cross + zero, yielding ``inf`` or ``NaN``. With inverted dropout, the + expected linear-space sum of the output equals the linear-space + sum of the input. + """ + rng = np.random.default_rng(7) + linear_weights = rng.uniform(100.0, 5000.0, size=500) + log_weights = torch.tensor(np.log(linear_weights), dtype=torch.float32) + + torch.manual_seed(0) + n_trials = 200 + totals = [] + for _ in range(n_trials): + out = dropout_weights(log_weights, 0.3) + linear_out = torch.exp(out) + assert torch.isfinite(linear_out).all() + totals.append(linear_out.sum().item()) + + expected_sum = linear_weights.sum() + observed_mean = float(np.mean(totals)) + # Monte Carlo tolerance: standard error scales ~sqrt(p*(1-p)/n). + assert abs(observed_mean - expected_sum) / expected_sum < 0.02, ( + f"Inverted dropout should preserve the linear-space sum in " + f"expectation; got mean {observed_mean:.1f} vs expected " + f"{expected_sum:.1f}." + ) + + +def test_dropout_drops_approximately_p_fraction() -> None: + """At p=0.5, roughly half the linear-space outputs must be zero.""" + weights = torch.log(torch.ones(10_000) * 42.0) + torch.manual_seed(1) + result = dropout_weights(weights, 0.5) + fraction_zero = (torch.exp(result) == 0).float().mean().item() + assert 0.45 < fraction_zero < 0.55 + + +def test_dropout_rejects_out_of_range() -> None: + """Out-of-range dropout probabilities must raise explicitly.""" + weights = torch.log(torch.tensor([1.0, 2.0])) + with pytest.raises(ValueError): + dropout_weights(weights, 1.5) + with pytest.raises(ValueError): + dropout_weights(weights, -0.1) + + +def test_reweight_runs_with_realistic_scale_dropout() -> None: + """End-to-end: training with dropout_rate > 0 on realistic-scale + weights must not inject NaN/Inf into the loss. + """ + + def estimate_function(w: torch.Tensor) -> torch.Tensor: + return w.sum().unsqueeze(0) + + rng = np.random.default_rng(0) + original_weights = rng.uniform(100.0, 5000.0, size=200) + targets = np.array([original_weights.sum() * 1.05]) + logger = logging.getLogger("test_dropout_regression") + + torch.manual_seed(0) + final_weights, _sparse, _df = reweight( + original_weights=original_weights, + estimate_function=estimate_function, + targets_array=targets, + target_names=np.array(["total"]), + l0_lambda=0.0, + init_mean=0.999, + temperature=0.5, + regularize_with_l0=False, + sparse_learning_rate=0.2, + dropout_rate=0.3, + epochs=5, + noise_level=0.0, + learning_rate=1e-3, + logger=logger, + ) + assert np.all(np.isfinite(final_weights))