-
Notifications
You must be signed in to change notification settings - Fork 338
Add real-time rendering callback #634
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 21 commits
29b9df6
8ff1653
98f8fb5
294a742
44e3cdf
a635053
15cd3f9
23ba14a
4b8535b
7538152
9fdd365
260850f
4ea3687
0350739
afdd457
f354c39
da756f6
2373e7a
06bdb45
eaa1398
ddc6d1e
7fe573c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,15 +15,17 @@ | |
| """Brax training acting functions.""" | ||
|
|
||
| import time | ||
| from typing import Callable, Sequence, Tuple | ||
| from typing import Callable, Optional, Sequence, Tuple | ||
|
|
||
| from brax import envs | ||
| from brax.training.types import Metrics | ||
| from brax.training.types import Policy | ||
| from brax.training.types import PolicyParams | ||
| from brax.training.types import PRNGKey | ||
| from brax.training.types import Transition | ||
| from jax.experimental import io_callback | ||
| import jax | ||
| import jax.numpy as jnp | ||
| import numpy as np | ||
|
|
||
| State = envs.State | ||
|
|
@@ -58,6 +60,8 @@ def generate_unroll( | |
| key: PRNGKey, | ||
| unroll_length: int, | ||
| extra_fields: Sequence[str] = (), | ||
| render_fn: Optional[Callable[[State], None]] = None, | ||
| should_render: jax.Array = jnp.array(False, dtype=jnp.bool_), | ||
| ) -> Tuple[State, Transition]: | ||
| """Collect trajectories of given unroll_length.""" | ||
|
|
||
|
|
@@ -68,6 +72,13 @@ def f(carry, unused_t): | |
| nstate, transition = actor_step( | ||
| env, state, policy, current_key, extra_fields=extra_fields | ||
| ) | ||
|
|
||
| def render(state: State): | ||
| if render_fn is None: | ||
| return | ||
| io_callback(render_fn, None, state) | ||
|
|
||
| jax.lax.cond(should_render, render, lambda s: None, nstate) | ||
| return (nstate, next_key), transition | ||
|
|
||
| (final_state, _), data = jax.lax.scan( | ||
|
|
@@ -115,6 +126,7 @@ def generate_eval_unroll( | |
| eval_policy_fn(policy_params), | ||
| key, | ||
| unroll_length=episode_length // action_repeat, | ||
| should_render=jnp.array(False, dtype=jnp.bool_), # No rendering during eval | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you're not passing
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the argu |
||
| )[0] | ||
|
|
||
| self._generate_eval_unroll = jax.jit(generate_eval_unroll) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -236,6 +236,9 @@ def train( | |
| # callbacks | ||
| progress_fn: Callable[[int, Metrics], None] = lambda *args: None, | ||
| policy_params_fn: Callable[..., None] = lambda *args: None, | ||
| # rendering | ||
| render_fn: Optional[Callable[[envs.State], None]] = None, | ||
| should_render: jax.Array = jnp.array(True, dtype=jnp.bool_), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removed line #241 |
||
| # checkpointing | ||
| save_checkpoint_path: Optional[str] = None, | ||
| restore_checkpoint_path: Optional[str] = None, | ||
|
|
@@ -317,6 +320,11 @@ def train( | |
| Returns: | ||
| Tuple of (make_policy function, network params, metrics) | ||
| """ | ||
| # If the environment is wrapped with ViewerWrapper, use its rendering functions. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not sure what ViewerWrapper is, maybe update the comment?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed to As |
||
| render_fn = None | ||
| if hasattr(environment, 'render_fn'): | ||
| render_fn = environment.render_fn | ||
|
|
||
| assert batch_size * num_minibatches % num_envs == 0 | ||
| _validate_madrona_args( | ||
| madrona_backend, num_envs, num_eval_envs, action_repeat, eval_env | ||
|
|
@@ -483,7 +491,7 @@ def convert_data(x: jnp.ndarray): | |
| return (optimizer_state, params, key), metrics | ||
|
|
||
| def training_step( | ||
| carry: Tuple[TrainingState, envs.State, PRNGKey], unused_t | ||
| carry: Tuple[TrainingState, envs.State, PRNGKey], unused_t, should_render: jax.Array, | ||
| ) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]: | ||
| training_state, state, key = carry | ||
| key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3) | ||
|
|
@@ -504,6 +512,8 @@ def f(carry, unused_t): | |
| current_key, | ||
| unroll_length, | ||
| extra_fields=('truncation', 'episode_metrics', 'episode_done'), | ||
| render_fn=render_fn, | ||
| should_render=should_render, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. again, you maybe can get away without the bool |
||
| ) | ||
| return (next_state, next_key), data | ||
|
|
||
|
|
@@ -552,10 +562,13 @@ def f(carry, unused_t): | |
| return (new_training_state, state, new_key), metrics | ||
|
|
||
| def training_epoch( | ||
| training_state: TrainingState, state: envs.State, key: PRNGKey | ||
| training_state: TrainingState, state: envs.State, key: PRNGKey, should_render: jax.Array, | ||
| ) -> Tuple[TrainingState, envs.State, Metrics]: | ||
| training_step_partial = functools.partial( | ||
| training_step, should_render=should_render | ||
| ) | ||
| (training_state, state, _), loss_metrics = jax.lax.scan( | ||
| training_step, | ||
| training_step_partial, | ||
| (training_state, state, key), | ||
| (), | ||
| length=num_training_steps_per_epoch, | ||
|
|
@@ -567,12 +580,12 @@ def training_epoch( | |
|
|
||
| # Note that this is NOT a pure jittable method. | ||
| def training_epoch_with_timing( | ||
| training_state: TrainingState, env_state: envs.State, key: PRNGKey | ||
| training_state: TrainingState, env_state: envs.State, key: PRNGKey, should_render: jax.Array, | ||
| ) -> Tuple[TrainingState, envs.State, Metrics]: | ||
| nonlocal training_walltime | ||
| t = time.time() | ||
| training_state, env_state = _strip_weak_type((training_state, env_state)) | ||
| result = training_epoch(training_state, env_state, key) | ||
| result = training_epoch(training_state, env_state, key, should_render) | ||
| training_state, env_state, metrics = _strip_weak_type(result) | ||
|
|
||
| metrics = jax.tree_util.tree_map(jnp.mean, metrics) | ||
|
|
@@ -696,10 +709,21 @@ def training_epoch_with_timing( | |
|
|
||
| for _ in range(max(num_resets_per_eval, 1)): | ||
| # optimization | ||
|
|
||
| # check for rendering dynamically | ||
| should_render_py = False | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so you're ignoring the arg to
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. deleted args |
||
| if hasattr(environment, 'sender'): | ||
| should_render_py = environment.sender.rendering_enabled | ||
|
|
||
| should_render_jax = jnp.array(should_render_py, dtype=jnp.bool_) | ||
| should_render_replicated = jax.device_put_replicated( | ||
| should_render_jax, jax.local_devices()[:local_devices_to_use] | ||
| ) | ||
|
|
||
| epoch_key, local_key = jax.random.split(local_key) | ||
| epoch_keys = jax.random.split(epoch_key, local_devices_to_use) | ||
| (training_state, env_state, training_metrics) = ( | ||
| training_epoch_with_timing(training_state, env_state, epoch_keys) | ||
| training_epoch_with_timing(training_state, env_state, epoch_keys, should_render_replicated) | ||
| ) | ||
| current_step = int(_unpmap(training_state.env_steps)) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can this whole block just be
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This may get rid of the fixed overhead in your main post, JAX should be ignoring this whole block
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the review! A JAX `Array(bool)` for `should_render`, instead of checking `render_fn is None`, is used because users need to toggle rendering on/off during training without re-JIT. This enables real-time visualization that can be disabled mid-training to restore full training speed.