Skip to content
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion brax/training/acting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@
"""Brax training acting functions."""

import time
from typing import Callable, Sequence, Tuple
from typing import Callable, Optional, Sequence, Tuple

from brax import envs
from brax.training.types import Metrics
from brax.training.types import Policy
from brax.training.types import PolicyParams
from brax.training.types import PRNGKey
from brax.training.types import Transition
from jax.experimental import io_callback
import jax
import jax.numpy as jnp
import numpy as np

State = envs.State
Expand Down Expand Up @@ -58,6 +60,8 @@ def generate_unroll(
key: PRNGKey,
unroll_length: int,
extra_fields: Sequence[str] = (),
render_fn: Optional[Callable[[State], None]] = None,
should_render: jax.Array = jnp.array(False, dtype=jnp.bool_),
) -> Tuple[State, Transition]:
"""Collect trajectories of given unroll_length."""

Expand All @@ -68,6 +72,13 @@ def f(carry, unused_t):
nstate, transition = actor_step(
env, state, policy, current_key, extra_fields=extra_fields
)

def render(state: State):
if render_fn is None:
return
io_callback(render_fn, None, state)

jax.lax.cond(should_render, render, lambda s: None, nstate)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this whole block just be

if render_fn:
  io_callback(render_fn, None, state)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may get rid of the fixed overhead in your main post, JAX should be ignoring this whole block

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the review! A JAX Array(bool) for should_render instead of checking render_fn is None is used because users need to toggle rendering on/off during training without re-JIT. This enables real-time visualization that can be disabled mid-training to restore full training speed.

return (nstate, next_key), transition

(final_state, _), data = jax.lax.scan(
Expand Down Expand Up @@ -115,6 +126,7 @@ def generate_eval_unroll(
eval_policy_fn(policy_params),
key,
unroll_length=episode_length // action_repeat,
should_render=jnp.array(False, dtype=jnp.bool_), # No rendering during eval
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're not passing render_fn anyways, not sure you really need should_render

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The argument render_fn to train() is removed

)[0]

self._generate_eval_unroll = jax.jit(generate_eval_unroll)
Expand Down
36 changes: 30 additions & 6 deletions brax/training/agents/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ def train(
# callbacks
progress_fn: Callable[[int, Metrics], None] = lambda *args: None,
policy_params_fn: Callable[..., None] = lambda *args: None,
# rendering
render_fn: Optional[Callable[[envs.State], None]] = None,
should_render: jax.Array = jnp.array(True, dtype=jnp.bool_),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dtype=bool

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed line #241

# checkpointing
save_checkpoint_path: Optional[str] = None,
restore_checkpoint_path: Optional[str] = None,
Expand Down Expand Up @@ -317,6 +320,11 @@ def train(
Returns:
Tuple of (make_policy function, network params, metrics)
"""
# If the environment is wrapped with ViewerWrapper, use its rendering functions.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure what ViewerWrapper is, maybe update the comment?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to # If the environment exposes a render_fn, use it for real-time rendering during training.

As render_fn and should_render are removed from train()'s args, render_fn now lives on the environment; it can be provided by an external wrapper.

render_fn = None
if hasattr(environment, 'render_fn'):
render_fn = environment.render_fn

assert batch_size * num_minibatches % num_envs == 0
_validate_madrona_args(
madrona_backend, num_envs, num_eval_envs, action_repeat, eval_env
Expand Down Expand Up @@ -483,7 +491,7 @@ def convert_data(x: jnp.ndarray):
return (optimizer_state, params, key), metrics

def training_step(
carry: Tuple[TrainingState, envs.State, PRNGKey], unused_t
carry: Tuple[TrainingState, envs.State, PRNGKey], unused_t, should_render: jax.Array,
) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]:
training_state, state, key = carry
key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)
Expand All @@ -504,6 +512,8 @@ def f(carry, unused_t):
current_key,
unroll_length,
extra_fields=('truncation', 'episode_metrics', 'episode_done'),
render_fn=render_fn,
should_render=should_render,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again, you maybe can get away without the bool should_render

)
return (next_state, next_key), data

Expand Down Expand Up @@ -552,10 +562,13 @@ def f(carry, unused_t):
return (new_training_state, state, new_key), metrics

def training_epoch(
training_state: TrainingState, state: envs.State, key: PRNGKey
training_state: TrainingState, state: envs.State, key: PRNGKey, should_render: jax.Array,
) -> Tuple[TrainingState, envs.State, Metrics]:
training_step_partial = functools.partial(
training_step, should_render=should_render
)
(training_state, state, _), loss_metrics = jax.lax.scan(
training_step,
training_step_partial,
(training_state, state, key),
(),
length=num_training_steps_per_epoch,
Expand All @@ -567,12 +580,12 @@ def training_epoch(

# Note that this is NOT a pure jittable method.
def training_epoch_with_timing(
training_state: TrainingState, env_state: envs.State, key: PRNGKey
training_state: TrainingState, env_state: envs.State, key: PRNGKey, should_render: jax.Array,
) -> Tuple[TrainingState, envs.State, Metrics]:
nonlocal training_walltime
t = time.time()
training_state, env_state = _strip_weak_type((training_state, env_state))
result = training_epoch(training_state, env_state, key)
result = training_epoch(training_state, env_state, key, should_render)
training_state, env_state, metrics = _strip_weak_type(result)

metrics = jax.tree_util.tree_map(jnp.mean, metrics)
Expand Down Expand Up @@ -696,10 +709,21 @@ def training_epoch_with_timing(

for _ in range(max(num_resets_per_eval, 1)):
# optimization

# check for rendering dynamically
should_render_py = False
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so you're ignoring the arg to train(), just delete all the args

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleted args

if hasattr(environment, 'sender'):
should_render_py = environment.sender.rendering_enabled

should_render_jax = jnp.array(should_render_py, dtype=jnp.bool_)
should_render_replicated = jax.device_put_replicated(
should_render_jax, jax.local_devices()[:local_devices_to_use]
)

epoch_key, local_key = jax.random.split(local_key)
epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
(training_state, env_state, training_metrics) = (
training_epoch_with_timing(training_state, env_state, epoch_keys)
training_epoch_with_timing(training_state, env_state, epoch_keys, should_render_replicated)
)
current_step = int(_unpmap(training_state.env_steps))

Expand Down