Add explicit position_ids to GPT-Neo attention layers #44687
vxa8502 wants to merge 1 commit into huggingface:main
Conversation
saivedant169 left a comment
Thanks for working on this — same issue I tackled for RoFormer/Bloom/MPT in #44705-7.
A couple of observations from my experience with the other models:
- GPT-Neo uses learned absolute embeddings, not RoPE, so `position_ids` needs to actually be consumed somewhere in the position embedding layer for this to have functional impact. In the current diff, `position_ids` is threaded through to the attention function via `ALL_ATTENTION_FUNCTIONS`, but does the embedding layer (`GPTNeoModel.wpe`) use it? If not, the parameter is accepted but silently ignored, which is fine for API consistency (same as the Bloom/MPT approach) but worth noting in the PR description.
- The `**kwargs` additions (lines 294, 344, 493) look like a separate concern from `position_ids`. Was this needed to pass through some other arguments, or was it introduced to forward `position_ids` specifically? If it's unrelated, splitting it out would make the diff cleaner for reviewers.
- Test results: did you run `test_for_generate_causal_lm`? When I added `position_ids` to RoFormer, the 2D shape from `GenerationMixin` caused a shape mismatch that needed handling. Worth confirming GPT-Neo's generation path works.
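The reviewer's first point can be illustrated with a toy lookup. This is a sketch, not the GPT-Neo code: the function and table here are hypothetical stand-ins for a learned absolute position-embedding layer like `wpe`.

```python
# Toy sketch of a learned absolute position-embedding lookup.
# `wpe_lookup` and `table` are illustrative names, not transformers code.

def wpe_lookup(table, seq_len, position_ids=None):
    """Return one embedding vector per token.

    Falling back to 0..seq_len-1 when position_ids is None mirrors
    what the model effectively does if the argument is accepted at
    the attention layer but never consumed by the embedding layer.
    """
    if position_ids is None:
        position_ids = list(range(seq_len))
    return [table[p] for p in position_ids]

# One "learned" vector per absolute position.
table = [[0.0], [1.0], [2.0], [3.0]]

default_path = wpe_lookup(table, seq_len=3)  # positions 0, 1, 2
explicit_path = wpe_lookup(table, seq_len=3, position_ids=[0, 1, 0])

# If the embedding layer ignores position_ids, both calls collapse to
# default_path and the new parameter has no functional effect.
```

Only when the two paths can diverge, as above, does passing explicit `position_ids` change the model's output for an absolute-embedding architecture.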
@saivedant169 Thanks for the thorough review. Addressing each point:
Pushing updated changes shortly.
[For maintainers] Suggested jobs to run (before merge): `run-slow: gpt_neo`
Partially fixes #32937
Adds explicit `position_ids` threading through GPT-Neo's attention layers to enable flash attention's packed-sequence optimization.

Context
GPT-Neo uses learned absolute position embeddings (`wpe`) applied at the model level, unlike RoPE models (Llama, Falcon) that apply rotations inside attention. The `position_ids` parameter passed to attention layers serves two purposes:
- allows `_flash_attention_forward()` to detect packed sequences via `_is_packed_sequence()` (batch_size=1 edge case)

Changes
Add `position_ids` parameter to:
- `GPTNeoSelfAttention.forward()`
- `GPTNeoFlashAttention2.forward()`
- `GPTNeoAttention.forward()`
- `GPTNeoBlock.forward()`

Pass `position_ids` to the `_flash_attention_forward()` call.

Tests
- `test_generate_with_and_without_position_ids`
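The two behaviors this PR relies on (packed-sequence detection from `position_ids`, and default-equals-explicit generation) can be sketched with toy code. The helper names `looks_packed` and `toy_forward` are illustrative inventions, not the transformers implementation of `_is_packed_sequence()` or GPT-Neo's forward pass.

```python
# Illustrative sketch only; names below are hypothetical, not
# the transformers implementation.

def looks_packed(position_ids):
    """With batch_size=1, several sequences packed into one row show
    up as a position counter that restarts mid-row."""
    return any(b < a for a, b in zip(position_ids, position_ids[1:]))

def toy_forward(token_ids, position_ids=None):
    """Stand-in for the model: defaulting to 0..n-1 mirrors what
    GPT-Neo does when no position_ids are passed."""
    if position_ids is None:
        position_ids = list(range(len(token_ids)))
    # Trivial "logits": token id plus its position, just enough to
    # show that explicit-default ids and no ids must agree.
    return [t + p for t, p in zip(token_ids, position_ids)]

unpacked = looks_packed([0, 1, 2, 3])   # single ordinary sequence
packed = looks_packed([0, 1, 2, 0, 1])  # two sequences in one row

with_ids = toy_forward([5, 7, 9], position_ids=[0, 1, 2])
without_ids = toy_forward([5, 7, 9])
```

A test in the spirit of `test_generate_with_and_without_position_ids` asserts that the `with_ids` and `without_ids` outputs match, since explicit `0..n-1` ids must be indistinguishable from the default path.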