9 changes: 8 additions & 1 deletion src/transformers/modeling_utils.py
@@ -1747,6 +1747,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
quantization works well for values of magnitude ~5, but beyond that, there is a significant performance
penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models
(small models, fine-tuning).
no_load_in_8bit_modules (`List[str]`, *optional*, defaults to `None`):
An explicit list of the modules that we do not want to convert to 8-bit. This is useful for models such as
Jukebox, which has several heads in different places and not necessarily at the last position.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
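As a hedged illustration of how the new kwarg documented above would be passed (the checkpoint name and module names here are hypothetical placeholders, and actually running this requires `bitsandbytes` and a CUDA device), a call might look like:

```python
# Illustrative only: the module names below are made-up stand-ins for the
# extra heads a Jukebox-style model keeps outside the last position.
kwargs = {
    "load_in_8bit": True,  # enable int8 weight loading via bitsandbytes
    "int8_threshold": 6.0,  # default outlier threshold from the docstring
    "no_load_in_8bit_modules": ["lm_head", "prior.head"],  # keep in native precision
}
# model = AutoModelForCausalLM.from_pretrained("some/jukebox-checkpoint", **kwargs)
```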
@@ -1839,6 +1842,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_state_dict = kwargs.pop("offload_state_dict", False)
load_in_8bit = kwargs.pop("load_in_8bit", False)
int8_threshold = kwargs.pop("int8_threshold", 6.0)
no_load_in_8bit_modules = kwargs.pop("no_load_in_8bit_modules", None)
Collaborator
Would it make more sense to have this be a class variable of PreTrainedModel (like the no_split variable used for big model inference)? I'm afraid the user won't know what to set this to, and it looks like something we should handle automatically.

Contributor Author
I don't have a strong opinion on that, but this argument is optional because the function get_keys_to_not_convert should automatically take care of it, except for some models like Jukebox where it is a bit trickier due to their architecture.
In that case the user just has to manually decide which modules should be kept in their native precision and specify them in the kwargs. I feel that is a bit easier than having it as a class variable of PreTrainedModel, because then you would need to open a PR to add support for a new model.
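The fallback described in this thread can be sketched as a small helper. Note that `pick_modules_to_not_convert` is a hypothetical name for illustration, and the last-module heuristic below only stands in for what the real `get_keys_to_not_convert` computes:

```python
def pick_modules_to_not_convert(module_names, no_load_in_8bit_modules=None):
    """Return the module names to keep in native precision.

    If the user supplied an explicit list, trust it; otherwise fall back
    to an automatic choice (here, simply the last module, standing in for
    get_keys_to_not_convert, which usually keeps lm_head).
    """
    if no_load_in_8bit_modules is not None:
        return list(no_load_in_8bit_modules)
    # Automatic fallback: keep the final module (e.g. the LM head).
    return [module_names[-1]]
```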

subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)

@@ -2142,7 +2146,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
logger.info("Detected 8-bit loading: activating 8-bit loading for this model")

# We never convert lm_head or any last modules for numerical stability reasons
modules_to_not_convert = get_keys_to_not_convert(model)
if no_load_in_8bit_modules is None:
modules_to_not_convert = get_keys_to_not_convert(model)
else:
modules_to_not_convert = no_load_in_8bit_modules
model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert)
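To make the skip behavior concrete, here is a toy sketch of a replace_8bit_linear-style traversal. This is not the real bitsandbytes implementation: modules are modeled as a nested dict whose `"linear"` leaves get a fake `"_int8"` conversion unless their name is in `modules_to_not_convert`:

```python
def convert_8bit(tree, modules_to_not_convert):
    """Recursively 'convert' linear leaves, skipping protected names."""
    converted = {}
    for name, child in tree.items():
        if isinstance(child, dict):
            # Recurse into submodules.
            converted[name] = convert_8bit(child, modules_to_not_convert)
        elif name in modules_to_not_convert:
            # Keep this module in its native precision.
            converted[name] = child
        else:
            # Pretend-convert the linear layer to int8.
            converted[name] = child + "_int8"
    return converted
```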

if isinstance(device_map, str):