shap.plots.bar() 图的子图?

huangapple go评论91阅读模式
英文:

Subplots for shap.plots.bar() plots?

问题

我想将 `shap.plots.bar`(https://github.com/slundberg/shap图形添加到子图中类似于这样...

```python
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(8,20))
for (X, y) in [(x1, y1), (x2, y2)]:
  model = xgboost.XGBRegressor().fit(X, y)
  explainer = shap.Explainer(model, check_additivity=False)
  shap_values = explainer(X, check_additivity=False)
  shap.plots.bar(shap_values, max_display=6, show=False) # ax=ax ?? 
plt.show()

然而,对于 shap.plots.barax 未定义,不像其他绘图方法,如 shap.dependence_plot(..., ax=ax[0, 0], show=False)。是否有办法将多个条形图添加到子图中?


<details>
<summary>英文:</summary>

I&#39;d like to add the `shap.plots.bar` (https://github.com/slundberg/shap) figure to a subplot. Something like this...

```python
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(8,20))
for (X, y) in [(x1, y1), (x2, y2)]:
  model = xgboost.XGBRegressor().fit(X, y)
  explainer = shap.Explainer(model, check_additivity=False)
  shap_values = explainer(X, check_additivity=False)
  shap.plots.bar(shap_values, max_display=6, show=False) # ax=ax ?? 
plt.show()

However, ax is undefined for shap.plots.bar, unlike some other plotting methods such as shap.dependence_plot(..., ax=ax[0, 0], show=False). Is there a way to add many bar plots to a subplot?

答案1

得分: 1

查看源代码,该函数不会创建自己的图形。因此,您可以创建一个图形,然后使用plt.sca将所需的坐标轴设置为当前坐标轴。

以下是使用文档中的条形图示例代码来执行此操作的方法。

import xgboost
import shap
import matplotlib.pyplot as plt

X, y = shap.datasets.adult()
model = xgboost.XGBClassifier().fit(X, y)

explainer = shap.Explainer(model, X)
shap_values = explainer(X)

fig, (ax1, ax2) = plt.subplots(2, 1)
plt.sca(ax2)
shap.plots.bar(shap_values)
fig.tight_layout()
fig.show()

shap.plots.bar() 图的子图?

如果您确实想要使用ax参数,您需要编辑源代码以添加该选项。

英文:

Looking at the source code, the function does not create it's own figure. So, you can create a figure and then set the desired axis as the current axis using plt.sca.

Here is how you'd do it using the bar plot sample code from the documentation.

import xgboost
import shap
import matplotlib.pyplot as plt

X, y = shap.datasets.adult()
model = xgboost.XGBClassifier().fit(X, y)

explainer = shap.Explainer(model, X)
shap_values = explainer(X)


fig, (ax1, ax2) = plt.subplots(2, 1)
plt.sca(ax2)
shap.plots.bar(shap_values)
fig.tight_layout()
fig.show()

shap.plots.bar() 图的子图?

If you really want to have the ax argument, you'll have to edit the source code to add that option.

huangapple
  • 本文由 发表于 2023年7月6日 22:32:20
  • 转载请务必保留本文链接:https://go.coder-hub.com/76629925.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定