Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/192.breaking
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Removed caller-supplied rare-class thresholds from `ZeroInflatedImputer`; regime detection is now based solely on observed negative, zero, and positive support in the training data.
1 change: 1 addition & 0 deletions changelog.d/192.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Capped `scikit-learn` below 1.9 while `quantile-forest` depends on the pre-1.9 sklearn tree extension API.
57 changes: 13 additions & 44 deletions microimpute/models/zero_inflated.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,23 @@
The wrapper is generic over the base imputer — ``QRF`` is the obvious
default, but ``MDN``, ``OLS``, or ``Matching`` all compose the same way.

Regime detection is parameterized by ``min_class_count`` and
``min_class_fraction``: a class with fewer observations than both
thresholds collapses into the closest adjacent regime. This avoids
fitting a full three-sign split on a variable whose negative tail is
five outlier rows — the cost-benefit flips toward the simpler
architecture.
Regime detection is based only on observed support. If the training data
contains negative, zero, and positive values, the imputer uses the
three-sign architecture. Callers do not provide sign/regime metadata.
"""

from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Type, Union

import numpy as np
import pandas as pd
from pydantic import SkipValidation, validate_call
from pydantic import validate_call

from microimpute.config import RANDOM_STATE, VALIDATE_CONFIG
from microimpute.models.imputer import (
Imputer,
ImputerResults,
_ConstantValueModel,
)
from microimpute.models.qrf import QRF

Expand Down Expand Up @@ -95,19 +90,14 @@ def _make_classifier(kind: str, seed: int):
def _detect_regime(
y: np.ndarray,
*,
min_class_count: int,
min_class_fraction: float,
zero_atol: float,
) -> str:
"""Classify the training distribution into one of seven regimes.

A class (neg/zero/pos) counts as present iff its count is at least
``min_class_count`` AND its fraction of total rows is at least
``min_class_fraction``. Below both thresholds, the class collapses
into its closest adjacent regime (minority negatives merge into
zero → ZI_POSITIVE; minority zeros merge into the majority sign;
etc.). This keeps the gate architecture stable in the presence of
measurement-error outliers.
A class (neg/zero/pos) counts as present when it appears at least
once in the training data. Sign support is inferred from donor data;
callers cannot force a variable to be positive-only, negative-only,
or signed.
"""
n = len(y)
if n == 0:
Expand All @@ -121,22 +111,12 @@ def _detect_regime(
n_pos = int(is_pos.sum())
n_neg = int(is_neg.sum())

# Apply both thresholds.
def _meaningful(count: int) -> bool:
return count >= min_class_count and (count / n) >= min_class_fraction

has_zero = _meaningful(n_zero)
has_pos = _meaningful(n_pos)
has_neg = _meaningful(n_neg)
has_zero = n_zero > 0
has_pos = n_pos > 0
has_neg = n_neg > 0

if not (has_zero or has_pos or has_neg):
# All three classes are below threshold. Pick the one with the
# largest raw count as a degenerate fallback.
counts = {"zero": n_zero, "pos": n_pos, "neg": n_neg}
majority = max(counts, key=counts.get)
if majority == "zero":
return REGIME_DEGENERATE_ZERO
return REGIME_POSITIVE_ONLY if majority == "pos" else REGIME_NEGATIVE_ONLY
return REGIME_DEGENERATE_ZERO

if has_pos and has_neg and has_zero:
return REGIME_THREE_SIGN
Expand All @@ -161,11 +141,6 @@ class ZeroInflatedImputer(Imputer):
regression step. Defaults to ``QRF``.
base_imputer_kwargs: Keyword arguments forwarded to the base
imputer constructor. ``{}`` by default.
min_class_count: Minimum raw count per class (neg/0/pos) for
that class to be considered present. Below this, the class
collapses into an adjacent regime. Defaults to 10.
min_class_fraction: Minimum fraction of total rows per class
for that class to be considered present. Defaults to 0.01.
zero_atol: Absolute tolerance for "equals zero" in the regime
detector. Defaults to 1e-6, matching the upstream
``_MultiSourceBase`` convention.
Expand All @@ -179,8 +154,6 @@ def __init__(
self,
base_imputer_class: Optional[Type[Imputer]] = None,
base_imputer_kwargs: Optional[Dict[str, Any]] = None,
min_class_count: int = 10,
min_class_fraction: float = 0.01,
zero_atol: float = 1e-6,
classifier_type: str = "hist_gb",
seed: Optional[int] = RANDOM_STATE,
Expand All @@ -189,8 +162,6 @@ def __init__(
super().__init__(seed=seed, log_level=log_level)
self.base_imputer_class = base_imputer_class or QRF
self.base_imputer_kwargs = dict(base_imputer_kwargs or {})
self.min_class_count = int(min_class_count)
self.min_class_fraction = float(min_class_fraction)
self.zero_atol = float(zero_atol)
self.classifier_type = classifier_type

Expand Down Expand Up @@ -267,8 +238,6 @@ def fit(
y = X_train[var].to_numpy(dtype=float, copy=False)
regime = _detect_regime(
y,
min_class_count=self.min_class_count,
min_class_fraction=self.min_class_fraction,
zero_atol=self.zero_atol,
)
self._regimes[var] = regime
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies = [
"numpy>=2.0.0,<3.0.0",
"pandas>=2.2.0,<4.0.0",
"plotly>=5.24.0,<7.0.0",
"scikit-learn>=1.7.0,<2.0.0",
"scikit-learn>=1.7.0,<1.9.0",
"scipy>=1.16.0,<2.0.0",
"requests>=2.32.0,<3.0.0",
"tqdm>=4.65.0,<5.0.0",
Expand Down
28 changes: 12 additions & 16 deletions tests/test_models/test_zero_inflated.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,14 @@
2. Predictions respect the detected regime (no zero leaks, no
sign-interpolation between positive and negative regimes).
3. Fit/predict lifecycle matches the base `Imputer` contract.
4. Rare-class thresholds: tiny negative tails don't trigger a full
three-sign split unless above a configurable minimum.
4. Pure support detection: any observed negative, zero, or positive
support participates in regime selection.
"""

