Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/fix-dropout-log-space.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Reimplement dropout_weights() correctly for log-space weights. The previous implementation set masked log-entries to 0 (exp(0) = 1, not dropped) and divided by the sum of logs, which on realistic weight scales could cross zero and inject Inf/NaN into training. The new implementation applies standard inverted dropout: dropped entries go to -inf in log space (and therefore zero in linear space), survivors are scaled by 1/(1-p) so the expected linear-space sum is preserved.
55 changes: 36 additions & 19 deletions src/microcalibrate/reweight.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,42 @@
from .utils.metrics import loss, pct_close


def dropout_weights(weights: torch.Tensor, p: float) -> torch.Tensor:
    """Apply inverted dropout to weights held in log space.

    ``weights`` represents log(w); linear-space weights are recovered
    with ``torch.exp``. Dropping an entry therefore means sending its
    log to ``-inf`` so that ``exp`` returns 0. Surviving entries are
    scaled by ``1/(1-p)`` in linear space (equivalently, ``-log(1-p)``
    is added in log space) so the expected linear-space sum is
    preserved, matching standard inverted dropout semantics.

    Args:
        weights (torch.Tensor): Current weights in log space.
        p (float): Probability of dropping each weight, in [0, 1].

    Returns:
        torch.Tensor: Weights in log space after applying dropout. For
            ``p == 0`` the input tensor itself is returned.

    Raises:
        ValueError: If ``p`` is NaN or lies outside ``[0, 1]``.
    """
    # Negated chained comparison on purpose: every comparison against
    # NaN is False, so ``p < 0 or p > 1`` would let NaN through and it
    # would then silently drop every entry.
    if not (0 <= p <= 1):
        raise ValueError(f"dropout_rate must be in [0, 1]; got {p}.")
    if p == 0:
        return weights
    if p == 1:
        # Everything is dropped: zero all linear-space weights. The
        # result has no gradient path back to ``weights`` because every
        # entry is a constant -inf; callers must not rely on training
        # under full dropout. Handled separately because log1p(-1) is
        # -inf and the survivor scale would be undefined.
        return torch.full_like(weights, float("-inf"))
    # ``survive_mask`` is True where an entry SURVIVES.
    survive_mask = torch.rand_like(weights) >= p
    # log(1 / (1 - p)), computed via log1p for accuracy at small p.
    scale = -float(np.log1p(-p))
    neg_inf = torch.full_like(weights, float("-inf"))
    return torch.where(survive_mask, weights + scale, neg_inf)


