使用稳定的baseline3创建自定义策略

huangapple go评论57阅读模式
英文:

Create Custom Policy using stable baseline3

问题

我正在尝试创建一个自定义的LSTM策略。似乎缺少BasePolicy。我们如何创建一个自定义的LSTM策略以传递给PPO或A2C算法。此外,如果不行,是否可以修改当前设置中LSTM层,以帮助自定义我的结果。

import gym
import torch.nn as nn
from sb3_contrib.common.policies import BasePolicy
from sb3_contrib.ppo_recurrent import RecurrentPPO

class CustomPolicy(BasePolicy):
    def __init__(self, *args, **kwargs):
        super(CustomPolicy, self).__init__(*args, **kwargs)

    def make_lstm_layer(self, n_lstm_layers: int) -> nn.Module:
        # 在这里创建您的自定义LSTM层
        # 例如:
        lstm_layer = nn.LSTM(input_size=self.features_dim, hidden_size=64, num_layers=n_lstm_layers)
        return lstm_layer

env_name = "CartPole-v1"
env = gym.make(env_name)
model = RecurrentPPO(CustomPolicy, env, verbose=1)

请指导。

我尝试安装这些包,但无法弄清楚如何创建自定义策略或修改他们提供的策略。

英文:

I am trying to create a custom lstm policy. It seems that BasePolicy is missing. How can we create a custom LSTM policy to pass to PPO or A2C algorithm. Also, if not, can modify the layer of lstm in the current setting that will help in customizing my results.

import gym
import torch.nn as nn
from sb3_contrib.common.policies import BasePolicy
from sb3_contrib.ppo_recurrent import RecurrentPPO


class CustomPolicy(BasePolicy):
    def __init__(self, *args, **kwargs):
        super(CustomPolicy, self).__init__(*args, **kwargs)

    def make_lstm_layer(self, n_lstm_layers: int) -> nn.Module:
        # Create your custom LSTM layer here
        # For example:
        lstm_layer = nn.LSTM(input_size=self.features_dim, hidden_size=64, num_layers=n_lstm_layers)
        return lstm_layer


env_name = "CartPole-v1"
env = gym.make(env_name)
model = RecurrentPPO(CustomPolicy, env, verbose=1)

Please guide.

I tried installing the packages but was unable to figure out how to create the custom policies or modify the policies they have provided

答案1

得分: 1

LSTM-policies不受sb3默认支持,但它们受到来自sb3-contribRecurrentPPO支持。

英文:

LSTM-policies are not supported by sb3 out-of-the-box, but they are supported by RecurrentPPO from sb3-contrib.

huangapple
  • 本文由 发表于 2023年3月8日 14:29:59
  • 转载请务必保留本文链接:https://go.coder-hub.com/75669986.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定