Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: patch
changes:
added:
- Make models return a single pandas DataFrame directly, instead of a dictionary keyed by quantile, when no quantiles are specified.
38 changes: 22 additions & 16 deletions microimpute/models/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,9 +750,9 @@ def _preprocess_data_types(
@validate_call(config=VALIDATE_CONFIG)
def _postprocess_imputations(
self,
imputations: Dict[float, pd.DataFrame],
imputations: Union[Dict[float, pd.DataFrame], pd.DataFrame],
dummy_info: Dict[str, Any],
) -> Dict[float, pd.DataFrame]:
) -> Union[Dict[float, pd.DataFrame], pd.DataFrame]:
"""Convert imputed bool and categorical dummy variables back to original data types.

This function reverses the encoding applied by preprocess_data,
Expand Down Expand Up @@ -815,10 +815,10 @@ def _get_reference_category(
try:
processed_imputations = {}

for quantile, df in imputations.items():
self.logger.debug(
f"Processing quantile {quantile} with shape {df.shape}"
)
def process_single_quantile(
df: pd.DataFrame, dummy_info: Dict[str, Any]
) -> pd.DataFrame:

df_processed = df.copy()

for orig_col, dummy_cols in dummy_info.get(
Expand Down Expand Up @@ -1021,16 +1021,22 @@ def _get_reference_category(
self.logger.warning(
f"No dummy columns found for categorical variable {orig_col}"
)

processed_imputations[quantile] = df_processed
self.logger.debug(
f"Processed quantile {quantile}, final shape: {df_processed.shape}"
)

self.logger.info(
f"Successfully post-processed {len(processed_imputations)} quantile imputations"
)
return processed_imputations
return df_processed

if isinstance(imputations, pd.DataFrame):
processed_df = process_single_quantile(imputations, dummy_info)
return processed_df
else:
for quantile, df in imputations.items():
self.logger.debug(
f"Processing quantile {quantile} with shape {df.shape}"
)
processed_df = process_single_quantile(df, dummy_info)
processed_imputations[quantile] = processed_df
self.logger.debug(
f"Processed quantile {quantile}, final shape: {processed_df.shape}"
)
return processed_imputations

except Exception as e:
self.logger.error(f"Error in postprocess_imputations: {str(e)}")
Expand Down
19 changes: 5 additions & 14 deletions microimpute/models/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,29 +287,20 @@ def _process_matching_results(
imputed_df[variable] = fused0[variable].values

imputations[q] = imputed_df
return imputations
else:
# If no quantiles specified, use a default one
q = 0.5
q_default = 0.5
self.logger.info(
f"Creating imputation for default quantile {q}"
f"Creating imputation for default quantile {q_default}"
)
imputed_df = pd.DataFrame(index=X_test_copy.index)
for variable in self.imputed_variables:
self.logger.info(f"Imputing variable {variable}")
imputed_df[variable] = fused0[variable].values
imputations[q] = imputed_df
imputations[q_default] = imputed_df

# Verify output shapes
for q, df in imputations.items():
self.logger.debug(
f"Imputation result for q={q}: shape={df.shape}"
)
if len(df) != len(X_test_copy):
self.logger.warning(
f"Result shape mismatch: expected {len(X_test_copy)} rows, got {len(df)}"
)

return imputations
return imputations[q_default]
except Exception as output_error:
self.logger.error(
f"Error creating output imputations: {str(output_error)}"
Expand Down
9 changes: 5 additions & 4 deletions microimpute/models/ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ def _predict(
random_sample=random_quantile_sample,
)
imputations[q] = pd.DataFrame(imputed_df)
return imputations
else:
q = 0.5
q_default = 0.5
imputed_df = pd.DataFrame()
for variable in self.imputed_variables:
self.logger.info(f"Imputing variable {variable}")
Expand All @@ -107,11 +108,11 @@ def _predict(
imputed_df[variable] = self._predict_quantile(
mean_preds=mean_preds,
se=se,
mean_quantile=q,
mean_quantile=q_default,
random_sample=random_quantile_sample,
)
imputations[q] = pd.DataFrame(imputed_df)
return imputations
imputations[q_default] = pd.DataFrame(imputed_df)
return imputations[q_default]

except Exception as e:
self.logger.error(f"Error during prediction: {str(e)}")
Expand Down
6 changes: 5 additions & 1 deletion microimpute/models/qrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,11 @@ def _predict(

imputations[q] = imputed_df

return imputations
qs = imputations.keys()
if len(qs) < 2:
q = list(qs)[0]

return imputations if quantiles else imputations[q]

except Exception as e:
self.logger.error(f"Error during QRF prediction: {str(e)}")
Expand Down
7 changes: 6 additions & 1 deletion microimpute/models/quantreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,12 @@ def _predict(
self.logger.info(
f"Completed predictions for {len(quantiles)} quantiles"
)
return imputations

quantiles = imputations.keys()
if len(quantiles) < 2:
q = list(quantiles)[0]

return imputations if len(imputations) > 1 else imputations[q]

except ValueError as e:
# Re-raise value errors directly
Expand Down
8 changes: 4 additions & 4 deletions tests/test_models/test_imputers.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def test_fit_predict_interface(
)

default_predictions = fitted_default_model.predict(X_test)
assert isinstance(default_predictions, dict), (
f"{model_class.__name__} predict should return a dictionary even with "
assert isinstance(default_predictions, pd.DataFrame), (
f"{model_class.__name__} predict should return a DataFrame with "
f"default quantiles"
)

Expand Down Expand Up @@ -171,8 +171,8 @@ def test_imputation_categorical_bool_vars() -> None:
fitted_ols = ols.fit(X_train, predictors, imputed_variables)
ols_predictions = fitted_ols.predict(X_test)

assert ols_predictions[0.5]["categorical"].dtype == "object"
assert ols_predictions[0.5]["bool"].dtype == "bool"
assert ols_predictions["categorical"].dtype == "object"
assert ols_predictions["bool"].dtype == "bool"


@pytest.mark.parametrize(
Expand Down
4 changes: 0 additions & 4 deletions tests/test_models/test_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,6 @@ def test_matching_example_use(
)

# Check structure of predictions
assert isinstance(predictions, dict)
assert 0.5 in predictions

# Check that predictions are pandas DataFrame for matching model
assert isinstance(predictions[0.5], pd.DataFrame)

transformed_df = pd.DataFrame()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_models/test_qrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ def test_qrf_imputes_multiple_variables(
predictions: Dict[float, pd.DataFrame] = fitted_model.predict(X_test)

# Check structure of predictions
assert isinstance(predictions, dict)
assert predictions[0.5].shape[1] == len(imputed_variables)
assert isinstance(predictions, pd.DataFrame)
assert predictions.shape[1] == len(imputed_variables)


def test_qrf_sequential_imputation(
Expand Down
37 changes: 16 additions & 21 deletions tests/test_models/test_qrf_extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,8 @@ def test_qrf_with_missing_categorical_columns_in_test():
# Predict - should handle missing category gracefully
predictions = fitted_model.predict(test_data[["numeric", "category"]])

assert 0.5 in predictions
assert len(predictions[0.5]) == len(test_data)
assert not predictions[0.5]["target"].isna().any()
assert len(predictions) == len(test_data)
assert not predictions["target"].isna().any()


def test_qrf_with_single_predictor():
Expand Down Expand Up @@ -248,8 +247,8 @@ def test_qrf_reproducibility():

# Results should be identical
np.testing.assert_array_almost_equal(
predictions1[0.5]["y"].values,
predictions2[0.5]["y"].values,
predictions1["y"].values,
predictions2["y"].values,
)


Expand Down Expand Up @@ -286,12 +285,12 @@ def test_qrf_with_highly_correlated_predictors():
predictions = fitted_model.predict(test_data[["x1", "x2", "x3"]])

# Model should still produce reasonable predictions despite correlation
assert len(predictions[0.5]) == len(test_data)
assert not predictions[0.5]["y"].isna().any()
assert len(predictions) == len(test_data)
assert not predictions["y"].isna().any()

# Check that predictions are somewhat correlated with true values
true_y = test_data["y"].values
pred_y = predictions[0.5]["y"].values
pred_y = predictions["y"].values
correlation = np.corrcoef(true_y, pred_y)[0, 1]
assert correlation > 0.5 # Should have reasonable correlation

Expand Down Expand Up @@ -507,9 +506,8 @@ def test_qrf_batch_processing():
test_data = data[["predictor1", "predictor2"]].head(5)
predictions = fitted_model.predict(test_data)

assert 0.5 in predictions
assert len(predictions[0.5]) == len(test_data)
assert not predictions[0.5].isna().any().any()
assert len(predictions) == len(test_data)
assert not predictions.isna().any().any()

# Clean up
model.logger.removeHandler(handler)
Expand Down Expand Up @@ -698,9 +696,8 @@ def test_qrf_missing_variables_handling():
test_data = data[["x1", "x2"]].head(10)
predictions = fitted_model.predict(test_data)

assert 0.5 in predictions
assert "existing_var" in predictions[0.5].columns
assert len(predictions[0.5]) == len(test_data)
assert "existing_var" in predictions.columns
assert len(predictions) == len(test_data)

# Clean up
model_lenient.logger.removeHandler(handler)
Expand Down Expand Up @@ -749,10 +746,9 @@ def test_qrf_all_variables_missing():
test_data = data[["x1", "x2"]].head(5)
predictions = fitted_model.predict(test_data)

assert 0.5 in predictions
assert len(predictions[0.5].columns) == 0 # No variables to predict
assert len(predictions.columns) == 0 # No variables to predict
# When there are no variables to impute, predictions should be empty but defined
assert isinstance(predictions[0.5], pd.DataFrame)
assert isinstance(predictions, pd.DataFrame)

# Clean up
model.logger.removeHandler(handler)
Expand Down Expand Up @@ -815,10 +811,9 @@ def test_qrf_partial_missing_variables():
test_data = data[["predictor1", "predictor2"]].head(8)
predictions = fitted_model.predict(test_data)

assert 0.5 in predictions
assert "target1" in predictions[0.5].columns
assert "target3" in predictions[0.5].columns
assert "target2" not in predictions[0.5].columns
assert "target1" in predictions.columns
assert "target3" in predictions.columns
assert "target2" not in predictions.columns

# Clean up
model.logger.removeHandler(handler)