From 70f113141c533b142ed60aefb7557cdcde34ffd2 Mon Sep 17 00:00:00 2001 From: somebody Date: Wed, 21 Jun 2023 16:40:12 -0500 Subject: [PATCH] Fix Transformers 4.30 --- modeling/patches.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/modeling/patches.py b/modeling/patches.py index 71319ef8..83d157b4 100644 --- a/modeling/patches.py +++ b/modeling/patches.py @@ -151,7 +151,7 @@ def patch_transformers_for_lazyload() -> None: def _load_state_dict_into_meta_model( model, state_dict, - loaded_state_dict_keys, # left for now but could be removed, see below + loaded_state_dict_keys, start_prefix, expected_keys, device_map=None, @@ -160,28 +160,17 @@ def patch_transformers_for_lazyload() -> None: state_dict_folder=None, state_dict_index=None, dtype=None, + # PATCH: load_in_8bit was renamed to is_quantized in Transformers 4.30, keep + # both for short term compatibility load_in_8bit=False, + is_quantized=False, + is_safetensors=False, keep_in_fp32_modules=None, ): - """ - This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its - params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the - params back to the normal device, but only for `loaded_state_dict_keys`. + is_quantized = is_quantized or load_in_8bit - `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in - `bert.pooler.dense.weight` - - """ - - # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model - # - deepspeed zero 3 support - # - need to copy metadata if any - see _load_state_dict_into_model - # - handling error_msgs - mimicking the error handling in module._load_from_state_dict() - # - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case - # they won't get loaded. - - if load_in_8bit: + if is_quantized: from .utils.bitsandbytes import set_module_8bit_tensor_to_device error_msgs = [] @@ -280,7 +269,7 @@ def patch_transformers_for_lazyload() -> None: state_dict_index = offload_weight( param, param_name, state_dict_folder, state_dict_index ) - elif not load_in_8bit: + elif not is_quantized: # For backward compatibility with older versions of `accelerate` set_module_tensor_to_device( model,