...truncated...
[ 126 ] Hidden shape: torch.Size([1, 1, 64]) | Training mode: False
[ 127 ] Hidden shape: torch.Size([1, 1, 64]) | Training mode: False
[ 128 ] Hidden shape: torch.Size([1, 1, 64]) | Training mode: False
[ 129 ] Hidden shape: torch.Size([1, 2, 64]) | Training mode: True
Anomalous shape encountered
Slice 0:
tensor([[-0.0829, -0.1313, -0.0784, -0.0005, 0.1998, -0.1344, 0.0718, 0.0770,
0.0633, -0.1303, -0.1619, -0.0804, 0.0221, 0.0895, 0.1033, 0.1790,
-0.0631, -0.0337, 0.0096, 0.0497, 0.1222, -0.0074, 0.0711, -0.0137,
-0.0203, -0.0729, 0.1066, -0.1037, 0.2195, -0.1407, 0.0269, -0.0968,
-0.1034, 0.1163, -0.0931, 0.0071, 0.0914, -0.0213, 0.1505, 0.0700,
0.1196, 0.0809, -0.0368, -0.0342, 0.0384, -0.1669, 0.0109, 0.1535,
-0.0206, -0.1599, -0.0975, -0.0114, -0.0040, 0.0757, 0.0590, -0.0663,
0.0200, -0.0264, -0.2038, -0.2178, -0.0444, -0.1478, 0.0482, -0.0903]],
grad_fn=)
Slice 1:
tensor([[-0.0356, -0.0018, 0.0154, 0.0565, -0.0229, 0.0009, 0.0023, 0.0469,
0.0245, 0.0436, 0.0444, 0.0029, -0.0717, -0.0088, -0.0291, 0.0016,
0.0177, -0.0729, -0.0166, -0.0091, -0.0412, -0.0033, 0.0781, -0.0058,
0.0177, -0.0253, 0.0068, -0.0249, -0.0008, 0.0197, -0.0673, 0.0052,
-0.0183, -0.0185, 0.0483, 0.0335, -0.0625, -0.0217, -0.0155, -0.0839,
-0.0129, 0.0036, 0.0642, -0.0446, -0.0107, -0.0002, -0.0331, -0.0013,
-0.0827, -0.0065, 0.0314, 0.0236, 0.0364, 0.0125, 0.0256, 0.0098,
0.0923, -0.0712, -0.0166, -0.0073, 0.0583, 0.0363, -0.0390, 0.0089]],
grad_fn=)
❓ Question
First, thank you for the great work on this package (and happy Friday)!
I have a question about the way sequences are handled in `RecurrentRolloutBuffer`. The context is that I am trying to extract the hidden and cell states at every time step from an LSTM actor, but during training the shapes of the internal hidden and cell states change, and I don't quite understand why, since I'm using a single environment (which seems to be a factor; I've detailed the steps below, so 🐻 with me).

Here is my minimal setup:
This produces the output shown above (hidden shape and training mode at each step).
The second dimension of the shape varies between 1 and 3, which makes it difficult to determine what the actual hidden state is at the current time step. The slices are not identical; breaking at the first step where `shape[1]` is not 1 produces the slice dump shown above (Slice 0 / Slice 1).
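For reference, PyTorch's `nn.LSTM` keeps its hidden state as `(num_layers, batch, hidden_size)`, and in `_process_sequence` the batch dimension is the number of sequences (`n_seq`), not time. A standalone sketch (plain `torch`, not sb3-contrib code; the sizes here are illustrative):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=64, num_layers=1)

# One sequence, e.g. rollout collection with a single environment:
x1 = torch.randn(5, 1, 8)   # (seq_len, n_seq, input_size)
_, (h1, c1) = lstm(x1)
print(h1.shape)             # torch.Size([1, 1, 64])

# Two sequences, e.g. after the buffer splits the rollout in two:
x2 = torch.randn(5, 2, 8)
_, (h2, c2) = lstm(x2)
print(h2.shape)             # torch.Size([1, 2, 64])
```

So a middle dimension of 2 or 3 means the LSTM was fed 2 or 3 sequences in one batch, not that the per-step hidden state grew.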
After some digging, it seems that this is caused at least partially by `create_sequencers` (L64 in `sb3_contrib/common/recurrent/buffers.py`). The flow is as follows:

1. `agent.learn()` calls `OnPolicyAlgorithm.learn()` in SB3, which in turn calls `RecurrentPPO.collect_rollouts()` (L324).
2. `RecurrentPPO.collect_rollouts()` calls `RecurrentPPO.policy.forward()` on L242.
3. `RecurrentActorCriticPolicy.forward()` calls `RecurrentActorCriticPolicy._process_sequence()` (L237).
4. In `_process_sequence()`, `n_seq` is set to whatever the second dimension of the hidden state is (L182).

Now, at the beginning this produces the expected outcome, because the condition on L191 is satisfied for the entire duration of the `while` loop in `RecurrentPPO.collect_rollouts()` (L233). So for the first 128 steps we are just collecting rollouts with gradients disabled, the shape is `[1, 1, 64]`, and training mode is `False`, as confirmed by the output above. So far, so good.

But then we hit `OnPolicyAlgorithm.train()` (SB3 L337), resp. L345 in `RecurrentPPO`, which sends us to `RecurrentRolloutBuffer.get()` (L147). Here is where my understanding starts to falter. The code from L184 onwards employs a "shuffling trick" for minibatch sampling. Fast-forward through the `yield` statement on L196 -> `_get_samples()` on L199 -> `create_sequencers` (L206), and we get to L82 (see lines 81-89 of `buffers.py`).

Because of `seq_start[0] = True` (actually, because the first element of `env_change` was already marked as `1.0`, this is probably redundant), `seq_start` now contains two elements that are `True`: the original `split_index` and index `0`. This means that `seq_start_indices` has length 2, and since `n_seq` is assigned the length of `seq_start_indices` on L354, it ends up evaluating to 2, which propagates all the way to the second dimension of the LSTM actor (and critic) hidden states.

So my questions are:

1. What is the purpose of the shuffling trick in `RecurrentRolloutBuffer.get()`? It would be good to add some logic to disable it on demand.
2. Is `create_sequencers()` working as intended?
3. Is using the `create_sequencers()` function in `_get_samples()` justified when we are using a single environment?
4. How do we end up with *three* sequences (step `[ 129 ]` onwards in the output above)? I understand how we end up with two (also, based on the comment on L180, I think the intention is to have only two), but I'm not sure how this would result in three with a single environment.

Thank you, and apologies about this essay; this question ended up a lot longer than I intended it to be.
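To make the `n_seq = 2` mechanism concrete, here is a simplified standalone rendition of the `seq_start` bookkeeping described above. This is not the library code verbatim; `split_index` stands in for the random rotation point introduced by the shuffling trick, and the array sizes are illustrative:

```python
import numpy as np

buffer_size, n_envs = 128, 1

# No episode boundaries occur during the rollout:
episode_starts = np.zeros(buffer_size * n_envs)

# The shuffling trick rotates the flattened buffer at a random split point,
# so even with a single environment env_change ends up marking a second
# sequence boundary in addition to index 0:
env_change = np.zeros(buffer_size * n_envs)
env_change[0] = 1.0
split_index = np.random.randint(1, buffer_size * n_envs)
env_change[split_index] = 1.0

seq_start = np.logical_or(episode_starts, env_change).flatten()
# First index always starts a sequence (redundant here, since
# env_change[0] is already 1.0, as noted above):
seq_start[0] = True
seq_start_indices = np.where(seq_start)[0]

n_seq = len(seq_start_indices)
print(n_seq)  # 2 -> becomes the second dim of the LSTM hidden state
```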
Checklist