Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions src/microplex_us/pipelines/pe_l0.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@

from __future__ import annotations

import inspect
import os
import sys
from collections.abc import Callable
from os import PathLike
from pathlib import Path
from typing import Any, Self

import numpy as np
Expand All @@ -15,6 +19,75 @@
)
from scipy import sparse as sp

_PE_US_DATA_REPO_ENV = "MICROPLEX_US_POLICYENGINE_US_DATA_REPO"


def make_policyengine_us_data_fit_l0_weights_fn(
repo_root: str | PathLike[str] | None = None,
) -> Callable[..., np.ndarray]:
"""Return a lazy wrapper around PE-US-data's L0 weight optimizer.

Microplex passes adapter-specific diagnostics such as ``target_names`` and
``initial_weights``. The incumbent PE-US-data function accepts a narrower
signature, so this wrapper keeps the public hook stable while delegating
only supported arguments.
"""

def _fit_l0_weights(**kwargs: Any) -> np.ndarray:
fit_l0_weights = _load_policyengine_us_data_fit_l0_weights(repo_root)
accepted_parameters = set(inspect.signature(fit_l0_weights).parameters)
call_kwargs = {
key: value for key, value in kwargs.items() if key in accepted_parameters
}
return np.asarray(fit_l0_weights(**call_kwargs), dtype=float)

return _fit_l0_weights


def _load_policyengine_us_data_fit_l0_weights(
repo_root: str | PathLike[str] | None = None,
) -> Callable[..., np.ndarray]:
resolved_repo = _resolve_policyengine_us_data_repo_root(repo_root)
inserted_path: str | None = None
if resolved_repo is not None:
inserted_path = str(resolved_repo)
if inserted_path not in sys.path:
sys.path.insert(0, inserted_path)
try:
from policyengine_us_data.calibration.unified_calibration import (
fit_l0_weights,
)
except ImportError as exc:
location = (
f" at {resolved_repo}"
if resolved_repo is not None
else " from the active Python environment"
)
raise RuntimeError(
"The pe_l0 backend requires policyengine-us-data's "
f"fit_l0_weights{location}. Set "
f"{_PE_US_DATA_REPO_ENV} or install policyengine-us-data."
) from exc
finally:
if inserted_path is not None and sys.path[0] == inserted_path:
sys.path.pop(0)
return fit_l0_weights


def _resolve_policyengine_us_data_repo_root(
repo_root: str | PathLike[str] | None = None,
) -> Path | None:
candidate = repo_root or os.environ.get(_PE_US_DATA_REPO_ENV)
if candidate is None:
return None
resolved = Path(candidate).expanduser().resolve()
if not (resolved / "policyengine_us_data").exists():
raise RuntimeError(
"policyengine-us-data repo root does not contain "
f"policyengine_us_data/: {resolved}"
)
return resolved


