Stable-Baselines make_vec_env() not calling the wrapper kwargs as expected

Question

I have a custom gym environment in which I am implementing action masking. The code below works well:

from toyEnv_mask import PortfolioEnv  # file with my custom gym environment
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.env_util import make_vec_env
from sb3_contrib.common.maskable.utils import get_action_masks

env = PortfolioEnv()
env = ActionMasker(env, action_mask_fn=PortfolioEnv.action_masks)
env = DummyVecEnv([lambda: env])

model = MaskablePPO("MlpPolicy", env, gamma=0.4, verbose=1, tensorboard_log="./MPPO_tensorboard/")
...

I want to parallelize this and am trying to use the make_vec_env() function. It sets up fine and starts running, but it no longer respects the action masks. This is the line I am using to replace the three lines above that initialize the env:

env = make_vec_env(PortfolioEnv, wrapper_class=ActionMasker, wrapper_kwargs={'action_mask_fn': PortfolioEnv.action_masks}, n_envs=2)

Any suggestions / help would be greatly appreciated.

Answer 1

Score: 0

I was making it unnecessarily complicated.

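A simple env = make_vec_env(PortfolioEnv, n_envs=16) works, since PortfolioEnv already defines action_masks() itself and so does not need the ActionMasker wrapper.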

And the 3 lines that were working previously could have been just:

env = PortfolioEnv()
env = DummyVecEnv([lambda: env])
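
As far as I understand it, the maskable algorithm only needs every sub-environment to expose an action_masks() method, and ActionMasker exists to add that method to environments that lack one. Below is a library-free sketch of that lookup. It is not sb3_contrib's actual code: ToyEnv, TinyVecEnv, and collect_masks are hypothetical names made up for illustration, and the three-action mask is an assumption.

```python
from typing import Callable, List


class ToyEnv:
    """Hypothetical stand-in for PortfolioEnv with 3 discrete actions."""

    def __init__(self, blocked_action: int) -> None:
        self.blocked_action = blocked_action

    def action_masks(self) -> List[bool]:
        # True means the action is allowed in the current state.
        return [a != self.blocked_action for a in range(3)]


class TinyVecEnv:
    """Hypothetical, heavily simplified vectorized env holding several sub-envs."""

    def __init__(self, env_fns: List[Callable[[], ToyEnv]]) -> None:
        # Each factory is called once, so every slot gets its own env instance.
        self.envs = [fn() for fn in env_fns]

    def env_method(self, name: str) -> list:
        # Call the named method on every sub-env and collect the results.
        return [getattr(env, name)() for env in self.envs]


def collect_masks(vec_env: TinyVecEnv) -> List[List[bool]]:
    # One mask row per sub-env, read straight from the env; no wrapper involved.
    return vec_env.env_method("action_masks")


# Two sub-envs, each blocking a different action (i=i binds the loop value).
vec_env = TinyVecEnv([lambda i=i: ToyEnv(blocked_action=i) for i in range(2)])
masks = collect_masks(vec_env)
print(masks)
```

Because the method lives on the environment class, every copy created by the vectorized container carries its own mask without any extra wrapper arguments.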

Posted on February 8, 2023, 23:36:02
Source: https://go.coder-hub.com/75388139.html