[bnb] Small improvements on utils #18646
Changes from 1 commit
ea155ed
```diff
@@ -1747,6 +1747,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 quantization works well for values of magnitude ~5, but beyond that, there is a significant performance
                 penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models
                 (small models, fine-tuning).
+            no_load_in_8bit_modules (`List[str]`, *optional*, defaults to `None`):
+                An explicit list of the modules that we do not want to convert to 8-bit. This is useful for models such
+                as Jukebox that have several heads in different places and not necessarily at the last position.
             subfolder (`str`, *optional*, defaults to `""`):
                 In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
                 specify the folder name here.
```
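The new `no_load_in_8bit_modules` list acts as a skip list during conversion. A minimal, torch-free sketch of that idea (`convert_modules` and the module names here are hypothetical; the real `replace_8bit_linear` recurses over `torch.nn` modules):

```python
def convert_modules(module_names, modules_to_not_convert):
    """Mark each module for 8-bit conversion unless its name is on the skip list."""
    return {name: name not in modules_to_not_convert for name in module_names}

# Jukebox-style case: keep a head in full precision by listing it explicitly.
plan = convert_modules(
    ["encoder.layer.0", "encoder.layer.1", "lm_head"],
    modules_to_not_convert=["lm_head"],
)
# plan["lm_head"] is False, so the head would stay in full precision.
```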
```diff
@@ -1839,6 +1842,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         offload_state_dict = kwargs.pop("offload_state_dict", False)
         load_in_8bit = kwargs.pop("load_in_8bit", False)
         int8_threshold = kwargs.pop("int8_threshold", 6.0)
+        no_load_in_8bit_modules = kwargs.pop("no_load_in_8bit_modules", None)

         subfolder = kwargs.pop("subfolder", "")
         commit_hash = kwargs.pop("_commit_hash", None)
```
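The `kwargs.pop` pattern used above consumes an optional keyword if present and falls back to a default otherwise, so any remaining kwargs can be forwarded untouched. A small illustration (the sample dict is an assumption for demonstration):

```python
# Simulate the keyword arguments a caller might pass to from_pretrained.
kwargs = {"load_in_8bit": True, "subfolder": ""}

# Present key: popped and removed from kwargs.
load_in_8bit = kwargs.pop("load_in_8bit", False)
# Absent key: the default (None) is returned, kwargs is unchanged.
no_load_in_8bit_modules = kwargs.pop("no_load_in_8bit_modules", None)
```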
```diff
@@ -2142,7 +2146,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             logger.info("Detected 8-bit loading: activating 8-bit loading for this model")

             # We never convert lm_head or any last modules for numerical stability reasons
-            modules_to_not_convert = get_keys_to_not_convert(model)
+            if no_load_in_8bit_modules is None:
+                modules_to_not_convert = get_keys_to_not_convert(model)
+            else:
+                modules_to_not_convert = no_load_in_8bit_modules
             model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert)

         if isinstance(device_map, str):
```
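The branch above reduces to a simple precedence rule: an explicit user list wins, otherwise the automatically detected keep-in-full-precision keys are used. A standalone sketch of that rule (function name and arguments are hypothetical; in the real code the default comes from `get_keys_to_not_convert(model)`):

```python
def pick_modules_to_not_convert(auto_detected, no_load_in_8bit_modules=None):
    """Prefer the user-supplied skip list; otherwise keep the auto-detected modules."""
    if no_load_in_8bit_modules is None:
        return auto_detected
    return no_load_in_8bit_modules

# Default path: fall back to the detected last module (e.g. lm_head).
default_choice = pick_modules_to_not_convert(["lm_head"])
# Override path: a Jukebox-like model with several heads to protect.
override_choice = pick_modules_to_not_convert(["lm_head"], ["head_a", "head_b"])
```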