Refactor gptj output tracing to use standardized decorators #44722
Open
chandan11248 wants to merge 3 commits into huggingface:main from
Conversation
…tuple

Migrate the GPT-J model to use the new standardized output collection decorators, replacing manual accumulation of hidden states and attention weights with hook-based capturing.

Changes:
- Add `_can_record_outputs` to `GPTJPreTrainedModel`, mapping hidden_states to GPTJBlock and attentions to GPTJAttention
- Add `@capture_outputs` and `@merge_with_config_defaults` to `GPTJModel.forward()`
- Add `@can_return_tuple` to all task head models (ForCausalLM, ForSequenceClassification, ForQuestionAnswering)
- Remove `output_attentions`, `output_hidden_states`, and `return_dict` parameters from all forward signatures
- Remove manual accumulator loops and return_dict branching
- Simplify GPTJBlock to return a plain `torch.Tensor` instead of a tuple
- Update attention forward signatures to always return `(attn_output, attn_weights)` without conditional logic

Resolves huggingface#43979
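The shift described above can be sketched in plain Python. This is an illustrative toy, not the actual transformers implementation: the names `capture_outputs` and `_can_record_outputs` mirror the PR, but the recorder mechanics, the `Block` class, and the arithmetic are invented stand-ins for hook-based capturing.

```python
# Illustrative sketch (NOT the real transformers API): instead of each
# forward loop appending to all_hidden_states, the model declares which
# submodule outputs are recordable, and a decorator wires up a recorder
# that captures them as the blocks run.
from functools import wraps


class Block:
    """Stand-in for GPTJBlock: returns a plain value, not a tuple."""

    def __init__(self):
        self._recorder = None

    def __call__(self, hidden):
        out = hidden + 1  # placeholder for the real block computation
        if self._recorder is not None:
            self._recorder.setdefault("hidden_states", []).append(out)
        return out


def capture_outputs(forward):
    """Toy decorator: attach a recorder to each block, run forward,
    then merge whatever was captured into the returned dict."""

    @wraps(forward)
    def wrapper(self, *args, **kwargs):
        recorder = {}
        for block in self.blocks:
            block._recorder = recorder
        result = forward(self, *args, **kwargs)
        result.update(recorder)
        return result

    return wrapper


class Model:
    # Declarative mapping: which output key comes from which module type.
    _can_record_outputs = {"hidden_states": Block}

    def __init__(self, n_blocks=3):
        self.blocks = [Block() for _ in range(n_blocks)]

    @capture_outputs
    def forward(self, x):
        # Note: no manual all_hidden_states accumulation in the loop.
        for block in self.blocks:
            x = block(x)
        return {"last_hidden_state": x}


out = Model().forward(0)
print(out["last_hidden_state"])  # 3
print(out["hidden_states"])      # [1, 2, 3]
```

The point of the pattern is that `forward()` stays free of collection boilerplate; what gets recorded is declared once, per module type, on the model class.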
The CodeGenBlock is a documented copy of GPTJBlock. This syncs it to match the updated signature after removing the output_attentions parameter and simplifying the return type to a plain torch.Tensor. Generated via `python utils/check_copies.py --fix_and_overwrite`.
The previous commit auto-synced CodeGenBlock.forward() with the refactored GPTJBlock, but CodeGenModel still passes output_attentions to CodeGenBlock and expects a tuple return. Since the CodeGen model has not been refactored to use the new decorators yet, restore CodeGenBlock's original forward() signature and remove the '# Copied from' directive to decouple it from GPTJBlock until CodeGen gets its own output tracing refactor.
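The mismatch the revert fixes can be shown with two toy signatures. These are hypothetical simplifications, not the real model code; the function bodies are placeholders that only demonstrate the return shapes:

```python
# Hypothetical shapes of the two block signatures, showing why the
# auto-sync broke CodeGenModel: the caller still unpacks a tuple.

def gptj_block_forward(hidden_states):
    """After the refactor: a plain value, no tuple wrapping."""
    attn_output = hidden_states  # placeholder computation
    return attn_output


def codegen_block_forward(hidden_states, output_attentions=False):
    """Restored original shape: CodeGenModel indexes into a tuple."""
    attn_output, attn_weights = hidden_states, None  # placeholder computation
    outputs = (attn_output,)
    if output_attentions:
        outputs += (attn_weights,)
    return outputs  # tuple, as the unrefactored CodeGenModel expects
```

A caller written against the tuple form (e.g. `outputs[0]`) would silently misbehave if handed the plain-tensor form, which is why decoupling the copy until CodeGen gets its own refactor is the safer move.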
Contributor
[For maintainers] Suggested jobs to run (before merge): run-slow: codegen, gptj
What does this PR do?
Migrates the GPT-J model to use the new `@capture_outputs` and `@can_return_tuple` decorators for standardized output collection, as described in #43979.

Changes

- Add `_can_record_outputs` to `GPTJPreTrainedModel`, mapping `"hidden_states"` → `GPTJBlock` and `"attentions"` → `GPTJAttention`
- Add the `@capture_outputs` and `@merge_with_config_defaults` decorators to `GPTJModel.forward()`
- Add the `@can_return_tuple` decorator to `GPTJForCausalLM`, `GPTJForSequenceClassification`, and `GPTJForQuestionAnswering`
- Remove the `output_attentions`, `output_hidden_states`, and `return_dict` parameters from all `forward()` signatures
- Remove the manual accumulators (`all_hidden_states`, `all_self_attentions`) and `return_dict` branching from `GPTJModel.forward()`
- Simplify `GPTJBlock.forward()` to return a plain `torch.Tensor` instead of a tuple
- Update the attention forward signature to always return `(attn_output, attn_weights)` with a simplified type annotation

Net result: 38 insertions, 108 deletions, with no manual output collection boilerplate left in the model.
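The `@can_return_tuple` idea from the change list can be sketched minimally. This is a hedged illustration of the pattern, not the transformers implementation: the decorator body, the `Head` class, and the keyword handling are assumptions made for the example.

```python
# Illustrative sketch of the can_return_tuple pattern: forward() always
# builds a dict-like output, and the decorator converts it to a tuple
# when the caller asks for one, so no return_dict branching lives in
# forward() itself.
from functools import wraps


def can_return_tuple(forward):
    @wraps(forward)
    def wrapper(self, *args, return_dict=True, **kwargs):
        output = forward(self, *args, **kwargs)
        if return_dict:
            return output
        return tuple(output.values())

    return wrapper


class Head:
    """Stand-in for a task head model such as GPTJForCausalLM."""

    @can_return_tuple
    def forward(self, x):
        # Single unconditional code path: always build the dict.
        return {"logits": x * 2, "loss": None}


head = Head()
print(head.forward(3))                     # {'logits': 6, 'loss': None}
print(head.forward(3, return_dict=False))  # (6, None)
```

Centralizing the dict-to-tuple conversion in one decorator is what lets the PR delete the per-model `return_dict` branches counted in those 108 removed lines.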
Before submitting
Fixes #43979 (gptj model)