diff --git a/aiserver.py b/aiserver.py
index ad1efdab..390d1979 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -27,12 +27,12 @@ from ansi2html import Ansi2HTMLConverter
 
 logging.getLogger("urllib3").setLevel(logging.ERROR)
 
-import attention_bias
-attention_bias.do_patches()
-
 from modeling import patches
 patches.patch_transformers_for_lazyload()
 
+import attention_bias
+attention_bias.do_patches()
+
 from os import path, getcwd
 import time
 import re
diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index cc3b83c1..b00132be 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -9,6 +9,7 @@ import functools
 import itertools
 import traceback
 import contextlib
+from accelerate.big_modeling import load_checkpoint_and_dispatch
 from accelerate.utils.modeling import infer_auto_device_map, load_checkpoint_in_model
 from tqdm.auto import tqdm
 from typing import Dict, List, Optional, Union
@@ -263,6 +264,9 @@ class HFTorchInferenceModel(HFInferenceModel):
         tf_kwargs["revision"] = utils.koboldai_vars.revision
         tf_kwargs["cache_dir"] = "cache"
 
+        if self.lazy_load:
+            tf_kwargs.pop("low_cpu_mem_usage", None)
+
         # If we have model hints for legacy model, use them rather than fall back.
         try:
             if self.model_name == "GPT2Custom":
@@ -285,17 +289,25 @@ class HFTorchInferenceModel(HFInferenceModel):
                 #     offload_state_dict=True
                 # )
                 # model.tie_weights()
+                no_split_module_classes = ["GPTJBlock", "OPTDecoderLayer"]
 
+                print("[HUGE SKELETON] MAKING DEVICE MAP")
                 device_map = infer_auto_device_map(
                     model,
                     max_memory={0: "10GiB", 1: "7GiB", "cpu": "15GiB"},
-                    no_split_module_classes=["GPTJBlock", "OPTDecoderLayer"],
+                    no_split_module_classes=no_split_module_classes,
                     dtype="float16",
                 )
 
+                print("[HUGE SKELETON] TYING WEIGHTS")
+                model.tie_weights()
+
+                print("[HUGE SKELETON] LOADING FROM PRETRAINED")
                 return AutoModelForCausalLM.from_pretrained(
-                    location, device_map=device_map
-                )  # , **tf_kwargs)
+                    location,
+                    device_map=device_map,
+                    **tf_kwargs,
+                )
 
         except Exception as e:
             traceback_string = traceback.format_exc().lower()
diff --git a/modeling/patches.py b/modeling/patches.py
index 8cc436b5..79aefe9d 100644
--- a/modeling/patches.py
+++ b/modeling/patches.py
@@ -13,6 +13,7 @@ from transformers import (
 
 from modeling.lazy_loader import LazyTensor
 import utils
+from logger import logger
 
 
 def patch_transformers_download():
@@ -127,8 +128,27 @@ def patch_transformers_generation() -> None:
 
 
 def patch_transformers_for_lazyload() -> None:
+    """
+    Most of the code is modified code from the Accelerate and Transformers
+    projects, made by HuggingFace. The licenses for these projects are as follows:
+    ---
+    Copyright The HuggingFace Team. All rights reserved.
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+    """
     import torch
-    import inspect
+    import accelerate
+    from accelerate.utils.modeling import named_module_tensors
     from accelerate.utils import set_module_tensor_to_device, offload_weight
 
     def _load_state_dict_into_meta_model(
@@ -183,10 +203,9 @@ def patch_transformers_for_lazyload() -> None:
         for old_key, new_key in zip(old_keys, new_keys):
             state_dict[new_key] = state_dict.pop(old_key)
 
-# BEGIN PATCH
+        # BEGIN PATCH
         # TODO: Based on config
         dtype = torch.float16
-        set_module_kwargs = {"dtype": dtype}
 
         for param_name, param in sorted(
             state_dict.items(),
@@ -202,7 +221,7 @@ def patch_transformers_for_lazyload() -> None:
             if isinstance(param, LazyTensor):
                 # Should always be true
                 param = param.materialize()
-# END PATCH
+            # END PATCH
 
             # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
             if (
@@ -228,13 +247,6 @@ def patch_transformers_for_lazyload() -> None:
                 and dtype == torch.float16
             ):
                 param = param.to(torch.float32)
-
-                # For backward compatibility with older versions of `accelerate`
-                # TODO: @sgugger replace this check with version check at the next `accelerate` release
-                if "dtype" in list(
-                    inspect.signature(set_module_tensor_to_device).parameters
-                ):
-                    set_module_kwargs["dtype"] = torch.float32
             else:
                 param = param.to(dtype)
 
@@ -250,8 +262,6 @@ def patch_transformers_for_lazyload() -> None:
             if old_param is not None:
                 param = param.to(old_param.dtype)
 
-            set_module_kwargs["value"] = param
-
             if device_map is None:
                 param_device = "cpu"
             else:
@@ -263,6 +273,7 @@ def patch_transformers_for_lazyload() -> None:
                     # TODO: group all errors and raise at the end.
                     raise ValueError(f"{param_name} doesn't have any device set.")
                 param_device = device_map[module_name]
+
             if param_device == "disk":
                 if not is_safetensors:
                     offload_index = offload_weight(
@@ -275,7 +286,10 @@ def patch_transformers_for_lazyload() -> None:
             elif not load_in_8bit:
                 # For backward compatibility with older versions of `accelerate`
                 set_module_tensor_to_device(
-                    model, tensor_name=param_name, device=param_device, **set_module_kwargs
+                    model,
+                    tensor_name=param_name,
+                    device=param_device,
+                    value=param,
                 )
             else:
                 if (
@@ -301,6 +315,73 @@ def patch_transformers_for_lazyload() -> None:
         _load_state_dict_into_meta_model
     )
 
+    # Patch AlignDevicesHook to hack around OPT lm_head
+    HACK_ZERO_ON_FAIL_TENSORS = ["lm_head.weight"]
+
+    def _recursed_key_value(module, key):
+        """Gets a tensor from a recursive key (i.e. one containing dots)."""
+        if "." in key:
+            splits = key.split(".")
+            for split in splits[:-1]:
+                new_module = getattr(module, split)
+                if new_module is None:
+                    raise ValueError(f"{module} has no attribute {split}.")
+                module = new_module
+            key = splits[-1]
+        return getattr(module, key)
+
+    def _init_hook(self, module):
+        if not self.offload and self.execution_device is not None:
+            # BEGIN PATCH
+            for name, tensor in named_module_tensors(
+                module, recurse=self.place_submodules
+            ):
+                try:
+                    set_module_tensor_to_device(module, name, self.execution_device)
+                except ValueError:
+                    # ValueError: weight is on the meta device, we need a `value` to put in on 0.
+                    # Fall back to zero-filling tensors that are known to be safe to zero.
+                    if name in HACK_ZERO_ON_FAIL_TENSORS:
+                        logger.warning(f"Couldn't find value for weight {name}, zeroing.")
+                        set_module_tensor_to_device(
+                            module,
+                            name,
+                            self.execution_device,
+                            value=torch.zeros(tensor.shape),
+                        )
+            # END PATCH
+        elif self.offload:
+            self.original_devices = {
+                name: param.device
+                for name, param in named_module_tensors(
+                    module, recurse=self.place_submodules
+                )
+            }
+
+            if self.weights_map is None:
+                self.weights_map = {
+                    name: param.to("cpu")
+                    for name, param in named_module_tensors(
+                        module,
+                        include_buffers=self.offload_buffers,
+                        recurse=self.place_submodules,
+                    )
+                }
+
+            for name, _ in named_module_tensors(
+                module,
+                include_buffers=self.offload_buffers,
+                recurse=self.place_submodules,
+            ):
+                set_module_tensor_to_device(module, name, "meta")
+
+            if not self.offload_buffers and self.execution_device is not None:
+                for name, _ in module.named_buffers(recurse=self.place_submodules):
+                    set_module_tensor_to_device(module, name, self.execution_device)
+        return module
+
+    accelerate.hooks.AlignDevicesHook.init_hook = _init_hook
+
 
 def patch_transformers() -> None:
     patch_transformers_download()
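
Review note: the device-map path added to hf_torch.py is easier to evaluate in isolation. Below is a minimal sketch of the same infer_auto_device_map -> tie_weights -> from_pretrained sequence; the checkpoint id and the max_memory budget are illustrative placeholders (the hardcoded "10GiB"/"7GiB"/"15GiB" split above is clearly still debug scaffolding), not values this patch prescribes.

    # Sketch of the dispatch flow used in hf_torch.py, assuming a GPT-J checkpoint.
    import torch
    from accelerate import init_empty_weights
    from accelerate.utils.modeling import infer_auto_device_map
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6b")

    # Build a weightless skeleton on the meta device so planning allocates nothing.
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    # Plan a per-module placement; listed block classes are never split across devices.
    device_map = infer_auto_device_map(
        model,
        max_memory={0: "10GiB", 1: "7GiB", "cpu": "15GiB"},
        no_split_module_classes=["GPTJBlock", "OPTDecoderLayer"],
        dtype="float16",
    )

    # Tie shared tensors (e.g. embeddings and lm_head) before the real load,
    # as the patch does, so tied weights are planned and loaded only once.
    model.tie_weights()

    model = AutoModelForCausalLM.from_pretrained(
        "EleutherAI/gpt-j-6b",
        device_map=device_map,
        torch_dtype=torch.float16,
    )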
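
Review note: the AlignDevicesHook.init_hook override exists because accelerate's set_module_tensor_to_device refuses to move a tensor that still lives on the meta device unless an explicit value is supplied; that is exactly the ValueError the patched hook catches before zero-filling lm_head.weight. A self-contained reproduction of the behaviour being worked around:

    # Demonstrates the meta-device ValueError that _init_hook handles above.
    import torch
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    with init_empty_weights():
        layer = torch.nn.Linear(4, 4)  # parameters are created on "meta", with no data

    try:
        # No value given: accelerate cannot invent data for a meta tensor.
        set_module_tensor_to_device(layer, "weight", "cpu")
    except ValueError:
        # The same recovery as HACK_ZERO_ON_FAIL_TENSORS: pass an explicit
        # (here all-zero) value so the tensor materializes on the target device.
        set_module_tensor_to_device(
            layer, "weight", "cpu", value=torch.zeros(layer.weight.shape)
        )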
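
Review note: both fixes are installed by assigning over private HuggingFace internals (transformers.modeling_utils._load_state_dict_into_meta_model and accelerate.hooks.AlignDevicesHook.init_hook), so they can break silently when either library renames them; pinning the library versions is advisable. For reference, a minimal sketch of the same monkeypatching pattern that wraps the helper instead of substituting a full copy; the private name is version-dependent and assumed here:

    # Wrap a private transformers helper rather than replacing it wholesale.
    import transformers.modeling_utils

    _original = transformers.modeling_utils._load_state_dict_into_meta_model

    def _patched(*args, **kwargs):
        # Custom pre-processing (e.g. materializing lazy tensors) would go here.
        return _original(*args, **kwargs)

    transformers.modeling_utils._load_state_dict_into_meta_model = _patched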