from __future__ import annotations

from typing import Dict, List

import numpy as np
import pandas as pd
import pytest

from microimpute.models.qrf import QRF

Expand Down Expand Up @@ -136,14 +133,12 @@ def test_constant_zero_is_degenerate(self) -> None:
imputer.fit(data, predictors=["age", "income_bin"], imputed_variables=["y"])
assert imputer.get_regime("y") == "DEGENERATE_ZERO"

def test_rare_negative_tail_stays_zi_positive(self) -> None:
"""If negative-class count is below min_class_count, treat as ZI_POSITIVE.
def test_rare_negative_tail_triggers_three_sign(self) -> None:
"""Any observed negative support participates in auto detection.

Capital gains example: 97% zero, 2.9% positive, 0.1% negative.
The negative mass is real but below the 10-sample threshold on
a 500-record fixture. Should NOT trigger three-sign; instead
collapses to ZI_POSITIVE with the few negatives discarded from
the base imputer's fit (and a warning).
The negative mass is real and should trigger THREE_SIGN without
caller-supplied support/sign metadata.
"""
from microimpute.models.zero_inflated import ZeroInflatedImputer

Expand All @@ -152,15 +147,16 @@ def test_rare_negative_tail_stays_zi_positive(self) -> None:
u = rng.random(n)
y = np.zeros(n)
pos_mask = u > 0.971
neg_mask = (u > 0.970) & (u <= 0.971)
y[pos_mask] = rng.exponential(100, size=pos_mask.sum())
y[neg_mask] = -rng.exponential(50, size=neg_mask.sum())
assert (y < 0).sum() < 10, "fixture precondition"
y[0] = -50.0
assert (y < 0).sum() == 1, "fixture precondition"
assert (y > 0).sum() > 0, "fixture precondition"
assert (y == 0).sum() > 0, "fixture precondition"

data = _deterministic_frame(n, y)
imputer = ZeroInflatedImputer(base_imputer_class=QRF, min_class_count=10)
imputer = ZeroInflatedImputer(base_imputer_class=QRF)
imputer.fit(data, predictors=["age", "income_bin"], imputed_variables=["y"])
assert imputer.get_regime("y") == "ZI_POSITIVE"
assert imputer.get_regime("y") == "THREE_SIGN"


class TestPredictionsRespectRegime:
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading