Fix Transformers 4.30

This commit is contained in:
somebody
2023-06-21 16:40:12 -05:00
parent c56214c275
commit 70f113141c

View File

@@ -151,7 +151,7 @@ def patch_transformers_for_lazyload() -> None:
     def _load_state_dict_into_meta_model(
         model,
         state_dict,
-        loaded_state_dict_keys,  # left for now but could be removed, see below
+        loaded_state_dict_keys,
         start_prefix,
         expected_keys,
         device_map=None,
@@ -160,28 +160,17 @@ def patch_transformers_for_lazyload() -> None:
         state_dict_folder=None,
         state_dict_index=None,
         dtype=None,
+        # PATCH: load_in_8bit was renamed to is_quantized in Transformers 4.30, keep
+        # both for short term compatibility
         load_in_8bit=False,
+        is_quantized=False,
         is_safetensors=False,
         keep_in_fp32_modules=None,
     ):
-        """
-        This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
-        params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
-        params back to the normal device, but only for `loaded_state_dict_keys`.
-
-        `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
-        `bert.pooler.dense.weight`
-        """
-        # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
-        # - deepspeed zero 3 support
-        # - need to copy metadata if any - see _load_state_dict_into_model
-        # - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
-        # - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
-        #   they won't get loaded.
-        if load_in_8bit:
+        is_quantized = is_quantized or load_in_8bit
+
+        if is_quantized:
             from .utils.bitsandbytes import set_module_8bit_tensor_to_device

         error_msgs = []
@@ -280,7 +269,7 @@ def patch_transformers_for_lazyload() -> None:
                     state_dict_index = offload_weight(
                         param, param_name, state_dict_folder, state_dict_index
                     )
-                elif not load_in_8bit:
+                elif not is_quantized:
                     # For backward compatibility with older versions of `accelerate`
                     set_module_tensor_to_device(
                         model,