英文:
Subplots for shap.plots.bar() plots?
问题
我想将 `shap.plots.bar`(https://github.com/slundberg/shap)图形添加到子图中。类似于这样...
```python
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(8,20))
for (X, y) in [(x1, y1), (x2, y2)]:
model = xgboost.XGBRegressor().fit(X, y)
explainer = shap.Explainer(model, check_additivity=False)
shap_values = explainer(X, check_additivity=False)
shap.plots.bar(shap_values, max_display=6, show=False) # ax=ax ??
plt.show()
然而,对于 shap.plots.bar
,ax
未定义,不像其他绘图方法,如 shap.dependence_plot(..., ax=ax[0, 0], show=False)
。是否有办法将多个条形图添加到子图中?
<details>
<summary>英文:</summary>
I'd like to add the `shap.plots.bar` (https://github.com/slundberg/shap) figure to a subplot. Something like this...
```python
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(8,20))
for (X, y) in [(x1, y1), (x2, y2)]:
model = xgboost.XGBRegressor().fit(X, y)
explainer = shap.Explainer(model, check_additivity=False)
shap_values = explainer(X, check_additivity=False)
shap.plots.bar(shap_values, max_display=6, show=False) # ax=ax ??
plt.show()
However, ax
is undefined for shap.plots.bar
, unlike some other plotting methods such as shap.dependence_plot(..., ax=ax[0, 0], show=False)
. Is there a way to add many bar plots to a subplot?
答案1
得分: 1
查看源代码,该函数不会创建自己的图形。因此,您可以创建一个图形,然后使用plt.sca
将所需的坐标轴设置为当前坐标轴。
以下是使用文档中的条形图示例代码来执行此操作的方法。
import xgboost
import shap
import matplotlib.pyplot as plt
X, y = shap.datasets.adult()
model = xgboost.XGBClassifier().fit(X, y)
explainer = shap.Explainer(model, X)
shap_values = explainer(X)
fig, (ax1, ax2) = plt.subplots(2, 1)
plt.sca(ax2)
shap.plots.bar(shap_values)
fig.tight_layout()
fig.show()
如果您确实想要使用ax
参数,您需要编辑源代码以添加该选项。
英文:
Looking at the source code, the function does not create it's own figure. So, you can create a figure and then set the desired axis as the current axis using plt.sca
.
Here is how you'd do it using the bar plot sample code from the documentation.
import xgboost
import shap
import matplotlib.pyplot as plt
X, y = shap.datasets.adult()
model = xgboost.XGBClassifier().fit(X, y)
explainer = shap.Explainer(model, X)
shap_values = explainer(X)
fig, (ax1, ax2) = plt.subplots(2, 1)
plt.sca(ax2)
shap.plots.bar(shap_values)
fig.tight_layout()
fig.show()
If you really want to have the ax
argument, you'll have to edit the source code to add that option.
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论