From 29b9df65e37e79d1ffc0e5d20de28495b84f5f6d Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Sat, 12 Jul 2025 17:53:48 -0400 Subject: [PATCH 01/21] add viewer env wrapper --- brax/envs/wrappers/viewer.py | 64 ++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 brax/envs/wrappers/viewer.py diff --git a/brax/envs/wrappers/viewer.py b/brax/envs/wrappers/viewer.py new file mode 100644 index 000000000..27573585a --- /dev/null +++ b/brax/envs/wrappers/viewer.py @@ -0,0 +1,64 @@ +import jax +import jax.numpy as jnp +from brax.envs.base import Wrapper, State, Env +from typing import Optional + +from braxviewer.WebViewer import WebViewer + + +class ViewerWrapper(Wrapper): + """An environment wrapper that sends state to a WebViewer after each step.""" + + def __init__(self, env: Env, viewer: Optional[WebViewer] = None): + """Initializes the ViewerWrapper. + + Args: + env: A Brax environment instance to wrap. + viewer: An optional WebViewer instance for visualizing state. + """ + super().__init__(env) + self.viewer = viewer + + def reset(self, rng: jnp.ndarray) -> State: + """Resets the environment and sends the initial state to the viewer. + + Args: + rng: A JAX random number generator. + + Returns: + The initial environment state. + """ + state = self.env.reset(rng) + + if self.viewer is not None: + jax.debug.callback(self.viewer.send_frame, state) + + return state + + def step(self, state: State, action: jnp.ndarray) -> State: + """Performs one environment step and conditionally sends the state to the viewer. + + Args: + state: The current environment state. + action: The action to apply. + + Returns: + The next environment state. + """ + next_state = self.env.step(state, action) + + if self.viewer is not None: + def _send_frame(s): + jax.debug.callback(self.viewer.send_frame, s) + + def _do_nothing(s): + pass + + jax.lax.cond( + self.viewer.rendering_enabled, + _send_frame, + _do_nothing, + operand=state # Matches behavior in acting.py + ) + + return next_state From 8ff1653035a69ecb576f64694fb728c05866316d Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Sun, 13 Jul 2025 00:05:03 -0400 Subject: [PATCH 02/21] wrapper runnable workflow --- brax/training/agents/ppo/train.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 8a61afb1c..51bf4629c 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -393,6 +393,25 @@ def train( # Discard the batch axes over devices and envs. obs_shape = jax.tree_util.tree_map(lambda x: x.shape[2:], env_state.obs) + # Inject env_id into the state info. This allows the viewer to distinguish + # between different parallel environments. + if num_envs > 1 and wrap_env: + # Create a range of IDs for all environments. + env_ids = jnp.arange(num_envs) + # Reshape IDs to match the sharded state shape: (devices, envs_per_device). + sharded_env_ids = jnp.reshape(env_ids, (local_devices_to_use, -1)) + + # Create a function to add IDs to a state. + def add_env_id(state, id): + info = state.info.copy() + info['env_id'] = id + return state.replace(info=info) + + # vmap over the environments on a single device, then pmap over all devices. + # This is a standard and robust JAX pattern for this operation. + env_state = jax.pmap(jax.vmap(add_env_id))(env_state, sharded_env_ids) + + normalize = lambda x, y: x if normalize_observations: normalize = running_statistics.normalize From 98f8fb5a2700d582150f49a7713b2e6e90102b3b Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Sun, 13 Jul 2025 11:42:43 -0400 Subject: [PATCH 03/21] io_callback --- brax/envs/wrappers/viewer.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/brax/envs/wrappers/viewer.py b/brax/envs/wrappers/viewer.py index 27573585a..e657a35d2 100644 --- a/brax/envs/wrappers/viewer.py +++ b/brax/envs/wrappers/viewer.py @@ -1,5 +1,6 @@ import jax import jax.numpy as jnp +from jax.experimental import io_callback from brax.envs.base import Wrapper, State, Env from typing import Optional @@ -31,7 +32,8 @@ def reset(self, rng: jnp.ndarray) -> State: state = self.env.reset(rng) if self.viewer is not None: - jax.debug.callback(self.viewer.send_frame, state) + # The check for rendering enabled is now inside the viewer's send_frame + io_callback(self.viewer.send_frame, None, state) return state @@ -48,17 +50,7 @@ def step(self, state: State, action: jnp.ndarray) -> State: next_state = self.env.step(state, action) if self.viewer is not None: - def _send_frame(s): - jax.debug.callback(self.viewer.send_frame, s) - - def _do_nothing(s): - pass - - jax.lax.cond( - self.viewer.rendering_enabled, - _send_frame, - _do_nothing, - operand=state # Matches behavior in acting.py - ) + # The check for rendering enabled is now inside the viewer's send_frame + io_callback(self.viewer.send_frame, None, next_state) return next_state From 294a742d752301519aee25076362dea8aa13fe22 Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Tue, 15 Jul 2025 20:57:03 -0400 Subject: [PATCH 04/21] toggle --- brax/envs/wrappers/viewer.py | 63 ++++++++++++++++++++++++++++++- brax/training/acting.py | 13 ++++++- brax/training/agents/ppo/train.py | 3 ++ 3 files changed, 75 insertions(+), 4 deletions(-) diff --git a/brax/envs/wrappers/viewer.py b/brax/envs/wrappers/viewer.py index e657a35d2..5a8399bd4 100644 --- a/brax/envs/wrappers/viewer.py +++ b/brax/envs/wrappers/viewer.py @@ -2,7 +2,8 @@ import jax.numpy as jnp from jax.experimental import io_callback from brax.envs.base import Wrapper, State, Env -from typing import Optional +from brax.envs.wrappers.training import VmapWrapper +from typing import Optional, Callable from braxviewer.WebViewer import WebViewer @@ -38,7 +39,7 @@ def reset(self, rng: jnp.ndarray) -> State: return state def step(self, state: State, action: jnp.ndarray) -> State: - """Performs one environment step and conditionally sends the state to the viewer. + """Performs one environment step and sends the state to the viewer. Args: state: The current environment state. @@ -54,3 +55,61 @@ def step(self, state: State, action: jnp.ndarray) -> State: io_callback(self.viewer.send_frame, None, next_state) return next_state + + +class RenderableVmapWrapper(VmapWrapper): + """A VmapWrapper that supports conditional rendering with optimal performance. + + This wrapper implements the "Pass-the-Flag" pattern to avoid performance + degradation when using conditional callbacks in vmapped JAX functions. + """ + + def __init__(self, env: Env, batch_size: Optional[int] = None): + super().__init__(env, batch_size) + + def step_with_render( + self, + state: State, + action: jax.Array, + should_render: bool, + render_fn: Optional[Callable[[State], None]] + ) -> State: + """Performs a batched environment step with conditional rendering. + + Args: + state: The current batched environment state. + action: The batched actions to apply. + should_render: A scalar boolean flag controlling rendering. + render_fn: Optional callback function for rendering. + + Returns: + The next batched environment state. + """ + from jax import lax + + # First, perform the environment step + next_state = self.env.step(state, action) + + # Then, conditionally render using lax.cond with scalar flag + if render_fn is not None: + def _render_branch(batched_state): + # True branch: render all environments + def send_all_frames(state_batch): + num_envs = state_batch.pipeline_state.q.shape[0] + for i in range(num_envs): + # Extract single environment state + single_state = jax.tree_util.tree_map(lambda x: x[i], state_batch) + render_fn(single_state) + + io_callback(send_all_frames, None, batched_state) + return 0 # dummy return value + + def _no_op_branch(_): + # False branch: do nothing + return 0 # dummy return value + + # Use lax.cond with scalar should_render flag + # This preserves true conditional execution in vmap contexts + _ = lax.cond(should_render, _render_branch, _no_op_branch, next_state) + + return next_state diff --git a/brax/training/acting.py b/brax/training/acting.py index a234446ca..987b990af 100644 --- a/brax/training/acting.py +++ b/brax/training/acting.py @@ -15,7 +15,7 @@ """Brax training acting functions.""" import time -from typing import Callable, Sequence, Tuple +from typing import Callable, Optional, Sequence, Tuple from brax import envs from brax.training.types import Metrics @@ -36,10 +36,14 @@ def actor_step( policy: Policy, key: PRNGKey, extra_fields: Sequence[str] = (), + render_fn: Optional[Callable[[State], None]] = None, ) -> Tuple[State, Transition]: """Collect data.""" actions, policy_extras = policy(env_state.obs, key) nstate = env.step(env_state, actions) + if render_fn is not None: + from jax.experimental import io_callback + io_callback(render_fn, None, nstate) state_extras = {x: nstate.info[x] for x in extra_fields} return nstate, Transition( # pytype: disable=wrong-arg-types # jax-ndarray observation=env_state.obs, @@ -58,6 +62,7 @@ def generate_unroll( key: PRNGKey, unroll_length: int, extra_fields: Sequence[str] = (), + render_fn: Optional[Callable[[State], None]] = None, ) -> Tuple[State, Transition]: """Collect trajectories of given unroll_length.""" @@ -66,10 +71,14 @@ def f(carry, unused_t): state, current_key = carry current_key, next_key = jax.random.split(current_key) nstate, transition = actor_step( - env, state, policy, current_key, extra_fields=extra_fields + env, state, policy, current_key, + extra_fields=extra_fields, + render_fn=render_fn ) return (nstate, next_key), transition + # Pass should_render and render_fn as static arguments to scan + # This ensures they are treated as constants and don't interfere with JIT (final_state, _), data = jax.lax.scan( f, (env_state, key), (), length=unroll_length ) diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 51bf4629c..fb979b570 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -236,6 +236,8 @@ def train( # callbacks progress_fn: Callable[[int, Metrics], None] = lambda *args: None, policy_params_fn: Callable[..., None] = lambda *args: None, + # rendering + render_fn: Optional[Callable[[envs.State], None]] = None, # checkpointing save_checkpoint_path: Optional[str] = None, restore_checkpoint_path: Optional[str] = None, @@ -523,6 +525,7 @@ def f(carry, unused_t): current_key, unroll_length, extra_fields=('truncation', 'episode_metrics', 'episode_done'), + render_fn=render_fn, ) return (next_state, next_key), data From 44e3cdf61b95410efe183789c8e6db7a232bfaf8 Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Tue, 15 Jul 2025 21:36:22 -0400 Subject: [PATCH 05/21] coding style refine --- brax/training/acting.py | 4 +--- brax/training/agents/ppo/train.py | 4 ---- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/brax/training/acting.py b/brax/training/acting.py index 987b990af..28ab63500 100644 --- a/brax/training/acting.py +++ b/brax/training/acting.py @@ -23,6 +23,7 @@ from brax.training.types import PolicyParams from brax.training.types import PRNGKey from brax.training.types import Transition +from jax.experimental import io_callback import jax import numpy as np @@ -42,7 +43,6 @@ def actor_step( actions, policy_extras = policy(env_state.obs, key) nstate = env.step(env_state, actions) if render_fn is not None: - from jax.experimental import io_callback io_callback(render_fn, None, nstate) state_extras = {x: nstate.info[x] for x in extra_fields} return nstate, Transition( # pytype: disable=wrong-arg-types # jax-ndarray @@ -77,8 +77,6 @@ def f(carry, unused_t): ) return (nstate, next_key), transition - # Pass should_render and render_fn as static arguments to scan - # This ensures they are treated as constants and don't interfere with JIT (final_state, _), data = jax.lax.scan( f, (env_state, key), (), length=unroll_length ) diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index fb979b570..453bf1ac3 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -395,8 +395,6 @@ def train( # Discard the batch axes over devices and envs. obs_shape = jax.tree_util.tree_map(lambda x: x.shape[2:], env_state.obs) - # Inject env_id into the state info. This allows the viewer to distinguish - # between different parallel environments. if num_envs > 1 and wrap_env: # Create a range of IDs for all environments. env_ids = jnp.arange(num_envs) @@ -409,8 +407,6 @@ def add_env_id(state, id): info['env_id'] = id return state.replace(info=info) - # vmap over the environments on a single device, then pmap over all devices. - # This is a standard and robust JAX pattern for this operation. env_state = jax.pmap(jax.vmap(add_env_id))(env_state, sharded_env_ids) From a6350531105ceffaecbbd2bc7ee575b065e6376e Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Wed, 16 Jul 2025 01:07:39 -0400 Subject: [PATCH 06/21] add cond for should render --- brax/training/acting.py | 28 +++++++++++++++++++++++----- brax/training/agents/ppo/train.py | 2 ++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/brax/training/acting.py b/brax/training/acting.py index 28ab63500..1e5832371 100644 --- a/brax/training/acting.py +++ b/brax/training/acting.py @@ -25,25 +25,32 @@ from brax.training.types import Transition from jax.experimental import io_callback import jax +import jax.numpy as jnp import numpy as np State = envs.State Env = envs.Env +# Define render callback outside of JIT-compiled functions +def _do_render(state, render_fn): + if render_fn is not None: + io_callback(render_fn, None, state) + return None + +def _do_nothing(state): + return None + def actor_step( env: Env, env_state: State, policy: Policy, key: PRNGKey, extra_fields: Sequence[str] = (), - render_fn: Optional[Callable[[State], None]] = None, ) -> Tuple[State, Transition]: """Collect data.""" actions, policy_extras = policy(env_state.obs, key) nstate = env.step(env_state, actions) - if render_fn is not None: - io_callback(render_fn, None, nstate) state_extras = {x: nstate.info[x] for x in extra_fields} return nstate, Transition( # pytype: disable=wrong-arg-types # jax-ndarray observation=env_state.obs, @@ -63,6 +70,7 @@ def generate_unroll( unroll_length: int, extra_fields: Sequence[str] = (), render_fn: Optional[Callable[[State], None]] = None, + should_render: jax.Array = jnp.array(True, dtype=jnp.bool_), ) -> Tuple[State, Transition]: """Collect trajectories of given unroll_length.""" @@ -72,9 +80,18 @@ def f(carry, unused_t): current_key, next_key = jax.random.split(current_key) nstate, transition = actor_step( env, state, policy, current_key, - extra_fields=extra_fields, - render_fn=render_fn + extra_fields=extra_fields + ) + + # Use jax.lax.cond to avoid io_callback when should_render=False + # Only render if should_render is True (render_fn is checked outside JIT) + jax.lax.cond( + should_render, + lambda s: _do_render(s, render_fn), + lambda s: _do_nothing(s), + operand=nstate ) + return (nstate, next_key), transition (final_state, _), data = jax.lax.scan( @@ -122,6 +139,7 @@ def generate_eval_unroll( eval_policy_fn(policy_params), key, unroll_length=episode_length // action_repeat, + should_render=jnp.array(False, dtype=jnp.bool_), # No rendering during eval )[0] self._generate_eval_unroll = jax.jit(generate_eval_unroll) diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 453bf1ac3..e27ede6cb 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -238,6 +238,7 @@ def train( policy_params_fn: Callable[..., None] = lambda *args: None, # rendering render_fn: Optional[Callable[[envs.State], None]] = None, + should_render: jax.Array = jnp.array(True, dtype=jnp.bool_), # checkpointing save_checkpoint_path: Optional[str] = None, restore_checkpoint_path: Optional[str] = None, @@ -522,6 +523,7 @@ def f(carry, unused_t): unroll_length, extra_fields=('truncation', 'episode_metrics', 'episode_done'), render_fn=render_fn, + should_render=should_render, ) return (next_state, next_key), data From 15cd3f9f3d7cfd08dd42ec6403be4f444756b0cc Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Sat, 26 Jul 2025 14:59:38 -0400 Subject: [PATCH 07/21] update ViewerWrapper --- brax/envs/wrappers/viewer.py | 126 ++++++------------------------ brax/training/agents/ppo/train.py | 5 ++ 2 files changed, 31 insertions(+), 100 deletions(-) diff --git a/brax/envs/wrappers/viewer.py b/brax/envs/wrappers/viewer.py index 5a8399bd4..a71a6e4a7 100644 --- a/brax/envs/wrappers/viewer.py +++ b/brax/envs/wrappers/viewer.py @@ -1,115 +1,41 @@ +from brax.envs import Env, State, Wrapper import jax import jax.numpy as jnp -from jax.experimental import io_callback -from brax.envs.base import Wrapper, State, Env -from brax.envs.wrappers.training import VmapWrapper -from typing import Optional, Callable - -from braxviewer.WebViewer import WebViewer - class ViewerWrapper(Wrapper): - """An environment wrapper that sends state to a WebViewer after each step.""" + """A wrapper that provides rendering functionality for a Brax viewer.""" - def __init__(self, env: Env, viewer: Optional[WebViewer] = None): + def __init__(self, env: Env, viewer): """Initializes the ViewerWrapper. Args: - env: A Brax environment instance to wrap. - viewer: An optional WebViewer instance for visualizing state. + env: The environment to wrap. + viewer: An instance of a viewer (e.g., WebViewerBatched) that has a + `send_frame` method and a `rendering_enabled` property. """ super().__init__(env) self.viewer = viewer - def reset(self, rng: jnp.ndarray) -> State: - """Resets the environment and sends the initial state to the viewer. - - Args: - rng: A JAX random number generator. - - Returns: - The initial environment state. - """ - state = self.env.reset(rng) - - if self.viewer is not None: - # The check for rendering enabled is now inside the viewer's send_frame - io_callback(self.viewer.send_frame, None, state) - - return state - - def step(self, state: State, action: jnp.ndarray) -> State: - """Performs one environment step and sends the state to the viewer. + @property + def should_render(self) -> jax.Array: + """Returns a JAX array indicating whether rendering should occur.""" + return jnp.array(True, dtype=jnp.bool_) - Args: - state: The current environment state. - action: The action to apply. - - Returns: - The next environment state. - """ - next_state = self.env.step(state, action) - - if self.viewer is not None: - # The check for rendering enabled is now inside the viewer's send_frame - io_callback(self.viewer.send_frame, None, next_state) - - return next_state + def render_fn(self, state: State): + """The function to be called for rendering a state. - -class RenderableVmapWrapper(VmapWrapper): - """A VmapWrapper that supports conditional rendering with optimal performance. - - This wrapper implements the "Pass-the-Flag" pattern to avoid performance - degradation when using conditional callbacks in vmapped JAX functions. - """ - - def __init__(self, env: Env, batch_size: Optional[int] = None): - super().__init__(env, batch_size) - - def step_with_render( - self, - state: State, - action: jax.Array, - should_render: bool, - render_fn: Optional[Callable[[State], None]] - ) -> State: - """Performs a batched environment step with conditional rendering. - - Args: - state: The current batched environment state. - action: The batched actions to apply. - should_render: A scalar boolean flag controlling rendering. - render_fn: Optional callback function for rendering. - - Returns: - The next batched environment state. + This function is designed to be used with `jax.experimental.io_callback`. + It sends a single, unbatched state to the viewer. """ - from jax import lax - - # First, perform the environment step - next_state = self.env.step(state, action) - - # Then, conditionally render using lax.cond with scalar flag - if render_fn is not None: - def _render_branch(batched_state): - # True branch: render all environments - def send_all_frames(state_batch): - num_envs = state_batch.pipeline_state.q.shape[0] - for i in range(num_envs): - # Extract single environment state - single_state = jax.tree_util.tree_map(lambda x: x[i], state_batch) - render_fn(single_state) - - io_callback(send_all_frames, None, batched_state) - return 0 # dummy return value - - def _no_op_branch(_): - # False branch: do nothing - return 0 # dummy return value - - # Use lax.cond with scalar should_render flag - # This preserves true conditional execution in vmap contexts - _ = lax.cond(should_render, _render_branch, _no_op_branch, next_state) - - return next_state + if not self.viewer.rendering_enabled: + return + + # If the state is batched, iterate and send each frame. + if state.pipeline_state.q.ndim > 1: + num_envs = state.pipeline_state.q.shape[0] + for i in range(num_envs): + single_state = jax.tree_util.tree_map(lambda x: x[i], state) + self.viewer.send_frame(single_state) + else: + # If the state is not batched, send it directly. + self.viewer.send_frame(state) diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index e27ede6cb..5b7685f24 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -320,6 +320,11 @@ def train( Returns: Tuple of (make_policy function, network params, metrics) """ + # If the environment is wrapped with ViewerWrapper, use its rendering functions. + if hasattr(environment, 'render_fn'): + render_fn = environment.render_fn + should_render = environment.should_render + assert batch_size * num_minibatches % num_envs == 0 _validate_madrona_args( madrona_backend, num_envs, num_eval_envs, action_repeat, eval_env From 23ba14a7de953ddd71da99523d09e139169f19e5 Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Mon, 28 Jul 2025 19:47:22 -0400 Subject: [PATCH 08/21] type check for wrapper --- brax/envs/wrappers/viewer.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/brax/envs/wrappers/viewer.py b/brax/envs/wrappers/viewer.py index a71a6e4a7..cb45d3349 100644 --- a/brax/envs/wrappers/viewer.py +++ b/brax/envs/wrappers/viewer.py @@ -1,16 +1,29 @@ from brax.envs import Env, State, Wrapper +from typing import Protocol import jax import jax.numpy as jnp + +class IsViewer(Protocol): + """A protocol for viewers that can be used with ViewerWrapper.""" + + @property + def rendering_enabled(self) -> bool: + ... + + def send_frame(self, state: State): + ... + + class ViewerWrapper(Wrapper): """A wrapper that provides rendering functionality for a Brax viewer.""" - def __init__(self, env: Env, viewer): + def __init__(self, env: Env, viewer: IsViewer): """Initializes the ViewerWrapper. Args: env: The environment to wrap. - viewer: An instance of a viewer (e.g., WebViewerBatched) that has a + viewer: An instance of a viewer (e.g., WebViewer or WebViewerParallel) that has a `send_frame` method and a `rendering_enabled` property. """ super().__init__(env) From 4b8535b8d7934a786e1890f90a5db7447dea2cba Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Tue, 29 Jul 2025 20:24:57 -0400 Subject: [PATCH 09/21] optimize disabled render speed --- brax/envs/wrappers/viewer.py | 13 +++---- brax/training/acting.py | 26 ++++++-------- brax/training/agents/ppo/train.py | 56 +++++++++++++++++-------------- 3 files changed, 49 insertions(+), 46 deletions(-) diff --git a/brax/envs/wrappers/viewer.py b/brax/envs/wrappers/viewer.py index cb45d3349..d7e80eaa5 100644 --- a/brax/envs/wrappers/viewer.py +++ b/brax/envs/wrappers/viewer.py @@ -11,7 +11,7 @@ class IsViewer(Protocol): def rendering_enabled(self) -> bool: ... - def send_frame(self, state: State): + def send_frame(self, state: State, env_id: int = 0): ... @@ -38,17 +38,18 @@ def render_fn(self, state: State): """The function to be called for rendering a state. This function is designed to be used with `jax.experimental.io_callback`. - It sends a single, unbatched state to the viewer. + When called within the callback, the batched state is a concrete value. """ if not self.viewer.rendering_enabled: return - # If the state is batched, iterate and send each frame. if state.pipeline_state.q.ndim > 1: + # The state is batched, so we iterate through it in plain Python. num_envs = state.pipeline_state.q.shape[0] for i in range(num_envs): + # Extract the state for a single environment. single_state = jax.tree_util.tree_map(lambda x: x[i], state) - self.viewer.send_frame(single_state) + self.viewer.send_frame(single_state, env_id=i) else: - # If the state is not batched, send it directly. - self.viewer.send_frame(state) + # The state is not batched, send it directly. + self.viewer.send_frame(state, env_id=0) diff --git a/brax/training/acting.py b/brax/training/acting.py index 1e5832371..75ac32885 100644 --- a/brax/training/acting.py +++ b/brax/training/acting.py @@ -52,7 +52,7 @@ def actor_step( actions, policy_extras = policy(env_state.obs, key) nstate = env.step(env_state, actions) state_extras = {x: nstate.info[x] for x in extra_fields} - return nstate, Transition( # pytype: disable=wrong-arg-types # jax-ndarray + return nstate, Transition( observation=env_state.obs, action=actions, reward=nstate.reward, @@ -70,7 +70,7 @@ def generate_unroll( unroll_length: int, extra_fields: Sequence[str] = (), render_fn: Optional[Callable[[State], None]] = None, - should_render: jax.Array = jnp.array(True, dtype=jnp.bool_), + should_render: jax.Array = jnp.array(False, dtype=jnp.bool_), ) -> Tuple[State, Transition]: """Collect trajectories of given unroll_length.""" @@ -79,19 +79,15 @@ def f(carry, unused_t): state, current_key = carry current_key, next_key = jax.random.split(current_key) nstate, transition = actor_step( - env, state, policy, current_key, - extra_fields=extra_fields + env, state, policy, current_key, extra_fields=extra_fields ) - - # Use jax.lax.cond to avoid io_callback when should_render=False - # Only render if should_render is True (render_fn is checked outside JIT) - jax.lax.cond( - should_render, - lambda s: _do_render(s, render_fn), - lambda s: _do_nothing(s), - operand=nstate - ) - + + def render(state: State): + if render_fn is None: + return + io_callback(render_fn, None, state) + + jax.lax.cond(should_render, render, lambda s: None, nstate) return (nstate, next_key), transition (final_state, _), data = jax.lax.scan( @@ -179,4 +175,4 @@ def run_evaluation( **metrics, } - return metrics # pytype: disable=bad-return-type # jax-ndarray + return metrics # pytype: disable=bad-return-type # jax-ndarray \ No newline at end of file diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 5b7685f24..bcef69a12 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -42,7 +42,6 @@ import numpy as np import optax - InferenceParams = Tuple[running_statistics.NestedMeanStd, Params] Metrics = types.Metrics @@ -321,9 +320,9 @@ def train( Tuple of (make_policy function, network params, metrics) """ # If the environment is wrapped with ViewerWrapper, use its rendering functions. + render_fn = None if hasattr(environment, 'render_fn'): render_fn = environment.render_fn - should_render = environment.should_render assert batch_size * num_minibatches % num_envs == 0 _validate_madrona_args( @@ -401,21 +400,6 @@ def train( # Discard the batch axes over devices and envs. obs_shape = jax.tree_util.tree_map(lambda x: x.shape[2:], env_state.obs) - if num_envs > 1 and wrap_env: - # Create a range of IDs for all environments. - env_ids = jnp.arange(num_envs) - # Reshape IDs to match the sharded state shape: (devices, envs_per_device). - sharded_env_ids = jnp.reshape(env_ids, (local_devices_to_use, -1)) - - # Create a function to add IDs to a state. - def add_env_id(state, id): - info = state.info.copy() - info['env_id'] = id - return state.replace(info=info) - - env_state = jax.pmap(jax.vmap(add_env_id))(env_state, sharded_env_ids) - - normalize = lambda x, y: x if normalize_observations: normalize = running_statistics.normalize @@ -506,7 +490,9 @@ def convert_data(x: jnp.ndarray): return (optimizer_state, params, key), metrics def training_step( - carry: Tuple[TrainingState, envs.State, PRNGKey], unused_t + carry: Tuple[TrainingState, envs.State, PRNGKey], + unused_t, + should_render: jax.Array, ) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]: training_state, state, key = carry key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3) @@ -577,10 +563,16 @@ def f(carry, unused_t): return (new_training_state, state, new_key), metrics def training_epoch( - training_state: TrainingState, state: envs.State, key: PRNGKey + training_state: TrainingState, + state: envs.State, + key: PRNGKey, + should_render: jax.Array, ) -> Tuple[TrainingState, envs.State, Metrics]: + partial_training_step = functools.partial( + training_step, should_render=should_render + ) (training_state, state, _), loss_metrics = jax.lax.scan( - training_step, + partial_training_step, (training_state, state, key), (), length=num_training_steps_per_epoch, @@ -592,12 +584,15 @@ def training_epoch( # Note that this is NOT a pure jittable method. def training_epoch_with_timing( - training_state: TrainingState, env_state: envs.State, key: PRNGKey + training_state: TrainingState, + env_state: envs.State, + key: PRNGKey, + should_render: jax.Array, ) -> Tuple[TrainingState, envs.State, Metrics]: nonlocal training_walltime t = time.time() training_state, env_state = _strip_weak_type((training_state, env_state)) - result = training_epoch(training_state, env_state, key) + result = training_epoch(training_state, env_state, key, should_render) training_state, env_state, metrics = _strip_weak_type(result) metrics = jax.tree_util.tree_map(jnp.mean, metrics) @@ -720,11 +715,22 @@ def training_epoch_with_timing( logging.info('starting iteration %s %s', it, time.time() - xt) for _ in range(max(num_resets_per_eval, 1)): - # optimization + # check for rendering dynamically + should_render_py = False + if hasattr(environment, 'viewer'): + should_render_py = environment.viewer.rendering_enabled + + should_render_jax = jnp.array(should_render_py, dtype=jnp.bool_) + should_render_replicated = jax.device_put_replicated( + should_render_jax, jax.local_devices()[:local_devices_to_use] + ) + epoch_key, local_key = jax.random.split(local_key) epoch_keys = jax.random.split(epoch_key, local_devices_to_use) (training_state, env_state, training_metrics) = ( - training_epoch_with_timing(training_state, env_state, epoch_keys) + training_epoch_with_timing( + training_state, env_state, epoch_keys, should_render_replicated + ) ) current_step = int(_unpmap(training_state.env_steps)) @@ -784,4 +790,4 @@ def training_epoch_with_timing( )) logging.info('total steps: %s', total_steps) pmap.synchronize_hosts() - return (make_policy, params, metrics) + return (make_policy, params, metrics) \ No newline at end of file From 75381521cb6c66faf5ed1fb44b0a29a2fcb7c295 Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Tue, 29 Jul 2025 20:52:23 -0400 Subject: [PATCH 10/21] delete redundent code --- brax/training/acting.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/brax/training/acting.py b/brax/training/acting.py index 75ac32885..66153b5e7 100644 --- a/brax/training/acting.py +++ b/brax/training/acting.py @@ -31,16 +31,6 @@ State = envs.State Env = envs.Env - -# Define render callback outside of JIT-compiled functions -def _do_render(state, render_fn): - if render_fn is not None: - io_callback(render_fn, None, state) - return None - -def _do_nothing(state): - return None - def actor_step( env: Env, env_state: State, From 9fdd3655a3d8e8d972d4af553d17ad47184c87ec Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Wed, 30 Jul 2025 03:16:16 -0400 Subject: [PATCH 11/21] replace viewer by sender --- brax/envs/wrappers/viewer.py | 55 ------------------------------- brax/training/agents/ppo/train.py | 4 +-- 2 files changed, 2 insertions(+), 57 deletions(-) delete mode 100644 brax/envs/wrappers/viewer.py diff --git a/brax/envs/wrappers/viewer.py b/brax/envs/wrappers/viewer.py deleted file mode 100644 index d7e80eaa5..000000000 --- a/brax/envs/wrappers/viewer.py +++ /dev/null @@ -1,55 +0,0 @@ -from brax.envs import Env, State, Wrapper -from typing import Protocol -import jax -import jax.numpy as jnp - - -class IsViewer(Protocol): - """A protocol for viewers that can be used with ViewerWrapper.""" - - @property - def rendering_enabled(self) -> bool: - ... - - def send_frame(self, state: State, env_id: int = 0): - ... - - -class ViewerWrapper(Wrapper): - """A wrapper that provides rendering functionality for a Brax viewer.""" - - def __init__(self, env: Env, viewer: IsViewer): - """Initializes the ViewerWrapper. - - Args: - env: The environment to wrap. - viewer: An instance of a viewer (e.g., WebViewer or WebViewerParallel) that has a - `send_frame` method and a `rendering_enabled` property. - """ - super().__init__(env) - self.viewer = viewer - - @property - def should_render(self) -> jax.Array: - """Returns a JAX array indicating whether rendering should occur.""" - return jnp.array(True, dtype=jnp.bool_) - - def render_fn(self, state: State): - """The function to be called for rendering a state. - - This function is designed to be used with `jax.experimental.io_callback`. - When called within the callback, the batched state is a concrete value. - """ - if not self.viewer.rendering_enabled: - return - - if state.pipeline_state.q.ndim > 1: - # The state is batched, so we iterate through it in plain Python. - num_envs = state.pipeline_state.q.shape[0] - for i in range(num_envs): - # Extract the state for a single environment. - single_state = jax.tree_util.tree_map(lambda x: x[i], state) - self.viewer.send_frame(single_state, env_id=i) - else: - # The state is not batched, send it directly. - self.viewer.send_frame(state, env_id=0) diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index bcef69a12..9dacf9d7d 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -717,8 +717,8 @@ def training_epoch_with_timing( for _ in range(max(num_resets_per_eval, 1)): # check for rendering dynamically should_render_py = False - if hasattr(environment, 'viewer'): - should_render_py = environment.viewer.rendering_enabled + if hasattr(environment, 'sender'): + should_render_py = environment.sender.rendering_enabled should_render_jax = jnp.array(should_render_py, dtype=jnp.bool_) should_render_replicated = jax.device_put_replicated( From 260850f7d98dac4233be8c1a8df642ba1b2ec362 Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Wed, 30 Jul 2025 12:03:08 -0400 Subject: [PATCH 12/21] remove unecessary change --- brax/training/acting.py | 4 ++-- brax/training/agents/ppo/train.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/brax/training/acting.py b/brax/training/acting.py index 66153b5e7..1c935f75b 100644 --- a/brax/training/acting.py +++ b/brax/training/acting.py @@ -42,7 +42,7 @@ def actor_step( actions, policy_extras = policy(env_state.obs, key) nstate = env.step(env_state, actions) state_extras = {x: nstate.info[x] for x in extra_fields} - return nstate, Transition( + return nstate, Transition( # pytype: disable=wrong-arg-types # jax-ndarray observation=env_state.obs, action=actions, reward=nstate.reward, @@ -165,4 +165,4 @@ def run_evaluation( **metrics, } - return metrics # pytype: disable=bad-return-type # jax-ndarray \ No newline at end of file + return metrics # pytype: disable=bad-return-type # jax-ndarray \ No newline at end of file diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 9dacf9d7d..db550d3e2 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -568,11 +568,11 @@ def training_epoch( key: PRNGKey, should_render: jax.Array, ) -> Tuple[TrainingState, envs.State, Metrics]: - partial_training_step = functools.partial( + training_step = functools.partial( training_step, should_render=should_render ) (training_state, state, _), loss_metrics = jax.lax.scan( - partial_training_step, + training_step, (training_state, state, key), (), length=num_training_steps_per_epoch, @@ -715,6 +715,8 @@ def training_epoch_with_timing( logging.info('starting iteration %s %s', it, time.time() - xt) for _ in range(max(num_resets_per_eval, 1)): + # optimization + # check for rendering dynamically should_render_py = False if hasattr(environment, 'sender'): From 4ea3687c3b1d0b0dcf006b65daa1170da21a8203 Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Wed, 30 Jul 2025 12:15:34 -0400 Subject: [PATCH 13/21] remove unecessary changes --- brax/training/acting.py | 2 +- brax/training/agents/ppo/train.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/brax/training/acting.py b/brax/training/acting.py index 1c935f75b..a634dec65 100644 --- a/brax/training/acting.py +++ b/brax/training/acting.py @@ -165,4 +165,4 @@ def run_evaluation( **metrics, } - return metrics # pytype: disable=bad-return-type # jax-ndarray \ No newline at end of file + return metrics # pytype: disable=bad-return-type # jax-ndarray \ No newline at end of file diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index db550d3e2..ec53f90f4 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -42,6 +42,7 @@ import numpy as np import optax + InferenceParams = Tuple[running_statistics.NestedMeanStd, Params] Metrics = types.Metrics @@ -792,4 +793,4 @@ def training_epoch_with_timing( )) logging.info('total steps: %s', total_steps) pmap.synchronize_hosts() - return (make_policy, params, metrics) \ No newline at end of file + return (make_policy, params, metrics) From 035073921ebf54813a893dd16bfd46af8f24f5a9 Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Wed, 30 Jul 2025 12:19:12 -0400 Subject: [PATCH 14/21] remove unecessary changes --- brax/training/acting.py | 6 ++++-- brax/training/agents/ppo/train.py | 20 ++++---------------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/brax/training/acting.py b/brax/training/acting.py index a634dec65..5773f0a08 100644 --- a/brax/training/acting.py +++ b/brax/training/acting.py @@ -31,6 +31,7 @@ State = envs.State Env = envs.Env + def actor_step( env: Env, env_state: State, @@ -42,7 +43,7 @@ def actor_step( actions, policy_extras = policy(env_state.obs, key) nstate = env.step(env_state, actions) state_extras = {x: nstate.info[x] for x in extra_fields} - return nstate, Transition( # pytype: disable=wrong-arg-types # jax-ndarray + return nstate, Transition( # pytype: disable=wrong-arg-types # jax-ndarray observation=env_state.obs, action=actions, reward=nstate.reward, @@ -165,4 +166,5 @@ def run_evaluation( **metrics, } - return metrics # pytype: disable=bad-return-type # jax-ndarray \ No newline at end of file + return metrics # pytype: disable=bad-return-type # jax-ndarray + \ No newline at end of file diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index ec53f90f4..ae9e98713 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -490,10 +490,7 @@ def convert_data(x: jnp.ndarray): ) return (optimizer_state, params, key), metrics - def training_step( - carry: Tuple[TrainingState, envs.State, PRNGKey], - unused_t, - should_render: jax.Array, + def training_step(carry: Tuple[TrainingState, envs.State, PRNGKey], unused_t, should_render: jax.Array, ) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]: training_state, state, key = carry key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3) @@ -564,10 +561,7 @@ def f(carry, unused_t): return (new_training_state, state, new_key), metrics def training_epoch( - training_state: TrainingState, - state: envs.State, - key: PRNGKey, - should_render: jax.Array, + training_state: TrainingState, state: envs.State, key: PRNGKey, should_render: jax.Array, ) -> Tuple[TrainingState, envs.State, Metrics]: training_step = functools.partial( training_step, should_render=should_render @@ -584,11 +578,7 @@ def training_epoch( training_epoch = jax.pmap(training_epoch, axis_name=_PMAP_AXIS_NAME) # Note that this is NOT a pure jittable method. - def training_epoch_with_timing( - training_state: TrainingState, - env_state: envs.State, - key: PRNGKey, - should_render: jax.Array, + def training_epoch_with_timing(training_state: TrainingState, env_state: envs.State, key: PRNGKey, should_render: jax.Array, ) -> Tuple[TrainingState, envs.State, Metrics]: nonlocal training_walltime t = time.time() @@ -731,9 +721,7 @@ def training_epoch_with_timing( epoch_key, local_key = jax.random.split(local_key) epoch_keys = jax.random.split(epoch_key, local_devices_to_use) (training_state, env_state, training_metrics) = ( - training_epoch_with_timing( - training_state, env_state, epoch_keys, should_render_replicated - ) + training_epoch_with_timing(training_state, env_state, epoch_keys, should_render_replicated) ) current_step = int(_unpmap(training_state.env_steps)) From afdd4573cc88493b91742e543fbddf21ca7de136 Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Wed, 30 Jul 2025 12:20:51 -0400 Subject: [PATCH 15/21] remove unecessary changes --- brax/training/agents/ppo/train.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index ae9e98713..82f4114d8 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -490,7 +490,8 @@ def convert_data(x: jnp.ndarray): ) return (optimizer_state, params, key), metrics - def training_step(carry: Tuple[TrainingState, envs.State, PRNGKey], unused_t, should_render: jax.Array, + def training_step( + carry: Tuple[TrainingState, envs.State, PRNGKey], unused_t, should_render: jax.Array, ) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]: training_state, state, key = carry key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3) @@ -578,7 +579,8 @@ def training_epoch( training_epoch = jax.pmap(training_epoch, axis_name=_PMAP_AXIS_NAME) # Note that this is NOT a pure jittable method. - def training_epoch_with_timing(training_state: TrainingState, env_state: envs.State, key: PRNGKey, should_render: jax.Array, + def training_epoch_with_timing( + training_state: TrainingState, env_state: envs.State, key: PRNGKey, should_render: jax.Array, ) -> Tuple[TrainingState, envs.State, Metrics]: nonlocal training_walltime t = time.time() @@ -781,4 +783,4 @@ def training_epoch_with_timing(training_state: TrainingState, env_state: envs.St )) logging.info('total steps: %s', total_steps) pmap.synchronize_hosts() - return (make_policy, params, metrics) + return (make_policy, params, metrics) \ No newline at end of file From f354c39f6493ae5b3564492e93bd18e0eb01be34 Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Wed, 30 Jul 2025 12:21:51 -0400 Subject: [PATCH 16/21] remove unecessary changes --- brax/training/acting.py | 3 +-- brax/training/agents/ppo/train.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/brax/training/acting.py b/brax/training/acting.py index 5773f0a08..ccf217d19 100644 --- a/brax/training/acting.py +++ b/brax/training/acting.py @@ -166,5 +166,4 @@ def run_evaluation( **metrics, } - return metrics # pytype: disable=bad-return-type # jax-ndarray - \ No newline at end of file + return metrics # pytype: disable=bad-return-type # jax-ndarray \ No newline at end of file diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 82f4114d8..19c8ffc34 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -783,4 +783,4 @@ def training_epoch_with_timing( )) logging.info('total steps: %s', total_steps) pmap.synchronize_hosts() - return (make_policy, params, metrics) \ No newline at end of file + return (make_policy, params, metrics) From da756f6cb6709777c6b72762de97aaf99e96debf Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Wed, 30 Jul 2025 12:22:13 -0400 Subject: [PATCH 17/21] format --- brax/training/acting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brax/training/acting.py b/brax/training/acting.py index ccf217d19..e205fbb6e 100644 --- a/brax/training/acting.py +++ b/brax/training/acting.py @@ -166,4 +166,4 @@ def run_evaluation( **metrics, } - return metrics # pytype: disable=bad-return-type # jax-ndarray \ No newline at end of file + return metrics # pytype: disable=bad-return-type # jax-ndarray \ No newline at end of file From 2373e7a0e19c4a1c5f021ebfdcf7709e18db0b2d Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Wed, 30 Jul 2025 12:22:55 -0400 Subject: [PATCH 18/21] format --- brax/training/acting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brax/training/acting.py b/brax/training/acting.py index e205fbb6e..ccf217d19 100644 --- a/brax/training/acting.py +++ b/brax/training/acting.py @@ -166,4 +166,4 @@ def run_evaluation( **metrics, } - return metrics # pytype: disable=bad-return-type # jax-ndarray \ No newline at end of file + return metrics # pytype: disable=bad-return-type # jax-ndarray \ No newline at end of file From 06bdb456703a23f0bc0f972cd75504f0014ea6ac Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Wed, 30 Jul 2025 12:23:37 -0400 Subject: [PATCH 19/21] format --- brax/training/acting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brax/training/acting.py b/brax/training/acting.py index ccf217d19..2e90aefc7 100644 --- a/brax/training/acting.py +++ b/brax/training/acting.py @@ -166,4 +166,4 @@ def run_evaluation( **metrics, } - return metrics # pytype: disable=bad-return-type # jax-ndarray \ No newline at end of file + return metrics # pytype: disable=bad-return-type # jax-ndarray From eaa13980ec5ccc5d95dd94675f8afd6325fae190 Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Wed, 30 Jul 2025 14:27:42 -0400 Subject: [PATCH 20/21] debug training_step_partial --- brax/training/agents/ppo/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 19c8ffc34..96c0c9143 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -564,11 +564,11 @@ def f(carry, unused_t): def training_epoch( training_state: TrainingState, state: envs.State, key: PRNGKey, should_render: jax.Array, ) -> Tuple[TrainingState, envs.State, Metrics]: - training_step = functools.partial( + training_step_partial = functools.partial( training_step, should_render=should_render ) (training_state, state, _), loss_metrics = jax.lax.scan( - training_step, + training_step_partial, (training_state, state, key), (), length=num_training_steps_per_epoch, From 7fe573c84f34f4d40447cd9b428b069bdd9be4b1 Mon Sep 17 00:00:00 2001 From: Bruce MBP Date: Fri, 28 Nov 2025 17:20:09 -0500 Subject: [PATCH 21/21] remove argus to train() change jax.bool_ to bool now render io_callback will check environment attribute "should_render" --- brax/training/acting.py | 15 ++++++++------- brax/training/agents/ppo/train.py | 16 +++++----------- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/brax/training/acting.py b/brax/training/acting.py index 2e90aefc7..c01e02460 100644 --- a/brax/training/acting.py +++ b/brax/training/acting.py @@ -61,7 +61,7 @@ def generate_unroll( unroll_length: int, extra_fields: Sequence[str] = (), render_fn: Optional[Callable[[State], None]] = None, - should_render: jax.Array = jnp.array(False, dtype=jnp.bool_), + should_render: jax.Array = jnp.array(False, dtype=bool), ) -> Tuple[State, Transition]: """Collect trajectories of given unroll_length.""" @@ -73,12 +73,13 @@ def f(carry, unused_t): env, state, policy, current_key, extra_fields=extra_fields ) - def render(state: State): - if render_fn is None: - return - io_callback(render_fn, None, state) + if render_fn is not None: + + def render(state: State): + io_callback(render_fn, None, state) + + jax.lax.cond(should_render, render, lambda s: None, nstate) - jax.lax.cond(should_render, render, lambda s: None, nstate) return (nstate, next_key), transition (final_state, _), data = jax.lax.scan( @@ -126,7 +127,7 @@ def generate_eval_unroll( eval_policy_fn(policy_params), key, unroll_length=episode_length // action_repeat, - should_render=jnp.array(False, dtype=jnp.bool_), # No rendering during eval + should_render=jnp.array(False, dtype=bool), # No rendering during eval )[0] self._generate_eval_unroll = jax.jit(generate_eval_unroll) diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 96c0c9143..d9d26d8fe 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -236,9 +236,6 @@ def train( # callbacks progress_fn: Callable[[int, Metrics], None] = lambda *args: None, policy_params_fn: Callable[..., None] = lambda *args: None, - # rendering - render_fn: Optional[Callable[[envs.State], None]] = None, - should_render: jax.Array = jnp.array(True, dtype=jnp.bool_), # checkpointing save_checkpoint_path: Optional[str] = None, restore_checkpoint_path: Optional[str] = None, @@ -320,10 +317,10 @@ def train( Returns: Tuple of (make_policy function, network params, metrics) """ - # If the environment is wrapped with ViewerWrapper, use its rendering functions. + # If the environment exposes a `render_fn`, use it for real-time rendering during training. render_fn = None if hasattr(environment, 'render_fn'): - render_fn = environment.render_fn + render_fn = environment.render_fn assert batch_size * num_minibatches % num_envs == 0 _validate_madrona_args( @@ -708,14 +705,11 @@ def training_epoch_with_timing( logging.info('starting iteration %s %s', it, time.time() - xt) for _ in range(max(num_resets_per_eval, 1)): - # optimization - - # check for rendering dynamically should_render_py = False - if hasattr(environment, 'sender'): - should_render_py = environment.sender.rendering_enabled + if hasattr(environment, 'should_render'): + should_render_py = bool(environment.should_render) - should_render_jax = jnp.array(should_render_py, dtype=jnp.bool_) + should_render_jax = jnp.array(should_render_py, dtype=bool) should_render_replicated = jax.device_put_replicated( should_render_jax, jax.local_devices()[:local_devices_to_use] )