From ceaefa9f5ed932a230be950fe88dd8efa2472b88 Mon Sep 17 00:00:00 2001
From: somebody
Date: Sun, 28 May 2023 14:57:45 -0500
Subject: [PATCH] Not quite

---
 .../generic_hf_torch/class.py         |  2 +-
 modeling/inference_models/hf_torch.py |  3 +-
 modeling/lazy_loader.py               | 35 +++++++++----------
 modeling/patches.py                   |  7 ++--
 4 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/modeling/inference_models/generic_hf_torch/class.py b/modeling/inference_models/generic_hf_torch/class.py
index d4c254b8..b45b002c 100644
--- a/modeling/inference_models/generic_hf_torch/class.py
+++ b/modeling/inference_models/generic_hf_torch/class.py
@@ -238,7 +238,7 @@ class model_backend(HFTorchInferenceModel):
                 shutil.rmtree("cache/")
 
         self.patch_embedding()
-        lazy_loader.post_load_cleanup()
+        self.model.tie_weights()
 
         self.model.kai_model = self
         utils.koboldai_vars.modeldim = self.get_hidden_size()
diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index abc3c54c..cc3b83c1 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -289,7 +289,8 @@ class HFTorchInferenceModel(HFInferenceModel):
             device_map = infer_auto_device_map(
                 model,
                 max_memory={0: "10GiB", 1: "7GiB", "cpu": "15GiB"},
-                no_split_module_classes=["GPTJBlock"],
+                no_split_module_classes=["GPTJBlock", "OPTDecoderLayer"],
+                dtype="float16",
             )
 
             return AutoModelForCausalLM.from_pretrained(
diff --git a/modeling/lazy_loader.py b/modeling/lazy_loader.py
index 54cbb912..b6ea1623 100644
--- a/modeling/lazy_loader.py
+++ b/modeling/lazy_loader.py
@@ -46,6 +46,7 @@ POSSIBILITY OF SUCH DAMAGE.
 
 import contextlib
 from functools import reduce
+import time
 import zipfile
 import pickle
 import torch
@@ -55,6 +56,8 @@ import _codecs
 import os
 from typing import Any, Callable, Dict, Optional, Tuple, Type
 
+from torch.storage import UntypedStorage
+
 # Safetensors is a dependency for the local version, TPU/Colab doesn't
 # support it yet.
 try:
@@ -65,21 +68,9 @@ except ModuleNotFoundError:
     HAS_SAFETENSORS = False
 
 import utils
+from logger import logger
 
 
-STORAGE_TYPE_MAP = {
-    torch.float64: torch.DoubleStorage,
-    torch.float32: torch.FloatStorage,
-    torch.float16: torch.HalfStorage,
-    torch.int64: torch.LongStorage,
-    torch.int32: torch.IntStorage,
-    torch.int16: torch.ShortStorage,
-    torch.int8: torch.CharStorage,
-    torch.uint8: torch.ByteStorage,
-    torch.bool: torch.BoolStorage,
-    torch.bfloat16: torch.BFloat16Storage,
-}
-
 
 # Storage of zipfile handles for each shard
 torch_checkpoint_file_handles = {}
@@ -205,8 +196,8 @@ class TorchLazyTensor(LazyTensor):
             assert isinstance(checkpoint, zipfile.ZipFile)
 
             CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
-        storage = STORAGE_TYPE_MAP[self.dtype].from_buffer(
-            CheckpointChunkCache.handle.read(nbytes), "little"
+        storage = UntypedStorage.from_buffer(
+            CheckpointChunkCache.handle.read(nbytes), "little", dtype=self.dtype
         )
 
         storage = torch.serialization._get_restore_location(map_location)(
@@ -421,6 +412,8 @@ def use_lazy_load(
         yield False
         return
 
+    begin_time = time.time()
+
     try:
         old_rebuild_tensor = torch._utils._rebuild_tensor
         torch._utils._rebuild_tensor = _rebuild_tensor
@@ -471,17 +464,23 @@ def use_lazy_load(
     finally:
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
+
+        post_load_cleanup()
+        logger.debug(
+            f"[lazy_load] Context closed in {round(time.time() - begin_time, 2)} seconds."
+        )
+
         if dematerialized_modules:
             init_empty_weights.__exit__(None, None, None)
 
 
 def post_load_cleanup() -> None:
+    """Close dangling file pointers and clear caches after the load is complete."""
     global torch_checkpoint_file_handles
 
-    print("CheckpointChunkCache Hit Data:", CheckpointChunkCache.hit_data)
-    CheckpointChunkCache.clear()
+    logger.debug(f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}")
+    CheckpointChunkCache.clear(unload_model=True)
 
     for v in torch_checkpoint_file_handles.values():
         v.close()
-    torch_checkpoint_file_handles = {}
 
diff --git a/modeling/patches.py b/modeling/patches.py
index 5c0573f7..8cc436b5 100644
--- a/modeling/patches.py
+++ b/modeling/patches.py
@@ -184,6 +184,10 @@ def patch_transformers_for_lazyload() -> None:
             state_dict[new_key] = state_dict.pop(old_key)
 
         # BEGIN PATCH
+        # TODO: Based on config
+        dtype = torch.float16
+        set_module_kwargs = {"dtype": dtype}
+
         for param_name, param in sorted(
             state_dict.items(),
             # State dict must be ordered in this manner to make the caching in
@@ -211,7 +215,6 @@ def patch_transformers_for_lazyload() -> None:
                 param_name = param_name[len(start_prefix) :]
 
             module_name = param_name
-            set_module_kwargs = {}
 
             # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
             # in int/uint/bool and not cast them.
@@ -272,7 +275,7 @@ def patch_transformers_for_lazyload() -> None:
             elif not load_in_8bit:
                 # For backward compatibility with older versions of `accelerate`
                 set_module_tensor_to_device(
-                    model, param_name, param_device, **set_module_kwargs
+                    model, tensor_name=param_name, device=param_device, **set_module_kwargs
                 )
             else:
                 if (
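
Usage note (not part of the patch): a minimal sketch of how the reworked
lazy-load context is expected to be exercised after this change, assuming the
KoboldAI modules patched above are importable and that use_lazy_load() accepts
the dematerialized_modules flag referenced in the diff:

    import torch
    from transformers import AutoModelForCausalLM

    from modeling import lazy_loader

    # Tensors are materialized lazily from the checkpoint shards; on context
    # exit, post_load_cleanup() now runs automatically, closing the shard
    # zipfile handles and clearing the CheckpointChunkCache.
    with lazy_loader.use_lazy_load(dematerialized_modules=True):
        model = AutoModelForCausalLM.from_pretrained(
            "EleutherAI/gpt-j-6b",  # hypothetical model id, for illustration only
            torch_dtype=torch.float16,
        )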