Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
More cleaning
@@ -238,6 +238,7 @@ class model_backend(HFTorchInferenceModel):
             shutil.rmtree("cache/")
 
         self.patch_embedding()
+        lazy_loader.post_load_cleanup()
 
         self.model.kai_model = self
         utils.koboldai_vars.modeldim = self.get_hidden_size()
@@ -87,13 +87,29 @@ torch_checkpoint_file_handles = {}
 class CheckpointChunkCache:
     """Storage for common checkpoint weight files to speed up loading. In order
     for this to be effective at all, weights must be loaded in ascending order
-    of (key, seek_offset)."""
+    of (key, seek_offset).
+    """
+
+    # There is considerable room for improvement here; we could peek into the
+    # state dict and preload the N most frequent weight files or something, but
+    # this first implementation is on par with the speed of whatever the
+    # previous callback did.
+
     file_name = None
     key = None
     handle = None
 
+    hit_data = {"hits": 0, "misses": 0}
+
     @classmethod
-    def clear(cls) -> None:
+    def clear(cls, unload_model: bool = False) -> None:
+        if unload_model:
+            cls.hit_data["hits"] = 0
+            cls.hit_data["misses"] = 0
+
+        if cls.handle:
+            cls.handle.close()
+
         cls.file_name = None
         cls.key = None
         cls.handle = None
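
The docstring's precondition is the important part: a single cached handle only pays off if weights arrive in ascending (key, seek_offset) order. A minimal sketch of that ordering, assuming entries that expose key and seek_offset attributes like the LazyTensor objects in this file; the helper name and the shape of the dict are assumptions, not the repository's actual loading loop:

# Illustrative only: order lazy tensors so that consecutive loads hit the same
# checkpoint data file and seek strictly forward within it.
def order_for_chunk_cache(lazy_state_dict):
    return sorted(
        lazy_state_dict.items(),
        key=lambda item: (item[1].key, item[1].seek_offset),
    )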
@@ -140,20 +156,22 @@ class TorchLazyTensor(LazyTensor):
     ) -> torch.Tensor:
 
         checkpoint = torch_checkpoint_file_handles[self.file_name]
 
         filename = os.path.basename(os.path.normpath(self.file_name)).split(".")[0]
 
-        # Most of the operations are just seeks, let's see if we can optimize that.
+        # Often we are using the same weight file to store multiple tensors, so
+        # let's cache the file handle to maintain a seek position and other
+        # fast stuff.
         if (
             CheckpointChunkCache.file_name != filename
             or CheckpointChunkCache.key != self.key
             or not CheckpointChunkCache.handle
         ):
-            # Flush cache if invalid
+            # Cache miss. Assuming weights are loaded in order of
+            # (key, seek_offset), this means we need to invalidate the cache.
             print("!", end="", flush=True)
+            CheckpointChunkCache.hit_data["misses"] += 1
 
-            if CheckpointChunkCache.handle:
-                CheckpointChunkCache.handle.close()
+            CheckpointChunkCache.clear()
 
             CheckpointChunkCache.file_name = filename
             CheckpointChunkCache.key = self.key
@@ -165,6 +183,10 @@ class TorchLazyTensor(LazyTensor):
             CheckpointChunkCache.handle = checkpoint.open(
                 f"{filename}/data/{self.key}", "r"
             )
+        else:
+            # Cache hit. Hip hip hooray! :^)
+            print(".", end="", flush=True)
+            CheckpointChunkCache.hit_data["hits"] += 1
 
         size = reduce(lambda x, y: x * y, self.shape, 1)
         dtype = self.dtype
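
Together with the previous hunk, this turns the chunk cache into a one-entry cache keyed on (file, key): "!" is printed on a miss and "." on a hit, and hit_data keeps the running totals. A toy standalone model of the same bookkeeping, with hypothetical names and inputs, just to show how the counters accumulate:

# Toy model of the one-entry cache above; names and values are illustrative only.
class OneEntryCache:
    file_name = None
    key = None
    hit_data = {"hits": 0, "misses": 0}

    @classmethod
    def lookup(cls, file_name, key):
        if cls.file_name != file_name or cls.key != key:
            cls.hit_data["misses"] += 1   # the loader prints "!" here
            cls.file_name, cls.key = file_name, key
        else:
            cls.hit_data["hits"] += 1     # the loader prints "." here

for name, key in [("model-00001", "0"), ("model-00001", "0"), ("model-00001", "1")]:
    OneEntryCache.lookup(name, key)

print(OneEntryCache.hit_data)  # {'hits': 1, 'misses': 2}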
@@ -173,7 +195,9 @@ class TorchLazyTensor(LazyTensor):
             if dtype is torch.bool
             else size
             * (
-                (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits
+                (torch.finfo if self.dtype.is_floating_point else torch.iinfo)(
+                    self.dtype
+                ).bits
                 >> 3
             )
         )
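
This expression computes the raw byte length to read: element count times bytes per element (bits >> 3), with torch.bool special-cased to one byte per element. A worked example of the same arithmetic for an assumed (4, 8) float16 tensor:

import torch
from functools import reduce

# Same arithmetic as the hunk above, written out for a concrete (assumed) case.
shape = (4, 8)
dtype = torch.float16

size = reduce(lambda x, y: x * y, shape, 1)  # 32 elements
nbytes = (
    size
    if dtype is torch.bool
    else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
)
print(nbytes)  # 64: 32 elements * 2 bytes per float16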
@@ -181,14 +205,14 @@ class TorchLazyTensor(LazyTensor):
         assert isinstance(checkpoint, zipfile.ZipFile)
 
         CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
-        storage = STORAGE_TYPE_MAP[dtype].from_buffer(
+        storage = STORAGE_TYPE_MAP[self.dtype].from_buffer(
             CheckpointChunkCache.handle.read(nbytes), "little"
         )
 
         storage = torch.serialization._get_restore_location(map_location)(
             storage, self.location
         )
-        tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
+        tensor = torch.tensor([], dtype=self.dtype, device=storage.device)
         tensor.set_(storage, 0, self.shape, self.stride)
         tensor.requires_grad = not no_grad and self.requires_grad
         tensor._backward_hooks = self.backward_hooks
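
For intuition only, the seek-read-rebuild step can be sketched with public APIs. The loader itself goes through STORAGE_TYPE_MAP and tensor.set_() so it can honour self.stride and the caller's map_location; this simplified version assumes a contiguous buffer in native little-endian byte order, and its helper name and inputs are hypothetical:

import torch

# Simplified sketch only: rebuild a tensor from a little-endian byte buffer.
def tensor_from_bytes(raw: bytes, dtype: torch.dtype, shape: tuple) -> torch.Tensor:
    flat = torch.frombuffer(bytearray(raw), dtype=dtype)  # writable copy of the bytes
    return flat.reshape(shape)

raw = torch.arange(32, dtype=torch.float16).numpy().tobytes()
print(tensor_from_bytes(raw, torch.float16, (4, 8)).shape)  # torch.Size([4, 8])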
@@ -449,3 +473,15 @@ def use_lazy_load(
         torch.load = old_torch_load
         if dematerialized_modules:
             init_empty_weights.__exit__(None, None, None)
+
+
+def post_load_cleanup() -> None:
+    global torch_checkpoint_file_handles
+
+    print("CheckpointChunkCache Hit Data:", CheckpointChunkCache.hit_data)
+    CheckpointChunkCache.clear()
+
+    for v in torch_checkpoint_file_handles.values():
+        v.close()
+
+    torch_checkpoint_file_handles = {}
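
A sketch of the intended call order, matching the backend change in the first hunk. It assumes use_lazy_load is entered as a context manager (its teardown code above suggests as much), that the module import path is modeling.lazy_loader, and that the model-loading step is a placeholder:

import torch
from modeling import lazy_loader  # import path is an assumption


def load_my_model():
    # Placeholder: any code path that ends in torch.load() on a checkpoint.
    return torch.load("model.pt", map_location="cpu")


# Assumed call order based on this diff: use_lazy_load patches torch.load for
# the duration of the block, and post_load_cleanup() runs afterwards to print
# hit statistics, clear the chunk cache, and close the checkpoint handles.
with lazy_loader.use_lazy_load(dematerialized_modules=True):
    model = load_my_model()

lazy_loader.post_load_cleanup()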
@@ -196,7 +196,7 @@ def patch_transformers_for_lazyload() -> None:
     ):
 
         if isinstance(param, LazyTensor):
-            print(".", end="", flush=True)
+            # Should always be true
             param = param.materialize()
         # END PATCH
 