Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ seaborn
tqdm
rich
importlib-metadata~=4.13 # flake8 not compatible with importlib-metadata>5.0
envpool
87 changes: 56 additions & 31 deletions rl_zoo3/exp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
SubprocVecEnv,
VecEnv,
VecFrameStack,
VecMonitor,
VecNormalize,
VecTransposeImage,
is_vecenv_wrapped,
Expand All @@ -51,6 +52,13 @@
from rl_zoo3.hyperparams_opt import HYPERPARAMS_SAMPLER
from rl_zoo3.utils import ALGOS, get_callback_list, get_class_by_name, get_latest_run_id, get_wrapper_class, linear_schedule

try:
import envpool

from rl_zoo3.vec_env_wrappers import EnvPoolAdapter
except ImportError:
envpool = None


class ExperimentManager:
"""
Expand Down Expand Up @@ -97,6 +105,7 @@ def __init__(
device: Union[th.device, str] = "auto",
config: Optional[str] = None,
show_progress: bool = False,
use_envpool: bool = False,
):
super().__init__()
self.algo = algo
Expand All @@ -120,6 +129,7 @@ def __init__(
self.seed = seed
self.optimization_log_path = optimization_log_path

self.use_envpool = use_envpool
self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type]
self.vec_env_wrapper = None

Expand Down Expand Up @@ -581,37 +591,52 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False)
# Do not log eval env (issue with writing the same file)
log_dir = None if eval_env or no_log else self.save_path

# Special case for GoalEnvs: log success rate too
if (
"Neck" in self.env_name.gym_id
or self.is_robotics_env(self.env_name.gym_id)
or "parking-v0" in self.env_name.gym_id
and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs
):
self.monitor_kwargs = dict(info_keywords=("is_success",))

# Define make_env here so it works with subprocesses
# when the registry was modified with `--gym-packages`
# See https://github.com/HumanCompatibleAI/imitation/pull/160
spec = gym.spec(self.env_name.gym_id)

def make_env(**kwargs) -> gym.Env:
env = spec.make(**kwargs)
return env

# On most env, SubprocVecEnv does not help and is quite memory hungry
# therefore we use DummyVecEnv by default
env = make_vec_env(
make_env,
n_envs=n_envs,
seed=self.seed,
env_kwargs=self.env_kwargs,
monitor_dir=log_dir,
wrapper_class=self.env_wrapper,
vec_env_cls=self.vec_env_class,
vec_env_kwargs=self.vec_env_kwargs,
monitor_kwargs=self.monitor_kwargs,
)
if self.use_envpool:
# TODO: warning if env wrapper is passed
# Convert Atari game names
# See https://github.com/sail-sg/envpool/issues/14
env_id = self.env_name.gym_id
if self._is_atari and "NoFrameskip-v4" in env_id:
env_id = env_id.split("NoFrameskip-v4")[0] + "-v5"

env = envpool.make(env_id, env_type="gym", num_envs=n_envs, seed=self.seed)
env.spec.id = self.env_name.gym_id
env = EnvPoolAdapter(env)
filename = None if log_dir is None else f"{log_dir}/monitor.csv"
env = VecMonitor(env, filename, **self.monitor_kwargs)

else:
# Special case for GoalEnvs: log success rate too
if (
"Neck" in self.env_name.gym_id
or self.is_robotics_env(self.env_name.gym_id)
or "parking-v0" in self.env_name.gym_id
and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs
):
self.monitor_kwargs = dict(info_keywords=("is_success",))

# Define make_env here so it works with subprocesses
# when the registry was modified with `--gym-packages`
# See https://github.com/HumanCompatibleAI/imitation/pull/160
spec = gym.spec(self.env_name.gym_id)

def make_env(**kwargs) -> gym.Env:
env = spec.make(**kwargs)
return env

# On most env, SubprocVecEnv does not help and is quite memory hungry
# therefore we use DummyVecEnv by default
env = make_vec_env(
make_env,
n_envs=n_envs,
seed=self.seed,
env_kwargs=self.env_kwargs,
monitor_dir=log_dir,
wrapper_class=self.env_wrapper,
vec_env_cls=self.vec_env_class,
vec_env_kwargs=self.vec_env_kwargs,
monitor_kwargs=self.monitor_kwargs,
)

if self.vec_env_wrapper is not None:
env = self.vec_env_wrapper(env)
Expand Down
10 changes: 9 additions & 1 deletion rl_zoo3/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ def train() -> None:
default=False,
help="if toggled, display a progress bar using tqdm and rich",
)
parser.add_argument(
"-envpool",
"--use-envpool",
action="store_true",
default=False,
help="if toggled, try to use EnvPool to run the env, env_wrappers are not supported.",
)
parser.add_argument(
"-tags", "--wandb-tags", type=str, default=[], nargs="+", help="Tags for wandb run, e.g.: -tags optimized pr-123"
)
Expand Down Expand Up @@ -183,7 +190,7 @@ def train() -> None:
uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""
if args.seed < 0:
# Seed but with a random one
args.seed = np.random.randint(2**32 - 1, dtype="int64").item() # type: ignore[attr-defined]
args.seed = np.random.randint(2**31 - 1, dtype="int64").item() # type: ignore[attr-defined]

set_random_seed(args.seed)

Expand Down Expand Up @@ -259,6 +266,7 @@ def train() -> None:
device=args.device,
config=args.conf_file,
show_progress=args.progress,
use_envpool=args.use_envpool,
)

# Prepare experiment and launch hyperparameter optimization if needed
Expand Down
46 changes: 46 additions & 0 deletions rl_zoo3/vec_env_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Optional

import gym
import numpy as np
from envpool.python.protocol import EnvPool
from stable_baselines3.common.vec_env import VecEnvWrapper
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs, VecEnvStepReturn


class EnvPoolAdapter(VecEnvWrapper):
    """
    Wrap an EnvPool vectorized environment so it behaves like a
    Stable-Baselines3 (SB3) ``VecEnv``.

    :param venv: The EnvPool object to wrap.
    """

    def __init__(self, venv: EnvPool):
        # EnvPool does not expose ``num_envs`` as an attribute,
        # retrieve it from the spec config before calling the parent constructor
        venv.num_envs = venv.spec.config.num_envs
        super().__init__(venv=venv)
        # Tmp fix for https://github.com/DLR-RM/stable-baselines3/issues/1145
        if isinstance(self.action_space, gym.spaces.Box) and self.action_space.dtype == np.float64:
            self.action_space.dtype = np.dtype(np.float32)

    def step_async(self, actions: np.ndarray) -> None:
        # Store the actions; the actual step happens in ``step_wait()``
        self.actions = actions

    def reset(self) -> VecEnvObs:
        # Delegate directly to EnvPool
        return self.venv.reset()

    def seed(self, seed: Optional[int] = None) -> None:
        # Re-seeding after creation is not supported:
        # an EnvPool env can only be seeded when calling envpool.make()
        pass

    def step_wait(self) -> VecEnvStepReturn:
        obs, rewards, dones, info_dict = self.venv.step(self.actions)
        # EnvPool returns a dict of arrays (one entry per env),
        # while SB3 expects one info dict per env; convert here
        # and add the terminal observation on episode end.
        infos = []
        for env_idx in range(self.num_envs):
            env_info = {
                key: value[env_idx] for key, value in info_dict.items() if isinstance(value, np.ndarray)
            }
            if dones[env_idx]:
                # SB3 convention: keep the last obs in the info dict
                # and reset that particular env manually
                env_info["terminal_observation"] = obs[env_idx]
                obs[env_idx] = self.venv.reset(np.array([env_idx]))
            infos.append(env_info)

        return obs, rewards, dones, infos