Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions causalml/inference/meta/slearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def predict(
(numpy.ndarray): Predictions of treatment effects.
"""
X = collect_if_lazy(X)

X_new_c = prepend_column(0.0, X)
X_new_t = prepend_column(1.0, X)

yhat_cs = {}
yhat_ts = {}

Expand All @@ -135,8 +139,6 @@ def predict(
# Build separate frames for control and treatment to avoid in-place
# mutation, which fails when learners like CatBoost set the
# writeable flag to False on arrays passed to predict().
X_new_c = prepend_column(0.0, X)
X_new_t = prepend_column(1.0, X)
yhat_cs[group] = model.predict(X_new_c)
yhat_ts[group] = model.predict(X_new_t)

Expand Down Expand Up @@ -373,6 +375,10 @@ def predict(
(numpy.ndarray): Predictions of treatment effects.
"""
X = collect_if_lazy(X)

X_new_c = prepend_column(0.0, X)
X_new_t = prepend_column(1.0, X)

yhat_cs = {}
yhat_ts = {}

Expand All @@ -382,8 +388,6 @@ def predict(
# Build separate frames for control and treatment to avoid in-place
# mutation, which fails when learners like CatBoost set the
# writeable flag to False on arrays passed to predict().
X_new_c = prepend_column(0.0, X)
X_new_t = prepend_column(1.0, X)
yhat_cs[group] = model.predict_proba(X_new_c)[:, 1]
yhat_ts[group] = model.predict_proba(X_new_t)[:, 1]

Expand Down
10 changes: 8 additions & 2 deletions causalml/inference/meta/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ def fit(self, X, treatment, y, p=None):
y_filt_np = to_numpy(y_filt)

# Train treatment outcome model
self.models_mu_t[group].fit(X_filt_t, filter_mask(y_filt, w == 1))
self.models_mu_t[group].fit(
X_filt_t,
y_filt_np[w == 1],
)

var_t = (
y_filt_np[w == 1] - self.models_mu_t[group].predict(X_filt_t)
Expand Down Expand Up @@ -589,7 +592,10 @@ def fit(self, X, treatment, y, p=None):
y_filt_np = to_numpy(y_filt)

# Train treatment outcome model
self.models_mu_t[group].fit(X_filt_t, filter_mask(y_filt, w == 1))
self.models_mu_t[group].fit(
X_filt_t,
y_filt_np[w == 1],
)

var_t = (
y_filt_np[w == 1]
Expand Down
83 changes: 73 additions & 10 deletions tests/test_polars_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from sklearn.linear_model import LinearRegression


from sklearn.linear_model import LogisticRegression
from causalml.inference.meta.tlearner import BaseTClassifier
from causalml.inference.meta.slearner import BaseSClassifier
Expand All @@ -23,20 +24,24 @@
from causalml.inference.meta.slearner import BaseSRegressor
from causalml.inference.meta.xlearner import BaseXRegressor
from causalml.inference.meta.rlearner import BaseRRegressor
from causalml.inference.meta.drlearner import BaseDRRegressor
from causalml.inference.meta.drlearner import (
BaseDRRegressor,
BaseDRClassifier,
)
from causalml.inference.meta.utils import convert_pd_to_np, check_p_conditions

from tests.const import RANDOM_SEED

# Fixtures

N = 200
N_FEATURES = 5
RANDOM_STATE = 42


@pytest.fixture(scope="module")
def synthetic_data_numpy():
"""Return (X, treatment, y) as NumPy arrays — the baseline."""
rng = np.random.default_rng(RANDOM_STATE)
rng = np.random.default_rng(RANDOM_SEED)
X = rng.standard_normal((N, N_FEATURES))
treatment = rng.choice([0, 1], size=N)
y = X[:, 0] * treatment + rng.standard_normal(N) * 0.1
Expand Down Expand Up @@ -73,7 +78,7 @@ def synthetic_data_polars_lazy(synthetic_data_polars):
@pytest.fixture(scope="module")
def synthetic_data_numpy_binary():
"""Return (X, treatment, y) with binary y for classifier tests."""
rng = np.random.default_rng(RANDOM_STATE)
rng = np.random.default_rng(RANDOM_SEED)
X = rng.standard_normal((N, N_FEATURES))
treatment = rng.choice([0, 1], size=N)
y = rng.choice([0, 1], size=N)
Expand Down Expand Up @@ -212,6 +217,38 @@ def test_fit_predict_returns_numpy(self, synthetic_data_polars):
te = self._fit_predict(*synthetic_data_polars)
assert isinstance(te, np.ndarray)

@pytest.mark.parametrize(
"data_fixture",
[
"synthetic_data_pandas",
"synthetic_data_polars",
],
)
def test_bootstrap_ci_dataframe(self, request, data_fixture):
X, treatment, y = request.getfixturevalue(data_fixture)

learner = BaseTRegressor(
learner=LinearRegression(),
control_name=0,
)

learner.fit(
X,
treatment,
y,
store_bootstraps=True,
n_bootstraps=5,
bootstrap_size=100,
random_state=RANDOM_SEED,
)

te, lb, ub = learner.predict(X, return_ci=True)

assert te.shape == (N, len(learner.t_groups))
assert lb.shape == te.shape
assert ub.shape == te.shape
assert np.all(lb <= ub)

def test_estimate_ate_polars(self, synthetic_data_polars):
X, treatment, y = synthetic_data_polars
ate, lb, ub = self.learner.estimate_ate(X, treatment, y)
Expand Down Expand Up @@ -280,21 +317,21 @@ def test_fit_predict_returns_numpy(self, synthetic_data_polars):
class TestRLearnerPolars:
@pytest.fixture(autouse=True)
def _learner(self):
# fixed random_state so KFold splits are identical across both runs
# fixed RANDOM_SEED so KFold splits are identical across both runs
self.learner = BaseRRegressor(
learner=LinearRegression(), random_state=RANDOM_STATE
learner=LinearRegression(), random_state=RANDOM_SEED
)

def _fit_predict(self, X, treatment, y):
self.learner.fit(X, treatment, y)
return self.learner.predict(X)

def test_polars_matches_numpy(self, synthetic_data_numpy, synthetic_data_polars):
# With a fixed random_state the KFold splits are deterministic,
# With a fixed RANDOM_SEED the KFold splits are deterministic,
# so numpy and polars inputs must produce identical results.
te_np = self._fit_predict(*synthetic_data_numpy)
self.learner = BaseRRegressor(
learner=LinearRegression(), random_state=RANDOM_STATE
learner=LinearRegression(), random_state=RANDOM_SEED
)
te_pl = self._fit_predict(*synthetic_data_polars)
_assert_te_close(te_np, te_pl)
Expand All @@ -319,9 +356,9 @@ def _fit_predict(self, X, treatment, y, seed=None):
def test_polars_matches_numpy(self, synthetic_data_numpy, synthetic_data_polars):
# DR-Learner uses KFold with a seed parameter passed to fit(); fix it
# so both runs use the same splits.
te_np = self._fit_predict(*synthetic_data_numpy, seed=RANDOM_STATE)
te_np = self._fit_predict(*synthetic_data_numpy, seed=RANDOM_SEED)
self.learner = BaseDRRegressor(learner=LinearRegression())
te_pl = self._fit_predict(*synthetic_data_polars, seed=RANDOM_STATE)
te_pl = self._fit_predict(*synthetic_data_polars, seed=RANDOM_SEED)
_assert_te_close(te_np, te_pl)

def test_fit_predict_returns_numpy(self, synthetic_data_polars):
Expand Down Expand Up @@ -461,3 +498,29 @@ def test_lazyframe_input(
def test_fit_predict_returns_numpy(self, synthetic_data_polars_binary):
te = self._fit_predict(*synthetic_data_polars_binary)
assert isinstance(te, np.ndarray)


class TestDRClassifierPolars:
@pytest.fixture(autouse=True)
def _learner(self):
self.learner = BaseDRClassifier(
learner=LogisticRegression(),
treatment_effect_learner=LinearRegression(),
)

def _fit_predict(self, X, treatment, y):
self.learner.fit(
X=X,
treatment=treatment,
y=y,
)
return self.learner.predict(X)

def test_fit_predict_returns_numpy(
self,
synthetic_data_polars_binary,
):
te = self._fit_predict(*synthetic_data_polars_binary)

assert isinstance(te, np.ndarray)
assert te.shape == (N, len(self.learner.t_groups))