Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions flaml/automl/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import random
import re
import uuid
import warnings
from datetime import datetime, timedelta
from decimal import ROUND_HALF_UP, Decimal
from typing import TYPE_CHECKING, Union
Expand Down Expand Up @@ -350,6 +351,17 @@ def fit_transform(self, X: Union[DataFrame, np.ndarray], y, task: Union[str, "Ta
X.insert(0, TS_TIMESTAMP_COL, ds_col)
if cat_columns:
X[cat_columns] = X[cat_columns].astype("category")
# Pin the per-column category list seen at fit time so
# `transform()` produces the same integer codes for the same
# values regardless of what is passed at predict time (see
# issue #1101). "__NAN__" is reserved as the sentinel slot
# used for values unseen at fit time.
self._cat_categories = {}
for col in cat_columns:
cats = list(X[col].cat.categories)
if "__NAN__" not in cats:
cats.append("__NAN__")
self._cat_categories[col] = cats
if num_columns:
X_num = X[num_columns]
try:
Expand Down Expand Up @@ -450,6 +462,30 @@ def transform(self, X: Union[DataFrame, np.array]):
X[column] = X[column].cat.add_categories("__NAN__").fillna("__NAN__")
if cat_columns:
X[cat_columns] = X[cat_columns].astype("category")
# Pin codes to the categories seen at fit time so they do not
# drift when the predict-time column has a different value
# distribution than the fit-time column (see issue #1101).
# Older pickles without `_cat_categories` fall back to
# whatever `astype("category")` inferred above.
saved_cats_map = getattr(self, "_cat_categories", None)
if saved_cats_map:
for column in cat_columns:
saved_cats = saved_cats_map.get(column)
if saved_cats is None:
continue
current = X[column].astype(object)
unseen_mask = ~current.isin(saved_cats) & current.notna()
if unseen_mask.any():
samples = sorted({str(v) for v in current[unseen_mask].unique()})[:5]
warnings.warn(
f"Column '{column}' contains values unseen at fit time "
f"(e.g. {samples}); these rows will be encoded as '__NAN__' "
"and predictions may be unreliable.",
UserWarning,
stacklevel=2,
)
current = current.where(~unseen_mask, "__NAN__")
X[column] = pd.Categorical(current, categories=saved_cats)
if num_columns:
X_num = X[num_columns].fillna(np.nan)
if self._drop:
Expand Down
60 changes: 60 additions & 0 deletions test/automl/test_preprocess_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,5 +232,65 @@ def test_estimator_preprocess_without_automl(self):
self.assertEqual(X_preprocessed.shape, X_test.shape)


class TestCategoricalEncodingStability(unittest.TestCase):
"""Regression coverage for #1101 — DataTransformer must produce the same
integer code for the same categorical value regardless of which values
happen to be present in the predict-time DataFrame."""

def _fit_simple(self):
from flaml.automl.data import DataTransformer
from flaml.automl.task.factory import task_factory

rng = np.random.RandomState(0)
n = 100
fit_df = pd.DataFrame({"a": rng.randn(n), "gender": rng.choice(["M", "F"], n)})
fit_y = pd.Series(rng.randn(n))

transformer = DataTransformer()
task = task_factory("regression", fit_df, fit_y)
X_fit, _ = transformer.fit_transform(fit_df.copy(), fit_y, task)
return transformer, X_fit

def test_codes_stable_when_predict_uses_only_a_subset(self):
transformer, X_fit = self._fit_simple()
fit_code_for_M = int(X_fit["gender"].cat.codes[X_fit["gender"] == "M"].iloc[0])

# Predict-time DataFrame contains only "M" rows.
predict_df = pd.DataFrame({"a": np.zeros(20), "gender": ["M"] * 20})
X_pred = transformer.transform(predict_df.copy())
pred_code_for_M = int(X_pred["gender"].cat.codes[X_pred["gender"] == "M"].iloc[0])

self.assertEqual(
fit_code_for_M,
pred_code_for_M,
"categorical code for 'M' drifted between fit and predict — see #1101",
)

def test_unseen_categories_emit_warning_and_map_to_sentinel(self):
import warnings

transformer, _ = self._fit_simple()
predict_df = pd.DataFrame({"a": np.zeros(5), "gender": ["M", "F", "X", "M", "Y"]})

with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
X_pred = transformer.transform(predict_df.copy())

unseen_warnings = [
w for w in caught if issubclass(w.category, UserWarning) and "unseen at fit time" in str(w.message)
]
self.assertEqual(len(unseen_warnings), 1)
message = str(unseen_warnings[0].message)
self.assertIn("gender", message)
self.assertIn("X", message)
self.assertIn("Y", message)

# Unseen "X" and "Y" rows must be encoded as the "__NAN__" sentinel slot
# and seen "F" / "M" codes must still match fit-time codes.
nan_code = list(X_pred["gender"].cat.categories).index("__NAN__")
unseen_rows = X_pred["gender"].cat.codes[predict_df["gender"].isin(["X", "Y"]).values]
self.assertTrue((unseen_rows == nan_code).all())


if __name__ == "__main__":
unittest.main()
Loading