
[Question] recommended way to convert gymnasium/gym v26 envs to SB3 VecEnvs  #1356

@elliottower


❓ Question

Posting this here so as not to spam the Gymnasium integration PR (#1327), since afaik it's a usage question rather than an issue with the PR. I will edit with example code to make things clearer, but mainly I want to know the best practice for converting envs whose step() returns separate terminated and truncated bools into SB3's API, which uses a single done signal.

I would like to make vector envs, but I run into issues because step() returns a different number of values (5 vs 4). My initial thought was to ignore truncation and set done equal to terminated, but from the discussions and documentation it seems best to set done to terminated or truncated. PR comments here also say to use a TimeLimit wrapper to capture the truncation signal. Is this, then, the best practice?

done                        = terminated or truncated
info["TimeLimit.truncated"] = not terminated and truncated
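A minimal sketch of that conversion for a single step, to make the convention concrete. The helper name `gymnasium_to_sb3_step` is hypothetical and this is not SB3's own code; it only assumes that SB3 reads `done` plus the `TimeLimit.truncated` info key.

```python
from typing import Any, Dict, Tuple


def gymnasium_to_sb3_step(
    obs: Any,
    reward: float,
    terminated: bool,
    truncated: bool,
    info: Dict[str, Any],
) -> Tuple[Any, float, bool, Dict[str, Any]]:
    """Collapse a Gymnasium (obs, reward, terminated, truncated, info) tuple
    into an SB3-style (obs, reward, done, info) tuple. Hypothetical helper."""
    info = dict(info)  # copy so the caller's info dict is not mutated
    done = bool(terminated or truncated)
    # Flag a time-limit cut only when the episode did not also end "for real";
    # SB3 uses this flag to bootstrap from the value of the final observation.
    info["TimeLimit.truncated"] = bool(truncated and not terminated)
    return obs, reward, done, info


# Episode cut off by a time limit vs. one that genuinely ended on that step.
print(gymnasium_to_sb3_step("obs", 1.0, False, True, {}))
# -> ('obs', 1.0, True, {'TimeLimit.truncated': True})
print(gymnasium_to_sb3_step("obs", 1.0, True, True, {}))
# -> ('obs', 1.0, True, {'TimeLimit.truncated': False})
```

The `and not terminated` part matters when both flags are set on the same step: the episode really ended, so there is nothing to bootstrap.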

Example code of wrapping the env with this TimeLimit wrapper and doing this conversion would be greatly appreciated.

Relevant references:
https://github.com/DLR-RM/stable-baselines3/blob/feat/gymnasium-support/docs/guide/vec_envs.rst
openai/gym#3102 (comment)
https://gymnasium.farama.org/content/migration-guide/
#780 (comment)

Edit: a bit more context for what my issue was (converting the step function):
#1327 (comment)

Full code below, sb3_train.py (an update of an older training script that used an older PettingZoo release built on gym rather than Gymnasium):

"""Binary to run Stable Baselines 3 agents on meltingpot substrates."""

import gymnasium
import stable_baselines3
from stable_baselines3.common import callbacks
from stable_baselines3.common import torch_layers
from stable_baselines3.common import vec_env
from rl_zoo3.gym_patches import PatchedTimeLimit
# from sb3_contrib.common import vec_env # only has async env

import supersuit as ss
import torch
from torch import nn
import torch.nn.functional as F

from examples.pettingzoo import utils
from meltingpot.python import substrate

device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")


# Use this with lambda wrapper returning observations only
class CustomCNN(torch_layers.BaseFeaturesExtractor):
  """Class describing a custom feature extractor."""

  def __init__(
      self,
      observation_space: gymnasium.spaces.Box,
      features_dim=128,
      num_frames=6,
      fcnet_hiddens=(1024, 128),
  ):
    """Construct a custom CNN feature extractor.

    Args:
      observation_space: the observation space as a gymnasium.Space
      features_dim: Number of features extracted. This corresponds to the number
        of units in the last layer.
      num_frames: The number of (consecutive) frames to feed into the network.
      fcnet_hiddens: Sizes of hidden layers.
    """
    super(CustomCNN, self).__init__(observation_space, features_dim)
    # Observations arrive channels-last (HxWxC): VecTransposeImage is created
    # with skip=True below, so forward() permutes them to CxHxW itself.

    self.conv = nn.Sequential(
        nn.Conv2d(
            num_frames * 3, num_frames * 3, kernel_size=8, stride=4, padding=0),
        nn.ReLU(),  # 18 * 21 * 21
        nn.Conv2d(
            num_frames * 3, num_frames * 6, kernel_size=5, stride=2, padding=0),
        nn.ReLU(),  # 36 * 9 * 9
        nn.Conv2d(
            num_frames * 6, num_frames * 6, kernel_size=3, stride=1, padding=0),
        nn.ReLU(),  # 36 * 7 * 7
        nn.Flatten(),
    )
    flat_out = num_frames * 6 * 7 * 7
    self.fc1 = nn.Linear(in_features=flat_out, out_features=fcnet_hiddens[0])
    self.fc2 = nn.Linear(
        in_features=fcnet_hiddens[0], out_features=fcnet_hiddens[1])

  def forward(self, observations) -> torch.Tensor:
    # SB3's preprocessing has already turned the batch into a float tensor
    # (rescaled to [0, 1] for uint8 image spaces); here we only reorder
    #   B x H x W x C to B x C x H x W
    observations = observations.permute(0, 3, 1, 2)
    features = self.conv(observations)
    features = F.relu(self.fc1(features))
    features = F.relu(self.fc2(features))
    return features


def main():
  # Config
  substrate_name = "commons_harvest__open"
  player_roles = substrate.get_config(substrate_name).default_player_roles
  env_config = {"substrate": substrate_name, "roles": player_roles}

  env = utils.parallel_env(env_config)
  rollout_len = 1000
  total_timesteps = 2000000
  num_agents = env.max_num_agents

  max_steps = 1000
  max_steps_eval = 1000

  # Training
  num_cpus = 1  # number of cpus
  num_envs = 1  # number of parallel multi-agent environments
  # number of frames to stack together; use >4 to avoid automatic
  # VecTransposeImage
  num_frames = 4
  # output layer of cnn extractor AND shared layer for policy and value
  # functions
  features_dim = 128
  fcnet_hiddens = [1024, 128]  # Two hidden layers for cnn extractor
  ent_coef = 0.001  # entropy coefficient in loss
  # This is from the rllib baseline implementation
  batch_size = rollout_len * num_envs // 2
  lr = 0.0001
  n_epochs = 30
  gae_lambda = 1.0
  gamma = 0.99
  target_kl = 0.01
  grad_clip = 40
  verbose = 3
  model_path = None  # Replace this with a saved model

  env = utils.parallel_env(
      max_cycles=rollout_len,
      env_config=env_config,
  )
  env = ss.observation_lambda_v0(env, lambda x, _: x["RGB"], lambda s: s["RGB"])
  env = ss.dtype_v0(env, "uint8")
  env = ss.pettingzoo_env_to_vec_env_v1(env)
  env = ss.concat_vec_envs_v1(
      env,
      num_vec_envs=num_envs,
      num_cpus=num_cpus,
      base_class="stable_baselines3")
  env = PatchedTimeLimit(env, max_steps)
  env = vec_env.VecTransposeImage(env, True)
  env = vec_env.VecFrameStack(env, num_frames)

  eval_env = utils.parallel_env(
      max_cycles=rollout_len,
      env_config=env_config,
  )
  eval_env = ss.observation_lambda_v0(eval_env, lambda x, _: x["RGB"],
                                      lambda s: s["RGB"])
  eval_env = ss.dtype_v0(eval_env, "uint8")
  eval_env = ss.pettingzoo_env_to_vec_env_v1(eval_env)
  eval_env = ss.concat_vec_envs_v1(
      eval_env,
      num_vec_envs=1,
      num_cpus=1,
      base_class="stable_baselines3")
  eval_env = PatchedTimeLimit(eval_env, max_steps_eval)
  eval_env = vec_env.VecTransposeImage(eval_env, True)
  eval_env = vec_env.VecFrameStack(eval_env, num_frames)
  eval_freq = 100000 // (num_envs * num_agents)

  policy_kwargs = dict(
      features_extractor_class=CustomCNN,
      features_extractor_kwargs=dict(
          features_dim=features_dim,
          num_frames=num_frames,
          fcnet_hiddens=fcnet_hiddens,
      ),
      net_arch=[features_dim],
  )

  tensorboard_log = "./results/sb3/harvest_open_ppo_paramsharing"

  model = stable_baselines3.PPO(
      "CnnPolicy",
      env=env,
      learning_rate=lr,
      n_steps=rollout_len,
      batch_size=batch_size,
      n_epochs=n_epochs,
      gamma=gamma,
      gae_lambda=gae_lambda,
      ent_coef=ent_coef,
      max_grad_norm=grad_clip,
      target_kl=target_kl,
      policy_kwargs=policy_kwargs,
      tensorboard_log=tensorboard_log,
      verbose=verbose,
  )
  if model_path is not None:
    model = stable_baselines3.PPO.load(model_path, env=env)
  eval_callback = callbacks.EvalCallback(
      eval_env, eval_freq=eval_freq, best_model_save_path=tensorboard_log)
  model.learn(total_timesteps=total_timesteps, callback=eval_callback)

  logdir = model.logger.dir
  model.save(logdir + "/model")
  del model
  model = stable_baselines3.PPO.load(logdir + "/model")  # noqa: F841


if __name__ == "__main__":
  main()
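The shape comments in CustomCNN and the sizes handed to PPO and EvalCallback follow from simple arithmetic. The sketch below checks them under two assumptions not stated in the script: square 88x88 RGB observations (consistent with the `21 * 21` comment after the first convolution) and an illustrative `num_agents = 7`. The helper names are hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Spatial output size of an nn.Conv2d layer with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def cnn_flat_features(height: int, num_frames: int) -> int:
    """Flattened size after CustomCNN's three conv layers (square input assumed)."""
    h = conv_out(height, kernel=8, stride=4)  # 88 -> 21
    h = conv_out(h, kernel=5, stride=2)  # 21 -> 9
    h = conv_out(h, kernel=3, stride=1)  # 9 -> 7
    return num_frames * 6 * h * h  # out channels * H * W


def ppo_sizes(rollout_len: int, num_envs: int, num_agents: int):
    """(rollout buffer size, minibatch size, minibatches per epoch)."""
    n_slots = num_envs * num_agents  # SuperSuit gives each agent its own vec-env slot
    buffer_size = rollout_len * n_slots  # PPO's n_steps counts steps per slot
    batch_size = rollout_len * num_envs // 2  # same formula as main()
    return buffer_size, batch_size, buffer_size // batch_size


def timesteps_between_evals(num_envs: int, num_agents: int) -> int:
    """EvalCallback counts vec-env steps; each one advances every slot once."""
    n_slots = num_envs * num_agents
    eval_freq = 100000 // n_slots  # same formula as main()
    return eval_freq * n_slots


print(cnn_flat_features(88, num_frames=4))  # 4 * 6 * 7 * 7 = 1176
print(ppo_sizes(rollout_len=1000, num_envs=1, num_agents=7))  # (7000, 500, 14)
print(timesteps_between_evals(num_envs=1, num_agents=7))  # 99995
```

So `flat_out = num_frames * 6 * 7 * 7` holds for any `num_frames` as long as the input stays 88x88, and dividing by `num_envs * num_agents` keeps evaluations roughly 100k environment steps apart.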

Utils helper file (also an update of the original script, which used the old PettingZoo/gym rather than Gymnasium):

"""PettingZoo interface to meltingpot environments."""

import functools

from gymnasium import utils as gym_utils
import matplotlib.pyplot as plt
from ml_collections import config_dict
from pettingzoo import utils as pettingzoo_utils
from pettingzoo.utils import wrappers

from examples import utils
from meltingpot.python import substrate

PLAYER_STR_FORMAT = 'player_{index}'
MAX_CYCLES = 1000


def parallel_env(env_config, max_cycles=MAX_CYCLES):
  return _ParallelEnv(env_config, max_cycles)


def raw_env(env_config, max_cycles=MAX_CYCLES):
  return pettingzoo_utils.parallel_to_aec_wrapper(
      parallel_env(env_config, max_cycles))


def env(env_config, max_cycles=MAX_CYCLES):
  aec_env = raw_env(env_config, max_cycles)
  aec_env = wrappers.AssertOutOfBoundsWrapper(aec_env)
  aec_env = wrappers.OrderEnforcingWrapper(aec_env)
  return aec_env


class _MeltingPotPettingZooEnv(pettingzoo_utils.ParallelEnv):
  """An adapter between Melting Pot substrates and PettingZoo's ParallelEnv."""

  def __init__(self, env_config, max_cycles):
    self.env_config = config_dict.ConfigDict(env_config)
    self.max_cycles = max_cycles
    self._env = substrate.build(env_config['substrate'], roles=env_config['roles'])
    self._num_players = len(self._env.observation_spec())
    self.possible_agents = [
        PLAYER_STR_FORMAT.format(index=index)
        for index in range(self._num_players)
    ]
    self.agents = [agent for agent in self.possible_agents]
    observation_space = utils.remove_world_observations_from_space(
        utils.spec_to_space(self._env.observation_spec()[0]))
    self.observation_space = functools.lru_cache(
        maxsize=None)(lambda agent_id: observation_space)
    action_space = utils.spec_to_space(self._env.action_spec()[0])
    self.action_space = functools.lru_cache(maxsize=None)(
        lambda agent_id: action_space)
    self.state_space = utils.spec_to_space(
        self._env.observation_spec()[0]['WORLD.RGB'])

  def state(self):
    return self._env.observation()

  def reset(self, seed=None, **kwargs):
    """See base class."""
    timestep = self._env.reset()
    self.agents = self.possible_agents[:]
    self.num_cycles = 0
    return utils.timestep_to_observations(timestep)

  def step(self, action):
    """See base class."""
    actions = [action[agent] for agent in self.agents]
    timestep = self._env.step(actions)
    rewards = {
        agent: timestep.reward[index] for index, agent in enumerate(self.agents)
    }
    self.num_cycles += 1
    termination = timestep.last()
    terminations = {agent: termination for agent in self.agents}
    truncation = self.num_cycles >= self.max_cycles
    truncations = {agent: truncation for agent in self.agents}
    infos = {agent: {} for agent in self.agents}
    if termination:
      self.agents = []

    observations = utils.timestep_to_observations(timestep)
    return observations, rewards, terminations, truncations, infos

  def close(self):
    """See base class."""
    self._env.close()

  def render(self, mode='human', filename=None):
    rgb_arr = self.state()['WORLD.RGB']
    if mode == 'human':
      plt.cla()
      plt.imshow(rgb_arr, interpolation='nearest')
      if filename is None:
        plt.show(block=False)
      else:
        plt.savefig(filename)
      return None
    return rgb_arr


class _ParallelEnv(_MeltingPotPettingZooEnv, gym_utils.EzPickle):
  metadata = {'render_modes': ['human', 'rgb_array']}

  def __init__(self, env_config, max_cycles):
    gym_utils.EzPickle.__init__(self, env_config, max_cycles)
    _MeltingPotPettingZooEnv.__init__(self, env_config, max_cycles)
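The step() above returns per-agent dicts for terminations and truncations, while SB3 expects one done flag per vec-env slot plus the `TimeLimit.truncated` info key. A minimal sketch of that mapping, assuming slots are ordered like `possible_agents`; the function name is hypothetical and this is not SuperSuit's implementation.

```python
from typing import Any, Dict, List, Sequence, Tuple

import numpy as np


def agent_dicts_to_sb3(
    agents: Sequence[str],
    terminations: Dict[str, bool],
    truncations: Dict[str, bool],
    infos: Dict[str, Dict[str, Any]],
) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
    """Map per-agent PettingZoo flags to SB3-style per-slot dones and infos."""
    dones = np.zeros(len(agents), dtype=bool)
    sb3_infos: List[Dict[str, Any]] = []
    for slot, agent in enumerate(agents):
        terminated = bool(terminations.get(agent, False))
        truncated = bool(truncations.get(agent, False))
        dones[slot] = terminated or truncated
        info = dict(infos.get(agent, {}))  # copy; do not mutate the env's dict
        info["TimeLimit.truncated"] = truncated and not terminated
        sb3_infos.append(info)
    return dones, sb3_infos


# The step where num_cycles reaches max_cycles but the substrate has not ended.
agents = ["player_0", "player_1"]
dones, infos = agent_dicts_to_sb3(
    agents,
    terminations={a: False for a in agents},
    truncations={a: True for a in agents},
    infos={a: {} for a in agents},
)
print(dones, infos)
```

In this env all agents share the same termination and truncation value, so every slot ends on the same step; the per-agent loop simply keeps the mapping valid if that ever changes.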


Error:

Traceback (most recent call last):
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 1496, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/elliottower/Documents/GitHub/meltingpot/examples/pettingzoo/sb3_train.py", line 197, in <module>
    if __name__ == "__main__":
  File "/Users/elliottower/Documents/GitHub/meltingpot/examples/pettingzoo/sb3_train.py", line 188, in main
    eval_env, eval_freq=eval_freq, best_model_save_path=tensorboard_log)
  File "/Users/elliottower/Documents/GitHub/meltingpot/venv/lib/python3.9/site-packages/stable_baselines3/ppo/ppo.py", line 304, in learn
    return super().learn(
  File "/Users/elliottower/Documents/GitHub/meltingpot/venv/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 246, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/Users/elliottower/Documents/GitHub/meltingpot/venv/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 174, in collect_rollouts
    new_obs, rewards, dones, infos = env.step(clipped_actions)
  File "/Users/elliottower/Documents/GitHub/meltingpot/venv/lib/python3.9/site-packages/stable_baselines3/common/vec_env/base_vec_env.py", line 171, in step
    return self.step_wait()
  File "/Users/elliottower/Documents/GitHub/meltingpot/venv/lib/python3.9/site-packages/stable_baselines3/common/vec_env/vec_frame_stack.py", line 33, in step_wait
    observations, rewards, dones, infos = self.venv.step_wait()
  File "/Users/elliottower/Documents/GitHub/meltingpot/venv/lib/python3.9/site-packages/stable_baselines3/common/vec_env/vec_transpose.py", line 95, in step_wait
    observations, rewards, dones, infos = self.venv.step_wait()
ValueError: too many values to unpack (expected 4)
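The last frame shows the mismatch: VecTransposeImage.step_wait unpacks exactly four values (obs, rewards, dones, infos), but the env beneath it returned more, presumably the five-value Gymnasium-style result with separate terminated and truncated arrays. The sketch below reproduces that and shows a hypothetical adapter, `to_sb3_vec_step`, that accepts either shape and always returns SB3's four values. Calling it from the `step_wait` of an SB3 `VecEnvWrapper` subclass placed directly above the SuperSuit vec env is one possible placement; that is an assumption, not something confirmed in this thread.

```python
from typing import Any, Dict, List, Tuple

import numpy as np


def to_sb3_vec_step(result: Tuple) -> Tuple[Any, Any, Any, List[Dict[str, Any]]]:
    """Normalize a vectorized step result to SB3's (obs, rewards, dones, infos)."""
    if len(result) == 4:  # already SB3-style, pass through unchanged
        return result
    obs, rewards, terminated, truncated, infos = result
    terminated = np.asarray(terminated, dtype=bool)
    truncated = np.asarray(truncated, dtype=bool)
    dones = terminated | truncated
    new_infos = []
    for i, info in enumerate(infos):
        info = dict(info)  # copy so the wrapped env's dicts are untouched
        info["TimeLimit.truncated"] = bool(truncated[i] and not terminated[i])
        new_infos.append(info)
    return obs, np.asarray(rewards), dones, new_infos


# Two slots: slot 0 hit the time limit, slot 1 terminated on the same step.
five = (np.zeros((2, 1)), np.array([0.0, 1.0]), np.array([False, True]),
        np.array([True, True]), [{}, {}])
try:
    observations, rewards, dones, infos = five  # what VecTransposeImage does
except ValueError as err:
    print(err)  # e.g. "too many values to unpack (expected 4)"

observations, rewards, dones, infos = to_sb3_vec_step(five)
print(dones, [info["TimeLimit.truncated"] for info in infos])
```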

Checklist

  • I have checked that there is no similar issue in the repo
  • I have read the documentation
  • If code there is, it is minimal and working
  • If code there is, it is formatted using the markdown code blocks for both code and stack traces.