def reweight(
original_weights: np.ndarray,
estimate_function: Callable[[Tensor], Tensor],
Expand Down Expand Up @@ -91,25 +127,6 @@ def reweight(
f"std: {torch.exp(weights).std():.4f}"
)

def dropout_weights(weights: torch.Tensor, p: float) -> torch.Tensor:
    """Apply inverted dropout to weights stored in log space.

    Because ``weights`` holds log(w), an entry is dropped by setting it
    to ``-inf`` (so ``exp`` yields 0), not to 0 (``exp(0) == 1`` would
    keep it with unit weight). Survivors get ``-log(1 - p)`` added,
    i.e. they are multiplied by ``1 / (1 - p)`` in linear space, which
    keeps the expected linear-space total unchanged. No ratio of
    log-sums is involved, so the result cannot become Inf/NaN through
    a zero or sign-flipped denominator.

    Args:
        weights (torch.Tensor): Current weights in log space.
        p (float): Probability of dropping each weight, in [0, 1].

    Returns:
        torch.Tensor: Weights in log space after applying dropout. For
            ``p == 0`` the input tensor itself is returned.

    Raises:
        ValueError: If ``p`` is NaN or lies outside ``[0, 1]``.
    """
    # Function-scope import keeps this block self-contained.
    import math

    # Negated chained comparison so that NaN is rejected as well.
    if not (0 <= p <= 1):
        raise ValueError(f"dropout_rate must be in [0, 1]; got {p}.")
    if p == 0:
        return weights
    if p == 1:
        # Full dropout: every linear-space weight becomes zero. Handled
        # separately because log1p(-1) is undefined.
        return torch.full_like(weights, float("-inf"))
    # True where an entry is DROPPED.
    mask = torch.rand_like(weights) < p
    # log(1 / (1 - p)): the inverted-dropout rescale, in log space.
    log_rescale = -math.log1p(-p)
    return (weights + log_rescale).masked_fill(mask, float("-inf"))

optimizer = torch.optim.Adam([weights], lr=learning_rate)

iterator = tqdm(range(epochs), desc="Reweighting progress", unit="epoch")
Expand Down
115 changes: 115 additions & 0 deletions tests/test_dropout_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Regression tests for the log-space dropout bug (finding #2).

The previous implementation set masked entries in log space to ``0``,
which corresponds to ``exp(0) = 1`` in linear space -- the opposite of
dropping them. It then divided masked log-weights by their sum, which
is not a meaningful normalisation on logs and could cross zero,
producing ``inf`` / ``NaN`` on realistic weight scales.
"""

import logging

import numpy as np
import pytest
import torch

from microcalibrate.reweight import dropout_weights, reweight


def test_dropout_p_zero_is_identity() -> None:
    """A dropout probability of zero must leave the log-weights as-is."""
    log_w = torch.tensor([10.0, 100.0, 1000.0]).log()
    untouched = dropout_weights(log_w, 0.0)
    assert torch.equal(untouched, log_w)


def test_dropout_p_one_zeroes_all_linear_weights() -> None:
    """With p=1 every weight must be exactly zero once exponentiated."""
    log_w = torch.tensor([10.0, 100.0, 1000.0]).log()
    fully_dropped = dropout_weights(log_w, 1.0)
    assert torch.all(fully_dropped.exp() == 0)


def test_dropout_preserves_sum_in_expectation_on_realistic_scale() -> None:
    """Dropout on realistic-scale weights stays finite and unbiased.

    Regression: survey weights in the hundreds-to-thousands range have
    ``log(w)`` around 6-8, where the old normalisation could cross zero
    and emit ``inf`` / ``NaN``. Inverted dropout must instead keep every
    linear-space output finite, and the mean linear-space total over
    many draws must match the input total.
    """
    generator = np.random.default_rng(7)
    linear_weights = generator.uniform(100.0, 5000.0, size=500)
    log_weights = torch.tensor(np.log(linear_weights), dtype=torch.float32)

    torch.manual_seed(0)
    totals = []
    for _ in range(200):
        linear_out = dropout_weights(log_weights, 0.3).exp()
        assert torch.isfinite(linear_out).all()
        totals.append(linear_out.sum().item())

    expected_sum = linear_weights.sum()
    observed_mean = float(np.mean(totals))
    # Averaging over 200 draws keeps Monte Carlo noise well inside the
    # 2% relative tolerance.
    relative_error = abs(observed_mean - expected_sum) / expected_sum
    assert relative_error < 0.02, (
        "Inverted dropout should preserve the linear-space sum in "
        f"expectation; got mean {observed_mean:.1f} vs expected "
        f"{expected_sum:.1f}."
    )


def test_dropout_drops_approximately_p_fraction() -> None:
    """About half of the entries must be zeroed out when p=0.5."""
    log_w = torch.full((10_000,), 42.0).log()
    torch.manual_seed(1)
    linear = dropout_weights(log_w, 0.5).exp()
    zero_share = (linear == 0).float().mean().item()
    assert 0.45 < zero_share < 0.55


def test_dropout_rejects_out_of_range() -> None:
    """Probabilities above 1 or below 0 must raise ``ValueError``."""
    log_w = torch.tensor([1.0, 2.0]).log()
    for invalid_p in (1.5, -0.1):
        with pytest.raises(ValueError):
            dropout_weights(log_w, invalid_p)


def test_reweight_runs_with_realistic_scale_dropout() -> None:
    """End-to-end check: a short ``reweight`` run with dropout enabled
    on realistic-scale weights must return only finite weights.
    """

    def total_estimate(w: torch.Tensor) -> torch.Tensor:
        # Single-target estimate: the weight total as a 1-element tensor.
        total = w.sum()
        return total.unsqueeze(0)

    generator = np.random.default_rng(0)
    starting_weights = generator.uniform(100.0, 5000.0, size=200)
    # Aim 5% above the current total so the optimiser has work to do.
    target_values = np.array([starting_weights.sum() * 1.05])
    test_logger = logging.getLogger("test_dropout_regression")

    torch.manual_seed(0)
    calibrated, _sparse_unused, _frame_unused = reweight(
        original_weights=starting_weights,
        estimate_function=total_estimate,
        targets_array=target_values,
        target_names=np.array(["total"]),
        l0_lambda=0.0,
        init_mean=0.999,
        temperature=0.5,
        regularize_with_l0=False,
        sparse_learning_rate=0.2,
        dropout_rate=0.3,
        epochs=5,
        noise_level=0.0,
        learning_rate=1e-3,
        logger=test_logger,
    )
    assert np.all(np.isfinite(calibrated))
Loading