Hello its breaking breakmodel time

somebody
2023-05-27 16:31:53 -05:00
parent 97d2a78899
commit 1546b9efaa
8 changed files with 236 additions and 1097 deletions


@@ -10,6 +10,7 @@ from transformers import (
    PreTrainedModel,
    modeling_utils,
)
from modeling.lazy_loader import LazyTensor
import utils
@@ -125,6 +126,173 @@ def patch_transformers_generation() -> None:
    transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init


CURRENT_CHECKPOINT = None


def patch_transformers_for_lazyload() -> None:
    import torch
    import inspect
    from accelerate.utils import set_module_tensor_to_device, offload_weight
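
    # What follows is a near-verbatim copy of transformers'
    # modeling_utils._load_state_dict_into_meta_model; the KoboldAI-specific change is the
    # LazyTensor materialization marked with BEGIN/END PATCH inside the parameter loop.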
    def _load_state_dict_into_meta_model(
        model,
        state_dict,
        loaded_state_dict_keys,  # left for now but could be removed, see below
        start_prefix,
        expected_keys,
        device_map=None,
        offload_folder=None,
        offload_index=None,
        state_dict_folder=None,
        state_dict_index=None,
        dtype=None,
        load_in_8bit=False,
        is_safetensors=False,
        keep_in_fp32_modules=None,
    ):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
params back to the normal device, but only for `loaded_state_dict_keys`.
`start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
`bert.pooler.dense.weight`
"""
print("DEVMAP", device_map)
# XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
# - deepspeed zero 3 support
# - need to copy metadata if any - see _load_state_dict_into_model
# - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
# - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
# they won't get loaded.
        if load_in_8bit:
            from transformers.utils.bitsandbytes import set_module_8bit_tensor_to_device
        error_msgs = []

        old_keys = []
        new_keys = []
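
        # Upstream key fixup kept as-is: older checkpoints use TF-style gamma/beta names,
        # which are rewritten to the weight/bias keys PyTorch modules expect.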
        for key in state_dict.keys():
            new_key = None
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)

        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        for param_name, param in state_dict.items():
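            # PATCH context: a LazyTensor (from modeling.lazy_loader) is a placeholder that knows
            # where its data lives in the checkpoint; materialize() reads the real tensor only now,
            # right before placement, so the full state dict never has to be resident at once.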
            # BEGIN PATCH
            if isinstance(param, LazyTensor):
                print("Materializing", param_name)
                param = param.materialize()
            # END PATCH
            # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
            if (
                param_name not in loaded_state_dict_keys
                or param_name not in expected_keys
            ):
                continue

            if param_name.startswith(start_prefix):
                param_name = param_name[len(start_prefix) :]

            module_name = param_name
            set_module_kwargs = {}
            # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
            # in int/uint/bool and not cast them.
            if dtype is not None and torch.is_floating_point(param):
                if (
                    keep_in_fp32_modules is not None
                    and any(
                        module_to_keep_in_fp32 in param_name
                        for module_to_keep_in_fp32 in keep_in_fp32_modules
                    )
                    and dtype == torch.float16
                ):
                    param = param.to(torch.float32)

                    # For backward compatibility with older versions of `accelerate`
                    # TODO: @sgugger replace this check with version check at the next `accelerate` release
                    if "dtype" in list(
                        inspect.signature(set_module_tensor_to_device).parameters
                    ):
                        set_module_kwargs["dtype"] = torch.float32
                else:
                    param = param.to(dtype)
            # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
            if dtype is None:
                old_param = model
                splits = param_name.split(".")
                for split in splits:
                    old_param = getattr(old_param, split)
                    if old_param is None:
                        break

                if old_param is not None:
                    param = param.to(old_param.dtype)

            set_module_kwargs["value"] = param
            if device_map is None:
                param_device = "cpu"
            else:
                # find next higher level module that is defined in device_map:
                # bert.lm_head.weight -> bert.lm_head -> bert -> ''
                while len(module_name) > 0 and module_name not in device_map:
                    module_name = ".".join(module_name.split(".")[:-1])
                if module_name == "" and "" not in device_map:
                    # TODO: group all errors and raise at the end.
                    raise ValueError(f"{param_name} doesn't have any device set.")
                param_device = device_map[module_name]
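
            # Dispatch on the resolved device: spill to disk, stage in the CPU offload folder,
            # copy straight onto the target device, or hand the weight to bitsandbytes' 8-bit path.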
if param_device == "disk":
if not is_safetensors:
offload_index = offload_weight(
param, param_name, offload_folder, offload_index
)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(
param, param_name, state_dict_folder, state_dict_index
)
elif not load_in_8bit:
# For backward compatibility with older versions of `accelerate`
set_module_tensor_to_device(
model, param_name, param_device, **set_module_kwargs
)
else:
if (
param.dtype == torch.int8
and param_name.replace("weight", "SCB") in state_dict.keys()
):
fp16_statistics = state_dict[param_name.replace("weight", "SCB")]
else:
fp16_statistics = None
if "SCB" not in param_name:
set_module_8bit_tensor_to_device(
model,
param_name,
param_device,
value=param,
fp16_statistics=fp16_statistics,
)
return error_msgs, offload_index, state_dict_index
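
    # from_pretrained resolves this helper through transformers.modeling_utils at call time,
    # so rebinding the module attribute is enough to route meta-device (low_cpu_mem_usage /
    # device_map) loads through the lazy-aware copy above.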
    transformers.modeling_utils._load_state_dict_into_meta_model = (
        _load_state_dict_into_meta_model
    )


def patch_transformers() -> None:
    patch_transformers_download()
    patch_transformers_loader()
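
For reference, here is a minimal, self-contained sketch of the materialize-on-placement idea the patch relies on. The TinyLazyTensor class and the example state_dict below are hypothetical stand-ins, not the real modeling.lazy_loader.LazyTensor (which records checkpoint storage offsets rather than building tensors from scratch); only the isinstance check and the materialize() call mirror what the patched loop does.

import torch


class TinyLazyTensor:
    """Hypothetical placeholder standing in for modeling.lazy_loader.LazyTensor."""

    def __init__(self, shape, dtype=torch.float32):
        self.shape = shape
        self.dtype = dtype

    def materialize(self) -> torch.Tensor:
        # The real class reads the tensor's storage out of the checkpoint file here.
        return torch.zeros(self.shape, dtype=self.dtype)


state_dict = {"transformer.wte.weight": TinyLazyTensor((8, 4))}

for param_name, param in state_dict.items():
    # Mirrors the BEGIN/END PATCH block: turn the placeholder into real data only
    # right before the tensor would be placed on its target device.
    if isinstance(param, TinyLazyTensor):
        param = param.materialize()
    print(param_name, tuple(param.shape), param.dtype)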