Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Commit message: Not quite

@@ -238,7 +238,7 @@ class model_backend(HFTorchInferenceModel):
         shutil.rmtree("cache/")

         self.patch_embedding()
         lazy_loader.post_load_cleanup()
         self.model.tie_weights()

         self.model.kai_model = self
         utils.koboldai_vars.modeldim = self.get_hidden_size()

@@ -289,7 +289,8 @@ class HFTorchInferenceModel(HFInferenceModel):
         device_map = infer_auto_device_map(
             model,
             max_memory={0: "10GiB", 1: "7GiB", "cpu": "15GiB"},
-            no_split_module_classes=["GPTJBlock"],
+            no_split_module_classes=["GPTJBlock", "OPTDecoderLayer"],
             dtype="float16",
         )

         return AutoModelForCausalLM.from_pretrained(

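For reference, a minimal sketch of how a device map like the one above is normally built and consumed with accelerate and transformers. The checkpoint name and the two-GPU memory budget are assumptions for illustration, not values taken from this commit.

    # Illustrative sketch only, not this repository's code.
    import torch
    from accelerate import infer_auto_device_map, init_empty_weights
    from transformers import AutoConfig, AutoModelForCausalLM

    model_id = "facebook/opt-2.7b"  # hypothetical checkpoint
    config = AutoConfig.from_pretrained(model_id)

    # Build an empty (meta-device) model so the map can be inferred without
    # materializing any weights.
    with init_empty_weights():
        empty_model = AutoModelForCausalLM.from_config(config)

    device_map = infer_auto_device_map(
        empty_model,
        max_memory={0: "10GiB", 1: "7GiB", "cpu": "15GiB"},
        # Keep each transformer block on a single device; a block is never
        # split across devices.
        no_split_module_classes=["GPTJBlock", "OPTDecoderLayer"],
        dtype="float16",
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map=device_map, torch_dtype=torch.float16
    )
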
@@ -46,6 +46,7 @@ POSSIBILITY OF SUCH DAMAGE.
 import contextlib
 from functools import reduce
+import time
 import zipfile
 import pickle
 import torch

@@ -55,6 +56,8 @@ import _codecs
 import os
 from typing import Any, Callable, Dict, Optional, Tuple, Type

+from torch.storage import UntypedStorage
+
 # Safetensors is a dependency for the local version, TPU/Colab doesn't
 # support it yet.
 try:

@@ -65,21 +68,9 @@ except ModuleNotFoundError:
     HAS_SAFETENSORS = False

 import utils
 from logger import logger


-STORAGE_TYPE_MAP = {
-    torch.float64: torch.DoubleStorage,
-    torch.float32: torch.FloatStorage,
-    torch.float16: torch.HalfStorage,
-    torch.int64: torch.LongStorage,
-    torch.int32: torch.IntStorage,
-    torch.int16: torch.ShortStorage,
-    torch.int8: torch.CharStorage,
-    torch.uint8: torch.ByteStorage,
-    torch.bool: torch.BoolStorage,
-    torch.bfloat16: torch.BFloat16Storage,
-}
-
 # Storage of zipfile handles for each shard
 torch_checkpoint_file_handles = {}

@@ -205,8 +196,8 @@ class TorchLazyTensor(LazyTensor):
         assert isinstance(checkpoint, zipfile.ZipFile)

         CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
-        storage = STORAGE_TYPE_MAP[self.dtype].from_buffer(
-            CheckpointChunkCache.handle.read(nbytes), "little"
+        storage = UntypedStorage.from_buffer(
+            CheckpointChunkCache.handle.read(nbytes), "little", dtype=self.dtype
         )

         storage = torch.serialization._get_restore_location(map_location)(

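For reference, a self-contained illustration of this storage change, assuming a PyTorch recent enough to expose torch.storage.UntypedStorage (the typed-storage call may emit a deprecation warning on newer releases). The per-dtype STORAGE_TYPE_MAP lookup becomes a single untyped byte buffer, with the dtype passed so from_buffer knows the element size and byte order.

    # Minimal sketch, not the repository's code.
    import torch
    from torch.storage import UntypedStorage

    raw = torch.arange(4, dtype=torch.float16).numpy().tobytes()  # 8 bytes

    # Old path: pick a typed storage class per dtype via STORAGE_TYPE_MAP.
    typed = torch.HalfStorage.from_buffer(raw, "little")          # 4 half elements

    # New path: one untyped byte storage; dtype only tells from_buffer how to
    # interpret elements, the storage itself stays raw bytes.
    untyped = UntypedStorage.from_buffer(raw, "little", dtype=torch.float16)

    print(len(typed), untyped.nbytes())  # 4 elements vs. 8 bytes
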
@@ -421,6 +412,8 @@ def use_lazy_load(
         yield False
         return

+    begin_time = time.time()
+
     try:
         old_rebuild_tensor = torch._utils._rebuild_tensor
         torch._utils._rebuild_tensor = _rebuild_tensor

@@ -471,17 +464,23 @@ def use_lazy_load(
     finally:
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load

+        post_load_cleanup()
+        logger.debug(
+            f"[lazy_load] Context closed in {round(time.time() - begin_time, 2)} seconds."
+        )
+
         if dematerialized_modules:
             init_empty_weights.__exit__(None, None, None)


 def post_load_cleanup() -> None:
     """Close dangling file pointers and clear caches after the load is complete."""
     global torch_checkpoint_file_handles

-    print("CheckpointChunkCache Hit Data:", CheckpointChunkCache.hit_data)
-    CheckpointChunkCache.clear()
+    logger.debug(f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}")
+    CheckpointChunkCache.clear(unload_model=True)

     for v in torch_checkpoint_file_handles.values():
         v.close()

     torch_checkpoint_file_handles = {}

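The change above follows a common shape: patch globals inside a context manager, then restore them, run cleanup, and report timing in the finally block so it happens even on error. A generic sketch with illustrative names (timed_patch is not the repository's API):

    import contextlib
    import time


    @contextlib.contextmanager
    def timed_patch(module, name, replacement, cleanup=lambda: None, log=print):
        original = getattr(module, name)
        setattr(module, name, replacement)
        begin_time = time.time()
        try:
            yield
        finally:
            # Restore the patched attribute and clean up no matter what happened.
            setattr(module, name, original)
            cleanup()
            log(f"Context closed in {round(time.time() - begin_time, 2)} seconds.")

    # Usage (illustrative): with timed_patch(torch, "load", my_lazy_load, cleanup=post_load_cleanup): ...
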
@@ -184,6 +184,10 @@ def patch_transformers_for_lazyload() -> None:
             state_dict[new_key] = state_dict.pop(old_key)

+        # BEGIN PATCH
+        # TODO: Based on config
+        dtype = torch.float16
+        set_module_kwargs = {"dtype": dtype}

         for param_name, param in sorted(
             state_dict.items(),
             # State dict must be ordered in this manner to make the caching in

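One hypothetical way the "TODO: Based on config" above could later be resolved, shown only as an assumption (the checkpoint name is illustrative):

    import torch
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("facebook/opt-2.7b")
    dtype = config.torch_dtype or torch.float16  # fall back when the config has none
    if isinstance(dtype, str):                   # some configs store it as a string
        dtype = getattr(torch, dtype)
    set_module_kwargs = {"dtype": dtype}
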
@@ -211,7 +215,6 @@ def patch_transformers_for_lazyload() -> None:
                param_name = param_name[len(start_prefix) :]

            module_name = param_name
-            set_module_kwargs = {}

            # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
            # in int/uint/bool and not cast them.

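The comment in the hunk above describes the casting rule; a small stand-alone sketch of that rule (not the patched loader itself):

    import torch


    def cast_floating_only(state_dict, dtype=torch.float16):
        # Cast floating-point tensors to the target dtype, leave int/uint/bool
        # buffers and parameters untouched.
        return {
            name: t.to(dtype) if t.is_floating_point() else t
            for name, t in state_dict.items()
        }


    sd = {"weight": torch.randn(2, 2), "position_ids": torch.arange(2)}
    sd = cast_floating_only(sd)
    print(sd["weight"].dtype, sd["position_ids"].dtype)  # torch.float16 torch.int64
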
@@ -272,7 +275,7 @@ def patch_transformers_for_lazyload() -> None:
            elif not load_in_8bit:
                # For backward compatibility with older versions of `accelerate`
                set_module_tensor_to_device(
-                    model, param_name, param_device, **set_module_kwargs
+                    model, tensor_name=param_name, device=param_device, **set_module_kwargs
                )
            else:
                if (

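For reference, a hedged sketch of what the patched call does for a single parameter, using accelerate's public helpers (assumes a reasonably recent accelerate; the module and tensor values are illustrative, not this repository's loader):

    import torch
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    with init_empty_weights():
        model = torch.nn.Linear(4, 4)  # parameters are created on the "meta" device

    checkpoint_weight = torch.zeros(4, 4)  # stand-in for a tensor read from a shard
    set_module_tensor_to_device(
        model,
        tensor_name="weight",
        device="cpu",
        value=checkpoint_weight,
        dtype=torch.float16,
    )
    print(model.weight.device, model.weight.dtype)  # cpu torch.float16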