Skip to content
51 changes: 33 additions & 18 deletions causalml/inference/meta/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from abc import ABCMeta, abstractmethod
import copy
import logging
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from sklearn.base import clone
from sklearn.base import BaseEstimator, clone
from tqdm import tqdm

from causalml.inference.meta.explainer import Explainer
Expand All @@ -18,7 +17,9 @@ def _fit_bootstrap_clone(learner_template, X, treatment, y, p, seed, bootstrap_s
"""Module-level bootstrap helper for joblib pickling compatibility.

Args:
learner_template: an unfitted template to clone
learner_template: an *unfitted* learner to clone as a template.
Because BaseLearner now inherits BaseEstimator, ``clone(learner_template)``
produces a clean unfitted copy via ``get_params``/``set_params``.
X: feature matrix
treatment: treatment vector
y: outcome vector
Expand All @@ -34,12 +35,31 @@ def _fit_bootstrap_clone(learner_template, X, treatment, y, p, seed, bootstrap_s
treatment_b = treatment[idxs]
y_b = y[idxs]
p_b = {group: _p[idxs] for group, _p in p.items()} if p is not None else None
learner_b = clone(learner_template, safe=False)
learner_b = clone(learner_template) # safe=True works now via get_params/set_params
learner_b.fit(X=X_b, treatment=treatment_b, y=y_b, p=p_b)
return learner_b


class BaseLearner(metaclass=ABCMeta):
class BaseLearner(BaseEstimator, metaclass=ABCMeta):
"""Base class for all causalml meta-learners.

Inheriting ``sklearn.base.BaseEstimator`` gives every subclass:
* ``get_params`` / ``set_params`` for free (requires verbatim ``__init__``
argument storage — see scikit-learn conventions).
* ``sklearn.base.clone`` support without ``safe=False``.
* ``Pipeline`` / ``GridSearchCV`` compatibility.

Subclass contract
-----------------
* ``__init__`` **must** store every argument verbatim as ``self.<param> = param``.
No logic, no ``deepcopy``, no derived attributes.
* All model construction and validation move to ``fit()``.
* ``fit()`` deepcopies the verbatim-stored arg before fitting, so ``self.learner``
(and related params) remain unfitted across repeated ``fit()`` calls — this is
the warm-start invariant that replaces the old ``_model_*_template`` mechanism.
* ``__repr__`` is inherited from ``BaseEstimator`` and reflects constructor params.
"""

@classmethod
@abstractmethod
def fit(self, X, treatment, y, p=None):
Expand Down Expand Up @@ -99,11 +119,6 @@ def bootstrap(self, X, treatment, y, p=None, size=10000, rng=None):
self.fit(X=X_b, treatment=treatment_b, y=y_b, p=p_b)
return self.predict(X=X, p=p)

def _unfitted_clone(self):
"""Return an unfitted copy for bootstrap refitting. Subclasses that hold fitted
sub-models should override to reset them to their unfitted templates."""
return clone(self, safe=False)

def fit_bootstrap_ensemble(
self,
X,
Expand All @@ -121,12 +136,11 @@ def fit_bootstrap_ensemble(
and stores them in self.bootstrap_models_. Used by predict(return_ci=True)
to compute percentile-based confidence intervals on new data without refitting.

This design follows EconML's BootstrapEstimator pattern — each bootstrap
clone is a full copy of the learner, making this method generic across all
meta-learners.

Note: storing N bootstrap clones can be memory-intensive for heavy base
learners. Monitor RAM for large n_bootstraps.
Because ``BaseLearner`` now inherits ``BaseEstimator``, ``clone(self)``
produces a clean unfitted copy via ``get_params``/``set_params``. The
warm-start invariant — that ``self.learner`` stays unfitted across calls —
is maintained by each ``fit()`` deepcopying the verbatim-stored constructor
arg before fitting it.

Args:
X (np.matrix or np.array or pd.DataFrame): a feature matrix
Expand All @@ -138,15 +152,16 @@ def fit_bootstrap_ensemble(
random_state (int, optional): random seed for reproducibility.
n_jobs (int, optional): number of parallel jobs. -1 uses all cores. Default: 1.
"""
# clone(self) is now a proper sklearn clone — unfitted and cheap.
unfitted_template = clone(self)

rng = np.random.RandomState(random_state)
seeds = rng.randint(0, np.iinfo(np.int32).max, size=n_bootstraps)
logger.info("Storing bootstrap ensemble ({} iterations)".format(n_bootstraps))

learner_template = self._unfitted_clone()
self.bootstrap_models_ = Parallel(n_jobs=n_jobs)(
delayed(_fit_bootstrap_clone)(
learner_template, X, treatment, y, p, s, bootstrap_size
unfitted_template, X, treatment, y, p, s, bootstrap_size
)
for s in tqdm(seeds)
)
Expand Down
Loading