https://blog.csdn.net/wangxiancao/article/details/123487557

Error:

Traceback (most recent call last):
  File "view_adaboost.py", line 84, in <module>
    y_test, y_pred, y_probas = train_eval_adaboost()
  File "view_adaboost.py", line 64, in train_eval_adaboost
    explainer = shap.TreeExplainer(rf)
  File "/home/yuanzhiqiu/.local/lib/python3.8/site-packages/shap/explainers/_tree.py", line 175, in __init__
    self.model = TreeEnsemble(model, self.data, self.data_missing, model_output)
  File "/home/yuanzhiqiu/.local/lib/python3.8/site-packages/shap/explainers/_tree.py", line 1226, in __init__
    raise InvalidModelError("Model type not yet supported by TreeExplainer: " + str(type(model)))
shap.utils._exceptions.InvalidModelError: Model type not yet supported by TreeExplainer: <class 'sklearn.ensemble._weight_boosting.AdaBoostClassifier'>

The traceback shows which source file to modify: /home/yuanzhiqiu/.local/lib/python3.8/site-packages/shap/explainers/_tree.py. The place to edit is just above line 1226.
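The path will differ from machine to machine. As a minimal sketch, it can also be looked up without triggering the error first, using only the standard library. The helper name `locate_module_source` is made up for illustration and is not part of SHAP.

```python
import importlib.util
from typing import Optional


def locate_module_source(module_name: str) -> Optional[str]:
    """Return the file path of an importable module, or None if it cannot be found."""
    try:
        # For a dotted name, find_spec imports the parent packages first.
        spec = importlib.util.find_spec(module_name)
    except (ImportError, ValueError):
        # Raised when a parent package (for example `shap`) is not installed.
        return None
    if spec is None:
        return None
    return spec.origin


if __name__ == "__main__":
    # Prints the path to _tree.py if SHAP is installed, otherwise None.
    print(locate_module_source("shap.explainers._tree"))
```

The line number in the file still has to be taken from the traceback, or found by searching for the "Model type not yet supported by TreeExplainer" message.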
Add the following elif branch to the model-type checks above the raise InvalidModelError(...) line:

elif safe_isinstance(model, ["sklearn.ensemble.AdaBoostClassifier", "sklearn.ensemble._weight_boosting.AdaBoostClassifier"]):
    assert hasattr(model, "estimators_"), "Model has no `estimators_`! Have you called `model.fit`?"
    self.internal_dtype = model.estimators_[0].tree_.value.dtype.type
    self.input_dtype = np.float32
    scaling = 1.0 / len(model.estimators_)  # output is the average of the trees
    self.trees = [SingleTree(e.tree_, normalize=True, scaling=scaling, data=data, data_missing=data_missing) for e in model.estimators_]
    # Look up the split criterion of the base estimator, for example "gini".
    # On scikit-learn < 1.2 this attribute is named `base_estimator_`.
    self.objective = objective_name_map.get(model.estimator_.criterion, None)
    self.tree_output = "probability"  # last line of the added branch
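To make the scaling line concrete, here is a small sketch of what it implies: each tree's output is multiplied by 1 / n, so the sum over trees is an unweighted mean. The function name and the probabilities are invented for illustration and are not taken from SHAP or from the original post.

```python
def averaged_output(per_tree_probs):
    """Combine per-tree probabilities the way the patch's `scaling` does."""
    scaling = 1.0 / len(per_tree_probs)  # same factor as in the patch
    return sum(p * scaling for p in per_tree_probs)


if __name__ == "__main__":
    # Three hypothetical trees voting on the positive class.
    print(averaged_output([0.2, 0.4, 0.9]))
```

AdaBoost itself does not combine its trees by a plain mean of probabilities, so the output explained this way may differ from what the model's predict_proba returns. The SHAP values are best read as an approximation.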

With the patch in place, TreeExplainer accepts the AdaBoost model:

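import matplotlib.pyplot as plt
import shap
from sklearn.ensemble import AdaBoostClassifier

# X_train, y_train and X_test are assumed to be defined already.
rf = AdaBoostClassifier()
rf = rf.fit(X_train, y_train)

explainer = shap.TreeExplainer(rf)
shap_values = explainer.shap_values(X_test)
# shap_values[1] holds the values for class 1; show=False so the figure can be saved.
shap.summary_plot(shap_values[1], X_test, show=False, feature_names=["sex", "age", "weight"])
plt.savefig("shap.pdf", format="pdf")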
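The shap_values[1] indexing above relies on shap_values being a list with one array per class. Depending on the installed SHAP version, a classifier may instead yield a single array with the classes on the last axis. The helper below is a hedged sketch that handles both layouts; `select_class_values` is a hypothetical name, and the three-dimensional layout is an assumption about newer SHAP releases.

```python
def select_class_values(shap_values, class_index=1):
    """Pick the SHAP values for one class from either return layout."""
    if isinstance(shap_values, list):
        # List layout: one (n_samples, n_features) array per class.
        return shap_values[class_index]
    if getattr(shap_values, "ndim", None) == 3:
        # Assumed array layout: (n_samples, n_features, n_classes).
        return shap_values[..., class_index]
    # Anything else is treated as single-output and returned unchanged.
    return shap_values
```

It would be used as shap.summary_plot(select_class_values(shap_values), X_test, ...) in place of shap_values[1].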
