From f7e453efafd0f282458f2322934c9f5471a4daee Mon Sep 17 00:00:00 2001
From: Saminda Abeyruwan
Date: Mon, 1 Dec 2025 17:56:45 -0800
Subject: [PATCH] simple history encoder for brax

---
 brax/training/agents/ppo/networks.py |   8 ++-
 brax/training/networks.py            | 115 ++++++++++++++++++++++++++-
 2 files changed, 120 insertions(+), 3 deletions(-)

diff --git a/brax/training/agents/ppo/networks.py b/brax/training/agents/ppo/networks.py
index cc5689a62..1c444fd18 100644
--- a/brax/training/agents/ppo/networks.py
+++ b/brax/training/agents/ppo/networks.py
@@ -70,12 +70,15 @@ def make_ppo_networks(
     observation_size: types.ObservationSize,
     action_size: int,
     preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
+    encoder_hidden_sizes: Sequence[int] = (256, 256),
     policy_hidden_layer_sizes: Sequence[int] = (32,) * 4,
     value_hidden_layer_sizes: Sequence[int] = (256,) * 5,
     activation: networks.ActivationFn = linen.swish,
     policy_obs_key: str = 'state',
     value_obs_key: str = 'state',
-    distribution_type: Literal['normal', 'tanh_normal'] = 'tanh_normal',
+    distribution_type: Literal[
+        'normal', 'tanh_normal', 'history_encoder'
+    ] = 'tanh_normal',
     noise_std_type: Literal['scalar', 'log'] = 'scalar',
     init_noise_std: float = 1.0,
     state_dependent_std: bool = False,
@@ -89,7 +92,7 @@ def make_ppo_networks(
   value_kernel_init_kwargs = value_network_kernel_init_kwargs or {}
 
   parametric_action_distribution: distribution.ParametricDistribution
-  if distribution_type == 'normal':
+  if distribution_type in ['normal', 'history_encoder']:
     parametric_action_distribution = distribution.NormalDistribution(
         event_size=action_size
     )
@@ -106,6 +109,7 @@ def make_ppo_networks(
       parametric_action_distribution.param_size,
       observation_size,
       preprocess_observations_fn=preprocess_observations_fn,
+      encoder_hidden_sizes=encoder_hidden_sizes,
       hidden_layer_sizes=policy_hidden_layer_sizes,
       activation=activation,
       obs_key=policy_obs_key,
diff --git a/brax/training/networks.py b/brax/training/networks.py
index 27b9b9aad..f44621841 100644
--- a/brax/training/networks.py
+++ b/brax/training/networks.py
@@ -362,17 +362,118 @@ def __call__(self, obs):
 
     return mean_params, jnp.broadcast_to(std_params, mean_params.shape)
 
+class PolicyModuleWithStdAndHistoryEncoder(linen.Module):
+  """Policy module that encodes the history part of a flat observation.
+
+  The last axis of `obs` is split at `proprio_dim`: the leading features are
+  used as-is, the remaining features are passed through an encoder MLP, and
+  the concatenation of both feeds the main policy MLP that outputs the mean
+  and standard deviation parameters.
+
+  Attributes:
+    param_size: output size of the mean and std heads.
+    encoder_hidden_sizes: layer sizes of the history encoder MLP.
+    hidden_layer_sizes: layer sizes of the main policy MLP.
+    activation: activation function shared by both MLPs.
+    kernel_init: kernel initializer of the main policy MLP and Dense heads.
+    layer_norm: forwarded to the main policy MLP.
+    noise_std_type: 'scalar' or 'log' parameterization of the std.
+    init_noise_std: initial value handed to the state-independent std module.
+    state_dependent_std: if True, predict the std with a Dense head.
+    proprio_dim: number of leading observation features that skip the encoder.
+  """
+
+  param_size: int
+  encoder_hidden_sizes: Sequence[int]
+  hidden_layer_sizes: Sequence[int]
+  activation: ActivationFn
+  kernel_init: jax.nn.initializers.Initializer
+  layer_norm: bool
+  noise_std_type: Literal['scalar', 'log']
+  init_noise_std: float
+  state_dependent_std: bool = False
+  # TODO(saminda): remove this hardcoded proprio/history split.
+  # This is a temporary solution to unblock the development.
+  # This class duplicates code from the existing policy module in networks.py.
+  proprio_dim: int = 143
+
+  @linen.compact
+  def __call__(self, obs):
+    if self.noise_std_type not in ['scalar', 'log']:
+      raise ValueError(
+          f'Unsupported noise std type: {self.noise_std_type}. Must be one of'
+          ' "scalar" or "log".'
+      )
+    # Shapes are static under jit, so this check runs once at trace time.
+    if obs.shape[-1] <= self.proprio_dim:
+      raise ValueError(
+          f'Observation size {obs.shape[-1]} must be larger than proprio_dim'
+          f' {self.proprio_dim}; otherwise the history slice is empty.'
+      )
+
+    # obs is a flat array: proprio features first, history features after.
+    proprio = obs[..., : self.proprio_dim]
+    history = obs[..., self.proprio_dim :]
+
+    # 1. Encode History
+    # NOTE(review): the encoder uses the MLP defaults instead of
+    # self.kernel_init / self.layer_norm; confirm this is intended.
+    latent = MLP(
+        layer_sizes=list(self.encoder_hidden_sizes), activation=self.activation
+    )(history)
+
+    # 2. Concatenate with Proprioception
+    policy_input = jnp.concatenate([proprio, latent], axis=-1)
+
+    # 3. Main Policy MLP
+    outputs = MLP(
+        layer_sizes=list(self.hidden_layer_sizes),
+        activation=self.activation,
+        kernel_init=self.kernel_init,
+        layer_norm=self.layer_norm,
+        activate_final=True,
+    )(policy_input)
+
+    mean_params = linen.Dense(
+        self.param_size,
+        kernel_init=self.kernel_init,
+    )(outputs)
+
+    if self.state_dependent_std:
+      log_std_output = linen.Dense(
+          self.param_size, kernel_init=self.kernel_init
+      )(outputs)
+      if self.noise_std_type == 'log':
+        std_params = jnp.exp(log_std_output)
+      else:
+        std_params = log_std_output
+    else:
+      if self.noise_std_type == 'scalar':
+        std_module = Param(
+            self.init_noise_std, size=self.param_size, name='std_param'
+        )
+      else:
+        std_module = LogParam(
+            self.init_noise_std, size=self.param_size, name='std_logparam'
+        )
+      std_params = std_module()
+
+    return mean_params, jnp.broadcast_to(std_params, mean_params.shape)
+
 
 def make_policy_network(
     param_size: int,
     obs_size: types.ObservationSize,
     preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
+    encoder_hidden_sizes: Sequence[int] = (256, 256),
     hidden_layer_sizes: Sequence[int] = (256, 256),
     activation: ActivationFn = linen.relu,
     kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
     layer_norm: bool = False,
     obs_key: str = 'state',
-    distribution_type: Literal['normal', 'tanh_normal'] = 'tanh_normal',
+    distribution_type: Literal[
+        'normal', 'tanh_normal', 'history_encoder'
+    ] = 'tanh_normal',
     noise_std_type: Literal['scalar', 'log'] = 'scalar',
     init_noise_std: float = 1.0,
     state_dependent_std: bool = False,
@@ -396,6 +497,18 @@ def make_policy_network(
         init_noise_std=init_noise_std,
         state_dependent_std=state_dependent_std,
     )
+  elif distribution_type == 'history_encoder':
+    policy_module = PolicyModuleWithStdAndHistoryEncoder(
+        param_size=param_size,
+        encoder_hidden_sizes=encoder_hidden_sizes,
+        hidden_layer_sizes=hidden_layer_sizes,
+        activation=activation,
+        kernel_init=kernel_init,
+        layer_norm=layer_norm,
+        noise_std_type=noise_std_type,
+        init_noise_std=init_noise_std,
+        state_dependent_std=state_dependent_std,
+    )
   else:
     raise ValueError(
         f'Unsupported distribution type: {distribution_type}. Must be one'