diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py
index fb374220a97f..fd6ee6cb71d2 100644
--- a/src/transformers/models/roformer/modeling_roformer.py
+++ b/src/transformers/models/roformer/modeling_roformer.py
@@ -421,6 +421,7 @@ def forward(
         output_hidden_states=False,
         return_dict=True,
         cache_position=None,
+        position_ids=None,
     ):
         if self.gradient_checkpointing and self.training:
             if use_cache:
@@ -436,7 +437,15 @@ def forward(
         past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
 
         # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head]
-        sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1], past_key_values_length)[None, None, :, :]
+        sinusoidal_pos = self.embed_positions(
+            hidden_states.shape[:-1], past_key_values_length, position_ids=position_ids
+        )
+        if sinusoidal_pos.dim() == 2:
+            # position_ids was None or 1D: output is [seq_len, embed_dim]
+            sinusoidal_pos = sinusoidal_pos[None, None, :, :]
+        else:
+            # position_ids was 2D: output is [batch, seq_len, embed_dim]
+            sinusoidal_pos = sinusoidal_pos[:, None, :, :]
 
         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
@@ -682,6 +691,7 @@ def forward(
         input_ids: torch.LongTensor | None = None,
         attention_mask: torch.FloatTensor | None = None,
         token_type_ids: torch.LongTensor | None = None,
+        position_ids: torch.LongTensor | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
         encoder_hidden_states: torch.FloatTensor | None = None,
         encoder_attention_mask: torch.FloatTensor | None = None,
@@ -763,6 +773,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
+            position_ids=position_ids,
         )
 
         sequence_output = encoder_outputs[0]
@@ -898,6 +909,7 @@ def forward(
         input_ids: torch.LongTensor | None = None,
         attention_mask: torch.FloatTensor | None = None,
         token_type_ids: torch.LongTensor | None = None,
+        position_ids: torch.LongTensor | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
         encoder_hidden_states: torch.FloatTensor | None = None,
         encoder_attention_mask: torch.FloatTensor | None = None,
@@ -939,6 +951,7 @@ def forward(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
+            position_ids=position_ids,
             inputs_embeds=inputs_embeds,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
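
A minimal usage sketch of the new argument, for reference. It assumes this patch is applied together with the companion change (not shown in these hunks) that lets `embed_positions` accept a `position_ids` keyword; the checkpoint name and toy token ids below are illustrative, not part of the diff.

```python
import torch
from transformers import RoFormerModel

# Illustrative checkpoint; any RoFormer checkpoint would do.
model = RoFormerModel.from_pretrained("junnyu/roformer_chinese_base")
input_ids = torch.tensor([[101, 2769, 812, 102]])  # toy token ids

# A 2D [batch, seq_len] position_ids tensor. With the patch, embed_positions
# returns a [batch, seq_len, embed_dim] tensor in this case, and the encoder
# inserts only the missing head dimension ([:, None, :, :]).
position_ids = torch.arange(input_ids.shape[-1]).unsqueeze(0)

outputs = model(input_ids=input_ids, position_ids=position_ids)
print(outputs.last_hidden_state.shape)  # [1, 4, hidden_size]
```

When `position_ids` is omitted, `embed_positions` still produces the old 2D `[seq_len, embed_dim]` output and the encoder unsqueezes it exactly as before (`[None, None, :, :]`), so existing callers are unaffected.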