class PolicyEngineL0Calibrator:
"""Legacy L0 adapter for explicit experiments behind the Microplex interface."""
Expand Down
6 changes: 5 additions & 1 deletion src/microplex_us/pipelines/us.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@
ColumnwiseQRFDonorImputer,
RegimeAwareDonorImputer,
)
from microplex_us.pipelines.pe_l0 import PolicyEngineL0Calibrator
from microplex_us.pipelines.pe_l0 import (
PolicyEngineL0Calibrator,
make_policyengine_us_data_fit_l0_weights_fn,
)
from microplex_us.pipelines.pe_native_optimization import (
optimize_policyengine_us_native_loss_dataset,
)
Expand Down Expand Up @@ -3244,6 +3247,7 @@ def _build_weight_calibrator(
epochs=max(self.config.calibration_max_iter, 100),
device=self.config.device,
tol=self.config.calibration_tol,
fit_l0_weights_fn=make_policyengine_us_data_fit_l0_weights_fn(),
)
if self.config.calibration_backend == "microcalibrate":
from microplex_us.calibration import (
Expand Down
7 changes: 4 additions & 3 deletions tests/calibration/test_us_pipeline_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def test_backend_dispatch_fit_transform_end_to_end() -> None:
# Constraint: weighted count of households with income > 80k should be 1.4x current.
mask = (data["income"] > 80_000).to_numpy(dtype=float)
target = 1.4 * float(mask.sum())
constraint = LinearConstraint(
name="above_80k", coefficients=mask, target=target
)
constraint = LinearConstraint(name="above_80k", coefficients=mask, target=target)

result = calibrator.fit_transform(
data,
Expand Down Expand Up @@ -101,6 +99,9 @@ def test_pe_l0_deferred_stage_disables_sparsity_penalty() -> None:
assert stage1.lambda_l0 == pytest.approx(1e-4)
assert stage2.lambda_l0 == 0.0
assert stage3.lambda_l0 == 0.0
assert stage1.fit_l0_weights_fn is not None
assert stage2.fit_l0_weights_fn is not None
assert stage3.fit_l0_weights_fn is not None


def test_hardconcrete_deferred_stage_disables_sparsity_penalty() -> None:
Expand Down
75 changes: 74 additions & 1 deletion tests/pipelines/test_pe_l0.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@

from __future__ import annotations

import sys
import types

import numpy as np
import pandas as pd
import pytest
from microplex.calibration import LinearConstraint

from microplex_us.pipelines.pe_l0 import PolicyEngineL0Calibrator
from microplex_us.pipelines.pe_l0 import (
PolicyEngineL0Calibrator,
make_policyengine_us_data_fit_l0_weights_fn,
)


def _install_fake_policyengine_l0(weights: np.ndarray):
Expand Down Expand Up @@ -127,3 +133,70 @@ def test_policyengine_l0_requires_explicit_fit_function_for_nonzero_l0():
weight_col="weight",
linear_constraints=constraints,
)


def test_policyengine_l0_can_wrap_policyengine_us_data_fit_function(monkeypatch):
calls: dict[str, object] = {}

def fake_policyengine_fit_l0_weights(
*,
X_sparse,
targets,
lambda_l0,
epochs=100,
device="cpu",
verbose_freq=None,
target_groups=None,
):
kwargs = {
"X_sparse": X_sparse,
"targets": targets,
"lambda_l0": lambda_l0,
"epochs": epochs,
"device": device,
"verbose_freq": verbose_freq,
"target_groups": target_groups,
}
calls.update(kwargs)
return np.array([4.0, 5.0])

package = types.ModuleType("policyengine_us_data")
package.__path__ = []
calibration_package = types.ModuleType("policyengine_us_data.calibration")
calibration_package.__path__ = []
unified = types.ModuleType("policyengine_us_data.calibration.unified_calibration")
unified.fit_l0_weights = fake_policyengine_fit_l0_weights
monkeypatch.setitem(sys.modules, "policyengine_us_data", package)
monkeypatch.setitem(
sys.modules,
"policyengine_us_data.calibration",
calibration_package,
)
monkeypatch.setitem(
sys.modules,
"policyengine_us_data.calibration.unified_calibration",
unified,
)

fit_l0_weights = make_policyengine_us_data_fit_l0_weights_fn()
result = fit_l0_weights(
X_sparse="matrix",
targets=np.array([1.0]),
lambda_l0=1e-8,
epochs=2,
device="cpu",
verbose_freq=1,
initial_weights=np.array([1.0, 1.0]),
target_names=["target"],
)

assert result.tolist() == [4.0, 5.0]
assert calls["X_sparse"] == "matrix"
np.testing.assert_array_equal(calls["targets"], np.array([1.0]))
assert calls["lambda_l0"] == pytest.approx(1e-8)
assert calls["epochs"] == 2
assert calls["device"] == "cpu"
assert calls["verbose_freq"] == 1
assert calls["target_groups"] is None
assert "initial_weights" not in calls
assert "target_names" not in calls
Loading