More cleaning

somebody
2023-05-28 13:22:32 -05:00
parent 14241fc156
commit ed0728188a
3 changed files with 48 additions and 11 deletions


@@ -238,6 +238,7 @@ class model_backend(HFTorchInferenceModel):
         shutil.rmtree("cache/")
 
         self.patch_embedding()
+        lazy_loader.post_load_cleanup()
 
         self.model.kai_model = self
         utils.koboldai_vars.modeldim = self.get_hidden_size()
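The hunk above wires the new cleanup hook into the backend: cleanup must run exactly once, after every weight has been materialized. A minimal sketch of the intended ordering, assuming use_lazy_load works as a context manager (its teardown code appears in a later hunk) and with load_model_somehow as a hypothetical stand-in for the backend's real loading code:

    import lazy_loader

    # Load while the lazy-loading patches are active; LazyTensors are
    # materialized on demand as the state dict is walked.
    with lazy_loader.use_lazy_load(dematerialized_modules=True):
        model = load_model_somehow()  # hypothetical loader

    # Then close cached checkpoint handles and print hit/miss statistics.
    lazy_loader.post_load_cleanup()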


@@ -87,13 +87,29 @@ torch_checkpoint_file_handles = {}
 class CheckpointChunkCache:
     """Storage for common checkpoint weight files to speed up loading. In order
     for this to be effective at all, weights must be loaded in ascending order
-    of (key, seek_offset)."""
+    of (key, seek_offset).
+    """
+
+    # There is considerable room for improvement here; we could peek into the
+    # state dict and preload the N most frequent weight files or something, but
+    # this first implementation is on par with the speed of whatever the
+    # previous callback did.
+
     file_name = None
     key = None
     handle = None
+
+    hit_data = {"hits": 0, "misses": 0}
 
     @classmethod
-    def clear(cls) -> None:
+    def clear(cls, unload_model: bool = False) -> None:
+        if unload_model:
+            cls.hit_data["hits"] = 0
+            cls.hit_data["misses"] = 0
+
         if cls.handle:
             cls.handle.close()
 
         cls.file_name = None
         cls.key = None
         cls.handle = None
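The two modes of clear() differ only in whether the hit/miss counters survive; a small illustration, assuming the class exactly as defined above:

    # Between weight files: drop the cached handle, keep the statistics.
    CheckpointChunkCache.clear()
    assert CheckpointChunkCache.handle is None

    # When the model itself is unloaded: also reset the counters.
    CheckpointChunkCache.clear(unload_model=True)
    assert CheckpointChunkCache.hit_data == {"hits": 0, "misses": 0}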
@@ -140,20 +156,22 @@ class TorchLazyTensor(LazyTensor):
     ) -> torch.Tensor:
         checkpoint = torch_checkpoint_file_handles[self.file_name]
         filename = os.path.basename(os.path.normpath(self.file_name)).split(".")[0]
 
-        # Most of the operations are just seeks, let's see if we can optimize that.
+        # Often we are using the same weight file to store multiple tensors, so
+        # let's cache the file handle to maintain a seek position and other
+        # fast stuff.
         if (
             CheckpointChunkCache.file_name != filename
             or CheckpointChunkCache.key != self.key
             or not CheckpointChunkCache.handle
         ):
-            # Flush cache if invalid
+            # Cache miss. Assuming weights are loaded in order of
+            # (key, seek_offset), this means we need to invalidate the cache.
             print("!", end="", flush=True)
+            CheckpointChunkCache.hit_data["misses"] += 1
 
-            if CheckpointChunkCache.handle:
-                CheckpointChunkCache.handle.close()
+            CheckpointChunkCache.clear()
 
             CheckpointChunkCache.file_name = filename
             CheckpointChunkCache.key = self.key
@@ -165,6 +183,10 @@ class TorchLazyTensor(LazyTensor):
             CheckpointChunkCache.handle = checkpoint.open(
                 f"{filename}/data/{self.key}", "r"
             )
+        else:
+            # Cache hit. Hip hip hooray! :^)
+            print(".", end="", flush=True)
+            CheckpointChunkCache.hit_data["hits"] += 1
 
         size = reduce(lambda x, y: x * y, self.shape, 1)
         dtype = self.dtype
@@ -173,7 +195,9 @@
             if dtype is torch.bool
             else size
             * (
-                (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits
+                (torch.finfo if self.dtype.is_floating_point else torch.iinfo)(
+                    self.dtype
+                ).bits
                 >> 3
             )
         )
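To make the byte arithmetic concrete, a worked example of the same formula for a hypothetical float16 weight of shape (4096, 4096):

    import torch
    from functools import reduce

    shape, dtype = (4096, 4096), torch.float16
    size = reduce(lambda x, y: x * y, shape, 1)  # 16_777_216 elements
    bits = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits
    nbytes = size if dtype is torch.bool else size * (bits >> 3)
    assert nbytes == 33_554_432  # 16 bits >> 3 == 2 bytes per element

torch.bool needs the special case because torch.iinfo does not accept it; a bool element occupies one byte of storage, so nbytes is simply size.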
@@ -181,14 +205,14 @@
         assert isinstance(checkpoint, zipfile.ZipFile)
 
         CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
-        storage = STORAGE_TYPE_MAP[dtype].from_buffer(
+        storage = STORAGE_TYPE_MAP[self.dtype].from_buffer(
             CheckpointChunkCache.handle.read(nbytes), "little"
         )
 
         storage = torch.serialization._get_restore_location(map_location)(
             storage, self.location
         )
-        tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
+        tensor = torch.tensor([], dtype=self.dtype, device=storage.device)
         tensor.set_(storage, 0, self.shape, self.stride)
         tensor.requires_grad = not no_grad and self.requires_grad
         tensor._backward_hooks = self.backward_hooks
@@ -449,3 +473,15 @@ def use_lazy_load(
         torch.load = old_torch_load
         if dematerialized_modules:
             init_empty_weights.__exit__(None, None, None)
+
+
+def post_load_cleanup() -> None:
+    global torch_checkpoint_file_handles
+
+    print("CheckpointChunkCache Hit Data:", CheckpointChunkCache.hit_data)
+    CheckpointChunkCache.clear()
+
+    for v in torch_checkpoint_file_handles.values():
+        v.close()
+    torch_checkpoint_file_handles = {}
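Since plain clear() leaves the counters intact, a caller could turn them into a hit rate once loading finishes; a hypothetical report built only on the hit_data dict above:

    hits = CheckpointChunkCache.hit_data["hits"]
    misses = CheckpointChunkCache.hit_data["misses"]
    total = hits + misses
    if total:
        print(f"chunk cache hit rate: {hits / total:.1%} over {total} tensors")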


@@ -196,7 +196,7 @@ def patch_transformers_for_lazyload() -> None:
         ):
 
             if isinstance(param, LazyTensor):
-                print(".", end="", flush=True)
+                # Should always be true
                 param = param.materialize()
 
             # END PATCH
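The patched loop boils down to materialize-on-assignment; a rough sketch of the idea, not the actual transformers internals (state_dict and set_module_tensor are illustrative names):

    for name, param in state_dict.items():
        if isinstance(param, LazyTensor):
            # Pulls the real bytes through CheckpointChunkCache.
            param = param.materialize()
        set_module_tensor(model, name, param)  # hypothetical assignment helper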