Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/target-filters.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added common and target-specific row filters for QRF training.
123 changes: 110 additions & 13 deletions microimpute/models/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import numpy as np
import pandas as pd
from pydantic import SkipValidation, validate_call
from pydantic import validate_call

from microimpute.config import RANDOM_STATE, VALIDATE_CONFIG
from microimpute.utils.type_handling import (
Expand Down Expand Up @@ -104,6 +104,7 @@ def identify_target_types(
data: pd.DataFrame,
imputed_variables: List[str],
not_numeric_categorical: Optional[List[str]] = None,
target_fit_masks: Optional[Dict[str, pd.Series]] = None,
) -> None:
"""Identify and track variable types for imputation targets.

Expand All @@ -113,21 +114,27 @@ def identify_target_types(
not_numeric_categorical: Optional list of variable names that should
be treated as numeric even if they would normally be detected as
numeric_categorical.
target_fit_masks: Optional target-specific row masks to use when
inferring target type and constants.
"""
detector = VariableTypeDetector()
not_numeric_categorical = not_numeric_categorical or []
target_fit_masks = target_fit_masks or {}

for var in imputed_variables:
if var not in data.columns:
continue
target_data = data[var]
if var in target_fit_masks:
target_data = target_data.loc[target_fit_masks[var]]

# First check if the variable has a constant value
unique_values = data[var].dropna().unique()
unique_values = target_data.dropna().unique()
if len(unique_values) == 1:
constant_val = unique_values[0]
self.constant_targets[var] = {
"value": constant_val,
"dtype": data[var].dtype,
"dtype": target_data.dtype,
}
self.logger.warning(
f"Target variable '{var}' has constant value {constant_val}. "
Expand All @@ -136,7 +143,7 @@ def identify_target_types(
continue

var_type, categories = detector.categorize_variable(
data[var],
target_data,
var,
self.logger,
force_numeric=(var in not_numeric_categorical),
Expand All @@ -145,15 +152,15 @@ def identify_target_types(
if var_type == "bool":
self.boolean_targets[var] = {
"type": "boolean",
"dtype": data[var].dtype,
"dtype": target_data.dtype,
}
self.logger.info(f"Identified boolean target: {var}")

elif var_type in ["categorical", "numeric_categorical"]:
self.categorical_targets[var] = {
"type": var_type,
"categories": categories,
"dtype": data[var].dtype,
"dtype": target_data.dtype,
}
self.logger.info(
f"Identified categorical target: {var} with {len(categories) if categories else 0} categories"
Expand All @@ -163,6 +170,30 @@ def identify_target_types(
self.numeric_targets.append(var)
self.logger.debug(f"Identified numeric target: {var}")

def _coerce_fit_filter(
self,
X_train: pd.DataFrame,
fit_filter: Union[str, np.ndarray, pd.Series, List[bool], Tuple[bool, ...]],
*,
name: str,
) -> pd.Series:
"""Normalize a row-filter input to a boolean Series on ``X_train``."""
if isinstance(fit_filter, str):
if fit_filter not in X_train.columns:
raise ValueError(f"{name} column '{fit_filter}' not found in X_train")
mask = X_train[fit_filter]
elif isinstance(fit_filter, pd.Series):
mask = fit_filter.reindex(X_train.index)
else:
mask = pd.Series(fit_filter, index=X_train.index)

if len(mask) != len(X_train):
raise ValueError(f"{name} must have length {len(X_train)}, got {len(mask)}")
if mask.isna().any():
raise ValueError(f"{name} contains missing values")

return mask.astype(bool)

@validate_call(config=VALIDATE_CONFIG)
def preprocess_data_types(
self,
Expand Down Expand Up @@ -216,6 +247,12 @@ def fit(
weight_col: Optional[Union[str, np.ndarray, pd.Series]] = None,
skip_missing: bool = False,
not_numeric_categorical: Optional[List[str]] = None,
row_filter: Optional[
Union[str, np.ndarray, pd.Series, List[bool], Tuple[bool, ...]]
] = None,
target_filters: Optional[
Dict[str, Union[str, np.ndarray, pd.Series, List[bool], Tuple[bool, ...]]]
] = None,
**kwargs: Any,
) -> Any: # Returns ImputerResults
"""Fit the model to the training data.
Expand All @@ -229,6 +266,13 @@ def fit(
not_numeric_categorical: Optional list of variable names that should
be treated as numeric even if they would normally be detected as
numeric_categorical.
row_filter: Optional common row mask, or the name of a boolean
column in X_train, selecting rows eligible for all targets.
target_filters: Optional mapping from imputed variable name to a
target-specific row mask, or the name of a boolean column in
X_train. Target-specific filters are combined with row_filter.
They are supported by models that fit one model per target,
such as QRF.
**kwargs: Additional model-specific parameters.

Returns:
Expand All @@ -240,6 +284,48 @@ def fit(
NotImplementedError: If method is not implemented by subclass.
"""
original_predictors = predictors.copy()
target_filters = target_filters or {}
unknown_target_filters = set(target_filters) - set(imputed_variables)
if unknown_target_filters:
raise ValueError(
"target_filters contains variables not in imputed_variables: "
f"{sorted(unknown_target_filters)}"
)

base_mask = pd.Series(True, index=X_train.index)
if row_filter is not None:
base_mask = self._coerce_fit_filter(
X_train,
row_filter,
name="row_filter",
)

target_fit_masks = {}
for variable, target_filter in target_filters.items():
target_fit_masks[variable] = (
self._coerce_fit_filter(
X_train,
target_filter,
name=f"target_filters[{variable!r}]",
)
& base_mask
)

if target_filters and not getattr(self, "supports_target_filters", False):
raise NotImplementedError(
f"{type(self).__name__} does not support target_filters"
)

if not base_mask.all():
if isinstance(weight_col, np.ndarray):
weight_col = pd.Series(weight_col, index=base_mask.index).loc[base_mask]
elif isinstance(weight_col, pd.Series):
weight_col = weight_col.reindex(base_mask.index).loc[base_mask]
X_train = X_train.loc[base_mask].copy()
target_fit_masks = {
variable: mask.loc[X_train.index]
for variable, mask in target_fit_masks.items()
}

try:
# Handle missing variables if skip_missing is enabled
Expand Down Expand Up @@ -288,7 +374,12 @@ def fit(
)

# Identify target types BEFORE preprocessing
self.identify_target_types(X_train, imputed_variables, not_numeric_categorical)
self.identify_target_types(
X_train,
imputed_variables,
not_numeric_categorical,
target_fit_masks=target_fit_masks,
)

X_train, predictors, imputed_variables, imputed_vars_dummy_info = (
self.preprocess_data_types(
Expand Down Expand Up @@ -319,17 +410,23 @@ def fit(
)

# Defer actual training to subclass with all parameters
fit_kwargs = {
"categorical_targets": self.categorical_targets,
"boolean_targets": self.boolean_targets,
"numeric_targets": self.numeric_targets,
"constant_targets": self.constant_targets,
"sample_weight": sample_weight,
**kwargs,
}
if target_fit_masks:
fit_kwargs["target_fit_masks"] = target_fit_masks

fitted_model = self._fit(
X_train,
self.predictors,
self.imputed_variables,
self.original_predictors,
categorical_targets=self.categorical_targets,
boolean_targets=self.boolean_targets,
numeric_targets=self.numeric_targets,
constant_targets=self.constant_targets,
sample_weight=sample_weight,
**kwargs,
**fit_kwargs,
)
return fitted_model

Expand Down
Loading
Loading