diff --git a/src/microplex_us/__init__.py b/src/microplex_us/__init__.py index ada081a..ec61541 100644 --- a/src/microplex_us/__init__.py +++ b/src/microplex_us/__init__.py @@ -5,52 +5,44 @@ from importlib import import_module from typing import Any -from microplex.targets import TargetSet, TargetSpec - -from microplex_us.calibration_harness import ( - CalibrationHarness, - CalibrationResult, - run_pe_parity_suite, +_CALIBRATION_HARNESS_EXPORTS = ( + "CalibrationHarness", + "CalibrationResult", + "run_pe_parity_suite", ) -from microplex_us.cps_synthetic import ( - CPSSummaryStats, - CPSSyntheticGenerator, - validate_synthetic, +_CPS_SYNTHETIC_EXPORTS = ( + "CPSSummaryStats", + "CPSSyntheticGenerator", + "validate_synthetic", ) -from microplex_us.data import ( - create_sample_data, - get_data_info, - load_cps_asec, - load_cps_for_synthesis, +_DATA_EXPORTS = ( + "create_sample_data", + "get_data_info", + "load_cps_asec", + "load_cps_for_synthesis", ) - -try: - from microplex_us.geography import ( - BLOCK_LEN, - COUNTY_LEN, - STATE_LEN, - TRACT_LEN, - BlockGeography, - derive_geographies, - load_block_probabilities, - normalize_us_state_fips, - ) -except ImportError: - BLOCK_LEN = None - COUNTY_LEN = None - STATE_LEN = None - TRACT_LEN = None - BlockGeography = None - derive_geographies = None - load_block_probabilities = None - normalize_us_state_fips = None -from microplex_us.hierarchical import prepare_cps_for_hierarchical -from microplex_us.pe_targets import ( - PETargets, - create_calibration_targets, - get_pe_targets, +_GEOGRAPHY_EXPORTS = ( + "BLOCK_LEN", + "COUNTY_LEN", + "STATE_LEN", + "TRACT_LEN", + "BlockGeography", + "derive_geographies", + "load_block_probabilities", + "normalize_us_state_fips", +) +_HIERARCHICAL_EXPORTS = ( + "prepare_cps_for_hierarchical", +) +_MICROPLEX_TARGET_EXPORTS = ( + "TargetSet", + "TargetSpec", +) +_PE_TARGETS_EXPORTS = ( + "PETargets", + "create_calibration_targets", + "get_pe_targets", ) - _PIPELINE_EXPORTS = ( "DEFAULT_ATOMIC_AGE_BINS", "DEFAULT_ATOMIC_AGE_LABELS", @@ -180,13 +172,13 @@ "SourceVariablePolicySpec", "resolve_source_variable_capabilities", ) -from microplex_us.target_registry import ( - TargetCategory, - TargetGroup, - TargetLevel, - TargetRegistry, - get_registry, - print_registry_summary, +_TARGET_REGISTRY_EXPORTS = ( + "TargetCategory", + "TargetGroup", + "TargetLevel", + "TargetRegistry", + "get_registry", + "print_registry_summary", ) _TARGETS_EXPORTS = ( @@ -194,46 +186,40 @@ "policyengine_db_target_to_canonical_spec", "policyengine_db_targets_to_canonical_set", ) -from microplex_us.unified_calibration import ( - CalibrationTarget, - UnifiedCalibrator, - calibrate_to_pe_targets, +_UNIFIED_CALIBRATION_EXPORTS = ( + "CalibrationTarget", + "UnifiedCalibrator", + "calibrate_to_pe_targets", +) +_VALIDATION_EXPORTS = ( + "AGI_BRACKETS", + "FILING_STATUSES", + "BaselineComparison", + "MetricComparison", + "SOITargets", + "ValidationResult", + "compute_baseline_comparison", + "compute_validation_metrics", + "export_comparison_json", + "get_soi_years", + "load_soi_targets", + "validate_against_soi", ) - -try: - from microplex_us.validation import ( - AGI_BRACKETS, - FILING_STATUSES, - BaselineComparison, - MetricComparison, - SOITargets, - ValidationResult, - compute_baseline_comparison, - compute_validation_metrics, - export_comparison_json, - get_soi_years, - load_soi_targets, - validate_against_soi, - ) -except ImportError: - AGI_BRACKETS = None - FILING_STATUSES = None - BaselineComparison = None - MetricComparison = None - SOITargets = None - ValidationResult = None - compute_baseline_comparison = None - compute_validation_metrics = None - export_comparison_json = None - get_soi_years = None - load_soi_targets = None - validate_against_soi = None - _LAZY_EXPORT_MODULES: dict[str, str] = { + **dict.fromkeys(_CALIBRATION_HARNESS_EXPORTS, "microplex_us.calibration_harness"), + **dict.fromkeys(_CPS_SYNTHETIC_EXPORTS, "microplex_us.cps_synthetic"), + **dict.fromkeys(_DATA_EXPORTS, "microplex_us.data"), + **dict.fromkeys(_GEOGRAPHY_EXPORTS, "microplex_us.geography"), + **dict.fromkeys(_HIERARCHICAL_EXPORTS, "microplex_us.hierarchical"), + **dict.fromkeys(_MICROPLEX_TARGET_EXPORTS, "microplex.targets"), + **dict.fromkeys(_PE_TARGETS_EXPORTS, "microplex_us.pe_targets"), **dict.fromkeys(_PIPELINE_EXPORTS, "microplex_us.pipelines"), **dict.fromkeys(_POLICYENGINE_EXPORTS, "microplex_us.policyengine"), **dict.fromkeys(_SOURCE_REGISTRY_EXPORTS, "microplex_us.source_registry"), + **dict.fromkeys(_TARGET_REGISTRY_EXPORTS, "microplex_us.target_registry"), **dict.fromkeys(_TARGETS_EXPORTS, "microplex_us.targets"), + **dict.fromkeys(_UNIFIED_CALIBRATION_EXPORTS, "microplex_us.unified_calibration"), + **dict.fromkeys(_VALIDATION_EXPORTS, "microplex_us.validation"), } diff --git a/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py b/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py index 8c7e681..c467ed0 100644 --- a/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py +++ b/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py @@ -14,57 +14,16 @@ import h5py import numpy as np import pandas as pd -from microplex.core import ( - EntityObservation, - EntityType, - ObservationFrame, - SourceDescriptor, - SourceQuery, -) -from microplex.targets import assert_valid_benchmark_artifact_manifest - -from microplex_us.pipelines.artifacts import ( - USMicroplexArtifactPaths, - USMicroplexVersionedBuildArtifacts, - build_and_save_versioned_us_microplex_from_source_providers, -) -from microplex_us.pipelines.imputation_ablation import ( - ImputationAblationSliceSpec, - ImputationAblationVariant, - score_imputation_ablation_variants, -) -from microplex_us.pipelines.index_db import append_us_microplex_run_index_entry -from microplex_us.pipelines.pe_us_data_rebuild import ( - PEUSDataRebuildProgram, - default_policyengine_us_data_rebuild_config, - default_policyengine_us_data_rebuild_program, - default_policyengine_us_data_rebuild_source_providers, -) -from microplex_us.pipelines.pe_us_data_rebuild_audit import ( - build_policyengine_us_data_rebuild_native_audit, -) -from microplex_us.pipelines.pe_us_data_rebuild_parity import ( - build_policyengine_us_data_rebuild_parity_artifact, - write_policyengine_us_data_rebuild_parity_artifact, -) -from microplex_us.pipelines.registry import ( - append_us_microplex_run_registry_entry, - build_us_microplex_run_registry_entry, - load_us_microplex_run_registry, - select_us_microplex_frontier_entry, -) -from microplex_us.pipelines.stage_contracts import ( - resolve_us_stage_artifact_contract_path, -) -from microplex_us.pipelines.stage_run import ( - USStageInputOverride, - parse_us_stage_input_override, - write_us_stage_run_manifests_from_artifact_manifest, -) -from microplex_us.variables import prune_redundant_variables if TYPE_CHECKING: - from microplex.core import SourceProvider + from microplex.core import ( + EntityObservation, + EntityType, + ObservationFrame, + SourceDescriptor, + SourceProvider, + SourceQuery, + ) from microplex.targets import TargetProvider from microplex_us.pipelines.registry import FrontierMetric @@ -80,6 +39,143 @@ LOGGER = logging.getLogger(__name__) +_RUNTIME_SYMBOLS_LOADED = False + +EntityObservation: Any = None +EntityType: Any = None +ObservationFrame: Any = None +SourceDescriptor: Any = None +SourceQuery: Any = None +assert_valid_benchmark_artifact_manifest: Any = None +USMicroplexArtifactPaths: Any = None +USMicroplexVersionedBuildArtifacts: Any = None +build_and_save_versioned_us_microplex_from_source_providers: Any = None +ImputationAblationSliceSpec: Any = None +ImputationAblationVariant: Any = None +score_imputation_ablation_variants: Any = None +append_us_microplex_run_index_entry: Any = None +PEUSDataRebuildProgram: Any = None +default_policyengine_us_data_rebuild_config: Any = None +default_policyengine_us_data_rebuild_program: Any = None +default_policyengine_us_data_rebuild_source_providers: Any = None +build_policyengine_us_data_rebuild_native_audit: Any = None +build_policyengine_us_data_rebuild_parity_artifact: Any = None +write_policyengine_us_data_rebuild_parity_artifact: Any = None +append_us_microplex_run_registry_entry: Any = None +build_us_microplex_run_registry_entry: Any = None +load_us_microplex_run_registry: Any = None +select_us_microplex_frontier_entry: Any = None +resolve_us_stage_artifact_contract_path: Any = None +USStageInputOverride: Any = None +parse_us_stage_input_override: Any = None +write_us_stage_run_manifests_from_artifact_manifest: Any = None +prune_redundant_variables: Any = None + + +def _load_runtime_symbols() -> None: + """Import execution dependencies after CLI parsing has had a chance to exit.""" + + global _RUNTIME_SYMBOLS_LOADED + global EntityObservation, EntityType, ObservationFrame, SourceDescriptor, SourceQuery + global assert_valid_benchmark_artifact_manifest + global USMicroplexArtifactPaths, USMicroplexVersionedBuildArtifacts + global build_and_save_versioned_us_microplex_from_source_providers + global ImputationAblationSliceSpec, ImputationAblationVariant + global score_imputation_ablation_variants, append_us_microplex_run_index_entry + global PEUSDataRebuildProgram, default_policyengine_us_data_rebuild_config + global default_policyengine_us_data_rebuild_program + global default_policyengine_us_data_rebuild_source_providers + global build_policyengine_us_data_rebuild_native_audit + global build_policyengine_us_data_rebuild_parity_artifact + global write_policyengine_us_data_rebuild_parity_artifact + global append_us_microplex_run_registry_entry, build_us_microplex_run_registry_entry + global load_us_microplex_run_registry, select_us_microplex_frontier_entry + global resolve_us_stage_artifact_contract_path + global USStageInputOverride, parse_us_stage_input_override + global write_us_stage_run_manifests_from_artifact_manifest + global prune_redundant_variables + + if _RUNTIME_SYMBOLS_LOADED: + return + + import microplex.core as microplex_core + import microplex.targets as microplex_targets + + import microplex_us.pipelines.artifacts as artifacts_module + import microplex_us.pipelines.imputation_ablation as imputation_ablation_module + import microplex_us.pipelines.index_db as index_db_module + import microplex_us.pipelines.pe_us_data_rebuild as rebuild_module + import microplex_us.pipelines.pe_us_data_rebuild_audit as rebuild_audit_module + import microplex_us.pipelines.pe_us_data_rebuild_parity as rebuild_parity_module + import microplex_us.pipelines.registry as registry_module + import microplex_us.pipelines.stage_contracts as stage_contracts_module + import microplex_us.pipelines.stage_run as stage_run_module + import microplex_us.variables as variables_module + + EntityObservation = microplex_core.EntityObservation + EntityType = microplex_core.EntityType + ObservationFrame = microplex_core.ObservationFrame + SourceDescriptor = microplex_core.SourceDescriptor + SourceQuery = microplex_core.SourceQuery + assert_valid_benchmark_artifact_manifest = ( + microplex_targets.assert_valid_benchmark_artifact_manifest + ) + USMicroplexArtifactPaths = artifacts_module.USMicroplexArtifactPaths + USMicroplexVersionedBuildArtifacts = ( + artifacts_module.USMicroplexVersionedBuildArtifacts + ) + build_and_save_versioned_us_microplex_from_source_providers = ( + artifacts_module.build_and_save_versioned_us_microplex_from_source_providers + ) + ImputationAblationSliceSpec = imputation_ablation_module.ImputationAblationSliceSpec + ImputationAblationVariant = imputation_ablation_module.ImputationAblationVariant + score_imputation_ablation_variants = ( + imputation_ablation_module.score_imputation_ablation_variants + ) + append_us_microplex_run_index_entry = ( + index_db_module.append_us_microplex_run_index_entry + ) + PEUSDataRebuildProgram = rebuild_module.PEUSDataRebuildProgram + default_policyengine_us_data_rebuild_config = ( + rebuild_module.default_policyengine_us_data_rebuild_config + ) + default_policyengine_us_data_rebuild_program = ( + rebuild_module.default_policyengine_us_data_rebuild_program + ) + default_policyengine_us_data_rebuild_source_providers = ( + rebuild_module.default_policyengine_us_data_rebuild_source_providers + ) + build_policyengine_us_data_rebuild_native_audit = ( + rebuild_audit_module.build_policyengine_us_data_rebuild_native_audit + ) + build_policyengine_us_data_rebuild_parity_artifact = ( + rebuild_parity_module.build_policyengine_us_data_rebuild_parity_artifact + ) + write_policyengine_us_data_rebuild_parity_artifact = ( + rebuild_parity_module.write_policyengine_us_data_rebuild_parity_artifact + ) + append_us_microplex_run_registry_entry = ( + registry_module.append_us_microplex_run_registry_entry + ) + build_us_microplex_run_registry_entry = ( + registry_module.build_us_microplex_run_registry_entry + ) + load_us_microplex_run_registry = registry_module.load_us_microplex_run_registry + select_us_microplex_frontier_entry = ( + registry_module.select_us_microplex_frontier_entry + ) + resolve_us_stage_artifact_contract_path = ( + stage_contracts_module.resolve_us_stage_artifact_contract_path + ) + USStageInputOverride = stage_run_module.USStageInputOverride + parse_us_stage_input_override = stage_run_module.parse_us_stage_input_override + write_us_stage_run_manifests_from_artifact_manifest = ( + stage_run_module.write_us_stage_run_manifests_from_artifact_manifest + ) + prune_redundant_variables = variables_module.prune_redundant_variables + _RUNTIME_SYMBOLS_LOADED = True + + def _root_logger_has_handlers() -> bool: return bool(logging.getLogger().handlers) @@ -1432,6 +1528,8 @@ def attach_policyengine_us_data_rebuild_checkpoint_evidence( ) -> PEUSDataRebuildCheckpointEvidenceResult: """Attach PE comparison evidence to an already-saved rebuild artifact.""" + _load_runtime_symbols() + from microplex_us.pipelines.pe_native_scores import compute_us_pe_native_scores from microplex_us.policyengine.harness import evaluate_policyengine_us_harness from microplex_us.policyengine.us import load_policyengine_us_entity_tables @@ -1695,6 +1793,8 @@ def default_policyengine_us_data_rebuild_checkpoint_config( ) -> USMicroplexBuildConfig: """Return the canonical rebuild config with required PE comparison context.""" + _load_runtime_symbols() + resolved_target_period = int(target_period) resolved_baseline_weight_sum = _infer_policyengine_baseline_household_weight_sum( policyengine_baseline_dataset, @@ -1757,6 +1857,8 @@ def default_policyengine_us_data_rebuild_queries( ) -> dict[str, SourceQuery]: """Return default provider queries for a rebuild checkpoint smoke run.""" + _load_runtime_symbols() + from microplex_us.data_sources.cps import CPSASECSourceProvider from microplex_us.data_sources.donor_surveys import DonorSurveySourceProvider from microplex_us.data_sources.puf import PUFSourceProvider @@ -1867,6 +1969,8 @@ def run_policyengine_us_data_rebuild_checkpoint( ) -> PEUSDataRebuildCheckpointResult: """Run one saved rebuild checkpoint and write its PE comparison sidecars.""" + _load_runtime_symbols() + if config is not None and config_overrides: raise ValueError( "config_overrides cannot be used when an explicit config is supplied" @@ -2070,8 +2174,8 @@ def run_policyengine_us_data_rebuild_checkpoint( ) -def main(argv: list[str] | None = None) -> None: - """CLI entry point for one PE-US-data rebuild checkpoint.""" +def build_policyengine_us_data_rebuild_checkpoint_parser() -> argparse.ArgumentParser: + """Build the lightweight checkpoint CLI parser.""" parser = argparse.ArgumentParser( description="Run a versioned PE-US-data rebuild checkpoint in microplex-us." @@ -2252,11 +2356,15 @@ def main(argv: list[str] | None = None) -> None: metavar="STAGE_ID.KEY=PATH", help=("Explicit stage input override. Requires --allow-stage-input-overrides."), ) + return parser + + +def main(argv: list[str] | None = None) -> None: + """CLI entry point for one PE-US-data rebuild checkpoint.""" + + parser = build_policyengine_us_data_rebuild_checkpoint_parser() args = parser.parse_args(argv) - stage_input_overrides = tuple( - parse_us_stage_input_override(value) for value in args.stage_input_override - ) - if stage_input_overrides and not args.allow_stage_input_overrides: + if args.stage_input_override and not args.allow_stage_input_overrides: parser.error("--stage-input-override requires --allow-stage-input-overrides") config_overrides = { @@ -2300,6 +2408,11 @@ def main(argv: list[str] | None = None) -> None: args.capital_gains_lots_random_seed ) + _load_runtime_symbols() + stage_input_overrides = tuple( + parse_us_stage_input_override(value) for value in args.stage_input_override + ) + result = run_policyengine_us_data_rebuild_checkpoint( output_root=args.output_root, policyengine_baseline_dataset=args.baseline_dataset, diff --git a/tests/test_package_imports.py b/tests/test_package_imports.py index 7328cab..0f1e12d 100644 --- a/tests/test_package_imports.py +++ b/tests/test_package_imports.py @@ -21,6 +21,105 @@ def test_root_import_leaves_pipeline_exports_lazy() -> None: assert result.stdout.strip() == "False" +def test_root_import_does_not_require_torch_or_core_microplex() -> None: + result = subprocess.run( + [ + sys.executable, + "-c", + ( + "import importlib.abc\n" + "import sys\n" + "\n" + "class BlockTorch(importlib.abc.MetaPathFinder):\n" + " def find_spec(self, fullname, path=None, target=None):\n" + " if fullname == 'torch' or fullname.startswith('torch.'):\n" + " raise ModuleNotFoundError(\"No module named 'torch'\")\n" + " return None\n" + "\n" + "sys.meta_path.insert(0, BlockTorch())\n" + "import microplex_us\n" + "print('microplex' in sys.modules)\n" + "print('TargetSpec' in vars(microplex_us))\n" + ), + ], + check=True, + capture_output=True, + text=True, + ) + + assert result.stdout.splitlines() == ["False", "False"] + + +def test_pe_rebuild_checkpoint_import_does_not_require_torch() -> None: + result = subprocess.run( + [ + sys.executable, + "-c", + ( + "import importlib.abc\n" + "import sys\n" + "\n" + "class BlockTorch(importlib.abc.MetaPathFinder):\n" + " def find_spec(self, fullname, path=None, target=None):\n" + " if fullname == 'torch' or fullname.startswith('torch.'):\n" + " raise ModuleNotFoundError(\"No module named 'torch'\")\n" + " return None\n" + "\n" + "sys.meta_path.insert(0, BlockTorch())\n" + "import microplex_us.pipelines.pe_us_data_rebuild_checkpoint\n" + "print('microplex' in sys.modules)\n" + ), + ], + check=True, + capture_output=True, + text=True, + ) + + assert result.stdout.strip() == "False" + + +def test_pe_rebuild_checkpoint_help_does_not_require_torch_or_core_microplex() -> None: + result = subprocess.run( + [ + sys.executable, + "-c", + ( + "import importlib.abc\n" + "import runpy\n" + "import sys\n" + "\n" + "class BlockTorch(importlib.abc.MetaPathFinder):\n" + " def find_spec(self, fullname, path=None, target=None):\n" + " if fullname == 'torch' or fullname.startswith('torch.'):\n" + " raise ModuleNotFoundError(\"No module named 'torch'\")\n" + " return None\n" + "\n" + "sys.meta_path.insert(0, BlockTorch())\n" + "sys.argv = [\n" + " 'pe_us_data_rebuild_checkpoint',\n" + " '--help',\n" + "]\n" + "try:\n" + " runpy.run_module(\n" + " 'microplex_us.pipelines.pe_us_data_rebuild_checkpoint',\n" + " run_name='__main__',\n" + " )\n" + "except SystemExit as exc:\n" + " print(f'exit={exc.code}')\n" + "print(f'microplex_imported={\"microplex\" in sys.modules}')\n" + ), + ], + check=True, + capture_output=True, + text=True, + ) + + assert result.stdout.splitlines()[-2:] == [ + "exit=0", + "microplex_imported=False", + ] + + def test_data_sources_import_leaves_family_benchmark_lazy() -> None: result = subprocess.run( [