diff --git a/src/microplex_us/pipelines/pe_l0.py b/src/microplex_us/pipelines/pe_l0.py index d24ba7b..761de42 100644 --- a/src/microplex_us/pipelines/pe_l0.py +++ b/src/microplex_us/pipelines/pe_l0.py @@ -2,8 +2,12 @@ from __future__ import annotations +import inspect +import os +import sys from collections.abc import Callable from os import PathLike +from pathlib import Path from typing import Any, Self import numpy as np @@ -15,6 +19,75 @@ ) from scipy import sparse as sp +_PE_US_DATA_REPO_ENV = "MICROPLEX_US_POLICYENGINE_US_DATA_REPO" + + +def make_policyengine_us_data_fit_l0_weights_fn( + repo_root: str | PathLike[str] | None = None, +) -> Callable[..., np.ndarray]: + """Return a lazy wrapper around PE-US-data's L0 weight optimizer. + + Microplex passes adapter-specific diagnostics such as ``target_names`` and + ``initial_weights``. The incumbent PE-US-data function accepts a narrower + signature, so this wrapper keeps the public hook stable while delegating + only supported arguments. + """ + + def _fit_l0_weights(**kwargs: Any) -> np.ndarray: + fit_l0_weights = _load_policyengine_us_data_fit_l0_weights(repo_root) + accepted_parameters = set(inspect.signature(fit_l0_weights).parameters) + call_kwargs = { + key: value for key, value in kwargs.items() if key in accepted_parameters + } + return np.asarray(fit_l0_weights(**call_kwargs), dtype=float) + + return _fit_l0_weights + + +def _load_policyengine_us_data_fit_l0_weights( + repo_root: str | PathLike[str] | None = None, +) -> Callable[..., np.ndarray]: + resolved_repo = _resolve_policyengine_us_data_repo_root(repo_root) + inserted_path: str | None = None + if resolved_repo is not None: + inserted_path = str(resolved_repo) + if inserted_path not in sys.path: + sys.path.insert(0, inserted_path) + try: + from policyengine_us_data.calibration.unified_calibration import ( + fit_l0_weights, + ) + except ImportError as exc: + location = ( + f" at {resolved_repo}" + if resolved_repo is not None + else " from the active Python environment" + ) + raise RuntimeError( + "The pe_l0 backend requires policyengine-us-data's " + f"fit_l0_weights{location}. Set " + f"{_PE_US_DATA_REPO_ENV} or install policyengine-us-data." + ) from exc + finally: + if inserted_path is not None and sys.path[0] == inserted_path: + sys.path.pop(0) + return fit_l0_weights + + +def _resolve_policyengine_us_data_repo_root( + repo_root: str | PathLike[str] | None = None, +) -> Path | None: + candidate = repo_root or os.environ.get(_PE_US_DATA_REPO_ENV) + if candidate is None: + return None + resolved = Path(candidate).expanduser().resolve() + if not (resolved / "policyengine_us_data").exists(): + raise RuntimeError( + "policyengine-us-data repo root does not contain " + f"policyengine_us_data/: {resolved}" + ) + return resolved + class PolicyEngineL0Calibrator: """Legacy L0 adapter for explicit experiments behind the Microplex interface.""" diff --git a/src/microplex_us/pipelines/us.py b/src/microplex_us/pipelines/us.py index 7d3f959..6b5651e 100644 --- a/src/microplex_us/pipelines/us.py +++ b/src/microplex_us/pipelines/us.py @@ -69,7 +69,10 @@ ColumnwiseQRFDonorImputer, RegimeAwareDonorImputer, ) -from microplex_us.pipelines.pe_l0 import PolicyEngineL0Calibrator +from microplex_us.pipelines.pe_l0 import ( + PolicyEngineL0Calibrator, + make_policyengine_us_data_fit_l0_weights_fn, +) from microplex_us.pipelines.pe_native_optimization import ( optimize_policyengine_us_native_loss_dataset, ) @@ -3244,6 +3247,7 @@ def _build_weight_calibrator( epochs=max(self.config.calibration_max_iter, 100), device=self.config.device, tol=self.config.calibration_tol, + fit_l0_weights_fn=make_policyengine_us_data_fit_l0_weights_fn(), ) if self.config.calibration_backend == "microcalibrate": from microplex_us.calibration import ( diff --git a/tests/calibration/test_us_pipeline_dispatch.py b/tests/calibration/test_us_pipeline_dispatch.py index d7fd33d..3e6da49 100644 --- a/tests/calibration/test_us_pipeline_dispatch.py +++ b/tests/calibration/test_us_pipeline_dispatch.py @@ -51,9 +51,7 @@ def test_backend_dispatch_fit_transform_end_to_end() -> None: # Constraint: weighted count of households with income > 80k should be 1.4x current. mask = (data["income"] > 80_000).to_numpy(dtype=float) target = 1.4 * float(mask.sum()) - constraint = LinearConstraint( - name="above_80k", coefficients=mask, target=target - ) + constraint = LinearConstraint(name="above_80k", coefficients=mask, target=target) result = calibrator.fit_transform( data, @@ -101,6 +99,9 @@ def test_pe_l0_deferred_stage_disables_sparsity_penalty() -> None: assert stage1.lambda_l0 == pytest.approx(1e-4) assert stage2.lambda_l0 == 0.0 assert stage3.lambda_l0 == 0.0 + assert stage1.fit_l0_weights_fn is not None + assert stage2.fit_l0_weights_fn is not None + assert stage3.fit_l0_weights_fn is not None def test_hardconcrete_deferred_stage_disables_sparsity_penalty() -> None: diff --git a/tests/pipelines/test_pe_l0.py b/tests/pipelines/test_pe_l0.py index 3696129..c6cdb7a 100644 --- a/tests/pipelines/test_pe_l0.py +++ b/tests/pipelines/test_pe_l0.py @@ -2,12 +2,18 @@ from __future__ import annotations +import sys +import types + import numpy as np import pandas as pd import pytest from microplex.calibration import LinearConstraint -from microplex_us.pipelines.pe_l0 import PolicyEngineL0Calibrator +from microplex_us.pipelines.pe_l0 import ( + PolicyEngineL0Calibrator, + make_policyengine_us_data_fit_l0_weights_fn, +) def _install_fake_policyengine_l0(weights: np.ndarray): @@ -127,3 +133,70 @@ def test_policyengine_l0_requires_explicit_fit_function_for_nonzero_l0(): weight_col="weight", linear_constraints=constraints, ) + + +def test_policyengine_l0_can_wrap_policyengine_us_data_fit_function(monkeypatch): + calls: dict[str, object] = {} + + def fake_policyengine_fit_l0_weights( + *, + X_sparse, + targets, + lambda_l0, + epochs=100, + device="cpu", + verbose_freq=None, + target_groups=None, + ): + kwargs = { + "X_sparse": X_sparse, + "targets": targets, + "lambda_l0": lambda_l0, + "epochs": epochs, + "device": device, + "verbose_freq": verbose_freq, + "target_groups": target_groups, + } + calls.update(kwargs) + return np.array([4.0, 5.0]) + + package = types.ModuleType("policyengine_us_data") + package.__path__ = [] + calibration_package = types.ModuleType("policyengine_us_data.calibration") + calibration_package.__path__ = [] + unified = types.ModuleType("policyengine_us_data.calibration.unified_calibration") + unified.fit_l0_weights = fake_policyengine_fit_l0_weights + monkeypatch.setitem(sys.modules, "policyengine_us_data", package) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration", + calibration_package, + ) + monkeypatch.setitem( + sys.modules, + "policyengine_us_data.calibration.unified_calibration", + unified, + ) + + fit_l0_weights = make_policyengine_us_data_fit_l0_weights_fn() + result = fit_l0_weights( + X_sparse="matrix", + targets=np.array([1.0]), + lambda_l0=1e-8, + epochs=2, + device="cpu", + verbose_freq=1, + initial_weights=np.array([1.0, 1.0]), + target_names=["target"], + ) + + assert result.tolist() == [4.0, 5.0] + assert calls["X_sparse"] == "matrix" + np.testing.assert_array_equal(calls["targets"], np.array([1.0])) + assert calls["lambda_l0"] == pytest.approx(1e-8) + assert calls["epochs"] == 2 + assert calls["device"] == "cpu" + assert calls["verbose_freq"] == 1 + assert calls["target_groups"] is None + assert "initial_weights" not in calls + assert "target_names" not in calls