From 4802bad5b903728b85797d86a9c3e7b12929daae Mon Sep 17 00:00:00 2001 From: juaristi22 Date: Tue, 5 Aug 2025 18:58:27 +0200 Subject: [PATCH] return df directly when no quantiles are specified --- changelog_entry.yaml | 4 +++ microimpute/models/imputer.py | 38 +++++++++++++++----------- microimpute/models/matching.py | 19 ++++--------- microimpute/models/ols.py | 9 +++--- microimpute/models/qrf.py | 6 +++- microimpute/models/quantreg.py | 7 ++++- tests/test_models/test_imputers.py | 8 +++--- tests/test_models/test_matching.py | 4 --- tests/test_models/test_qrf.py | 4 +-- tests/test_models/test_qrf_extended.py | 37 +++++++++++-------------- 10 files changed, 69 insertions(+), 67 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29..5b53726 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + added: + - Make models return a dataframe directly when no quantiles are specified. diff --git a/microimpute/models/imputer.py b/microimpute/models/imputer.py index c2e5077..30f48b5 100644 --- a/microimpute/models/imputer.py +++ b/microimpute/models/imputer.py @@ -750,9 +750,9 @@ def _preprocess_data_types( @validate_call(config=VALIDATE_CONFIG) def _postprocess_imputations( self, - imputations: Dict[float, pd.DataFrame], + imputations: Union[Dict[float, pd.DataFrame], pd.DataFrame], dummy_info: Dict[str, Any], - ) -> Dict[float, pd.DataFrame]: + ) -> Union[Dict[float, pd.DataFrame], pd.DataFrame]: """Convert imputed bool and categorical dummy variables back to original data types. This function reverses the encoding applied by preprocess_data, @@ -815,10 +815,10 @@ def _get_reference_category( try: processed_imputations = {} - for quantile, df in imputations.items(): - self.logger.debug( - f"Processing quantile {quantile} with shape {df.shape}" - ) + def process_single_quantile( + df: pd.DataFrame, dummy_info: Dict[str, Any] + ) -> pd.DataFrame: + df_processed = df.copy() for orig_col, dummy_cols in dummy_info.get( @@ -1021,16 +1021,22 @@ def _get_reference_category( self.logger.warning( f"No dummy columns found for categorical variable {orig_col}" ) - - processed_imputations[quantile] = df_processed - self.logger.debug( - f"Processed quantile {quantile}, final shape: {df_processed.shape}" - ) - - self.logger.info( - f"Successfully post-processed {len(processed_imputations)} quantile imputations" - ) - return processed_imputations + return df_processed + + if isinstance(imputations, pd.DataFrame): + processed_df = process_single_quantile(imputations, dummy_info) + return processed_df + else: + for quantile, df in imputations.items(): + self.logger.debug( + f"Processing quantile {quantile} with shape {df.shape}" + ) + processed_df = process_single_quantile(df, dummy_info) + processed_imputations[quantile] = processed_df + self.logger.debug( + f"Processed quantile {quantile}, final shape: {processed_df.shape}" + ) + return processed_imputations except Exception as e: self.logger.error(f"Error in postprocess_imputations: {str(e)}") diff --git a/microimpute/models/matching.py b/microimpute/models/matching.py index 167a995..6ec0dc8 100644 --- a/microimpute/models/matching.py +++ b/microimpute/models/matching.py @@ -287,29 +287,20 @@ def _process_matching_results( imputed_df[variable] = fused0[variable].values imputations[q] = imputed_df + return imputations else: # If no quantiles specified, use a default one - q = 0.5 + q_default = 0.5 self.logger.info( - f"Creating imputation for default quantile {q}" + f"Creating imputation for default quantile {q_default}" ) imputed_df = pd.DataFrame(index=X_test_copy.index) for variable in self.imputed_variables: self.logger.info(f"Imputing variable {variable}") imputed_df[variable] = fused0[variable].values - imputations[q] = imputed_df + imputations[q_default] = imputed_df - # Verify output shapes - for q, df in imputations.items(): - self.logger.debug( - f"Imputation result for q={q}: shape={df.shape}" - ) - if len(df) != len(X_test_copy): - self.logger.warning( - f"Result shape mismatch: expected {len(X_test_copy)} rows, got {len(df)}" - ) - - return imputations + return imputations[q_default] except Exception as output_error: self.logger.error( f"Error creating output imputations: {str(output_error)}" diff --git a/microimpute/models/ols.py b/microimpute/models/ols.py index b5c68d4..5ab7342 100644 --- a/microimpute/models/ols.py +++ b/microimpute/models/ols.py @@ -96,8 +96,9 @@ def _predict( random_sample=random_quantile_sample, ) imputations[q] = pd.DataFrame(imputed_df) + return imputations else: - q = 0.5 + q_default = 0.5 imputed_df = pd.DataFrame() for variable in self.imputed_variables: self.logger.info(f"Imputing variable {variable}") @@ -107,11 +108,11 @@ def _predict( imputed_df[variable] = self._predict_quantile( mean_preds=mean_preds, se=se, - mean_quantile=q, + mean_quantile=q_default, random_sample=random_quantile_sample, ) - imputations[q] = pd.DataFrame(imputed_df) - return imputations + imputations[q_default] = pd.DataFrame(imputed_df) + return imputations[q_default] except Exception as e: self.logger.error(f"Error during prediction: {str(e)}") diff --git a/microimpute/models/qrf.py b/microimpute/models/qrf.py index 631faa8..2777e68 100644 --- a/microimpute/models/qrf.py +++ b/microimpute/models/qrf.py @@ -228,7 +228,11 @@ def _predict( imputations[q] = imputed_df - return imputations + qs = imputations.keys() + if len(qs) < 2: + q = list(qs)[0] + + return imputations if quantiles else imputations[q] except Exception as e: self.logger.error(f"Error during QRF prediction: {str(e)}") diff --git a/microimpute/models/quantreg.py b/microimpute/models/quantreg.py index ae27d06..b98755b 100644 --- a/microimpute/models/quantreg.py +++ b/microimpute/models/quantreg.py @@ -161,7 +161,12 @@ def _predict( self.logger.info( f"Completed predictions for {len(quantiles)} quantiles" ) - return imputations + + quantiles = imputations.keys() + if len(quantiles) < 2: + q = list(quantiles)[0] + + return imputations if len(imputations) > 1 else imputations[q] except ValueError as e: # Re-raise value errors directly diff --git a/tests/test_models/test_imputers.py b/tests/test_models/test_imputers.py index 3b4a550..6ff7d19 100644 --- a/tests/test_models/test_imputers.py +++ b/tests/test_models/test_imputers.py @@ -129,8 +129,8 @@ def test_fit_predict_interface( ) default_predictions = fitted_default_model.predict(X_test) - assert isinstance(default_predictions, dict), ( - f"{model_class.__name__} predict should return a dictionary even with " + assert isinstance(default_predictions, pd.DataFrame), ( + f"{model_class.__name__} predict should return a DataFrame with " f"default quantiles" ) @@ -171,8 +171,8 @@ def test_imputation_categorical_bool_vars() -> None: fitted_ols = ols.fit(X_train, predictors, imputed_variables) ols_predictions = fitted_ols.predict(X_test) - assert ols_predictions[0.5]["categorical"].dtype == "object" - assert ols_predictions[0.5]["bool"].dtype == "bool" + assert ols_predictions["categorical"].dtype == "object" + assert ols_predictions["bool"].dtype == "bool" @pytest.mark.parametrize( diff --git a/tests/test_models/test_matching.py b/tests/test_models/test_matching.py index 0a79e89..9b4424a 100644 --- a/tests/test_models/test_matching.py +++ b/tests/test_models/test_matching.py @@ -106,10 +106,6 @@ def test_matching_example_use( ) # Check structure of predictions - assert isinstance(predictions, dict) - assert 0.5 in predictions - - # Check that predictions are pandas DataFrame for matching model assert isinstance(predictions[0.5], pd.DataFrame) transformed_df = pd.DataFrame() diff --git a/tests/test_models/test_qrf.py b/tests/test_models/test_qrf.py index 17ac05a..bf275f4 100644 --- a/tests/test_models/test_qrf.py +++ b/tests/test_models/test_qrf.py @@ -296,8 +296,8 @@ def test_qrf_imputes_multiple_variables( predictions: Dict[float, pd.DataFrame] = fitted_model.predict(X_test) # Check structure of predictions - assert isinstance(predictions, dict) - assert predictions[0.5].shape[1] == len(imputed_variables) + assert isinstance(predictions, pd.DataFrame) + assert predictions.shape[1] == len(imputed_variables) def test_qrf_sequential_imputation( diff --git a/tests/test_models/test_qrf_extended.py b/tests/test_models/test_qrf_extended.py index 50a4397..458ea01 100644 --- a/tests/test_models/test_qrf_extended.py +++ b/tests/test_models/test_qrf_extended.py @@ -140,9 +140,8 @@ def test_qrf_with_missing_categorical_columns_in_test(): # Predict - should handle missing category gracefully predictions = fitted_model.predict(test_data[["numeric", "category"]]) - assert 0.5 in predictions - assert len(predictions[0.5]) == len(test_data) - assert not predictions[0.5]["target"].isna().any() + assert len(predictions) == len(test_data) + assert not predictions["target"].isna().any() def test_qrf_with_single_predictor(): @@ -248,8 +247,8 @@ def test_qrf_reproducibility(): # Results should be identical np.testing.assert_array_almost_equal( - predictions1[0.5]["y"].values, - predictions2[0.5]["y"].values, + predictions1["y"].values, + predictions2["y"].values, ) @@ -286,12 +285,12 @@ def test_qrf_with_highly_correlated_predictors(): predictions = fitted_model.predict(test_data[["x1", "x2", "x3"]]) # Model should still produce reasonable predictions despite correlation - assert len(predictions[0.5]) == len(test_data) - assert not predictions[0.5]["y"].isna().any() + assert len(predictions) == len(test_data) + assert not predictions["y"].isna().any() # Check that predictions are somewhat correlated with true values true_y = test_data["y"].values - pred_y = predictions[0.5]["y"].values + pred_y = predictions["y"].values correlation = np.corrcoef(true_y, pred_y)[0, 1] assert correlation > 0.5 # Should have reasonable correlation @@ -507,9 +506,8 @@ def test_qrf_batch_processing(): test_data = data[["predictor1", "predictor2"]].head(5) predictions = fitted_model.predict(test_data) - assert 0.5 in predictions - assert len(predictions[0.5]) == len(test_data) - assert not predictions[0.5].isna().any().any() + assert len(predictions) == len(test_data) + assert not predictions.isna().any().any() # Clean up model.logger.removeHandler(handler) @@ -698,9 +696,8 @@ def test_qrf_missing_variables_handling(): test_data = data[["x1", "x2"]].head(10) predictions = fitted_model.predict(test_data) - assert 0.5 in predictions - assert "existing_var" in predictions[0.5].columns - assert len(predictions[0.5]) == len(test_data) + assert "existing_var" in predictions.columns + assert len(predictions) == len(test_data) # Clean up model_lenient.logger.removeHandler(handler) @@ -749,10 +746,9 @@ def test_qrf_all_variables_missing(): test_data = data[["x1", "x2"]].head(5) predictions = fitted_model.predict(test_data) - assert 0.5 in predictions - assert len(predictions[0.5].columns) == 0 # No variables to predict + assert len(predictions.columns) == 0 # No variables to predict # When there are no variables to impute, predictions should be empty but defined - assert isinstance(predictions[0.5], pd.DataFrame) + assert isinstance(predictions, pd.DataFrame) # Clean up model.logger.removeHandler(handler) @@ -815,10 +811,9 @@ def test_qrf_partial_missing_variables(): test_data = data[["predictor1", "predictor2"]].head(8) predictions = fitted_model.predict(test_data) - assert 0.5 in predictions - assert "target1" in predictions[0.5].columns - assert "target3" in predictions[0.5].columns - assert "target2" not in predictions[0.5].columns + assert "target1" in predictions.columns + assert "target3" in predictions.columns + assert "target2" not in predictions.columns # Clean up model.logger.removeHandler(handler)