Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions causalml/inference/meta/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
X (np.matrix): a feature matrix
tau (np.array): a treatment effect vector (estimated/actual)
classes (dict): a mapping of treatment names to indices (used for indexing tau array)
model_tau (sklearn/lightgbm/xgboost model object): a model object
model_tau (sklearn/lightgbm/xgboost/catboost model object): a model object
features (np.array): list/array of feature names. If None, an enumerated list will be used.
normalize (bool): normalize by sum of importances if method=auto (defaults to True)
test_size (float/int): if float, represents the proportion of the dataset to include in the test split.
Expand Down Expand Up @@ -84,7 +84,7 @@ def check_conditions(self):
Checks for multiple conditions:
- method is valid
- X, tau, and classes are specified
- model_tau has feature_importances_ attribute after fitting
- model_tau has feature_importances_ after fitting
"""
assert self.method in VALID_METHODS, "Current supported methods: {}".format(
", ".join(VALID_METHODS)
Expand All @@ -97,7 +97,8 @@ def check_conditions(self):
model_test = deepcopy(self.model_tau)
model_test.fit(
[[0], [1]], [0, 1]
) # Fit w/ dummy data to check for feature_importances_ below
) # Fit w/ dummy data to ensure feature importances are available

assert hasattr(
model_test, "feature_importances_"
), "model_tau must have the feature_importances_ method (after fitting)"
Expand Down
29 changes: 29 additions & 0 deletions tests/test_meta_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from causalml.inference.meta import BaseDRLearner
from causalml.inference.meta import BaseDRRegressor
from causalml.inference.meta import BaseDRClassifier
from causalml.inference.meta.explainer import Explainer
from causalml.metrics import ape, auuc_score

from .const import RANDOM_SEED, N_SAMPLE, ERROR_THRESHOLD, CONTROL_NAME, CONVERSION
Expand Down Expand Up @@ -640,6 +641,34 @@ def test_BaseRRegressor(generate_regression_data):
assert auuc["cate_p"] > 0.5


def test_explainer_auto_importance_catboost(generate_regression_data):
catboost = pytest.importorskip("catboost")

y, X, treatment, tau, b, e = generate_regression_data()

model_tau = catboost.CatBoostRegressor(
iterations=10,
verbose=False,
random_seed=RANDOM_SEED,
)

explainer = Explainer(
method="auto",
control_name=CONTROL_NAME,
X=X,
tau=tau,
classes={CONTROL_NAME: 0},
model_tau=model_tau,
)

importance = explainer.get_importance()

assert len(importance) == 1
assert CONTROL_NAME in importance
assert isinstance(importance[CONTROL_NAME], pd.Series)
assert len(importance[CONTROL_NAME]) == X.shape[1]


def test_BaseRLearner_without_p(generate_regression_data):
y, X, treatment, tau, b, e = generate_regression_data()

Expand Down