diff --git a/flaml/automl/data.py b/flaml/automl/data.py index a9ba323b3c..c27adc2879 100644 --- a/flaml/automl/data.py +++ b/flaml/automl/data.py @@ -7,6 +7,7 @@ import random import re import uuid +import warnings from datetime import datetime, timedelta from decimal import ROUND_HALF_UP, Decimal from typing import TYPE_CHECKING, Union @@ -350,6 +351,17 @@ def fit_transform(self, X: Union[DataFrame, np.ndarray], y, task: Union[str, "Ta X.insert(0, TS_TIMESTAMP_COL, ds_col) if cat_columns: X[cat_columns] = X[cat_columns].astype("category") + # Pin the per-column category list seen at fit time so + # `transform()` produces the same integer codes for the same + # values regardless of what is passed at predict time (see + # issue #1101). "__NAN__" is reserved as the sentinel slot + # used for values unseen at fit time. + self._cat_categories = {} + for col in cat_columns: + cats = list(X[col].cat.categories) + if "__NAN__" not in cats: + cats.append("__NAN__") + self._cat_categories[col] = cats if num_columns: X_num = X[num_columns] try: @@ -450,6 +462,30 @@ def transform(self, X: Union[DataFrame, np.array]): X[column] = X[column].cat.add_categories("__NAN__").fillna("__NAN__") if cat_columns: X[cat_columns] = X[cat_columns].astype("category") + # Pin codes to the categories seen at fit time so they do not + # drift when the predict-time column has a different value + # distribution than the fit-time column (see issue #1101). + # Older pickles without `_cat_categories` fall back to + # whatever `astype("category")` inferred above. + saved_cats_map = getattr(self, "_cat_categories", None) + if saved_cats_map: + for column in cat_columns: + saved_cats = saved_cats_map.get(column) + if saved_cats is None: + continue + current = X[column].astype(object) + unseen_mask = ~current.isin(saved_cats) & current.notna() + if unseen_mask.any(): + samples = sorted({str(v) for v in current[unseen_mask].unique()})[:5] + warnings.warn( + f"Column '{column}' contains values unseen at fit time " + f"(e.g. {samples}); these rows will be encoded as '__NAN__' " + "and predictions may be unreliable.", + UserWarning, + stacklevel=2, + ) + current = current.where(~unseen_mask, "__NAN__") + X[column] = pd.Categorical(current, categories=saved_cats) if num_columns: X_num = X[num_columns].fillna(np.nan) if self._drop: diff --git a/test/automl/test_preprocess_api.py b/test/automl/test_preprocess_api.py index 45b9c6143b..5b207595e7 100644 --- a/test/automl/test_preprocess_api.py +++ b/test/automl/test_preprocess_api.py @@ -232,5 +232,65 @@ def test_estimator_preprocess_without_automl(self): self.assertEqual(X_preprocessed.shape, X_test.shape) +class TestCategoricalEncodingStability(unittest.TestCase): + """Regression coverage for #1101 — DataTransformer must produce the same + integer code for the same categorical value regardless of which values + happen to be present in the predict-time DataFrame.""" + + def _fit_simple(self): + from flaml.automl.data import DataTransformer + from flaml.automl.task.factory import task_factory + + rng = np.random.RandomState(0) + n = 100 + fit_df = pd.DataFrame({"a": rng.randn(n), "gender": rng.choice(["M", "F"], n)}) + fit_y = pd.Series(rng.randn(n)) + + transformer = DataTransformer() + task = task_factory("regression", fit_df, fit_y) + X_fit, _ = transformer.fit_transform(fit_df.copy(), fit_y, task) + return transformer, X_fit + + def test_codes_stable_when_predict_uses_only_a_subset(self): + transformer, X_fit = self._fit_simple() + fit_code_for_M = int(X_fit["gender"].cat.codes[X_fit["gender"] == "M"].iloc[0]) + + # Predict-time DataFrame contains only "M" rows. + predict_df = pd.DataFrame({"a": np.zeros(20), "gender": ["M"] * 20}) + X_pred = transformer.transform(predict_df.copy()) + pred_code_for_M = int(X_pred["gender"].cat.codes[X_pred["gender"] == "M"].iloc[0]) + + self.assertEqual( + fit_code_for_M, + pred_code_for_M, + "categorical code for 'M' drifted between fit and predict — see #1101", + ) + + def test_unseen_categories_emit_warning_and_map_to_sentinel(self): + import warnings + + transformer, _ = self._fit_simple() + predict_df = pd.DataFrame({"a": np.zeros(5), "gender": ["M", "F", "X", "M", "Y"]}) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + X_pred = transformer.transform(predict_df.copy()) + + unseen_warnings = [ + w for w in caught if issubclass(w.category, UserWarning) and "unseen at fit time" in str(w.message) + ] + self.assertEqual(len(unseen_warnings), 1) + message = str(unseen_warnings[0].message) + self.assertIn("gender", message) + self.assertIn("X", message) + self.assertIn("Y", message) + + # Unseen "X" and "Y" rows must be encoded as the "__NAN__" sentinel slot + # and seen "F" / "M" codes must still match fit-time codes. + nan_code = list(X_pred["gender"].cat.categories).index("__NAN__") + unseen_rows = X_pred["gender"].cat.codes[predict_df["gender"].isin(["X", "Y"]).values] + self.assertTrue((unseen_rows == nan_code).all()) + + if __name__ == "__main__": unittest.main()