Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Hello its breaking breakmodel time
@@ -10,6 +10,7 @@ from transformers import (
    PreTrainedModel,
    modeling_utils,
)
from modeling.lazy_loader import LazyTensor

import utils
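The `LazyTensor` imported above is defined in `modeling/lazy_loader.py` and is not shown in this diff; the patch below only relies on it being a stand-in for a not-yet-loaded checkpoint tensor that exposes a `materialize()` method returning a real `torch.Tensor`. A minimal placeholder sketch of that interface (the class name, fields, and storage scheme here are invented for illustration and do not match the real implementation):

import torch


class LazyTensorSketch:
    """Hypothetical stand-in for a checkpoint tensor not yet read from disk."""

    def __init__(self, checkpoint_path: str, key: str):
        self.checkpoint_path = checkpoint_path  # file the tensor lives in
        self.key = key  # name of the tensor inside that file

    def materialize(self) -> torch.Tensor:
        # Read the tensor only when it is actually needed. The real LazyTensor
        # avoids loading the whole checkpoint the way this sketch does.
        return torch.load(self.checkpoint_path, map_location="cpu")[self.key]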
@@ -125,6 +126,173 @@ def patch_transformers_generation() -> None:
    transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init


CURRENT_CHECKPOINT = None
def patch_transformers_for_lazyload() -> None:
    import torch
    import inspect
    from accelerate.utils import set_module_tensor_to_device, offload_weight

    def _load_state_dict_into_meta_model(
        model,
        state_dict,
        loaded_state_dict_keys,  # left for now but could be removed, see below
        start_prefix,
        expected_keys,
        device_map=None,
        offload_folder=None,
        offload_index=None,
        state_dict_folder=None,
        state_dict_index=None,
        dtype=None,
        load_in_8bit=False,
        is_safetensors=False,
        keep_in_fp32_modules=None,
    ):
        """
        This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
        params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
        params back to the normal device, but only for `loaded_state_dict_keys`.

        `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
        `bert.pooler.dense.weight`

        """

        print("DEVMAP", device_map)

        # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
        # - deepspeed zero 3 support
        # - need to copy metadata if any - see _load_state_dict_into_model
        # - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
        # - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
        #   they won't get loaded.

        if load_in_8bit:
            from .utils.bitsandbytes import set_module_8bit_tensor_to_device

        error_msgs = []

        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        for param_name, param in state_dict.items():

            # BEGIN PATCH
            if isinstance(param, LazyTensor):
                print("Materializing", param_name)
                param = param.materialize()
            # END PATCH

            # First part of the test is always true as loaded_state_dict_keys always contains state_dict keys.
            if (
                param_name not in loaded_state_dict_keys
                or param_name not in expected_keys
            ):
                continue

            if param_name.startswith(start_prefix):
                param_name = param_name[len(start_prefix) :]

            module_name = param_name
            set_module_kwargs = {}

            # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
            # in int/uint/bool and not cast them.
            if dtype is not None and torch.is_floating_point(param):
                if (
                    keep_in_fp32_modules is not None
                    and any(
                        module_to_keep_in_fp32 in param_name
                        for module_to_keep_in_fp32 in keep_in_fp32_modules
                    )
                    and dtype == torch.float16
                ):
                    param = param.to(torch.float32)

                    # For backward compatibility with older versions of `accelerate`
                    # TODO: @sgugger replace this check with version check at the next `accelerate` release
                    if "dtype" in list(
                        inspect.signature(set_module_tensor_to_device).parameters
                    ):
                        set_module_kwargs["dtype"] = torch.float32
                else:
                    param = param.to(dtype)

            # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
            if dtype is None:
                old_param = model
                splits = param_name.split(".")
                for split in splits:
                    old_param = getattr(old_param, split)
                    if old_param is None:
                        break

                if old_param is not None:
                    param = param.to(old_param.dtype)

            set_module_kwargs["value"] = param

            if device_map is None:
                param_device = "cpu"
            else:
                # find next higher level module that is defined in device_map:
                # bert.lm_head.weight -> bert.lm_head -> bert -> ''
                while len(module_name) > 0 and module_name not in device_map:
                    module_name = ".".join(module_name.split(".")[:-1])
                if module_name == "" and "" not in device_map:
                    # TODO: group all errors and raise at the end.
                    raise ValueError(f"{param_name} doesn't have any device set.")
                param_device = device_map[module_name]
            if param_device == "disk":
                if not is_safetensors:
                    offload_index = offload_weight(
                        param, param_name, offload_folder, offload_index
                    )
            elif param_device == "cpu" and state_dict_index is not None:
                state_dict_index = offload_weight(
                    param, param_name, state_dict_folder, state_dict_index
                )
            elif not load_in_8bit:
                # For backward compatibility with older versions of `accelerate`
                set_module_tensor_to_device(
                    model, param_name, param_device, **set_module_kwargs
                )
            else:
                if (
                    param.dtype == torch.int8
                    and param_name.replace("weight", "SCB") in state_dict.keys()
                ):
                    fp16_statistics = state_dict[param_name.replace("weight", "SCB")]
                else:
                    fp16_statistics = None

                if "SCB" not in param_name:
                    set_module_8bit_tensor_to_device(
                        model,
                        param_name,
                        param_device,
                        value=param,
                        fp16_statistics=fp16_statistics,
                    )

        return error_msgs, offload_index, state_dict_index

    transformers.modeling_utils._load_state_dict_into_meta_model = (
        _load_state_dict_into_meta_model
    )


def patch_transformers() -> None:
    patch_transformers_download()
    patch_transformers_loader()
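For context on the mechanism the replacement loader drives: when transformers builds a model with a device map or low_cpu_mem_usage, the parameters initially live on the `meta` device (shape and dtype only, no storage), and `accelerate.utils.set_module_tensor_to_device` swaps each meta parameter for a real tensor. A small self-contained illustration of that step, separate from the commit itself:

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Build a module whose parameters live on the meta device (no real storage yet).
with init_empty_weights():
    layer = torch.nn.Linear(4, 4)
assert layer.weight.is_meta

# Replace the meta parameters with real tensors, which is what the patched
# loader does for every checkpoint entry after materializing it.
set_module_tensor_to_device(layer, "weight", "cpu", value=torch.zeros(4, 4))
set_module_tensor_to_device(layer, "bias", "cpu", value=torch.zeros(4))
assert not layer.weight.is_meta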
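Intended call order, sketched under assumptions this diff does not confirm: `patch_transformers_for_lazyload()` has to run before the model is constructed so that the monkeypatched `_load_state_dict_into_meta_model` is the one transformers invokes, and the loading path must hand it a state dict whose values are `LazyTensor` objects (the lazy_loader machinery elsewhere in the repo is responsible for that). Roughly:

import transformers

# Hypothetical wiring; assumes this runs in (or imports from) the module the
# diff modifies. KoboldAI applies these patches during its own startup.
patch_transformers_for_lazyload()

model = transformers.AutoModelForCausalLM.from_pretrained(
    "models/some-model",     # placeholder path
    device_map="auto",       # or an explicit {"module.name": "cuda:0" / "cpu" / "disk"} dict
    low_cpu_mem_usage=True,  # start from meta-device parameters
)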