diff --git a/causalml/inference/meta/explainer.py b/causalml/inference/meta/explainer.py index b92fee86..b156b027 100644 --- a/causalml/inference/meta/explainer.py +++ b/causalml/inference/meta/explainer.py @@ -47,7 +47,7 @@ def __init__( X (np.matrix): a feature matrix tau (np.array): a treatment effect vector (estimated/actual) classes (dict): a mapping of treatment names to indices (used for indexing tau array) - model_tau (sklearn/lightgbm/xgboost model object): a model object + model_tau (sklearn/lightgbm/xgboost/catboost model object): a model object features (np.array): list/array of feature names. If None, an enumerated list will be used. normalize (bool): normalize by sum of importances if method=auto (defaults to True) test_size (float/int): if float, represents the proportion of the dataset to include in the test split. @@ -84,7 +84,7 @@ def check_conditions(self): Checks for multiple conditions: - method is valid - X, tau, and classes are specified - - model_tau has feature_importances_ attribute after fitting + - model_tau has feature_importances_ after fitting """ assert self.method in VALID_METHODS, "Current supported methods: {}".format( ", ".join(VALID_METHODS) @@ -97,7 +97,8 @@ def check_conditions(self): model_test = deepcopy(self.model_tau) model_test.fit( [[0], [1]], [0, 1] - ) # Fit w/ dummy data to check for feature_importances_ below + ) # Fit w/ dummy data to ensure feature importances are available + assert hasattr( model_test, "feature_importances_" ), "model_tau must have the feature_importances_ method (after fitting)" diff --git a/tests/test_meta_learners.py b/tests/test_meta_learners.py index 2efb77b0..ad0dcd7b 100644 --- a/tests/test_meta_learners.py +++ b/tests/test_meta_learners.py @@ -36,6 +36,7 @@ from causalml.inference.meta import BaseDRLearner from causalml.inference.meta import BaseDRRegressor from causalml.inference.meta import BaseDRClassifier +from causalml.inference.meta.explainer import Explainer from causalml.metrics import ape, auuc_score from .const import RANDOM_SEED, N_SAMPLE, ERROR_THRESHOLD, CONTROL_NAME, CONVERSION @@ -640,6 +641,34 @@ def test_BaseRRegressor(generate_regression_data): assert auuc["cate_p"] > 0.5 +def test_explainer_auto_importance_catboost(generate_regression_data): + catboost = pytest.importorskip("catboost") + + y, X, treatment, tau, b, e = generate_regression_data() + + model_tau = catboost.CatBoostRegressor( + iterations=10, + verbose=False, + random_seed=RANDOM_SEED, + ) + + explainer = Explainer( + method="auto", + control_name=CONTROL_NAME, + X=X, + tau=tau, + classes={CONTROL_NAME: 0}, + model_tau=model_tau, + ) + + importance = explainer.get_importance() + + assert len(importance) == 1 + assert CONTROL_NAME in importance + assert isinstance(importance[CONTROL_NAME], pd.Series) + assert len(importance[CONTROL_NAME]) == X.shape[1] + + def test_BaseRLearner_without_p(generate_regression_data): y, X, treatment, tau, b, e = generate_regression_data()