More cleaning

somebody
2023-05-28 13:22:32 -05:00
parent 14241fc156
commit ed0728188a
3 changed files with 48 additions and 11 deletions

View File

@@ -238,6 +238,7 @@ class model_backend(HFTorchInferenceModel):
         shutil.rmtree("cache/")
 
         self.patch_embedding()
+        lazy_loader.post_load_cleanup()
 
         self.model.kai_model = self
         utils.koboldai_vars.modeldim = self.get_hidden_size()
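A rough sketch of the load sequence this call slots into, with load_weights() as a hypothetical stand-in for the backend's real weight-loading path: the cleanup must run only after every tensor has been materialized, since it closes the checkpoint handles the lazy tensors read from.

    # Hypothetical sketch; load_weights() stands in for the real load path.
    import lazy_loader

    def load_weights(model):
        ...  # materialize each LazyTensor; every read goes through the cache

    def load_model(model):
        load_weights(model)
        # Runs once, after all tensors are read: prints hit/miss stats,
        # closes checkpoint file handles, and clears the chunk cache.
        lazy_loader.post_load_cleanup()
        return model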

View File

@@ -87,13 +87,29 @@ torch_checkpoint_file_handles = {}
 class CheckpointChunkCache:
     """Storage for common checkpoint weight files to speed up loading. In order
     for this to be effective at all, weights must be loaded in ascending order
-    of (key, seek_offset)."""
+    of (key, seek_offset).
+    """
+
+    # There is considerable room for improvement here; we could peek into the
+    # state dict and preload the N most frequent weight files or something, but
+    # this first implementation is on par with the speed of whatever the
+    # previous callback did.
 
     file_name = None
     key = None
     handle = None
 
+    hit_data = {"hits": 0, "misses": 0}
+
     @classmethod
-    def clear(cls) -> None:
+    def clear(cls, unload_model: bool = False) -> None:
+        if unload_model:
+            cls.hit_data["hits"] = 0
+            cls.hit_data["misses"] = 0
+
+        if cls.handle:
+            cls.handle.close()
+
         cls.file_name = None
         cls.key = None
         cls.handle = None
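A minimal, self-contained sketch of the semantics introduced here (the demo class and byte string are made up): clear() on an ordinary cache miss drops the handle but keeps the hit/miss counters, while clear(unload_model=True) also resets the counters so the next model starts from zero.

    import io

    class ChunkCacheDemo:
        # Class-level state, mirroring CheckpointChunkCache.
        file_name = None
        key = None
        handle = None
        hit_data = {"hits": 0, "misses": 0}

        @classmethod
        def clear(cls, unload_model: bool = False) -> None:
            if unload_model:
                cls.hit_data["hits"] = 0
                cls.hit_data["misses"] = 0
            if cls.handle:
                cls.handle.close()
            cls.file_name = None
            cls.key = None
            cls.handle = None

    ChunkCacheDemo.handle = io.BytesIO(b"fake weights")
    ChunkCacheDemo.hit_data["misses"] += 1

    ChunkCacheDemo.clear()  # plain miss: handle closed, counters survive
    assert ChunkCacheDemo.hit_data["misses"] == 1

    ChunkCacheDemo.clear(unload_model=True)  # unload: counters reset too
    assert ChunkCacheDemo.hit_data == {"hits": 0, "misses": 0}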
@@ -140,20 +156,22 @@ class TorchLazyTensor(LazyTensor):
     ) -> torch.Tensor:
         checkpoint = torch_checkpoint_file_handles[self.file_name]
         filename = os.path.basename(os.path.normpath(self.file_name)).split(".")[0]
 
-        # Most of the operations are just seeks, let's see if we can optimize that.
+        # Often we are using the same weight file to store multiple tensors, so
+        # let's cache the file handle to maintain a seek position and other
+        # fast stuff.
         if (
             CheckpointChunkCache.file_name != filename
             or CheckpointChunkCache.key != self.key
             or not CheckpointChunkCache.handle
         ):
-            # Flush cache if invalid
+            # Cache miss. Assuming weights are loaded in order of
+            # (key, seek_offset), this means we need to invalidate the cache.
             print("!", end="", flush=True)
+            CheckpointChunkCache.hit_data["misses"] += 1
 
-            if CheckpointChunkCache.handle:
-                CheckpointChunkCache.handle.close()
+            CheckpointChunkCache.clear()
 
             CheckpointChunkCache.file_name = filename
             CheckpointChunkCache.key = self.key
@@ -165,6 +183,10 @@ class TorchLazyTensor(LazyTensor):
             CheckpointChunkCache.handle = checkpoint.open(
                 f"{filename}/data/{self.key}", "r"
             )
+        else:
+            # Cache hit. Hip hip hooray! :^)
+            print(".", end="", flush=True)
+            CheckpointChunkCache.hit_data["hits"] += 1
 
         size = reduce(lambda x, y: x * y, self.shape, 1)
         dtype = self.dtype
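The cache only pays off when consecutive materializations read from the same (file, key) pair, which is why the docstring insists on ascending (key, seek_offset) order. A toy simulation of the check above, with a made-up access sequence:

    # Toy simulation of the hit/miss check; the key sequence is made up.
    cached_key = None
    hit_data = {"hits": 0, "misses": 0}

    for key in ["layer.0", "layer.0", "layer.0", "layer.1", "layer.1"]:
        if cached_key != key:
            hit_data["misses"] += 1  # "!" in the real loader: reopen handle
            cached_key = key
        else:
            hit_data["hits"] += 1    # "." in the real loader: reuse handle

    print(hit_data)  # {'hits': 3, 'misses': 2}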
@@ -173,7 +195,9 @@ class TorchLazyTensor(LazyTensor):
             if dtype is torch.bool
             else size
             * (
-                (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits
+                (torch.finfo if self.dtype.is_floating_point else torch.iinfo)(
+                    self.dtype
+                ).bits
                 >> 3
             )
         )
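To make the byte-count arithmetic concrete (the torch.bool branch falls outside this hunk, so only the general case is sketched; the shapes are arbitrary): a float16 tensor of shape (4, 3) has 12 elements and torch.finfo(torch.float16).bits == 16, giving 12 * (16 >> 3) = 24 bytes to read from the checkpoint.

    from functools import reduce

    import torch

    def lazy_nbytes(shape, dtype):
        # Same arithmetic as the non-bool branch above: element count times
        # bytes per element (bits >> 3 converts bit width to bytes).
        size = reduce(lambda x, y: x * y, shape, 1)
        info = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype)
        return size * (info.bits >> 3)

    print(lazy_nbytes((4, 3), torch.float16))  # 12 * 2 = 24
    print(lazy_nbytes((4, 3), torch.int32))    # 12 * 4 = 48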
@@ -181,14 +205,14 @@ class TorchLazyTensor(LazyTensor):
         assert isinstance(checkpoint, zipfile.ZipFile)
 
         CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
-        storage = STORAGE_TYPE_MAP[dtype].from_buffer(
+        storage = STORAGE_TYPE_MAP[self.dtype].from_buffer(
             CheckpointChunkCache.handle.read(nbytes), "little"
         )
 
         storage = torch.serialization._get_restore_location(map_location)(
             storage, self.location
         )
 
-        tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
+        tensor = torch.tensor([], dtype=self.dtype, device=storage.device)
         tensor.set_(storage, 0, self.shape, self.stride)
         tensor.requires_grad = not no_grad and self.requires_grad
         tensor._backward_hooks = self.backward_hooks
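The reconstruction at the end of this hunk can be tried in isolation. A rough sketch, using struct-packed bytes in place of handle.read(nbytes) and torch.FloatStorage in place of the STORAGE_TYPE_MAP lookup (storage APIs shift between torch versions, so treat this as illustrative):

    import struct

    import torch

    # Three little-endian float32 values stand in for handle.read(nbytes).
    raw = struct.pack("<3f", 1.0, 2.0, 3.0)

    # Build a typed storage over the raw bytes, then point an empty tensor
    # at it, exactly as the hunk does with self.shape and self.stride.
    storage = torch.FloatStorage.from_buffer(raw, "little")
    tensor = torch.tensor([], dtype=torch.float32)
    tensor.set_(storage, 0, (3,))
    print(tensor)  # tensor([1., 2., 3.])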
@@ -449,3 +473,15 @@ def use_lazy_load(
         torch.load = old_torch_load
 
         if dematerialized_modules:
             init_empty_weights.__exit__(None, None, None)
+
+
+def post_load_cleanup() -> None:
+    global torch_checkpoint_file_handles
+
+    print("CheckpointChunkCache Hit Data:", CheckpointChunkCache.hit_data)
+    CheckpointChunkCache.clear()
+
+    for v in torch_checkpoint_file_handles.values():
+        v.close()
+    torch_checkpoint_file_handles = {}
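Putting the new hook together with the existing context manager, the intended call pattern presumably looks like this (the checkpoint path is made up, and use_lazy_load() is assumed to patch torch.load with its default arguments):

    import lazy_loader
    import torch

    with lazy_loader.use_lazy_load():
        # torch.load is patched here, so values come back as LazyTensors.
        state_dict = torch.load("models/pytorch_model.bin")
        ...  # materialize tensors / hand the state dict to the model

    # After the load: report hit/miss stats, close every checkpoint
    # handle, and reset module-level state for the next load.
    lazy_loader.post_load_cleanup()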

View File

@@ -196,7 +196,7 @@ def patch_transformers_for_lazyload() -> None:
             ):
                 if isinstance(param, LazyTensor):
-                    print(".", end="", flush=True)
+                    # Should always be true
                     param = param.materialize()
                 # END PATCH
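For reference, a paraphrase of the loop this last hunk patches (not the exact transformers code; LazyTensor's import path is assumed): every value in the state dict should already be a LazyTensor under use_lazy_load(), and calling materialize() on it is what drives the chunk cache and the hit/miss counters above.

    from lazy_loader import LazyTensor  # import path assumed

    def materialize_state_dict(state_dict):
        # Paraphrase of the patched loading loop in transformers.
        for name, param in state_dict.items():
            if isinstance(param, LazyTensor):  # should always be true
                state_dict[name] = param.materialize()
        return state_dict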