2 changes: 0 additions & 2 deletions src/transformers/models/codegen/modeling_codegen.py
@@ -228,9 +228,7 @@ def forward(self, hidden_states: torch.FloatTensor | None) -> torch.FloatTensor:
return hidden_states


# Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
class CodeGenBlock(GradientCheckpointingLayer):
# Ignore copy
def __init__(self, config, layer_idx=None):
super().__init__()
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
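The two deleted lines here are the `# Copied from ...` marker and its `# Ignore copy` escape hatch. The repo-consistency tooling (`utils/check_copies.py`) uses these comments to keep duplicated classes textually in sync, so removing them detaches `CodeGenBlock` from `GPTJBlock` now that the GPTJ block diverges. A minimal sketch of the convention, using hypothetical `Src`/`Dst` classes rather than the real models:

```python
# Hypothetical illustration of the "Copied from" marker, not the real
# GPTJ/CodeGen code. The consistency check re-renders SrcBlock with the
# Src->Dst substitution applied and fails CI if the result no longer
# matches this class body.

# Copied from mylib.models.src.modeling_src.SrcBlock with Src->Dst
class DstBlock:
    def forward(self, x):
        return x + 1
```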
146 changes: 38 additions & 108 deletions src/transformers/models/gptj/modeling_gptj.py
@@ -33,7 +33,10 @@
SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from .configuration_gptj import GPTJConfig
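The import changes summarize the whole refactor: `Unpack`/`TransformersKwargs` type the catch-all `**kwargs`, `merge_with_config_defaults` fills unset flags from the model config, `capture_outputs` records intermediate tensors through hooks, and `can_return_tuple` keeps legacy tuple returns working on the head models. A rough sketch of what a `merge_with_config_defaults`-style decorator could look like (an assumption about its behavior, not the actual transformers implementation):

```python
import functools

def merge_with_config_defaults(forward):
    """Sketch only: fill keyword arguments that are None from model.config.

    This illustrates the idea of moving the repeated
    `x = x if x is not None else self.config.x` boilerplate out of every
    forward() body; the real transformers decorator is more involved.
    """
    @functools.wraps(forward)
    def wrapper(self, *args, **kwargs):
        for name in ("use_cache", "output_attentions", "output_hidden_states"):
            if kwargs.get(name) is None:
                kwargs[name] = getattr(self.config, name, None)
        return forward(self, *args, **kwargs)
    return wrapper
```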


@@ -172,13 +175,8 @@ def forward(
attention_mask: torch.FloatTensor | None = None,
position_ids: torch.LongTensor | None = None,
use_cache: bool | None = False,
output_attentions: bool | None = False,
cache_position: torch.LongTensor | None = None,
) -> (
tuple[torch.Tensor, tuple[torch.Tensor]]
| tuple[torch.Tensor, tuple[torch.Tensor], tuple[torch.Tensor, ...]]
| None
):
) -> tuple[torch.Tensor, torch.Tensor]:
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
@@ -253,13 +251,8 @@ def forward(
attention_mask: torch.FloatTensor | None = None,
position_ids: torch.LongTensor | None = None,
use_cache: bool | None = False,
output_attentions: bool | None = False,
cache_position: torch.LongTensor | None = None,
) -> (
tuple[torch.Tensor, tuple[torch.Tensor]]
| tuple[torch.Tensor, tuple[torch.Tensor], tuple[torch.Tensor, ...]]
| None
):
) -> tuple[torch.Tensor, torch.Tensor]:
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
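Both attention variants trade the old three-way union for a fixed `tuple[torch.Tensor, torch.Tensor]` of attention output and weights, so call sites can unpack unconditionally and simply discard the weights when unused (as `GPTJBlock` does below with `attn_outputs, _`). A toy stand-in illustrating the simplified contract:

```python
import torch

# Toy attention stand-in (not the real GPTJ module): always return the
# two-tuple, letting callers drop the weights instead of branching on a flag.
def toy_attn(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    weights = torch.softmax(x @ x.transpose(-1, -2), dim=-1)
    return weights @ x, weights

out, _ = toy_attn(torch.randn(2, 4, 8))  # weights discarded when not needed
```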
@@ -406,24 +399,22 @@ def forward(
attention_mask: torch.FloatTensor | None = None,
position_ids: torch.LongTensor | None = None,
use_cache: bool | None = False,
output_attentions: bool | None = False,
cache_position: torch.LongTensor | None = None,
) -> tuple[torch.Tensor] | tuple[torch.Tensor, tuple[torch.FloatTensor, ...]] | None:
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs, attn_weights = self.attn(
attn_outputs, _ = self.attn(
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = attn_outputs + feed_forward_hidden_states + residual

return hidden_states, attn_weights
return hidden_states


@auto_docstring
@@ -435,6 +426,10 @@ class GPTJPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn = True
_can_compile_fullgraph = True
_can_record_outputs = {
"hidden_states": GPTJBlock,
"attentions": GPTJAttention,
}

def _init_weights(self, module):
super()._init_weights(module)
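`_can_record_outputs` maps output names to the module classes whose forward results should be recorded: combined with `@capture_outputs` on `GPTJModel.forward`, hidden states and attention weights are collected by hooks on `GPTJBlock` and `GPTJAttention` rather than threaded through every return value, which is why the per-layer accumulation loops disappear below. A rough sketch of the hook-based idea (assumed mechanics, not the actual transformers code):

```python
import torch
from torch import nn

def record_outputs(model: nn.Module, cls: type, store: list):
    """Sketch: collect each matching submodule's output via forward hooks,
    approximating what a capture_outputs-style decorator automates."""
    return [
        m.register_forward_hook(lambda _m, _inp, out: store.append(out))
        for m in model.modules() if isinstance(m, cls)
    ]

# Toy usage on a stack of linear "blocks":
model = nn.Sequential(*[nn.Linear(8, 8) for _ in range(3)])
hidden_states: list[torch.Tensor] = []
handles = record_outputs(model, nn.Linear, hidden_states)
model(torch.randn(2, 8))
assert len(hidden_states) == 3
for h in handles:
    h.remove()
```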
@@ -465,6 +460,8 @@ def get_input_embeddings(self):
def set_input_embeddings(self, new_embeddings):
self.wte = new_embeddings

@merge_with_config_defaults
@capture_outputs
@auto_docstring
def forward(
self,
@@ -475,35 +472,18 @@ def forward(
position_ids: torch.LongTensor | None = None,
inputs_embeds: torch.FloatTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
cache_position: torch.LongTensor | None = None,
**kwargs,
) -> tuple | BaseModelOutputWithPast:
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
r"""
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
model's internal embedding lookup matrix.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False

if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
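With `@merge_with_config_defaults` and `@capture_outputs` applied, the config fallbacks, `return_dict` handling, and the gradient-checkpointing `use_cache` warning all leave the method body (the latter interaction is handled by `GradientCheckpointingLayer`, which `GPTJBlock` extends). Callers still request optional outputs the same way, now routed through the typed kwargs, e.g. (checkpoint name assumed for illustration):

```python
from transformers import AutoModel, AutoTokenizer

# Assumed checkpoint; any GPT-J checkpoint would behave the same way.
tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
model = AutoModel.from_pretrained("EleutherAI/gpt-j-6b")

inputs = tok("hello world", return_tensors="pt")
out = model(**inputs, output_attentions=True, output_hidden_states=True)
print(len(out.hidden_states), len(out.attentions))  # recorded via hooks
```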

@@ -539,43 +519,22 @@ def forward(
hidden_states = self.drop(hidden_states)
output_shape = (-1, seq_length, hidden_states.size(-1))

all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, block in enumerate(self.h):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

outputs = block(
for block in self.h:
hidden_states = block(
hidden_states,
layer_past=past_key_values,
attention_mask=causal_mask,
position_ids=position_ids,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)

hidden_states = outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[1],)

hidden_states = self.ln_f(hidden_states)

hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(
v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
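The manual `if not return_dict:` tuple branch disappears from the base model. Code that indexed the output positionally keeps working, since `ModelOutput` subclasses support both index access and `.to_tuple()`:

```python
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast

out = BaseModelOutputWithPast(last_hidden_state=torch.zeros(1, 3, 8))
assert out[0] is out.last_hidden_state   # positional access still works
assert out.to_tuple()[0] is out.last_hidden_state
```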


@@ -595,6 +554,7 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

@can_return_tuple
@auto_docstring
def forward(
self,
@@ -606,13 +566,10 @@ def forward(
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
cache_position: torch.LongTensor | None = None,
logits_to_keep: int | torch.Tensor = 0,
**kwargs,
) -> tuple | CausalLMOutputWithPast:
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
r"""
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
@@ -623,23 +580,19 @@ def forward(
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
outputs: BaseModelOutputWithPast = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)

hidden_states = transformer_outputs[0]
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
@@ -648,16 +601,12 @@ def forward(
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

if not return_dict:
output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
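On the head models, `@can_return_tuple` absorbs the deleted `return_dict` plumbing: forward always builds the output dataclass, and the decorator converts it when a caller passes `return_dict=False` (or the config requests it). A sketch of the idea (assumed mechanics, not the decorator's actual source):

```python
import functools

def can_return_tuple(forward):
    """Sketch: pop `return_dict` before calling forward, then convert the
    ModelOutput to a plain tuple when the caller asked for one."""
    @functools.wraps(forward)
    def wrapper(self, *args, **kwargs):
        return_dict = kwargs.pop("return_dict", None)
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)
        output = forward(self, *args, **kwargs)
        return output if return_dict else output.to_tuple()
    return wrapper
```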


@@ -685,6 +634,7 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

@can_return_tuple
@auto_docstring
def forward(
self,
@@ -696,11 +646,8 @@ def forward(
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
**kwargs,
) -> tuple | SequenceClassifierOutputWithPast:
**kwargs: Unpack[TransformersKwargs],
) -> SequenceClassifierOutputWithPast:
r"""
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
@@ -711,8 +658,6 @@ def forward(
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
@@ -721,11 +666,9 @@ def forward(
position_ids=position_ids,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
hidden_states = transformer_outputs[0]
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)

if input_ids is not None:
@@ -774,9 +717,6 @@ def forward(
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return SequenceClassifierOutputWithPast(
loss=loss,
@@ -798,6 +738,7 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

@can_return_tuple
@auto_docstring
def forward(
self,
@@ -808,31 +749,24 @@ def forward(
inputs_embeds: torch.FloatTensor | None = None,
start_positions: torch.LongTensor | None = None,
end_positions: torch.LongTensor | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
**kwargs,
) -> tuple | QuestionAnsweringModelOutput:
**kwargs: Unpack[TransformersKwargs],
) -> QuestionAnsweringModelOutput:
r"""
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
model's internal embedding lookup matrix.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

sequence_output = outputs[0]
sequence_output = outputs.last_hidden_state

logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
@@ -856,10 +790,6 @@ def forward(
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2

if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output

return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,