diff --git a/changelog.d/192.breaking b/changelog.d/192.breaking new file mode 100644 index 0000000..b9ae6e9 --- /dev/null +++ b/changelog.d/192.breaking @@ -0,0 +1 @@ +Removed caller-supplied rare-class thresholds from `ZeroInflatedImputer`; regime detection is now based solely on observed negative, zero, and positive support in the training data. diff --git a/changelog.d/192.fixed b/changelog.d/192.fixed new file mode 100644 index 0000000..35df179 --- /dev/null +++ b/changelog.d/192.fixed @@ -0,0 +1 @@ +Capped `scikit-learn` below 1.9 while `quantile-forest` depends on the pre-1.9 sklearn tree extension API. diff --git a/microimpute/models/zero_inflated.py b/microimpute/models/zero_inflated.py index 5ffb21a..8858446 100644 --- a/microimpute/models/zero_inflated.py +++ b/microimpute/models/zero_inflated.py @@ -35,28 +35,23 @@ The wrapper is generic over the base imputer — ``QRF`` is the obvious default, but ``MDN``, ``OLS``, or ``Matching`` all compose the same way. -Regime detection is parameterized by ``min_class_count`` and -``min_class_fraction``: a class with fewer observations than both -thresholds collapses into the closest adjacent regime. This avoids -fitting a full three-sign split on a variable whose negative tail is -five outlier rows — the cost-benefit flips toward the simpler -architecture. +Regime detection is based only on observed support. If the training data +contains negative, zero, and positive values, the imputer uses the +three-sign architecture. Callers do not provide sign/regime metadata. """ from __future__ import annotations -import logging -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Type, Union import numpy as np import pandas as pd -from pydantic import SkipValidation, validate_call +from pydantic import validate_call from microimpute.config import RANDOM_STATE, VALIDATE_CONFIG from microimpute.models.imputer import ( Imputer, ImputerResults, - _ConstantValueModel, ) from microimpute.models.qrf import QRF @@ -95,19 +90,14 @@ def _make_classifier(kind: str, seed: int): def _detect_regime( y: np.ndarray, *, - min_class_count: int, - min_class_fraction: float, zero_atol: float, ) -> str: """Classify the training distribution into one of seven regimes. - A class (neg/zero/pos) counts as present iff its count is at least - ``min_class_count`` AND its fraction of total rows is at least - ``min_class_fraction``. Below both thresholds, the class collapses - into its closest adjacent regime (minority negatives merge into - zero → ZI_POSITIVE; minority zeros merge into the majority sign; - etc.). This keeps the gate architecture stable in the presence of - measurement-error outliers. + A class (neg/zero/pos) counts as present when it appears at least + once in the training data. Sign support is inferred from donor data; + callers cannot force a variable to be positive-only, negative-only, + or signed. """ n = len(y) if n == 0: @@ -121,22 +111,12 @@ def _detect_regime( n_pos = int(is_pos.sum()) n_neg = int(is_neg.sum()) - # Apply both thresholds. - def _meaningful(count: int) -> bool: - return count >= min_class_count and (count / n) >= min_class_fraction - - has_zero = _meaningful(n_zero) - has_pos = _meaningful(n_pos) - has_neg = _meaningful(n_neg) + has_zero = n_zero > 0 + has_pos = n_pos > 0 + has_neg = n_neg > 0 if not (has_zero or has_pos or has_neg): - # All three classes are below threshold. Pick the one with the - # largest raw count as a degenerate fallback. - counts = {"zero": n_zero, "pos": n_pos, "neg": n_neg} - majority = max(counts, key=counts.get) - if majority == "zero": - return REGIME_DEGENERATE_ZERO - return REGIME_POSITIVE_ONLY if majority == "pos" else REGIME_NEGATIVE_ONLY + return REGIME_DEGENERATE_ZERO if has_pos and has_neg and has_zero: return REGIME_THREE_SIGN @@ -161,11 +141,6 @@ class ZeroInflatedImputer(Imputer): regression step. Defaults to ``QRF``. base_imputer_kwargs: Keyword arguments forwarded to the base imputer constructor. ``{}`` by default. - min_class_count: Minimum raw count per class (neg/0/pos) for - that class to be considered present. Below this, the class - collapses into an adjacent regime. Defaults to 10. - min_class_fraction: Minimum fraction of total rows per class - for that class to be considered present. Defaults to 0.01. zero_atol: Absolute tolerance for "equals zero" in the regime detector. Defaults to 1e-6, matching the upstream ``_MultiSourceBase`` convention. @@ -179,8 +154,6 @@ def __init__( self, base_imputer_class: Optional[Type[Imputer]] = None, base_imputer_kwargs: Optional[Dict[str, Any]] = None, - min_class_count: int = 10, - min_class_fraction: float = 0.01, zero_atol: float = 1e-6, classifier_type: str = "hist_gb", seed: Optional[int] = RANDOM_STATE, @@ -189,8 +162,6 @@ def __init__( super().__init__(seed=seed, log_level=log_level) self.base_imputer_class = base_imputer_class or QRF self.base_imputer_kwargs = dict(base_imputer_kwargs or {}) - self.min_class_count = int(min_class_count) - self.min_class_fraction = float(min_class_fraction) self.zero_atol = float(zero_atol) self.classifier_type = classifier_type @@ -267,8 +238,6 @@ def fit( y = X_train[var].to_numpy(dtype=float, copy=False) regime = _detect_regime( y, - min_class_count=self.min_class_count, - min_class_fraction=self.min_class_fraction, zero_atol=self.zero_atol, ) self._regimes[var] = regime diff --git a/pyproject.toml b/pyproject.toml index d2635d2..01fa51b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "numpy>=2.0.0,<3.0.0", "pandas>=2.2.0,<4.0.0", "plotly>=5.24.0,<7.0.0", - "scikit-learn>=1.7.0,<2.0.0", + "scikit-learn>=1.7.0,<1.9.0", "scipy>=1.16.0,<2.0.0", "requests>=2.32.0,<3.0.0", "tqdm>=4.65.0,<5.0.0", diff --git a/tests/test_models/test_zero_inflated.py b/tests/test_models/test_zero_inflated.py index 3a43a29..c75195e 100644 --- a/tests/test_models/test_zero_inflated.py +++ b/tests/test_models/test_zero_inflated.py @@ -25,17 +25,14 @@ 2. Predictions respect the detected regime (no zero leaks, no sign-interpolation between positive and negative regimes). 3. Fit/predict lifecycle matches the base `Imputer` contract. -4. Rare-class thresholds: tiny negative tails don't trigger a full - three-sign split unless above a configurable minimum. +4. Pure support detection: any observed negative, zero, or positive + support participates in regime selection. """ from __future__ import annotations -from typing import Dict, List - import numpy as np import pandas as pd -import pytest from microimpute.models.qrf import QRF @@ -136,14 +133,12 @@ def test_constant_zero_is_degenerate(self) -> None: imputer.fit(data, predictors=["age", "income_bin"], imputed_variables=["y"]) assert imputer.get_regime("y") == "DEGENERATE_ZERO" - def test_rare_negative_tail_stays_zi_positive(self) -> None: - """If negative-class count is below min_class_count, treat as ZI_POSITIVE. + def test_rare_negative_tail_triggers_three_sign(self) -> None: + """Any observed negative support participates in auto detection. Capital gains example: 97% zero, 2.9% positive, 0.1% negative. - The negative mass is real but below the 10-sample threshold on - a 500-record fixture. Should NOT trigger three-sign; instead - collapses to ZI_POSITIVE with the few negatives discarded from - the base imputer's fit (and a warning). + The negative mass is real and should trigger THREE_SIGN without + caller-supplied support/sign metadata. """ from microimpute.models.zero_inflated import ZeroInflatedImputer @@ -152,15 +147,16 @@ def test_rare_negative_tail_stays_zi_positive(self) -> None: u = rng.random(n) y = np.zeros(n) pos_mask = u > 0.971 - neg_mask = (u > 0.970) & (u <= 0.971) y[pos_mask] = rng.exponential(100, size=pos_mask.sum()) - y[neg_mask] = -rng.exponential(50, size=neg_mask.sum()) - assert (y < 0).sum() < 10, "fixture precondition" + y[0] = -50.0 + assert (y < 0).sum() == 1, "fixture precondition" + assert (y > 0).sum() > 0, "fixture precondition" + assert (y == 0).sum() > 0, "fixture precondition" data = _deterministic_frame(n, y) - imputer = ZeroInflatedImputer(base_imputer_class=QRF, min_class_count=10) + imputer = ZeroInflatedImputer(base_imputer_class=QRF) imputer.fit(data, predictors=["age", "income_bin"], imputed_variables=["y"]) - assert imputer.get_regime("y") == "ZI_POSITIVE" + assert imputer.get_regime("y") == "THREE_SIGN" class TestPredictionsRespectRegime: diff --git a/uv.lock b/uv.lock index 7dc427c..259e65b 100644 --- a/uv.lock +++ b/uv.lock @@ -1117,7 +1117,7 @@ requires-dist = [ { name = "requests", specifier = ">=2.32.0,<3.0.0" }, { name = "rpy2", marker = "extra == 'matching'", specifier = ">=3.5.0,<4.0.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.9.0" }, - { name = "scikit-learn", specifier = ">=1.7.0,<2.0.0" }, + { name = "scikit-learn", specifier = ">=1.7.0,<1.9.0" }, { name = "scipy", specifier = ">=1.16.0,<2.0.0" }, { name = "statsmodels", specifier = ">=0.14.5,<0.16.0" }, { name = "torch", marker = "extra == 'mdn'", specifier = ">=2.0.0" },