Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Fix Transformers 4.30
@@ -151,7 +151,7 @@ def patch_transformers_for_lazyload() -> None:
     def _load_state_dict_into_meta_model(
         model,
         state_dict,
-        loaded_state_dict_keys,  # left for now but could be removed, see below
+        loaded_state_dict_keys,
         start_prefix,
         expected_keys,
         device_map=None,
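For context, `_load_state_dict_into_meta_model` materializes parameters that were first created on PyTorch's `meta` device, which records only shapes and dtypes and allocates no storage. A minimal sketch of that idea, independent of the KoboldAI and Transformers code above, assuming PyTorch 2.0+ for the `torch.device` context manager:

import torch
import torch.nn as nn

# Build a module on the meta device: parameters exist but own no memory.
with torch.device("meta"):
    layer = nn.Linear(1024, 1024)
assert layer.weight.is_meta

# Materialize uninitialized storage on a real device...
layer = layer.to_empty(device="cpu")

# ...then copy real data in, e.g. tensors read lazily from a checkpoint.
with torch.no_grad():
    layer.weight.copy_(torch.randn(1024, 1024))  # stand-in for checkpoint data
    layer.bias.copy_(torch.zeros(1024))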
@@ -160,28 +160,17 @@ def patch_transformers_for_lazyload() -> None:
         state_dict_folder=None,
         state_dict_index=None,
         dtype=None,
+        # PATCH: load_in_8bit was renamed to is_quantized in Transformers 4.30, keep
+        # both for short term compatibility
         load_in_8bit=False,
+        is_quantized=False,
         is_safetensors=False,
         keep_in_fp32_modules=None,
     ):
-        """
-        This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
-        params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
-        params back to the normal device, but only for `loaded_state_dict_keys`.
-
-        `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
-        `bert.pooler.dense.weight`
-
-        """
-
-        # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
-        # - deepspeed zero 3 support
-        # - need to copy metadata if any - see _load_state_dict_into_model
-        # - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
-        # - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
-        #   they won't get loaded.
-
-        if load_in_8bit:
+        is_quantized = is_quantized or load_in_8bit
+
+        if is_quantized:
             from .utils.bitsandbytes import set_module_8bit_tensor_to_device
 
         error_msgs = []
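The signature change in this hunk is the heart of the fix: Transformers 4.30 renamed the `load_in_8bit` parameter to `is_quantized`, so the patched copy accepts both keywords and coalesces them, letting it be called from either side of the rename. A minimal sketch of the dual-keyword pattern, using a hypothetical `load_tensors` function that is not from the codebase:

def load_tensors(state_dict, load_in_8bit=False, is_quantized=False):
    """Accept both the pre-4.30 and the 4.30+ keyword for the same flag."""
    # Whichever keyword the caller used wins; both default to False.
    is_quantized = is_quantized or load_in_8bit
    mode = "8bit" if is_quantized else "fp"
    return {name: mode for name in state_dict}

# Old-style and new-style callers both work:
sd = {"w": None}
assert load_tensors(sd, load_in_8bit=True) == {"w": "8bit"}  # pre-4.30 caller
assert load_tensors(sd, is_quantized=True) == {"w": "8bit"}  # 4.30+ caller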
@@ -280,7 +269,7 @@ def patch_transformers_for_lazyload() -> None:
                 state_dict_index = offload_weight(
                     param, param_name, state_dict_folder, state_dict_index
                 )
-            elif not load_in_8bit:
+            elif not is_quantized:
                 # For backward compatibility with older versions of `accelerate`
                 set_module_tensor_to_device(
                     model,
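The hunks above edit KoboldAI's local copy of Transformers' internal `_load_state_dict_into_meta_model`; the excerpt does not show how `patch_transformers_for_lazyload()` installs it. A hedged sketch of how a monkey-patch of this shape is typically wired up, assuming `transformers.modeling_utils` is the patch target (the usual home of this function):

import transformers.modeling_utils


def patch_transformers_for_lazyload() -> None:
    # Signature mirrors the diff above; body elided for brevity.
    def _load_state_dict_into_meta_model(
        model, state_dict, *args, load_in_8bit=False, is_quantized=False, **kwargs
    ):
        is_quantized = is_quantized or load_in_8bit
        ...  # patched loading logic from the hunks above

    # Rebind the library-internal name so code inside transformers that looks
    # the function up on the module resolves to the patched copy.
    transformers.modeling_utils._load_state_dict_into_meta_model = (
        _load_state_dict_into_meta_model
    )

One caveat of this style of patch: any code that ran `from transformers.modeling_utils import _load_state_dict_into_meta_model` before the patch keeps its reference to the original function, which is why such patches must be applied before any model loading happens.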