英文:
How to map class id to class name in summary plot's legend
问题
我在鸢尾花数据集上拟合了一个随机森林分类器,如下所示:
iris = datasets.load_iris()
X = iris.data
y = iris.target
# 将 X, y 分为训练和测试数据
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)
model = RandomForestClassifier()
model.fit(X_train, y_train)
predictions = model.predict(X_test)
然后以以下方式绘制 shap 值总结:
import shap
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values, X_test)
输出:
我想将图例中的 class
ID 映射到类名,如下所示:
iris.target_names
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
这样就能实现:Class 0 -> setosa,Class 1 -> versicolor,Class 2 -> virginica
。
英文:
I fitted a random forest classifier on the iris dataset like so:
iris = datasets.load_iris()
X = iris.data
y = iris.target
# dividing X, y into train and test data
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)
model = RandomForestClassifier()
model.fit(X_train, y_train)
predictions = model.predict(X_test)
Then plot the shap values summary this way:
import shap
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values, X_test)
Output:
I want to map the class
ids to the class name in the plot legend:
iris.target_names
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
Such that: Class 0 -> setosa, Class 1 -> versicolor, Class 2 -> virginica
答案1
得分: 1
你可以在shap.summary_plot()
函数中使用class_names
参数。
首先按照以下方式获取类名:
class_names = iris.target_names
然后将其传递给summary_plot()
函数:
shap.summary_plot(shap_values, X_test, class_names=class_names)
英文:
You can just use the class_names
parameter in the shap.summary_plot()
function.
Start by getting the class name like this:
class_names = iris.target_names
and then pass it to the summary_plot()
shap.summary_plot(shap_values, X_test, class_names=class_names)
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论