diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py
index ccc540207789..419e1ed0f86a 100644
--- a/src/transformers/models/codegen/modeling_codegen.py
+++ b/src/transformers/models/codegen/modeling_codegen.py
@@ -228,9 +228,7 @@ def forward(self, hidden_states: torch.FloatTensor | None) -> torch.FloatTensor:
         return hidden_states
 
 
-# Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
 class CodeGenBlock(GradientCheckpointingLayer):
-    # Ignore copy
     def __init__(self, config, layer_idx=None):
         super().__init__()
         inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index 007cc6fd9822..e86219124887 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -33,7 +33,10 @@
     SequenceClassifierOutputWithPast,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import auto_docstring, logging
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import merge_with_config_defaults
+from ...utils.output_capturing import capture_outputs
 from .configuration_gptj import GPTJConfig
 
 
@@ -172,13 +175,8 @@ def forward(
         attention_mask: torch.FloatTensor | None = None,
         position_ids: torch.LongTensor | None = None,
         use_cache: bool | None = False,
-        output_attentions: bool | None = False,
         cache_position: torch.LongTensor | None = None,
-    ) -> (
-        tuple[torch.Tensor, tuple[torch.Tensor]]
-        | tuple[torch.Tensor, tuple[torch.Tensor], tuple[torch.Tensor, ...]]
-        | None
-    ):
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         query = self.q_proj(hidden_states)
         key = self.k_proj(hidden_states)
         value = self.v_proj(hidden_states)
@@ -253,13 +251,8 @@ def forward(
         attention_mask: torch.FloatTensor | None = None,
         position_ids: torch.LongTensor | None = None,
         use_cache: bool | None = False,
-        output_attentions: bool | None = False,
         cache_position: torch.LongTensor | None = None,
-    ) -> (
-        tuple[torch.Tensor, tuple[torch.Tensor]]
-        | tuple[torch.Tensor, tuple[torch.Tensor], tuple[torch.Tensor, ...]]
-        | None
-    ):
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         query = self.q_proj(hidden_states)
         key = self.k_proj(hidden_states)
         value = self.v_proj(hidden_states)
@@ -406,24 +399,22 @@ def forward(
         attention_mask: torch.FloatTensor | None = None,
         position_ids: torch.LongTensor | None = None,
         use_cache: bool | None = False,
-        output_attentions: bool | None = False,
         cache_position: torch.LongTensor | None = None,
-    ) -> tuple[torch.Tensor] | tuple[torch.Tensor, tuple[torch.FloatTensor, ...]] | None:
+    ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
-        attn_outputs, attn_weights = self.attn(
+        attn_outputs, _ = self.attn(
             hidden_states=hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
             position_ids=position_ids,
             use_cache=use_cache,
-            output_attentions=output_attentions,
             cache_position=cache_position,
         )
 
         feed_forward_hidden_states = self.mlp(hidden_states)
         hidden_states = attn_outputs + feed_forward_hidden_states + residual
-        return hidden_states, attn_weights
+        return hidden_states
 
 
 @auto_docstring
@@ -435,6 +426,10 @@ class GPTJPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn = True
    _can_compile_fullgraph = True
+    _can_record_outputs = {
+        "hidden_states": GPTJBlock,
+        "attentions": GPTJAttention,
+    }
 
     def _init_weights(self, module):
         super()._init_weights(module)
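The `_can_record_outputs` map is what lets the per-layer bookkeeping disappear from the forward passes below: instead of every block returning its attention weights, the base class declares which submodule classes produce which output stream, and the `capture_outputs` decorator imported above collects them during the forward pass. A minimal, self-contained sketch of that idea using plain forward hooks; the helper name `record_outputs` and its return shape are illustrative, not the actual transformers implementation:

```python
from collections import defaultdict

from torch import nn


def record_outputs(model: nn.Module, spec: dict[str, type]):
    """Attach forward hooks that collect the output of every submodule whose class
    appears in ``spec``, keyed by stream name ("hidden_states", "attentions", ...).
    Illustrative sketch only -- not the real ``capture_outputs`` machinery."""
    captured = defaultdict(list)
    handles = []
    for stream, module_cls in spec.items():
        for module in model.modules():
            if isinstance(module, module_cls):
                def hook(_module, _inputs, output, _stream=stream):
                    # Some layers return tuples; keep the first tensor.
                    tensor = output[0] if isinstance(output, tuple) else output
                    captured[_stream].append(tensor)

                handles.append(module.register_forward_hook(hook))
    return captured, handles
```

With a spec like `{"hidden_states": GPTJBlock, "attentions": GPTJAttention}`, hooks of this kind can rebuild the tuples that the deleted `all_hidden_states` / `all_self_attentions` loops used to accumulate, without threading `output_attentions` through every signature.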
@@ -465,6 +460,8 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, new_embeddings):
         self.wte = new_embeddings
 
+    @merge_with_config_defaults
+    @capture_outputs
     @auto_docstring
     def forward(
         self,
@@ -475,35 +472,18 @@ def forward(
         position_ids: torch.LongTensor | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
         use_cache: bool | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
         cache_position: torch.LongTensor | None = None,
-        **kwargs,
-    ) -> tuple | BaseModelOutputWithPast:
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPast:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
             is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
             model's internal embedding lookup matrix.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-            use_cache = False
-
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
 
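The `x if x is not None else self.config.x` boilerplate removed here is presumably what `@merge_with_config_defaults` now handles once, at the decorator level. A rough sketch of what a decorator with that job could look like; `merge_config_defaults` below is a hypothetical stand-in, not the library's code:

```python
import functools
import inspect


def merge_config_defaults(forward):
    """For any parameter left as None (e.g. ``use_cache``), fall back to the value
    stored on ``self.config``. Hypothetical sketch of the decorator's contract."""
    signature = inspect.signature(forward)

    @functools.wraps(forward)
    def wrapper(self, *args, **kwargs):
        bound = signature.bind_partial(self, *args, **kwargs)
        for name in signature.parameters:
            # Missing or explicitly-None arguments inherit the config default.
            if bound.arguments.get(name) is None and hasattr(self.config, name):
                bound.arguments[name] = getattr(self.config, name)
        return forward(*bound.args, **bound.kwargs)

    return wrapper
```

The gradient-checkpointing warning can go for the same reason: reconciling `use_cache` with the training/checkpointing state becomes shared, decorator-level behaviour instead of per-model code.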
@@ -539,43 +519,22 @@ def forward(
         hidden_states = self.drop(hidden_states)
         output_shape = (-1, seq_length, hidden_states.size(-1))
 
-        all_self_attentions = () if output_attentions else None
-        all_hidden_states = () if output_hidden_states else None
-        for i, block in enumerate(self.h):
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
-            outputs = block(
+        for block in self.h:
+            hidden_states = block(
                 hidden_states,
                 layer_past=past_key_values,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 use_cache=use_cache,
-                output_attentions=output_attentions,
                 cache_position=cache_position,
             )
 
-            hidden_states = outputs[0]
-            if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[1],)
-
         hidden_states = self.ln_f(hidden_states)
-
         hidden_states = hidden_states.view(output_shape)
-        # Add last hidden state
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
-
-        if not return_dict:
-            return tuple(
-                v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
-            )
 
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
         )
 
 
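With this hunk the decoder loop is reduced to `hidden_states = block(...)` and `GPTJModel.forward` always builds a `BaseModelOutputWithPast`; intermediate tensors come from the recording declared in `_can_record_outputs` rather than from explicit collection. A small usage sketch, assuming the public `output_hidden_states` / `output_attentions` flags are still accepted through `**kwargs` and with an illustrative checkpoint name:

```python
import torch
from transformers import AutoTokenizer, GPTJModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = GPTJModel.from_pretrained("EleutherAI/gpt-j-6B")

inputs = tokenizer("def hello():", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True, output_attentions=True)

print(type(outputs).__name__)       # BaseModelOutputWithPast
print(len(outputs.hidden_states))   # one entry per recorded GPTJBlock
print(outputs.attentions[0].shape)  # recorded from GPTJAttention
```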
@@ -595,6 +554,7 @@ def __init__(self, config):
         # Initialize weights and apply final processing
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -606,13 +566,10 @@ def forward(
         inputs_embeds: torch.FloatTensor | None = None,
         labels: torch.LongTensor | None = None,
         use_cache: bool | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
         cache_position: torch.LongTensor | None = None,
         logits_to_keep: int | torch.Tensor = 0,
-        **kwargs,
-    ) -> tuple | CausalLMOutputWithPast:
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> CausalLMOutputWithPast:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
@@ -623,9 +580,7 @@ def forward(
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        transformer_outputs = self.transformer(
+        outputs: BaseModelOutputWithPast = self.transformer(
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
@@ -633,13 +588,11 @@ def forward(
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
             cache_position=cache_position,
+            **kwargs,
         )
-        hidden_states = transformer_outputs[0]
+        hidden_states = outputs.last_hidden_state
 
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.lm_head(hidden_states[:, slice_indices, :])
@@ -648,16 +601,12 @@ def forward(
 
         if labels is not None:
             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
 
-        if not return_dict:
-            output = (logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
 
 
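`@can_return_tuple` is what absorbs the deleted `if not return_dict:` branches: the forward now always builds a `ModelOutput`, and the decorator downgrades it to a tuple when the caller asks. A rough sketch of that contract; `return_tuple_when_asked` is a hypothetical helper, not the library's implementation:

```python
import functools


def return_tuple_when_asked(forward):
    """Let ``forward`` always return a ModelOutput and convert it to a plain tuple
    when the caller passes ``return_dict=False``. Hypothetical sketch."""

    @functools.wraps(forward)
    def wrapper(self, *args, **kwargs):
        return_dict = kwargs.pop("return_dict", None)
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)
        output = forward(self, *args, **kwargs)
        # ModelOutput.to_tuple() keeps only the fields that were set,
        # roughly matching the old ``if v is not None`` filtering.
        return output if return_dict else output.to_tuple()

    return wrapper
```

The same decorator is applied to the sequence-classification and question-answering heads below, so tuple callers keep working without each forward carrying its own `return_dict` tail.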
""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, @@ -721,11 +666,9 @@ def forward( position_ids=position_ids, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + **kwargs, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -774,9 +717,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, @@ -798,6 +738,7 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @auto_docstring def forward( self, @@ -808,31 +749,24 @@ def forward( inputs_embeds: torch.FloatTensor | None = None, start_positions: torch.LongTensor | None = None, end_positions: torch.LongTensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - **kwargs, - ) -> tuple | QuestionAnsweringModelOutput: + **kwargs: Unpack[TransformersKwargs], + ) -> QuestionAnsweringModelOutput: r""" inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert *input_ids* indices into associated vectors than the model's internal embedding lookup matrix. """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.transformer( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + **kwargs, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -856,10 +790,6 @@ def forward( end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits,