8 changes: 6 additions & 2 deletions brax/training/agents/ppo/networks.py
@@ -70,12 +70,15 @@ def make_ppo_networks(
     observation_size: types.ObservationSize,
     action_size: int,
     preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
+    encoder_hidden_sizes: Sequence[int] = (256, 256),
     policy_hidden_layer_sizes: Sequence[int] = (32,) * 4,
     value_hidden_layer_sizes: Sequence[int] = (256,) * 5,
     activation: networks.ActivationFn = linen.swish,
     policy_obs_key: str = 'state',
     value_obs_key: str = 'state',
-    distribution_type: Literal['normal', 'tanh_normal'] = 'tanh_normal',
+    distribution_type: Literal[
+        'normal', 'tanh_normal', 'history_encoder'
+    ] = 'tanh_normal',
     noise_std_type: Literal['scalar', 'log'] = 'scalar',
     init_noise_std: float = 1.0,
     state_dependent_std: bool = False,
@@ -89,7 +92,7 @@ def make_ppo_networks(
   value_kernel_init_kwargs = value_network_kernel_init_kwargs or {}
 
   parametric_action_distribution: distribution.ParametricDistribution
-  if distribution_type == 'normal':
+  if distribution_type in ['normal', 'history_encoder']:
     parametric_action_distribution = distribution.NormalDistribution(
         event_size=action_size
     )
@@ -106,6 +109,7 @@
       parametric_action_distribution.param_size,
       observation_size,
       preprocess_observations_fn=preprocess_observations_fn,
+      encoder_hidden_sizes=encoder_hidden_sizes,
       hidden_layer_sizes=policy_hidden_layer_sizes,
       activation=activation,
       obs_key=policy_obs_key,
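Note: a minimal usage sketch of the new option (the observation and action sizes below are illustrative, not from this PR). Selecting `distribution_type='history_encoder'` keeps a plain `NormalDistribution` for actions (no tanh squashing) while routing the policy through the new history-encoder module:

```python
# Hypothetical usage; sizes are illustrative assumptions.
from brax.training.agents.ppo import networks as ppo_networks

ppo_nets = ppo_networks.make_ppo_networks(
    observation_size=143 + 5 * 143,  # assumed: proprio (143) + 5 history frames
    action_size=12,
    encoder_hidden_sizes=(256, 256),      # new parameter from this PR
    distribution_type='history_encoder',  # new option from this PR
)
```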
89 changes: 88 additions & 1 deletion brax/training/networks.py
@@ -362,17 +362,92 @@ def __call__(self, obs):

     return mean_params, jnp.broadcast_to(std_params, mean_params.shape)
 
+class PolicyModuleWithStdAndHistoryEncoder(linen.Module):
+  """Policy module with a history encoder and learnable standard deviation."""
+
+  param_size: int
+  encoder_hidden_sizes: Sequence[int]
+  hidden_layer_sizes: Sequence[int]
+  activation: ActivationFn
+  kernel_init: jax.nn.initializers.Initializer
+  layer_norm: bool
+  noise_std_type: Literal['scalar', 'log']
+  init_noise_std: float
+  state_dependent_std: bool = False
+  # TODO(saminda): remove the hardcoded proprio and history sizes; this is a
+  # temporary workaround to unblock development. The std handling below is
+  # duplicated from PolicyModuleWithStd.
+  proprio_dim: int = 143
+
+  @linen.compact
+  def __call__(self, obs):
+    if self.noise_std_type not in ['scalar', 'log']:
+      raise ValueError(
+          f'Unsupported noise std type: {self.noise_std_type}. Must be one of'
+          ' "scalar" or "log".'
+      )
+
+    # obs is a flat vector: [proprio (proprio_dim dims), flattened history].
+    proprio = obs[..., : self.proprio_dim]
+    history = obs[..., self.proprio_dim :]
+
+    # 1. Encode the history.
+    latent = MLP(
+        layer_sizes=list(self.encoder_hidden_sizes), activation=self.activation
+    )(history)
+
+    # 2. Concatenate the latent with proprioception.
+    policy_input = jnp.concatenate([proprio, latent], axis=-1)

+    # 3. Main policy MLP.
+    outputs = MLP(
+        layer_sizes=list(self.hidden_layer_sizes),
+        activation=self.activation,
+        kernel_init=self.kernel_init,
+        layer_norm=self.layer_norm,
+        activate_final=True,
+    )(policy_input)
+
+    mean_params = linen.Dense(
+        self.param_size,
+        kernel_init=self.kernel_init,
+    )(outputs)
+
+    if self.state_dependent_std:
+      log_std_output = linen.Dense(
+          self.param_size, kernel_init=self.kernel_init
+      )(outputs)
+      if self.noise_std_type == 'log':
+        std_params = jnp.exp(log_std_output)
+      else:
+        std_params = log_std_output
+    else:
+      if self.noise_std_type == 'scalar':
+        std_module = Param(
+            self.init_noise_std, size=self.param_size, name='std_param'
+        )
+      else:
+        std_module = LogParam(
+            self.init_noise_std, size=self.param_size, name='std_logparam'
+        )
+      std_params = std_module()
+
+    return mean_params, jnp.broadcast_to(std_params, mean_params.shape)
+

 def make_policy_network(
     param_size: int,
     obs_size: types.ObservationSize,
     preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
+    encoder_hidden_sizes: Sequence[int] = (256, 256),
     hidden_layer_sizes: Sequence[int] = (256, 256),
     activation: ActivationFn = linen.relu,
     kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
     layer_norm: bool = False,
     obs_key: str = 'state',
-    distribution_type: Literal['normal', 'tanh_normal'] = 'tanh_normal',
+    distribution_type: Literal[
+        'normal', 'tanh_normal', 'history_encoder'
+    ] = 'tanh_normal',
     noise_std_type: Literal['scalar', 'log'] = 'scalar',
     init_noise_std: float = 1.0,
     state_dependent_std: bool = False,
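Note: a hypothetical smoke test of the new module's interface, assuming `param_size` equals the action dimension (the module returns mean and std separately) and the default `proprio_dim=143`:

```python
# Hypothetical smoke test; field values are illustrative assumptions.
import jax
import jax.numpy as jnp
from brax.training import networks

module = networks.PolicyModuleWithStdAndHistoryEncoder(
    param_size=12,  # assumed: one mean/std entry per action dimension
    encoder_hidden_sizes=(256, 256),
    hidden_layer_sizes=(32, 32, 32, 32),
    activation=jax.nn.swish,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    layer_norm=False,
    noise_std_type='scalar',
    init_noise_std=1.0,
)
# proprio_dim defaults to 143, so obs must carry at least that many features.
obs = jnp.zeros((1, 143 + 5 * 143))
variables = module.init(jax.random.PRNGKey(0), obs)
mean, std = module.apply(variables, obs)
assert mean.shape == std.shape == (1, 12)
```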
@@ -396,6 +471,18 @@ def make_policy_network(
         init_noise_std=init_noise_std,
         state_dependent_std=state_dependent_std,
     )
+  elif distribution_type == 'history_encoder':
+    policy_module = PolicyModuleWithStdAndHistoryEncoder(
+        param_size=param_size,
+        encoder_hidden_sizes=encoder_hidden_sizes,
+        hidden_layer_sizes=hidden_layer_sizes,
+        activation=activation,
+        kernel_init=kernel_init,
+        layer_norm=layer_norm,
+        noise_std_type=noise_std_type,
+        init_noise_std=init_noise_std,
+        state_dependent_std=state_dependent_std,
+    )
   else:
     raise ValueError(
         f'Unsupported distribution type: {distribution_type}. Must be one'
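Note: the lower-level factory can be exercised the same way; a sketch assuming brax's usual `FeedForwardNetwork` contract, where `apply` takes preprocessor params first (ignored by the default identity preprocessor):

```python
# Hypothetical factory usage; observation layout is an assumption.
import jax
import jax.numpy as jnp
from brax.training import networks

policy_net = networks.make_policy_network(
    param_size=12,           # assumed distribution parameter size
    obs_size=143 + 3 * 143,  # assumed: proprio (143) + 3 history frames
    encoder_hidden_sizes=(128, 128),
    distribution_type='history_encoder',
)
params = policy_net.init(jax.random.PRNGKey(0))
obs = jnp.zeros((1, 143 + 3 * 143))
mean, std = policy_net.apply(None, params, obs)  # None: identity preprocessor ignores it
```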