diff --git a/requirements.txt b/requirements.txt index 5bb2a0460..df75a19ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,4 @@ seaborn tqdm rich importlib-metadata~=4.13 # flake8 not compatible with importlib-metadata>5.0 +envpool diff --git a/rl_zoo3/exp_manager.py b/rl_zoo3/exp_manager.py index 4503121b0..1ab64a63d 100644 --- a/rl_zoo3/exp_manager.py +++ b/rl_zoo3/exp_manager.py @@ -37,6 +37,7 @@ SubprocVecEnv, VecEnv, VecFrameStack, + VecMonitor, VecNormalize, VecTransposeImage, is_vecenv_wrapped, @@ -51,6 +52,13 @@ from rl_zoo3.hyperparams_opt import HYPERPARAMS_SAMPLER from rl_zoo3.utils import ALGOS, get_callback_list, get_class_by_name, get_latest_run_id, get_wrapper_class, linear_schedule +try: + import envpool + + from rl_zoo3.vec_env_wrappers import EnvPoolAdapter +except ImportError: + envpool = None + class ExperimentManager: """ @@ -97,6 +105,7 @@ def __init__( device: Union[th.device, str] = "auto", config: Optional[str] = None, show_progress: bool = False, + use_envpool: bool = False, ): super().__init__() self.algo = algo @@ -120,6 +129,7 @@ def __init__( self.seed = seed self.optimization_log_path = optimization_log_path + self.use_envpool = use_envpool self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type] self.vec_env_wrapper = None @@ -581,37 +591,52 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False) # Do not log eval env (issue with writing the same file) log_dir = None if eval_env or no_log else self.save_path - # Special case for GoalEnvs: log success rate too - if ( - "Neck" in self.env_name.gym_id - or self.is_robotics_env(self.env_name.gym_id) - or "parking-v0" in self.env_name.gym_id - and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs - ): - self.monitor_kwargs = dict(info_keywords=("is_success",)) - - # Define make_env here so it works with subprocesses - # when the registry was modified with `--gym-packages` - # See https://github.com/HumanCompatibleAI/imitation/pull/160 - spec = gym.spec(self.env_name.gym_id) - - def make_env(**kwargs) -> gym.Env: - env = spec.make(**kwargs) - return env - - # On most env, SubprocVecEnv does not help and is quite memory hungry - # therefore we use DummyVecEnv by default - env = make_vec_env( - make_env, - n_envs=n_envs, - seed=self.seed, - env_kwargs=self.env_kwargs, - monitor_dir=log_dir, - wrapper_class=self.env_wrapper, - vec_env_cls=self.vec_env_class, - vec_env_kwargs=self.vec_env_kwargs, - monitor_kwargs=self.monitor_kwargs, - ) + if self.use_envpool: + # TODO: warning if env wrapper is passed + # Convert Atari game names + # See https://github.com/sail-sg/envpool/issues/14 + env_id = self.env_name.gym_id + if self._is_atari and "NoFrameskip-v4" in env_id: + env_id = env_id.split("NoFrameskip-v4")[0] + "-v5" + + env = envpool.make(env_id, env_type="gym", num_envs=n_envs, seed=self.seed) + env.spec.id = self.env_name.gym_id + env = EnvPoolAdapter(env) + filename = None if log_dir is None else f"{log_dir}/monitor.csv" + env = VecMonitor(env, filename, **self.monitor_kwargs) + + else: + # Special case for GoalEnvs: log success rate too + if ( + "Neck" in self.env_name.gym_id + or self.is_robotics_env(self.env_name.gym_id) + or "parking-v0" in self.env_name.gym_id + and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs + ): + self.monitor_kwargs = dict(info_keywords=("is_success",)) + + # Define make_env here so it works with subprocesses + # when the registry was modified with `--gym-packages` + # See https://github.com/HumanCompatibleAI/imitation/pull/160 + spec = gym.spec(self.env_name.gym_id) + + def make_env(**kwargs) -> gym.Env: + env = spec.make(**kwargs) + return env + + # On most env, SubprocVecEnv does not help and is quite memory hungry + # therefore we use DummyVecEnv by default + env = make_vec_env( + make_env, + n_envs=n_envs, + seed=self.seed, + env_kwargs=self.env_kwargs, + monitor_dir=log_dir, + wrapper_class=self.env_wrapper, + vec_env_cls=self.vec_env_class, + vec_env_kwargs=self.vec_env_kwargs, + monitor_kwargs=self.monitor_kwargs, + ) if self.vec_env_wrapper is not None: env = self.vec_env_wrapper(env) diff --git a/rl_zoo3/train.py b/rl_zoo3/train.py index 1e52a5fc0..53a55bc3e 100644 --- a/rl_zoo3/train.py +++ b/rl_zoo3/train.py @@ -153,6 +153,13 @@ def train() -> None: default=False, help="if toggled, display a progress bar using tqdm and rich", ) + parser.add_argument( + "-envpool", + "--use-envpool", + action="store_true", + default=False, + help="if toggled, try to use EnvPool to run the env, env_wrappers are not supported.", + ) parser.add_argument( "-tags", "--wandb-tags", type=str, default=[], nargs="+", help="Tags for wandb run, e.g.: -tags optimized pr-123" ) @@ -183,7 +190,7 @@ def train() -> None: uuid_str = f"_{uuid.uuid4()}" if args.uuid else "" if args.seed < 0: # Seed but with a random one - args.seed = np.random.randint(2**32 - 1, dtype="int64").item() # type: ignore[attr-defined] + args.seed = np.random.randint(2**31 - 1, dtype="int64").item() # type: ignore[attr-defined] set_random_seed(args.seed) @@ -259,6 +266,7 @@ def train() -> None: device=args.device, config=args.conf_file, show_progress=args.progress, + use_envpool=args.use_envpool, ) # Prepare experiment and launch hyperparameter optimization if needed diff --git a/rl_zoo3/vec_env_wrappers.py b/rl_zoo3/vec_env_wrappers.py new file mode 100644 index 000000000..cc16ddaef --- /dev/null +++ b/rl_zoo3/vec_env_wrappers.py @@ -0,0 +1,46 @@ +from typing import Optional + +import gym +import numpy as np +from envpool.python.protocol import EnvPool +from stable_baselines3.common.vec_env import VecEnvWrapper +from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs, VecEnvStepReturn + + +class EnvPoolAdapter(VecEnvWrapper): + """ + Convert EnvPool object to a Stable-Baselines3 (SB3) VecEnv. + + :param venv: The envpool object. + """ + + def __init__(self, venv: EnvPool): + # Retrieve the number of environments from the config + venv.num_envs = venv.spec.config.num_envs + super().__init__(venv=venv) + # Tmp fix for https://github.com/DLR-RM/stable-baselines3/issues/1145 + if isinstance(self.action_space, gym.spaces.Box) and self.action_space.dtype == np.float64: + self.action_space.dtype = np.dtype(np.float32) + + def step_async(self, actions: np.ndarray) -> None: + self.actions = actions + + def reset(self) -> VecEnvObs: + return self.venv.reset() + + def seed(self, seed: Optional[int] = None) -> None: + # You can only seed EnvPool env by calling envpool.make() + pass + + def step_wait(self) -> VecEnvStepReturn: + obs, rewards, dones, info_dict = self.venv.step(self.actions) + infos = [] + # Convert dict to list of dict + # and add terminal observation + for i in range(self.num_envs): + infos.append({key: info_dict[key][i] for key in info_dict.keys() if isinstance(info_dict[key], np.ndarray)}) + if dones[i]: + infos[i]["terminal_observation"] = obs[i] + obs[i] = self.venv.reset(np.array([i])) + + return obs, rewards, dones, infos