Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 14 additions & 21 deletions vbridge/modeling/modeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from sklearn.utils import class_weight
from xgboost import XGBClassifier

from pyreal.explainers import ShapFeatureContribution
from pyreal.utils.transformer import DataFrameWrapper

classification_metrics = {
'Accuracy': sklearn.metrics.accuracy_score,
'F1 Macro': lambda y_true, y_pred: sklearn.metrics.f1_score(y_true, y_pred, average="macro"),
Expand All @@ -37,8 +40,8 @@ def test(model, X, y):
class Modeler:
def __init__(self, topk=10, **kwargs):
self._one_hot_encoder = OneHotEncoder(topk=topk)
self._imputer = SimpleImputer()
self._scaler = MinMaxScaler()
self._imputer = DataFrameWrapper(SimpleImputer())
self._scaler = DataFrameWrapper(MinMaxScaler())
self._model = XGBClassifier(use_label_encoder=False, **kwargs)
self._explainer = None

Expand Down Expand Up @@ -68,8 +71,10 @@ def fit(self, X, y, eval_set=None, target='complication', explain=True):
y_train = y[target].values

X_train = self._one_hot_encoder.fit_transform(X)
X_train = self._imputer.fit_transform(X_train)
X_train = self._scaler.fit_transform(X_train)
self._imputer.fit(X_train)
X_train = self._imputer.transform(X_train)
self._scaler.fit(X_train)
X_train = self._scaler.transform(X_train)

if eval_set:
X_eval, y_eval = eval_set
Expand All @@ -83,9 +88,10 @@ def fit(self, X, y, eval_set=None, target='complication', explain=True):
sample_weight = [weights[l] for l in y_train]
self._model.fit(X_train, y_train, sample_weight=sample_weight, eval_metric='auc',
eval_set=eval_set, early_stopping_rounds=10, verbose=False)
# self._model.fit(X_train, y_train)
if explain:
self._explainer = shap.TreeExplainer(self._model)
transforms = [self._one_hot_encoder, self._imputer, self._scaler]
self._explainer = ShapFeatureContribution(self._model, X, transforms=transforms,
fit_on_init=True)

def transform(self, X):
X = self._one_hot_encoder.transform(X)
Expand All @@ -101,21 +107,8 @@ def test(self, X, y, target='complication'):
return test(self.model, X_test, y_test)

def SHAP(self, X):
columns = X.columns
X = self._one_hot_encoder.transform(X)
X = self._imputer.transform(X)
X = self._scaler.transform(X)
dummy_columns = self._one_hot_encoder.dummy_columns

# print('SHAP', X, dummy_columns)
shap_values = pd.DataFrame(self._explainer.shap_values(X), columns=dummy_columns)
for original_col, dummies in self._one_hot_encoder.dummy_dict.items():
sub_dummy_column = ["{}_{}".format(original_col, cat) for cat in dummies]
assert np.array([col in dummy_columns for col in sub_dummy_column]).all()
if original_col + "_Others" in dummy_columns:
sub_dummy_column.append(original_col + "_Others")
shap_values[original_col] = shap_values.loc[:, sub_dummy_column].sum(axis=1)
return shap_values.reindex(columns=columns)
contributions = self._explainer.produce(X)
return contributions


class OneHotEncoder(TransformerMixin):
Expand Down