diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 7dfb1ee749ec..3e7fc9a7f6d8 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -261,6 +261,7 @@ def forward( input_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, inputs_embeds: torch.LongTensor | None = None, use_cache: bool | None = None, output_attentions: bool | None = None, @@ -396,6 +397,7 @@ def forward( input_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, inputs_embeds: torch.Tensor | None = None, labels: torch.Tensor | None = None, use_cache: bool | None = None, @@ -429,6 +431,7 @@ def forward( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, + position_ids=position_ids, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions,