Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)

Commit: More cleaning
@@ -238,6 +238,7 @@ class model_backend(HFTorchInferenceModel):
             shutil.rmtree("cache/")
 
         self.patch_embedding()
+        lazy_loader.post_load_cleanup()
 
         self.model.kai_model = self
         utils.koboldai_vars.modeldim = self.get_hidden_size()
@@ -87,13 +87,29 @@ torch_checkpoint_file_handles = {}
 class CheckpointChunkCache:
     """Storage for common checkpoint weight files to speed up loading. In order
     for this to be effective at all, weights must be loaded in ascending order
-    of (key, seek_offset)."""
+    of (key, seek_offset).
+    """
+
+    # There is considerable room for improvement here; we could peek into the
+    # state dict and preload the N most frequent weight files or something, but
+    # this first implementation is on par with the speed of whatever the
+    # previous callback did.
 
     file_name = None
     key = None
     handle = None
 
+    hit_data = {"hits": 0, "misses": 0}
+
     @classmethod
-    def clear(cls) -> None:
+    def clear(cls, unload_model: bool = False) -> None:
+        if unload_model:
+            cls.hit_data["hits"] = 0
+            cls.hit_data["misses"] = 0
+
+        if cls.handle:
+            cls.handle.close()
+
         cls.file_name = None
         cls.key = None
         cls.handle = None
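The docstring's ordering requirement is the whole point of the single-slot design: a hit only happens when consecutive tensors come from the same (file_name, key) chunk, so out-of-order access degenerates into constant reopening. A standalone sketch of that behaviour, separate from the diff and using made-up request data:

from typing import Optional, Tuple


class OneSlotCache:
    """Single-slot cache keyed on (file_name, key), like CheckpointChunkCache."""

    def __init__(self) -> None:
        self.slot: Optional[Tuple[str, str]] = None
        self.hits = 0
        self.misses = 0

    def access(self, file_name: str, key: str) -> None:
        if self.slot == (file_name, key):
            self.hits += 1  # same chunk as last time: the handle could be reused
        else:
            self.misses += 1  # different chunk: the handle would be closed and reopened
            self.slot = (file_name, key)


requests = [
    ("pytorch_model.bin", "0"),
    ("pytorch_model.bin", "0"),
    ("pytorch_model.bin", "1"),
    ("pytorch_model.bin", "0"),
]

ordered = OneSlotCache()
for file_name, key in sorted(requests):  # grouped by chunk, as the loader expects
    ordered.access(file_name, key)

unordered = OneSlotCache()
for file_name, key in requests:  # arbitrary order thrashes the single slot
    unordered.access(file_name, key)

print("ordered:", ordered.hits, "hits,", ordered.misses, "misses")        # 2 hits, 2 misses
print("unordered:", unordered.hits, "hits,", unordered.misses, "misses")  # 1 hit, 3 misses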
@@ -140,20 +156,22 @@ class TorchLazyTensor(LazyTensor):
     ) -> torch.Tensor:
 
         checkpoint = torch_checkpoint_file_handles[self.file_name]
 
         filename = os.path.basename(os.path.normpath(self.file_name)).split(".")[0]
 
         # Most of the operations are just seeks, let's see if we can optimize that.
         # Often we are using the same weight file to store multiple tensors, so
         # let's cache the file handle to maintain a seek position and other
         # fast stuff.
         if (
             CheckpointChunkCache.file_name != filename
             or CheckpointChunkCache.key != self.key
             or not CheckpointChunkCache.handle
         ):
-            # Flush cache if invalid
+            # Cache miss. Assuming weights are loaded in order of
+            # (key, seek_offset), this means we need to invalidate the cache.
+            print("!", end="", flush=True)
+            CheckpointChunkCache.hit_data["misses"] += 1
 
-            if CheckpointChunkCache.handle:
-                CheckpointChunkCache.handle.close()
             CheckpointChunkCache.clear()
 
             CheckpointChunkCache.file_name = filename
             CheckpointChunkCache.key = self.key
@@ -165,6 +183,10 @@ class TorchLazyTensor(LazyTensor):
             CheckpointChunkCache.handle = checkpoint.open(
                 f"{filename}/data/{self.key}", "r"
             )
+        else:
+            # Cache hit. Hip hip hooray! :^)
+            print(".", end="", flush=True)
+            CheckpointChunkCache.hit_data["hits"] += 1
 
         size = reduce(lambda x, y: x * y, self.shape, 1)
         dtype = self.dtype
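The checkpoint.open(f"{filename}/data/{self.key}", "r") call leans on the layout of PyTorch's zip-based checkpoints, where every storage is its own archive member under <archive>/data/<key>. A quick way to look at that layout, assuming a recent torch and an illustrative file name:

import zipfile

import torch

# Save a tiny checkpoint, then list its per-storage members.
torch.save({"weight": torch.ones(2, 2)}, "tiny_checkpoint.pt")

with zipfile.ZipFile("tiny_checkpoint.pt") as checkpoint:
    print([name for name in checkpoint.namelist() if "/data/" in name])
    # Typically something like ['tiny_checkpoint/data/0']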
@@ -173,7 +195,9 @@ class TorchLazyTensor(LazyTensor):
             if dtype is torch.bool
             else size
             * (
-                (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits
+                (torch.finfo if self.dtype.is_floating_point else torch.iinfo)(
+                    self.dtype
+                ).bits
                 >> 3
             )
         )
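The expression above works out how many bytes to read for one tensor: the element count times the element width in bytes (bits >> 3), with torch.bool handled by a separate branch that is not shown in this hunk. A worked example with illustrative values:

from functools import reduce

import torch

shape = (4, 3)
dtype = torch.float16

size = reduce(lambda x, y: x * y, shape, 1)  # 12 elements
bits = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits
nbytes = size * (bits >> 3)  # 16 bits per element -> 2 bytes, so 24 bytes total

print(size, bits, nbytes)  # 12 16 24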
@@ -181,14 +205,14 @@ class TorchLazyTensor(LazyTensor):
         assert isinstance(checkpoint, zipfile.ZipFile)
 
         CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
-        storage = STORAGE_TYPE_MAP[dtype].from_buffer(
+        storage = STORAGE_TYPE_MAP[self.dtype].from_buffer(
             CheckpointChunkCache.handle.read(nbytes), "little"
         )
 
         storage = torch.serialization._get_restore_location(map_location)(
             storage, self.location
         )
-        tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
+        tensor = torch.tensor([], dtype=self.dtype, device=storage.device)
         tensor.set_(storage, 0, self.shape, self.stride)
         tensor.requires_grad = not no_grad and self.requires_grad
         tensor._backward_hooks = self.backward_hooks
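This block turns raw little-endian bytes from the zip member into a tensor: a typed storage is built from the buffer, relocated according to map_location, then viewed through set_ with the recorded shape and stride. A minimal stand-in for the same bytes-to-tensor step, using torch.frombuffer rather than the storage API for brevity:

import struct

import torch

# Pack three float16 values as little-endian bytes, roughly as they would sit
# in the checkpoint, then rebuild a tensor straight from the raw buffer.
# (frombuffer uses native byte order, which is little-endian on typical hardware.)
values = [1.0, -2.5, 0.5]
raw = bytearray(b"".join(struct.pack("<e", v) for v in values))

tensor = torch.frombuffer(raw, dtype=torch.float16)
print(tensor)  # tensor([ 1.0000, -2.5000,  0.5000], dtype=torch.float16)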
@@ -449,3 +473,15 @@ def use_lazy_load(
         torch.load = old_torch_load
         if dematerialized_modules:
             init_empty_weights.__exit__(None, None, None)
+
+
+def post_load_cleanup() -> None:
+    global torch_checkpoint_file_handles
+
+    print("CheckpointChunkCache Hit Data:", CheckpointChunkCache.hit_data)
+    CheckpointChunkCache.clear()
+
+    for v in torch_checkpoint_file_handles.values():
+        v.close()
+
+    torch_checkpoint_file_handles = {}
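post_load_cleanup() is where the hits/misses counters surface, so it is the natural spot to turn them into a hit rate if one is wanted; the helper below is illustrative and not part of the diff:

def hit_rate(hit_data: dict) -> float:
    """Fraction of chunk lookups served by the cached handle."""
    total = hit_data["hits"] + hit_data["misses"]
    return hit_data["hits"] / total if total else 0.0


print(f"{hit_rate({'hits': 180, 'misses': 20}):.0%}")  # 90%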
@@ -196,7 +196,7 @@ def patch_transformers_for_lazyload() -> None:
         ):
 
             if isinstance(param, LazyTensor):
-                print(".", end="", flush=True)
+                # Should always be true
                 param = param.materialize()
             # END PATCH
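The isinstance(param, LazyTensor) guard in the patched loader is the moment deferred weights become real tensors. A self-contained sketch of that lazy-materialization pattern; LazyValue is a made-up stand-in, not the repo's LazyTensor:

import torch


class LazyValue:
    """Placeholder that defers building the real tensor until asked."""

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype

    def materialize(self) -> torch.Tensor:
        # The real LazyTensor reads bytes out of the checkpoint zip here;
        # this stand-in just allocates zeros of the recorded shape and dtype.
        return torch.zeros(self.shape, dtype=self.dtype)


param = LazyValue((2, 2), torch.float16)
if isinstance(param, LazyValue):  # mirrors the guard in the patched loader
    param = param.materialize()

print(param.dtype, tuple(param.shape))  # torch.float16 (2, 2)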