How to export an A2C model created with stable-baselines3 to PyTorch?

Question

I have trained an A2C model (MlpPolicy) using stable-baselines3 (I am quite new to reinforcement learning and found this to be a good place to start).
However, I now want to use an XRL (eXplainable Reinforcement Learning) method to understand the model better. I decided to use DeepSHAP, as it has a nice implementation and I am familiar with SHAP.
DeepSHAP works on PyTorch, which is the underlying framework behind stable-baselines3. So my goal is to extract the underlying PyTorch model from the stable-baselines3 model. However, I am having some issues with this.

I have found the following thread: https://github.com/hill-a/stable-baselines/issues/372
It helped me a bit, but because the architecture of A2C is different from the model used in that thread, I have not yet been able to solve my problem.

From what I understand, stable-baselines3 offers the option to export models using

model.policy.state_dict()

However, I am struggling to import what I have exported through that method.
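For reference, the round-trip I am attempting looks roughly like this (a sketch; a2c_policy.pth is just a file name I picked, and PyTorchMlp is my re-implementation shown further below):

import torch as th

# export: state_dict() returns an OrderedDict of named parameter tensors
th.save(A2C_model.policy.state_dict(), "a2c_policy.pth")

# import: the target module must expose identically named parameters,
# otherwise load_state_dict() complains about missing/unexpected keys
my_model = PyTorchMlp()
my_model.load_state_dict(th.load("a2c_policy.pth"))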

When printing out

A2C_model.policy

I get a glimpse of what the structure of the PyTorch model looks like. Output:

ActorCriticPolicy(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (pi_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (vf_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (mlp_extractor): MlpExtractor(
    (policy_net): Sequential(
      (0): Linear(in_features=49, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
    (value_net): Sequential(
      (0): Linear(in_features=49, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
  )
  (action_net): Linear(in_features=64, out_features=5, bias=True)
  (value_net): Linear(in_features=64, out_features=1, bias=True)
)

I tried to recreate it myself, but I am not yet fluent enough with PyTorch to get it to work.

So my question is: how can I export the stable-baselines3 model to PyTorch?

I have tried re-building the model architecture in PyTorch according to the output of printing A2C_model.policy. My code is currently:

import torch as th
import torch.nn as nn

class PyTorchMlp(nn.Module):
    def __init__(self):
        super().__init__()

        n_inputs = 49
        n_actions = 5

        self.features_extractor = nn.Flatten(start_dim=1, end_dim=-1)
        self.pi_features_extractor = nn.Flatten(start_dim=1, end_dim=-1)
        self.vf_features_extractor = nn.Flatten(start_dim=1, end_dim=-1)

        # nn.Sequential cannot take named submodules via assignment, so I
        # use an nn.ModuleDict to reproduce the nested parameter names
        self.mlp_extractor = nn.ModuleDict({
            "policy_net": nn.Sequential(
                nn.Linear(in_features=n_inputs, out_features=64),
                nn.Tanh(),
                nn.Linear(in_features=64, out_features=64),
                nn.Tanh(),
            ),
            "value_net": nn.Sequential(
                nn.Linear(in_features=n_inputs, out_features=64),
                nn.Tanh(),
                nn.Linear(in_features=64, out_features=64),
                nn.Tanh(),
            ),
        })

        self.action_net = nn.Linear(in_features=64, out_features=n_actions)
        self.value_net = nn.Linear(in_features=64, out_features=1)

    def forward(self, x):
        # I do not yet know how to wire these modules together
        pass

Answer 1

Score: 1

"如果您只想将其导出为一个PyTorch模型,以便使用Shap框架中的DeepExplainer,您只需要创建一个类来包装模型的 policy_netaction_net。我的解决方案是基于stable-baselines3的PPO(MLP)模型实现的,但我确信对于A2C来说不会有太大不同。

我的PPO(MLP)模型的包装器类:

import shap
import torch
import torch.nn as nn
from stable_baselines3 import PPO

class sb3Wrapper(nn.Module):
    def __init__(self, model):
        super(sb3Wrapper,self).__init__()
        self.extractor = model.policy.mlp_extractor
        self.policy_net = model.policy.mlp_extractor.policy_net
        self.action_net = model.policy.action_net

    def forward(self,x):
        x = self.policy_net(x)
        x = self.action_net(x)
        return x

关于Shap框架的深度解释器,您需要确保以下几点:

  1. 您需要确保模型和传递给 DeepExplainer 函数的状态数据(作为PyTorch张量)位于相同的设备上(即 'cuda'/'cpu')。
  2. 如果您的状态数据是连续的,请确保使用 torch.FloatTensor() 函数。

以下是我实现的一些代码,希望能帮助您:

(我提取并存储了所有的数据在一个数据帧中,因为我还在执行其他分析)

model = PPO.load(model_path, device='cuda')
state_log = np.array(df['observation'].values.tolist())
data = torch.FloatTensor(state_log).to('cuda')
model = sb3Wrapper(model)
explainer = shap.DeepExplainer(model, data)
shap_vals= explainer.shap_values(data)

参考资料和有用的链接:

  • Shap框架的Github问题,其中一个人想要使用stable_baselines3的DQN与DeepExplainer一起使用。
  • stable-baselines3的Github问题,关于如何访问网络的每一层。
    • 如果您想逐层包装模型,这可能对您有所帮助(尽管我尚未尝试过)。
  • Kaggle代码,其中包含有关在PyTorch模型上使用 DeepExplainer 的有用信息。"
英文:

If you only want to export it as a PyTorch model for the purpose of using the DeepExplainer from the shap framework, all you need to do is create a class that wraps the model's policy_net and action_net together. My solution was implemented with stable-baselines3's PPO (MLP) model, but I'm sure it won't be too different for A2C.

My Wrapper class for my PPO (MLP) model:

import shap
import torch
import torch.nn as nn
from stable_baselines3 import PPO

class sb3Wrapper(nn.Module):
    def __init__(self, model):
        super(sb3Wrapper, self).__init__()
        # keep a handle on the full extractor in case other layers are needed
        self.extractor = model.policy.mlp_extractor
        # the two submodules that take an observation to action logits
        self.policy_net = model.policy.mlp_extractor.policy_net
        self.action_net = model.policy.action_net

    def forward(self, x):
        x = self.policy_net(x)
        x = self.action_net(x)
        return x
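Since your model is A2C, the same wrapper should apply unchanged: the ActorCriticPolicy printout in your question exposes the same mlp_extractor and action_net attributes. A quick sketch (untested on my side; a2c_model_path is a placeholder):

from stable_baselines3 import A2C

# assumption: A2C's ActorCriticPolicy has the same attribute layout as PPO's
a2c_model = A2C.load(a2c_model_path, device='cuda')
wrapped_a2c = sb3Wrapper(a2c_model)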

With regard to the shap framework's DeepExplainer, you are going to need to make sure of a few things:

  1. Both the model and the state data (as a torch tensor) you are passing into the DeepExplainer function must be on the same device (i.e. 'cuda'/'cpu').
  2. If your state data is continuous, make sure you use the torch.FloatTensor() function.

Here are a few lines from my implementation to help you along the way
(I extracted and stored all my data in a dataframe because I'm also performing other analyses):

import numpy as np

model = PPO.load(model_path, device='cuda')
# df is the dataframe mentioned above; each row holds one stored observation
state_log = np.array(df['observation'].values.tolist())
data = torch.FloatTensor(state_log).to('cuda')
model = sb3Wrapper(model)
explainer = shap.DeepExplainer(model, data)
shap_vals = explainer.shap_values(data)
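From there the attributions plug into shap's standard plots. A minimal sketch (assuming a discrete action space, where shap_values returns one attribution array per action):

# move the background data back to the CPU for plotting
shap.summary_plot(shap_vals, data.cpu().numpy())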

References and helpful links:

  • Shap framework GitHub issue where someone wanted to use a stable-baselines3 DQN with the DeepExplainer
  • stable-baselines3 GitHub issue on accessing every layer of a network
    • This might help you if you want to wrap the model layer by layer (I haven't tried that out, though)
  • Kaggle code that had useful information on using the DeepExplainer on PyTorch models
