Not quite

This commit is contained in:
somebody
2023-05-28 14:57:45 -05:00
parent ed0728188a
commit ceaefa9f5e
4 changed files with 25 additions and 22 deletions

View File

@@ -238,7 +238,7 @@ class model_backend(HFTorchInferenceModel):
shutil.rmtree("cache/")
self.patch_embedding()
lazy_loader.post_load_cleanup()
self.model.tie_weights()
self.model.kai_model = self
utils.koboldai_vars.modeldim = self.get_hidden_size()

View File

@@ -289,7 +289,8 @@ class HFTorchInferenceModel(HFInferenceModel):
device_map = infer_auto_device_map(
model,
max_memory={0: "10GiB", 1: "7GiB", "cpu": "15GiB"},
no_split_module_classes=["GPTJBlock"],
no_split_module_classes=["GPTJBlock", "OPTDecoderLayer"],
dtype="float16",
)
return AutoModelForCausalLM.from_pretrained(

View File

@@ -46,6 +46,7 @@ POSSIBILITY OF SUCH DAMAGE.
import contextlib
from functools import reduce
import time
import zipfile
import pickle
import torch
@@ -55,6 +56,8 @@ import _codecs
import os
from typing import Any, Callable, Dict, Optional, Tuple, Type
from torch.storage import UntypedStorage
# Safetensors is a dependency for the local version, TPU/Colab doesn't
# support it yet.
try:
@@ -65,21 +68,9 @@ except ModuleNotFoundError:
HAS_SAFETENSORS = False
import utils
from logger import logger
STORAGE_TYPE_MAP = {
torch.float64: torch.DoubleStorage,
torch.float32: torch.FloatStorage,
torch.float16: torch.HalfStorage,
torch.int64: torch.LongStorage,
torch.int32: torch.IntStorage,
torch.int16: torch.ShortStorage,
torch.int8: torch.CharStorage,
torch.uint8: torch.ByteStorage,
torch.bool: torch.BoolStorage,
torch.bfloat16: torch.BFloat16Storage,
}
# Storage of zipfile handles for each shard
torch_checkpoint_file_handles = {}
@@ -205,8 +196,8 @@ class TorchLazyTensor(LazyTensor):
assert isinstance(checkpoint, zipfile.ZipFile)
CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
storage = STORAGE_TYPE_MAP[self.dtype].from_buffer(
CheckpointChunkCache.handle.read(nbytes), "little"
storage = UntypedStorage.from_buffer(
CheckpointChunkCache.handle.read(nbytes), "little", dtype=self.dtype
)
storage = torch.serialization._get_restore_location(map_location)(
@@ -421,6 +412,8 @@ def use_lazy_load(
yield False
return
begin_time = time.time()
try:
old_rebuild_tensor = torch._utils._rebuild_tensor
torch._utils._rebuild_tensor = _rebuild_tensor
@@ -471,17 +464,23 @@ def use_lazy_load(
finally:
torch._utils._rebuild_tensor = old_rebuild_tensor
torch.load = old_torch_load
post_load_cleanup()
logger.debug(
f"[lazy_load] Context closed in {round(time.time() - begin_time, 2)} seconds."
)
if dematerialized_modules:
init_empty_weights.__exit__(None, None, None)
def post_load_cleanup() -> None:
"""Close dangling file pointers and clear caches after the load is complete."""
global torch_checkpoint_file_handles
print("CheckpointChunkCache Hit Data:", CheckpointChunkCache.hit_data)
CheckpointChunkCache.clear()
logger.debug(f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}")
CheckpointChunkCache.clear(unload_model=True)
for v in torch_checkpoint_file_handles.values():
v.close()
torch_checkpoint_file_handles = {}

View File

@@ -184,6 +184,10 @@ def patch_transformers_for_lazyload() -> None:
state_dict[new_key] = state_dict.pop(old_key)
# BEGIN PATCH
# TODO: Based on config
dtype = torch.float16
set_module_kwargs = {"dtype": dtype}
for param_name, param in sorted(
state_dict.items(),
# State dict must be ordered in this manner to make the caching in
@@ -211,7 +215,6 @@ def patch_transformers_for_lazyload() -> None:
param_name = param_name[len(start_prefix) :]
module_name = param_name
set_module_kwargs = {}
# We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
# in int/uint/bool and not cast them.
@@ -272,7 +275,7 @@ def patch_transformers_for_lazyload() -> None:
elif not load_in_8bit:
# For backward compatibility with older versions of `accelerate`
set_module_tensor_to_device(
model, param_name, param_device, **set_module_kwargs
model, tensor_name=param_name, device=param_device, **set_module_kwargs
)
else:
if (