如何在汇总图例中将类别ID映射到类别名称

huangapple go评论78阅读模式
英文:

How to map class id to class name in summary plot's legend

问题

我在鸢尾花数据集上拟合了一个随机森林分类器,如下所示:

iris = datasets.load_iris()
X = iris.data
y = iris.target

# 将 X, y 分为训练和测试数据
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)

model = RandomForestClassifier()
model.fit(X_train, y_train)
predictions = model.predict(X_test)

然后以以下方式绘制 shap 值总结:

import shap

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)

shap.summary_plot(shap_values, X_test)

输出:

如何在汇总图例中将类别ID映射到类别名称

我想将图例中的 class ID 映射到类名,如下所示:

iris.target_names
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')

这样就能实现:Class 0 -> setosa,Class 1 -> versicolor,Class 2 -> virginica

英文:

I fitted a random forest classifier on the iris dataset like so:

iris = datasets.load_iris()
X = iris.data
y = iris.target

# dividing X, y into train and test data
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)

model = RandomForestClassifier()
model.fit(X_train, y_train)
predictions = model.predict(X_test)

Then plot the shap values summary this way:

import shap

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)

shap.summary_plot(shap_values, X_test)

Output:

如何在汇总图例中将类别ID映射到类别名称

I want to map the class ids to the class name in the plot legend:

iris.target_names
array([&#39;setosa&#39;, &#39;versicolor&#39;, &#39;virginica&#39;], dtype=&#39;&lt;U10&#39;)

Such that: Class 0 -&gt; setosa, Class 1 -&gt; versicolor, Class 2 -&gt; virginica

答案1

得分: 1

你可以在shap.summary_plot()函数中使用class_names参数。

首先按照以下方式获取类名:

class_names = iris.target_names

然后将其传递给summary_plot()函数:

shap.summary_plot(shap_values, X_test, class_names=class_names)
英文:

You can just use the class_names parameter in the shap.summary_plot() function.

Start by getting the class name like this:

class_names = iris.target_names

and then pass it to the summary_plot()

shap.summary_plot(shap_values, X_test, class_names=class_names)

huangapple
  • 本文由 发表于 2023年7月3日 23:22:29
  • 转载请务必保留本文链接:https://go.coder-hub.com/76606132.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定