diff --git a/brax/training/acting.py b/brax/training/acting.py
index a234446ca..c01e02460 100644
--- a/brax/training/acting.py
+++ b/brax/training/acting.py
@@ -15,7 +15,7 @@
 """Brax training acting functions."""
 
 import time
-from typing import Callable, Sequence, Tuple
+from typing import Callable, Optional, Sequence, Tuple
 
 from brax import envs
 from brax.training.types import Metrics
@@ -23,7 +23,9 @@
 from brax.training.types import PolicyParams
 from brax.training.types import PRNGKey
 from brax.training.types import Transition
+from jax.experimental import io_callback
 import jax
+import jax.numpy as jnp
 import numpy as np
 
 State = envs.State
@@ -58,6 +60,8 @@ def generate_unroll(
     key: PRNGKey,
     unroll_length: int,
     extra_fields: Sequence[str] = (),
+    render_fn: Optional[Callable[[State], None]] = None,
+    should_render: jax.Array = jnp.array(False, dtype=bool),
 ) -> Tuple[State, Transition]:
   """Collect trajectories of given unroll_length."""
 
@@ -68,6 +72,14 @@ def f(carry, unused_t):
     nstate, transition = actor_step(
         env, state, policy, current_key, extra_fields=extra_fields
     )
+
+    if render_fn is not None:
+
+      def render(state: State):
+        io_callback(render_fn, None, state)
+
+      jax.lax.cond(should_render, render, lambda s: None, nstate)
+
     return (nstate, next_key), transition
 
   (final_state, _), data = jax.lax.scan(
@@ -115,6 +127,7 @@
           eval_policy_fn(policy_params),
           key,
           unroll_length=episode_length // action_repeat,
+          should_render=jnp.array(False, dtype=bool),  # No rendering during eval
       )[0]
 
     self._generate_eval_unroll = jax.jit(generate_eval_unroll)
diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py
index 8a61afb1c..d9d26d8fe 100644
--- a/brax/training/agents/ppo/train.py
+++ b/brax/training/agents/ppo/train.py
@@ -317,6 +317,11 @@
   Returns:
     Tuple of (make_policy function, network params, metrics)
   """
+  # If the environment exposes a `render_fn`, use it for real-time rendering during training.
+  render_fn = None
+  if hasattr(environment, 'render_fn'):
+    render_fn = environment.render_fn
+
   assert batch_size * num_minibatches % num_envs == 0
   _validate_madrona_args(
       madrona_backend, num_envs, num_eval_envs, action_repeat, eval_env
@@ -483,7 +488,7 @@ def convert_data(x: jnp.ndarray):
     return (optimizer_state, params, key), metrics
 
   def training_step(
-      carry: Tuple[TrainingState, envs.State, PRNGKey], unused_t
+      carry: Tuple[TrainingState, envs.State, PRNGKey], unused_t, should_render: jax.Array,
   ) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]:
     training_state, state, key = carry
     key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)
@@ -504,6 +509,8 @@ def f(carry, unused_t):
           current_key,
           unroll_length,
           extra_fields=('truncation', 'episode_metrics', 'episode_done'),
+          render_fn=render_fn,
+          should_render=should_render,
       )
       return (next_state, next_key), data
 
@@ -552,10 +559,13 @@ def f(carry, unused_t):
       return (new_training_state, state, new_key), metrics
 
   def training_epoch(
-      training_state: TrainingState, state: envs.State, key: PRNGKey
+      training_state: TrainingState, state: envs.State, key: PRNGKey, should_render: jax.Array,
   ) -> Tuple[TrainingState, envs.State, Metrics]:
+    training_step_partial = functools.partial(
+        training_step, should_render=should_render
+    )
     (training_state, state, _), loss_metrics = jax.lax.scan(
-        training_step,
+        training_step_partial,
         (training_state, state, key),
         (),
         length=num_training_steps_per_epoch,
@@ -567,12 +577,12 @@ def training_epoch(
 
   # Note that this is NOT a pure jittable method.
   def training_epoch_with_timing(
-      training_state: TrainingState, env_state: envs.State, key: PRNGKey
+      training_state: TrainingState, env_state: envs.State, key: PRNGKey, should_render: jax.Array,
   ) -> Tuple[TrainingState, envs.State, Metrics]:
     nonlocal training_walltime
     t = time.time()
     training_state, env_state = _strip_weak_type((training_state, env_state))
-    result = training_epoch(training_state, env_state, key)
+    result = training_epoch(training_state, env_state, key, should_render)
     training_state, env_state, metrics = _strip_weak_type(result)
 
     metrics = jax.tree_util.tree_map(jnp.mean, metrics)
@@ -695,11 +705,19 @@
     logging.info('starting iteration %s %s', it, time.time() - xt)
 
     for _ in range(max(num_resets_per_eval, 1)):
-      # optimization
+      should_render_py = False
+      if hasattr(environment, 'should_render'):
+        should_render_py = bool(environment.should_render)
+
+      should_render_jax = jnp.array(should_render_py, dtype=bool)
+      should_render_replicated = jax.device_put_replicated(
+          should_render_jax, jax.local_devices()[:local_devices_to_use]
+      )
+
       epoch_key, local_key = jax.random.split(local_key)
       epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
       (training_state, env_state, training_metrics) = (
-          training_epoch_with_timing(training_state, env_state, epoch_keys)
+          training_epoch_with_timing(training_state, env_state, epoch_keys, should_render_replicated)
       )
 
       current_step = int(_unpmap(training_state.env_steps))