How to export an A2C model created with stable-baselines3 to PyTorch?
Question
I have trained an A2C model (MlpPolicy) using stable-baselines3 (I am quite new to reinforcement learning and found this to be a good place to start).
However, I now want to use an XRL (eXplainable Reinforcement Learning) method to understand the model better. I decided to use DeepSHAP as it has a nice implementation and because I am familiar with SHAP.
DeepSHAP works on PyTorch, which is the underlying framework behind stable-baselines3. So my goal is to extract the underlying PyTorch model from the stable-baselines3 model. However, I am having some issues with this.
I have found the following thread: https://github.com/hill-a/stable-baselines/issues/372
This thread helped a bit, but because the architecture of A2C differs from the model used there, I have not yet been able to solve my problem.
From what I understand, stable-baselines3 offers the option to export models using
model.policy.state_dict()
However, I am struggling to import what I have exported through that method.
When printing out
A2C_model.policy
I get a glimpse of what the structure of the PyTorch model looks like. Output:
ActorCriticPolicy(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (pi_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (vf_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (mlp_extractor): MlpExtractor(
    (policy_net): Sequential(
      (0): Linear(in_features=49, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
    (value_net): Sequential(
      (0): Linear(in_features=49, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
  )
  (action_net): Linear(in_features=64, out_features=5, bias=True)
  (value_net): Linear(in_features=64, out_features=1, bias=True)
)
I tried to recreate it myself, but I am not fluent enough with PyTorch yet to get it to work.
So my question is: how can I export the stable-baselines3 model to PyTorch?
I have tried re-building the model architecture in PyTorch according to the output of printing A2C_model.policy. My code is currently:
import torch as th
import torch.nn as nn

class PyTorchMlp(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        n_inputs = 49
        n_actions = 5
        self.features_extractor = nn.Flatten(start_dim = 1, end_dim = -1)
        self.pi_features_extractor = nn.Flatten(start_dim = 1, end_dim = -1)
        self.vf_features_extractor = nn.Flatten(start_dim = 1, end_dim = -1)
        self.mlp_extractor = nn.Sequentail(
            self.policy_net = nn.Sequential(
                nn.Linear(in_features = n_inputs, out_features = 64),
                nn.Tanh(),
                nn.Linear(in_features = 64, out_features = 64),
                nn.Tanh()
            ),
            self.value_net = nn.Sequential(
                nn.Linear(in_features = n_inputs, out_features = 64),
                nn.Tanh(),
                nn.Linear(in_features = 64, out_features = 64),
                nn.Tanh()
            )
        )
        self.action_net = nn.Linear(in_features = 64, out_features = 5)
        self.value_net = nn.Linear(in_features = 64, out_features = 1)

    def forward(self, x):
        pass
Answer 1
Score: 1
If you only want to export it as a PyTorch model in order to use the DeepExplainer from the shap framework, all you need to do is create a class that wraps the model's policy_net and action_net together. I implemented my solution for stable-baselines3's PPO (MLP) model, but I'm sure it won't be too different for A2C.
My Wrapper class for my PPO (MLP) model:
import shap
import torch
import torch.nn as nn
from stable_baselines3 import PPO

class sb3Wrapper(nn.Module):
    def __init__(self, model):
        super(sb3Wrapper, self).__init__()
        # Reuse the trained sub-networks of the SB3 policy directly.
        self.extractor = model.policy.mlp_extractor
        self.policy_net = model.policy.mlp_extractor.policy_net
        self.action_net = model.policy.action_net

    def forward(self, x):
        # Actor path only: hidden layers first, then the action head.
        x = self.policy_net(x)
        x = self.action_net(x)
        return x
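If a wrapper around the live stable-baselines3 object is not an option, the actor path can probably also be rebuilt in plain PyTorch and filled from model.policy.state_dict(). The following is only a sketch under assumptions: the class name ActorMlp and the helper actor_weights are made up for illustration, the layer sizes (49, 64, 5) are taken from the printed structure in the question, and the key names are the ones that structure implies (mlp_extractor.policy_net.0.weight, action_net.weight, and so on). The state dict used here is a stand-in built from random weights so that the snippet runs without a trained model; in practice it would be replaced by A2C_model.policy.state_dict().

```python
import torch
import torch.nn as nn


class ActorMlp(nn.Module):
    """Plain-PyTorch copy of the actor path: policy_net followed by action_net."""

    def __init__(self, n_inputs: int = 49, n_hidden: int = 64, n_actions: int = 5):
        super().__init__()
        # Indices 0 and 2 hold the Linear layers, as in the printed Sequential.
        self.policy_net = nn.Sequential(
            nn.Linear(n_inputs, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, n_hidden),
            nn.Tanh(),
        )
        self.action_net = nn.Linear(n_hidden, n_actions)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.action_net(self.policy_net(x))


def actor_weights(policy_state: dict) -> dict:
    """Keep only the actor entries and rename them to ActorMlp's parameter names."""
    prefix = "mlp_extractor.policy_net."
    renamed = {}
    for key, value in policy_state.items():
        if key.startswith(prefix):
            renamed["policy_net." + key[len(prefix):]] = value
        elif key.startswith("action_net."):
            renamed[key] = value
    return renamed


# Stand-in for A2C_model.policy.state_dict() (an assumption, not real SB3 output):
# random tensors stored under the key names that the printed structure implies.
reference = ActorMlp()
policy_state = {}
for name, tensor in reference.state_dict().items():
    if name.startswith("policy_net."):
        policy_state["mlp_extractor." + name] = tensor.clone()
    else:
        policy_state[name] = tensor.clone()
# Critic entries, which the actor copy should simply ignore.
policy_state["value_net.weight"] = torch.zeros(1, 64)
policy_state["value_net.bias"] = torch.zeros(1)

actor = ActorMlp()
# The default strict=True raises an error if any key or shape does not match.
actor.load_state_dict(actor_weights(policy_state))
actor.eval()

observations = torch.rand(8, 49)
with torch.no_grad():
    outputs = actor(observations)  # one row of 5 action outputs per observation
```

Loading with the default strict=True is deliberate here: if the real policy uses different layer sizes or key names than assumed, the mismatch shows up as an error rather than as silently untrained weights.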
With regard to the shap framework's DeepExplainer, you need to make sure of a few things:
- Both the model and the state data (as a torch tensor) that you pass into DeepExplainer must be on the same device (i.e. 'cuda' or 'cpu').
- If your state data is continuous, make sure you convert it with torch.FloatTensor().
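The two points above can be checked before the explainer is built. This is a minimal sketch, assuming a small nn.Sequential as a hypothetical stand-in for the wrapped policy and random observations in place of logged states; it falls back to the CPU when no GPU is present.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical stand-in for the wrapped policy; any nn.Module is checked the same way.
net = nn.Sequential(nn.Linear(49, 64), nn.Tanh(), nn.Linear(64, 5))

# Use the GPU only when one is present, so the sketch also runs on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)

# NumPy produces float64 by default, while the network's weights are float32.
state_log = np.random.rand(16, 49)
data = torch.as_tensor(state_log, dtype=torch.float32, device=device)

# Compare against an actual parameter rather than the requested device string.
param = next(net.parameters())
same_device = param.device == data.device
same_dtype = param.dtype == data.dtype
```

torch.as_tensor with an explicit dtype and device does the conversion and the transfer in one call; torch.FloatTensor(state_log).to(device) should give an equivalent float32 tensor.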
Here are a few lines from my implementation to help you along the way:
(I extracted and stored all my data in a dataframe because I'm also performing other analyses)
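import numpy as np

# model_path and df (a dataframe with an 'observation' column) come from my own setup.
model = PPO.load(model_path, device='cuda')
state_log = np.array(df['observation'].values.tolist())
data = torch.FloatTensor(state_log).to('cuda')
# Wrap the SB3 model so that DeepExplainer receives a plain nn.Module.
model = sb3Wrapper(model)
explainer = shap.DeepExplainer(model, data)
shap_vals = explainer.shap_values(data)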
References and helpful links:
- Shap framework GitHub issue where the person wanted to use stable_baselines3 DQN with the DeepExplainer
- stable-baselines3 GitHub issue on accessing every layer of a network
- This might help you if you want to wrap the model layer by layer (I haven't tried that out, though)
- Kaggle code that had useful information on using the DeepExplainer on PyTorch models.