from __future__ import annotations

import copy
import requests
from typing import Iterable, List
from tqdm.auto import tqdm

import transformers
from transformers import (
    PreTrainedModel,
    modeling_utils,
)

from modeling.lazy_loader import LazyTensor
import utils
from logger import logger


def patch_transformers_download():
    def http_get(
        url: str,
        temp_file,
        proxies=None,
        resume_size=0,
        headers=None,
        file_name=None,
    ):
        """
        Download remote file. Do not gobble up errors.
        """
        headers = copy.deepcopy(headers)
        if resume_size > 0:
            headers["Range"] = f"bytes={resume_size}-"
        r = requests.get(url, stream=True, proxies=proxies, headers=headers)
        transformers.utils.hub._raise_for_status(r)
        content_length = r.headers.get("Content-Length")
        total = (
            resume_size + int(content_length) if content_length is not None else None
        )

        # `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
        # and can be set using `utils.logging.enable/disable_progress_bar()`
        if url[-11:] != "config.json":
            progress = tqdm(
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
                total=total,
                initial=resume_size,
                desc=f"Downloading {file_name}"
                if file_name is not None
                else "Downloading",
                file=utils.UIProgressBarFile(),
            )
            utils.koboldai_vars.status_message = "Download Model"
            utils.koboldai_vars.total_download_chunks = total

        for chunk in r.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive new chunks
                if url[-11:] != "config.json":
                    progress.update(len(chunk))
                    utils.koboldai_vars.downloaded_chunks += len(chunk)
                temp_file.write(chunk)

        if url[-11:] != "config.json":
            progress.close()

        utils.koboldai_vars.status_message = ""

    transformers.utils.hub.http_get = http_get


def patch_transformers_loader() -> None:
    """
    Patch the Transformers loader to use aria2 and our shard tracking.
    Universal for TPU/MTJ and Torch.
    """
    old_from_pretrained = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        utils.koboldai_vars.fp32_model = False
        utils.num_shards = None
        utils.current_shard = 0
        utils.from_pretrained_model_name = pretrained_model_name_or_path
        utils.from_pretrained_index_filename = None
        utils.from_pretrained_kwargs = kwargs
        utils.bar = None
        if not utils.args.no_aria2:
            utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
        return old_from_pretrained(
            cls, pretrained_model_name_or_path, *model_args, **kwargs
        )

    if not hasattr(PreTrainedModel, "_kai_patched"):
        PreTrainedModel.from_pretrained = new_from_pretrained
        PreTrainedModel._kai_patched = True

    if hasattr(modeling_utils, "get_checkpoint_shard_files"):
        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files

        def new_get_checkpoint_shard_files(
            pretrained_model_name_or_path, index_filename, *args, **kwargs
        ):
            utils.num_shards = utils.get_num_shards(index_filename)
            utils.from_pretrained_index_filename = index_filename
            return old_get_checkpoint_shard_files(
                pretrained_model_name_or_path, index_filename, *args, **kwargs
            )

        modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files


def patch_transformers_generation() -> None:
    # Not sure why this global is needed...
    global transformers

    # Allow bad words filter to ban <|endoftext|> token
    import transformers.generation.logits_process

    def new_init(self, bad_words_ids: List[List[int]], eos_token_id: int):
        return new_init.old_init(self, bad_words_ids, -1)

    new_init.old_init = (
        transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__
    )
    transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init


def patch_transformers_for_lazyload() -> None:
    """
    Most of this code is modified from the Accelerate and Transformers projects,
    made by HuggingFace. The licenses for those projects are as follows:
    ---
    Copyright The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
    """
    import torch
    import accelerate
    from accelerate.utils.modeling import named_module_tensors
    from accelerate.utils import set_module_tensor_to_device, offload_weight

    def _load_state_dict_into_meta_model(
        model,
        state_dict,
        loaded_state_dict_keys,  # left for now but could be removed, see below
        start_prefix,
        expected_keys,
        device_map=None,
        offload_folder=None,
        offload_index=None,
        state_dict_folder=None,
        state_dict_index=None,
        dtype=None,
        load_in_8bit=False,
        is_safetensors=False,
        keep_in_fp32_modules=None,
    ):
        """
        This is somewhat similar to `_load_state_dict_into_model`, but deals with
        a model that has some or all of its params on a `meta` device. It replaces
        the model params with the data from the `state_dict`, while moving the
        params back to the normal device, but only for `loaded_state_dict_keys`.

        `start_prefix` is used for models which insert their name into model keys,
        e.g. `bert` in `bert.pooler.dense.weight`
        """
        # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
        # - deepspeed zero 3 support
        # - need to copy metadata if any - see _load_state_dict_into_model
        # - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
        # - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
        #   they won't get loaded.
        if load_in_8bit:
            from .utils.bitsandbytes import set_module_8bit_tensor_to_device

        error_msgs = []

        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        # BEGIN PATCH
        # TODO: Based on config
        dtype = torch.float16

        for param_name, param in sorted(
            state_dict.items(),
            # State dict must be ordered in this manner to make the caching in
            # lazy_loader.py effective
            key=lambda x: (
                # NOTE: Assuming key is just decimal
                int(x[1].key),
                x[1].seek_offset,
            ),
        ):
            if isinstance(param, LazyTensor):  # Should always be true
                param = param.materialize()
            # END PATCH

            # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
            if (
                param_name not in loaded_state_dict_keys
                or param_name not in expected_keys
            ):
                continue

            if param_name.startswith(start_prefix):
                param_name = param_name[len(start_prefix) :]

            module_name = param_name

            # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
            # in int/uint/bool and not cast them.
            if dtype is not None and torch.is_floating_point(param):
                if (
                    keep_in_fp32_modules is not None
                    and any(
                        module_to_keep_in_fp32 in param_name
                        for module_to_keep_in_fp32 in keep_in_fp32_modules
                    )
                    and dtype == torch.float16
                ):
                    param = param.to(torch.float32)
                else:
                    param = param.to(dtype)

            # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
            if dtype is None:
                old_param = model
                splits = param_name.split(".")
                for split in splits:
                    old_param = getattr(old_param, split)
                    if old_param is None:
                        break
                if old_param is not None:
                    param = param.to(old_param.dtype)

            if device_map is None:
                param_device = "cpu"
            else:
                # find next higher level module that is defined in device_map:
                # bert.lm_head.weight -> bert.lm_head -> bert -> ''
                while len(module_name) > 0 and module_name not in device_map:
                    module_name = ".".join(module_name.split(".")[:-1])
                if module_name == "" and "" not in device_map:
                    # TODO: group all errors and raise at the end.
                    raise ValueError(f"{param_name} doesn't have any device set.")
                param_device = device_map[module_name]

            if param_device == "disk":
                if not is_safetensors:
                    offload_index = offload_weight(
                        param, param_name, offload_folder, offload_index
                    )
            elif param_device == "cpu" and state_dict_index is not None:
                state_dict_index = offload_weight(
                    param, param_name, state_dict_folder, state_dict_index
                )
            elif not load_in_8bit:
                # For backward compatibility with older versions of `accelerate`
                set_module_tensor_to_device(
                    model,
                    tensor_name=param_name,
                    device=param_device,
                    value=param,
                )
            else:
                if (
                    param.dtype == torch.int8
                    and param_name.replace("weight", "SCB") in state_dict.keys()
                ):
                    fp16_statistics = state_dict[param_name.replace("weight", "SCB")]
                else:
                    fp16_statistics = None

                if "SCB" not in param_name:
                    set_module_8bit_tensor_to_device(
                        model,
                        param_name,
                        param_device,
                        value=param,
                        fp16_statistics=fp16_statistics,
                    )

        return error_msgs, offload_index, state_dict_index

    transformers.modeling_utils._load_state_dict_into_meta_model = (
        _load_state_dict_into_meta_model
    )

    # Patch AlignDevicesHook to hack around OPT lm_head
    HACK_ZERO_ON_FAIL_TENSORS = ["lm_head.weight"]

    def _init_hook(self, module):
        if not self.offload and self.execution_device is not None:
            # BEGIN PATCH
            for name, tensor in named_module_tensors(
                module, recurse=self.place_submodules
            ):
                try:
                    set_module_tensor_to_device(module, name, self.execution_device)
                except ValueError:
                    # ValueError: weight is on the meta device, we need a `value` to put in on 0.
                    # bleuuuuuuuuuuuuuuuhhh
                    if name in HACK_ZERO_ON_FAIL_TENSORS:
                        logger.warning(
                            f"Couldn't find value for weight {name}, zeroing."
                        )
                        set_module_tensor_to_device(
                            module,
                            name,
                            self.execution_device,
                            value=torch.zeros(tensor.shape),
                        )
            # END PATCH
        elif self.offload:
            self.original_devices = {
                name: param.device
                for name, param in named_module_tensors(
                    module, recurse=self.place_submodules
                )
            }

            if self.weights_map is None:
                self.weights_map = {
                    name: param.to("cpu")
                    for name, param in named_module_tensors(
                        module,
                        include_buffers=self.offload_buffers,
                        recurse=self.place_submodules,
                    )
                }

            for name, _ in named_module_tensors(
                module,
                include_buffers=self.offload_buffers,
                recurse=self.place_submodules,
            ):
                set_module_tensor_to_device(module, name, "meta")

            if not self.offload_buffers and self.execution_device is not None:
                for name, _ in module.named_buffers(recurse=self.place_submodules):
                    set_module_tensor_to_device(module, name, self.execution_device)

        return module

    accelerate.hooks.AlignDevicesHook.init_hook = _init_hook


def patch_transformers() -> None:
    patch_transformers_download()
    patch_transformers_loader()

    # Doesn't do anything for TPU
    patch_transformers_generation()
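

# Usage sketch (illustrative, not part of the original module). Assuming this
# file lives at modeling/patches.py in the KoboldAI tree, the patches above are
# meant to be applied once at startup, before any model is loaded, e.g.:
#
#   from modeling import patches
#
#   patches.patch_transformers()  # download, loader and generation patches
#   model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
#
# Note that patch_transformers() does not call patch_transformers_for_lazyload();
# that one is presumably applied separately by whichever code path enables lazy
# checkpoint loading.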