Stable-Baselines make_vec_env() not calling the wrapper kwargs as expected
Question
I have a custom Gym environment in which I am implementing action masking. The code below works well:
from toyEnv_mask import PortfolioEnv # file with my custom gym environment
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.env_util import make_vec_env
from sb3_contrib.common.maskable.utils import get_action_masks
env = PortfolioEnv()
env = ActionMasker(env, action_mask_fn=PortfolioEnv.action_masks)
env = DummyVecEnv([lambda: env])
model = MaskablePPO("MlpPolicy", env, gamma=0.4, verbose=1, tensorboard_log="./MPPO_tensorboard/")
...
I want to parallelize this, so I am trying to use the make_vec_env() function. It sets up fine and starts running, but it no longer respects the action masks. This is the line I am using to replace the three lines above where I initialize the env:
env = make_vec_env(PortfolioEnv, wrapper_class=ActionMasker, wrapper_kwargs={'action_mask_fn': PortfolioEnv.action_masks}, n_envs=2)
Any suggestions / help would be greatly appreciated.
Answer 1
Score: 0
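I was making it unnecessarily complicated. MaskablePPO picks up the masks from the environment's own action_masks() method, so the ActionMasker wrapper is not needed here.
A simple
env = make_vec_env(PortfolioEnv, n_envs=16)
works.
And the three lines that were working previously could have been just: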
env = PortfolioEnv()
env = DummyVecEnv([lambda: env])
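To illustrate why the wrapper can be dropped: in sb3_contrib, get_action_masks() asks each sub-environment for a method named action_masks. The sketch below is a simplified stand-in in plain Python, not the library source. ToyPortfolioEnv, ToyVecEnv and collect_action_masks are hypothetical names, and the buy/sell masking rule is made up for the example.

```python
# Simplified stand-in, NOT the sb3_contrib source: it only mimics the
# by-name lookup of an `action_masks` method on each sub-environment.


class ToyPortfolioEnv:
    """Hypothetical environment with three discrete actions."""

    def __init__(self, holding=False):
        self.holding = holding

    def action_masks(self):
        # True means the action is allowed.
        # Actions: 0 = hold, 1 = buy, 2 = sell (made-up rule for this sketch).
        return [True, not self.holding, self.holding]


class ToyVecEnv:
    """Hypothetical minimal vectorized container built from env factories."""

    def __init__(self, env_fns):
        # Each factory is called once, so every sub-env is a separate object.
        self.envs = [fn() for fn in env_fns]

    def env_method(self, method_name):
        # Call the named method on every sub-environment and collect results.
        return [getattr(env, method_name)() for env in self.envs]


def collect_action_masks(vec_env):
    # The environment only needs to expose a method with this exact name.
    return vec_env.env_method("action_masks")


vec_env = ToyVecEnv([lambda: ToyPortfolioEnv(False), lambda: ToyPortfolioEnv(True)])
masks = collect_action_masks(vec_env)
print(masks)  # [[True, True, False], [True, False, True]]
```

Because the lookup is by method name, an environment class that already defines action_masks() can be passed to make_vec_env() directly, and each of the n_envs copies reports its own mask.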