# KoboldAI-Client/modeling/patches.py

from __future__ import annotations

import copy
import requests
from typing import Iterable, List

from tqdm.auto import tqdm

import transformers
from transformers import (
    PreTrainedModel,
    modeling_utils,
)

from modeling.lazy_loader import LazyTensor
import utils


def patch_transformers_download():
    def http_get(
        url: str,
        temp_file,
        proxies=None,
        resume_size=0,
        headers=None,
        file_name=None,
    ):
        """
        Download remote file. Do not gobble up errors.
        """
        headers = copy.deepcopy(headers)
        if resume_size > 0:
            headers["Range"] = f"bytes={resume_size}-"
        r = requests.get(url, stream=True, proxies=proxies, headers=headers)
        transformers.utils.hub._raise_for_status(r)
        content_length = r.headers.get("Content-Length")
        total = (
            resume_size + int(content_length) if content_length is not None else None
        )
        # `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
        # and can be set using `utils.logging.enable/disable_progress_bar()`
        if url[-11:] != "config.json":
            progress = tqdm(
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
                total=total,
                initial=resume_size,
                desc=f"Downloading {file_name}"
                if file_name is not None
                else "Downloading",
                file=utils.UIProgressBarFile(),
            )
            utils.koboldai_vars.status_message = "Download Model"
            utils.koboldai_vars.total_download_chunks = total

        for chunk in r.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive new chunks
                if url[-11:] != "config.json":
                    progress.update(len(chunk))
                    utils.koboldai_vars.downloaded_chunks += len(chunk)
                temp_file.write(chunk)

        if url[-11:] != "config.json":
            progress.close()

        utils.koboldai_vars.status_message = ""

    transformers.utils.hub.http_get = http_get
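

# Illustrative usage sketch (not part of the original file; the model name is
# only an example): once patch_transformers_download() has run, any checkpoint
# file that transformers pulls from the hub is streamed through the patched
# http_get above, so the web UI can watch the download, e.g.:
#
#     patch_transformers_download()
#     transformers.AutoModelForCausalLM.from_pretrained("gpt2")
#     # utils.koboldai_vars.total_download_chunks holds the expected size and
#     # utils.koboldai_vars.downloaded_chunks the bytes received so far for
#     # every non-config.json file.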


def patch_transformers_loader() -> None:
    """
    Patch the Transformers loader to use aria2 and our shard tracking.
    Universal for TPU/MTJ and Torch.
    """
    old_from_pretrained = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        utils.koboldai_vars.fp32_model = False
        utils.num_shards = None
        utils.current_shard = 0
        utils.from_pretrained_model_name = pretrained_model_name_or_path
        utils.from_pretrained_index_filename = None
        utils.from_pretrained_kwargs = kwargs
        utils.bar = None
        if not utils.args.no_aria2:
            utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
        return old_from_pretrained(
            cls, pretrained_model_name_or_path, *model_args, **kwargs
        )

    if not hasattr(PreTrainedModel, "_kai_patched"):
        PreTrainedModel.from_pretrained = new_from_pretrained
        PreTrainedModel._kai_patched = True

    if hasattr(modeling_utils, "get_checkpoint_shard_files"):
        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files

        def new_get_checkpoint_shard_files(
            pretrained_model_name_or_path, index_filename, *args, **kwargs
        ):
            utils.num_shards = utils.get_num_shards(index_filename)
            utils.from_pretrained_index_filename = index_filename
            return old_get_checkpoint_shard_files(
                pretrained_model_name_or_path, index_filename, *args, **kwargs
            )

        modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
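

# Descriptive note (a sketch of intent, based only on this module): for a
# sharded checkpoint, new_get_checkpoint_shard_files() records the shard count
# read from the index file (e.g. "pytorch_model.bin.index.json") in
# utils.num_shards before deferring to the stock implementation, while
# new_from_pretrained() resets utils.current_shard to 0; the per-shard counter
# is presumably advanced by the shard-loading code elsewhere in KoboldAI to
# drive the UI progress display.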


def patch_transformers_generation() -> None:
    # Not sure why this global is needed...
    global transformers

    # Allow bad words filter to ban <|endoftext|> token
    import transformers.generation.logits_process

    def new_init(self, bad_words_ids: List[List[int]], eos_token_id: int):
        return new_init.old_init(self, bad_words_ids, -1)

    new_init.old_init = (
        transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__
    )
    transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
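

# Illustrative sketch (not part of the original file): NoBadWordsLogitsProcessor
# normally drops any banned sequence equal to [eos_token_id], so <|endoftext|>
# could never be banned; constructing it with eos_token_id=-1 sidesteps that
# filter. For a GPT-2 tokenizer (eos id 50256), after
# patch_transformers_generation():
#
#     from transformers.generation.logits_process import NoBadWordsLogitsProcessor
#     proc = NoBadWordsLogitsProcessor(bad_words_ids=[[50256]], eos_token_id=50256)
#     # [50256] stays in the ban list instead of being silently removed.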
CURRENT_CHECKPOINT = None


def patch_transformers_for_lazyload() -> None:
    import torch
    import inspect
    from accelerate.utils import set_module_tensor_to_device, offload_weight

    def _load_state_dict_into_meta_model(
        model,
        state_dict,
        loaded_state_dict_keys,  # left for now but could be removed, see below
        start_prefix,
        expected_keys,
        device_map=None,
        offload_folder=None,
        offload_index=None,
        state_dict_folder=None,
        state_dict_index=None,
        dtype=None,
        load_in_8bit=False,
        is_safetensors=False,
        keep_in_fp32_modules=None,
    ):
        """
        This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
        params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
        params back to the normal device, but only for `loaded_state_dict_keys`.

        `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
        `bert.pooler.dense.weight`
        """
        print("DEVMAP", device_map)

        # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
        # - deepspeed zero 3 support
        # - need to copy metadata if any - see _load_state_dict_into_model
        # - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
        # - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
        #   they won't get loaded.

        if load_in_8bit:
            from transformers.utils.bitsandbytes import set_module_8bit_tensor_to_device

        error_msgs = []

        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        for param_name, param in state_dict.items():
            # BEGIN PATCH
            if isinstance(param, LazyTensor):
                print(".", end="", flush=True)
                param = param.materialize()
            # END PATCH

            # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
            if (
                param_name not in loaded_state_dict_keys
                or param_name not in expected_keys
            ):
                continue

            if param_name.startswith(start_prefix):
                param_name = param_name[len(start_prefix) :]

            module_name = param_name
            set_module_kwargs = {}

            # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
            # in int/uint/bool and not cast them.
            if dtype is not None and torch.is_floating_point(param):
                if (
                    keep_in_fp32_modules is not None
                    and any(
                        module_to_keep_in_fp32 in param_name
                        for module_to_keep_in_fp32 in keep_in_fp32_modules
                    )
                    and dtype == torch.float16
                ):
                    param = param.to(torch.float32)

                    # For backward compatibility with older versions of `accelerate`
                    # TODO: @sgugger replace this check with version check at the next `accelerate` release
                    if "dtype" in list(
                        inspect.signature(set_module_tensor_to_device).parameters
                    ):
                        set_module_kwargs["dtype"] = torch.float32
                else:
                    param = param.to(dtype)

            # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
            if dtype is None:
                old_param = model
                splits = param_name.split(".")
                for split in splits:
                    old_param = getattr(old_param, split)
                    if old_param is None:
                        break

                if old_param is not None:
                    param = param.to(old_param.dtype)

            set_module_kwargs["value"] = param

            if device_map is None:
                param_device = "cpu"
            else:
                # find next higher level module that is defined in device_map:
                # bert.lm_head.weight -> bert.lm_head -> bert -> ''
                while len(module_name) > 0 and module_name not in device_map:
                    module_name = ".".join(module_name.split(".")[:-1])
                if module_name == "" and "" not in device_map:
                    # TODO: group all errors and raise at the end.
                    raise ValueError(f"{param_name} doesn't have any device set.")
                param_device = device_map[module_name]

            if param_device == "disk":
                if not is_safetensors:
                    offload_index = offload_weight(
                        param, param_name, offload_folder, offload_index
                    )
            elif param_device == "cpu" and state_dict_index is not None:
                state_dict_index = offload_weight(
                    param, param_name, state_dict_folder, state_dict_index
                )
            elif not load_in_8bit:
                # For backward compatibility with older versions of `accelerate`
                set_module_tensor_to_device(
                    model, param_name, param_device, **set_module_kwargs
                )
            else:
                if (
                    param.dtype == torch.int8
                    and param_name.replace("weight", "SCB") in state_dict.keys()
                ):
                    fp16_statistics = state_dict[param_name.replace("weight", "SCB")]
                else:
                    fp16_statistics = None

                if "SCB" not in param_name:
                    set_module_8bit_tensor_to_device(
                        model,
                        param_name,
                        param_device,
                        value=param,
                        fp16_statistics=fp16_statistics,
                    )

        return error_msgs, offload_index, state_dict_index

    transformers.modeling_utils._load_state_dict_into_meta_model = (
        _load_state_dict_into_meta_model
    )
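

# Illustrative usage sketch (the call order and the model path below are
# assumptions, not something this file documents): the lazyload patch only
# pays off when the checkpoint's state dict contains
# modeling.lazy_loader.LazyTensor objects, e.g.:
#
#     patch_transformers_for_lazyload()
#     model = transformers.AutoModelForCausalLM.from_pretrained(
#         "models/some-local-model",  # hypothetical path
#         device_map="auto",
#         low_cpu_mem_usage=True,
#     )
#     # Each LazyTensor is materialized one tensor at a time inside the
#     # patched _load_state_dict_into_meta_model, rather than the whole
#     # state dict being loaded into RAM up front.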


def patch_transformers() -> None:
    patch_transformers_download()
    patch_transformers_loader()

    # Doesn't do anything for TPU
    patch_transformers_generation()
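

# Usage sketch (assumed entry point, not shown in this file): the loader is
# expected to call patch_transformers() once, before any from_pretrained()
# call, so the download, shard-tracking, and generation patches are all in
# place; patch_transformers_for_lazyload() is notably not included here and
# is presumably applied separately by the lazy-loading code path.
#
#     patch_transformers()
#     model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")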