Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)

Commit: Not quite
@@ -238,7 +238,7 @@ class model_backend(HFTorchInferenceModel):
         shutil.rmtree("cache/")

         self.patch_embedding()
-        lazy_loader.post_load_cleanup()
+        self.model.tie_weights()

         self.model.kai_model = self
         utils.koboldai_vars.modeldim = self.get_hidden_size()
@@ -289,7 +289,8 @@ class HFTorchInferenceModel(HFInferenceModel):
             device_map = infer_auto_device_map(
                 model,
                 max_memory={0: "10GiB", 1: "7GiB", "cpu": "15GiB"},
-                no_split_module_classes=["GPTJBlock"],
+                no_split_module_classes=["GPTJBlock", "OPTDecoderLayer"],
+                dtype="float16",
             )

             return AutoModelForCausalLM.from_pretrained(
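For context, this is roughly how such a device map is produced with accelerate. A minimal sketch, assuming an illustrative model id and memory budget; none of these concrete values come from this commit.

import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Build the model skeleton on the meta device so no weights are allocated yet.
config = AutoConfig.from_pretrained("facebook/opt-1.3b")  # assumed model id
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Plan which modules land on each GPU/CPU without splitting a decoder layer
# across devices, mirroring the no_split_module_classes/dtype arguments above.
device_map = infer_auto_device_map(
    model,
    max_memory={0: "10GiB", "cpu": "15GiB"},
    no_split_module_classes=["OPTDecoderLayer"],
    dtype="float16",
)
print(device_map)  # e.g. {"model.decoder.embed_tokens": 0, ...}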
@@ -46,6 +46,7 @@ POSSIBILITY OF SUCH DAMAGE.

 import contextlib
 from functools import reduce
+import time
 import zipfile
 import pickle
 import torch
@@ -55,6 +56,8 @@ import _codecs
 import os
 from typing import Any, Callable, Dict, Optional, Tuple, Type

+from torch.storage import UntypedStorage
+
 # Safetensors is a dependency for the local version, TPU/Colab doesn't
 # support it yet.
 try:
@@ -65,21 +68,9 @@ except ModuleNotFoundError:
     HAS_SAFETENSORS = False

 import utils
+from logger import logger


-STORAGE_TYPE_MAP = {
-    torch.float64: torch.DoubleStorage,
-    torch.float32: torch.FloatStorage,
-    torch.float16: torch.HalfStorage,
-    torch.int64: torch.LongStorage,
-    torch.int32: torch.IntStorage,
-    torch.int16: torch.ShortStorage,
-    torch.int8: torch.CharStorage,
-    torch.uint8: torch.ByteStorage,
-    torch.bool: torch.BoolStorage,
-    torch.bfloat16: torch.BFloat16Storage,
-}
-
 # Storage of zipfile handles for each shard
 torch_checkpoint_file_handles = {}

@@ -205,8 +196,8 @@ class TorchLazyTensor(LazyTensor):
         assert isinstance(checkpoint, zipfile.ZipFile)

         CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
-        storage = STORAGE_TYPE_MAP[self.dtype].from_buffer(
-            CheckpointChunkCache.handle.read(nbytes), "little"
+        storage = UntypedStorage.from_buffer(
+            CheckpointChunkCache.handle.read(nbytes), "little", dtype=self.dtype
         )

         storage = torch.serialization._get_restore_location(map_location)(
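The hunk above replaces the per-dtype STORAGE_TYPE_MAP lookup with a single untyped buffer read. A small standalone sketch of the same call shape; the byte values, shape, and stride here are purely illustrative stand-ins for the checkpoint chunk.

import torch
from torch.storage import UntypedStorage

# Eight little-endian float16 bytes stand in for what
# CheckpointChunkCache.handle.read(nbytes) would return.
raw = torch.arange(4, dtype=torch.float16).numpy().tobytes()

# Interpret the raw bytes as float16 without choosing a typed storage class.
storage = UntypedStorage.from_buffer(raw, "little", dtype=torch.float16)

# Attach the storage to an empty tensor, much like torch._utils._rebuild_tensor
# does internally: set_(source, storage_offset, size, stride).
tensor = torch.empty((0,), dtype=torch.float16).set_(storage, 0, (2, 2), (2, 1))
print(tensor)  # expected: [[0., 1.], [2., 3.]] in float16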
@@ -421,6 +412,8 @@ def use_lazy_load(
         yield False
         return

+    begin_time = time.time()
+
     try:
         old_rebuild_tensor = torch._utils._rebuild_tensor
         torch._utils._rebuild_tensor = _rebuild_tensor
@@ -471,17 +464,23 @@ def use_lazy_load(
     finally:
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load

+        post_load_cleanup()
+        logger.debug(
+            f"[lazy_load] Context closed in {round(time.time() - begin_time, 2)} seconds."
+        )
+
         if dematerialized_modules:
             init_empty_weights.__exit__(None, None, None)


 def post_load_cleanup() -> None:
+    """Close dangling file pointers and clear caches after the load is complete."""
     global torch_checkpoint_file_handles

-    print("CheckpointChunkCache Hit Data:", CheckpointChunkCache.hit_data)
-    CheckpointChunkCache.clear()
+    logger.debug(f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}")
+    CheckpointChunkCache.clear(unload_model=True)

     for v in torch_checkpoint_file_handles.values():
         v.close()

     torch_checkpoint_file_handles = {}
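The finally block above times the context and runs cleanup even when loading fails. A minimal, self-contained sketch of that try/finally timing pattern; timed_context and the stdlib logger here are stand-ins for illustration, not KoboldAI APIs.

import time
from contextlib import contextmanager
from logging import getLogger

logger = getLogger(__name__)  # stand-in for KoboldAI's `from logger import logger`

@contextmanager
def timed_context(name: str):
    # Record when the context opened, mirroring begin_time = time.time() above.
    begin_time = time.time()
    try:
        yield
    finally:
        # Always runs, even on error, so cleanup and timing are never skipped.
        logger.debug(f"[{name}] Context closed in {round(time.time() - begin_time, 2)} seconds.")

with timed_context("lazy_load"):
    time.sleep(0.1)  # placeholder for the patched torch.load work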
@@ -184,6 +184,10 @@ def patch_transformers_for_lazyload() -> None:
             state_dict[new_key] = state_dict.pop(old_key)

         # BEGIN PATCH
+        # TODO: Based on config
+        dtype = torch.float16
+        set_module_kwargs = {"dtype": dtype}
+
         for param_name, param in sorted(
             state_dict.items(),
             # State dict must be ordered in this manner to make the caching in
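The TODO above leaves the dtype hard-coded at float16. As a hedged sketch of what "based on config" could mean, a hypothetical helper that reads the checkpoint's declared torch_dtype and falls back to float16; the helper name and model id are assumptions, not part of this commit.

import torch
from transformers import AutoConfig

# Hypothetical helper: prefer the dtype declared in the model config.
def dtype_from_config(config) -> torch.dtype:
    torch_dtype = getattr(config, "torch_dtype", None)
    if isinstance(torch_dtype, str):
        torch_dtype = getattr(torch, torch_dtype, None)
    return torch_dtype if isinstance(torch_dtype, torch.dtype) else torch.float16

config = AutoConfig.from_pretrained("facebook/opt-1.3b")  # assumed model id
set_module_kwargs = {"dtype": dtype_from_config(config)}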
@@ -211,7 +215,6 @@ def patch_transformers_for_lazyload() -> None:
             param_name = param_name[len(start_prefix) :]

             module_name = param_name
-            set_module_kwargs = {}

             # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
             # in int/uint/bool and not cast them.
@@ -272,7 +275,7 @@ def patch_transformers_for_lazyload() -> None:
             elif not load_in_8bit:
                 # For backward compatibility with older versions of `accelerate`
                 set_module_tensor_to_device(
-                    model, param_name, param_device, **set_module_kwargs
+                    model, tensor_name=param_name, device=param_device, **set_module_kwargs
                 )
             else:
                 if (
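For reference, a self-contained sketch of accelerate's set_module_tensor_to_device called with the keyword form used in the hunk above; the tiny Linear module and zero-valued weight are illustrative assumptions.

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Build a tiny module on the meta device so its parameters hold no memory yet.
with init_empty_weights():
    module = torch.nn.Linear(4, 4)

# Materialize the weight on CPU with an explicit dtype, using keyword arguments
# (tensor_name=..., device=...) as in the patched loader.
set_module_tensor_to_device(
    module,
    tensor_name="weight",
    device="cpu",
    value=torch.zeros(4, 4),
    dtype=torch.float16,
)
print(module.weight.dtype, module.weight.device)  # expected: torch.float16 cpu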