Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions causalml/inference/meta/tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def fit(
bootstrap_size (int, optional): number of samples per bootstrap. Default: 10000.
random_state (int, optional): random seed for reproducible bootstrap sampling.
"""
X, treatment, y = convert_pd_to_np(X, treatment, y)
treatment, y = convert_pd_to_np(treatment, y)
check_treatment_vector(treatment, self.control_name)
self.t_groups = np.unique(treatment[treatment != self.control_name])
self.t_groups.sort()
Expand All @@ -129,13 +129,23 @@ def fit(
# re-calling fit() always starts from a clean state (safe with warm_start).
control_mask = treatment == self.control_name
self.model_c = deepcopy(self._model_c_template)
self.model_c.fit(X[control_mask], y[control_mask])
X_control = (
X[control_mask].reset_index(drop=True)
if hasattr(X, "loc")
else X[control_mask]
)
self.model_c.fit(X_control, y[control_mask])
# Expose as a shared-reference dict to preserve the public models_c API.
self.models_c = {group: self.model_c for group in self.t_groups}

for group in self.t_groups:
treatment_mask = treatment == group
self.models_t[group].fit(X[treatment_mask], y[treatment_mask])
mask = (treatment == group) | (treatment == self.control_name)
treatment_filt = treatment[mask]
X_filt = X[mask].reset_index(drop=True) if hasattr(X, "loc") else X[mask]
y_filt = y[mask]
w = (treatment_filt == group).astype(int)

self.models_t[group].fit(X_filt[w == 1], y_filt[w == 1])

if store_bootstraps:
self.fit_bootstrap_ensemble(
Expand Down Expand Up @@ -202,7 +212,7 @@ def predict(
if return_ci and return_components:
raise ValueError("return_ci and return_components cannot both be True.")

X, treatment, y = convert_pd_to_np(X, treatment, y)
treatment, y = convert_pd_to_np(treatment, y)
yhat_ts = {}

yhat_c = self.model_c.predict(X)
Expand Down Expand Up @@ -267,7 +277,7 @@ def fit_predict(
If return_ci, returns CATE [n_samples, n_treatment], LB [n_samples, n_treatment],
UB [n_samples, n_treatment]
"""
X, treatment, y = convert_pd_to_np(X, treatment, y)
treatment, y = convert_pd_to_np(treatment, y)
self.fit(X, treatment, y)
te = self.predict(X, treatment, y, return_components=return_components)

Expand Down Expand Up @@ -325,7 +335,7 @@ def estimate_ate(
The mean and confidence interval (LB, UB) of the ATE estimate.
pretrain (bool): whether a model has been fit, default False.
"""
X, treatment, y = convert_pd_to_np(X, treatment, y)
treatment, y = convert_pd_to_np(treatment, y)
if pretrain:
te, yhat_cs, yhat_ts = self.predict(X, treatment, y, return_components=True)
else:
Expand Down
9 changes: 8 additions & 1 deletion causalml/inference/meta/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@


def convert_pd_to_np(*args):
output = [obj.to_numpy() if hasattr(obj, "to_numpy") else obj for obj in args]
def _convert(obj):
if isinstance(obj, pd.DataFrame) and any(
pd.api.types.is_categorical_dtype(obj[c]) for c in obj.columns
):
return obj # pass through so learners can handle categoricals natively
return obj.to_numpy() if hasattr(obj, "to_numpy") else obj

output = [_convert(obj) for obj in args]
return output if len(output) > 1 else output[0]


Expand Down
40 changes: 40 additions & 0 deletions tests/test_meta_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,6 +1223,46 @@ def test_BaseDRClassifier(generate_classification_data):
assert te_separate.shape == te.shape


def test_BaseTLearner_with_categorical_features():
np.random.seed(RANDOM_SEED)
n = 200

X = pd.DataFrame(
{
"num1": np.random.randn(n),
"num2": np.random.randn(n),
"cat1": pd.Categorical(np.random.choice([0, 1, 2], size=n)),
}
)
treatment = np.random.binomial(1, 0.5, n)
y = X["num1"].values + (treatment * 0.5) + np.random.randn(n) * 0.1

learner = BaseTRegressor(learner=XGBRegressor(enable_categorical=True))
learner.fit(X=X, treatment=treatment, y=y)
te = learner.predict(X=X)

assert te.shape == (n, 1)


def test_BaseRLearner_with_categorical_features():
np.random.seed(RANDOM_SEED)
n = 200

X = pd.DataFrame(
{
"num1": np.random.randn(n),
"num2": np.random.randn(n),
"cat1": pd.Categorical(np.random.choice([0, 1, 2], size=n)),
}
)
treatment = np.random.binomial(1, 0.5, n)
y = X["num1"].values + (treatment * 0.5) + np.random.randn(n) * 0.1

learner = BaseRRegressor(learner=XGBRegressor(enable_categorical=True))
learner.fit(X=X, treatment=treatment, y=y)
te = learner.predict(X=X)

assert te.shape == (n, 1)
def test_BaseTLearner_predict_return_ci(generate_regression_data):
y, X, treatment, tau, b, e = generate_regression_data()

Expand Down