mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Fix Transformers 4.30
This commit is contained in:
@@ -151,7 +151,7 @@ def patch_transformers_for_lazyload() -> None:
|
||||
def _load_state_dict_into_meta_model(
|
||||
model,
|
||||
state_dict,
|
||||
loaded_state_dict_keys, # left for now but could be removed, see below
|
||||
loaded_state_dict_keys,
|
||||
start_prefix,
|
||||
expected_keys,
|
||||
device_map=None,
|
||||
@@ -160,28 +160,17 @@ def patch_transformers_for_lazyload() -> None:
|
||||
state_dict_folder=None,
|
||||
state_dict_index=None,
|
||||
dtype=None,
|
||||
# PATCH: load_in_8bit was renamed to is_quantized in Transformers 4.30, keep
|
||||
# both for short term compatibility
|
||||
load_in_8bit=False,
|
||||
is_quantized=False,
|
||||
|
||||
is_safetensors=False,
|
||||
keep_in_fp32_modules=None,
|
||||
):
|
||||
"""
|
||||
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
||||
params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
|
||||
params back to the normal device, but only for `loaded_state_dict_keys`.
|
||||
is_quantized = is_quantized or load_in_8bit
|
||||
|
||||
`start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
|
||||
`bert.pooler.dense.weight`
|
||||
|
||||
"""
|
||||
|
||||
# XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
|
||||
# - deepspeed zero 3 support
|
||||
# - need to copy metadata if any - see _load_state_dict_into_model
|
||||
# - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
|
||||
# - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
|
||||
# they won't get loaded.
|
||||
|
||||
if load_in_8bit:
|
||||
if is_quantized:
|
||||
from .utils.bitsandbytes import set_module_8bit_tensor_to_device
|
||||
|
||||
error_msgs = []
|
||||
@@ -280,7 +269,7 @@ def patch_transformers_for_lazyload() -> None:
|
||||
state_dict_index = offload_weight(
|
||||
param, param_name, state_dict_folder, state_dict_index
|
||||
)
|
||||
elif not load_in_8bit:
|
||||
elif not is_quantized:
|
||||
# For backward compatibility with older versions of `accelerate`
|
||||
set_module_tensor_to_device(
|
||||
model,
|
||||
|
Reference in New Issue
Block a user