diff --git a/modeling/inference_models/generic_hf_torch/class.py b/modeling/inference_models/generic_hf_torch/class.py
index 539d2018..d4c254b8 100644
--- a/modeling/inference_models/generic_hf_torch/class.py
+++ b/modeling/inference_models/generic_hf_torch/class.py
@@ -238,6 +238,7 @@ class model_backend(HFTorchInferenceModel):
                 shutil.rmtree("cache/")

         self.patch_embedding()
+        lazy_loader.post_load_cleanup()

         self.model.kai_model = self
         utils.koboldai_vars.modeldim = self.get_hidden_size()
diff --git a/modeling/lazy_loader.py b/modeling/lazy_loader.py
index 8591bc96..54cbb912 100644
--- a/modeling/lazy_loader.py
+++ b/modeling/lazy_loader.py
@@ -87,13 +87,29 @@ torch_checkpoint_file_handles = {}
 class CheckpointChunkCache:
     """Storage for common checkpoint weight files to speed up loading. In order
     for this to be effective at all, weights must be loaded in ascending order
-    of (key, seek_offset)."""
+    of (key, seek_offset).
+    """
+
+    # There is considerable room for improvement here; we could peek into the
+    # state dict and preload the N most frequent weight files or something, but
+    # this first implementation is on par with the speed of whatever the
+    # previous callback did.
+
     file_name = None
     key = None
     handle = None

+    hit_data = {"hits": 0, "misses": 0}
+
     @classmethod
-    def clear(cls) -> None:
+    def clear(cls, unload_model: bool = False) -> None:
+        if unload_model:
+            cls.hit_data["hits"] = 0
+            cls.hit_data["misses"] = 0
+
+        if cls.handle:
+            cls.handle.close()
+
         cls.file_name = None
         cls.key = None
         cls.handle = None
@@ -140,20 +156,22 @@ class TorchLazyTensor(LazyTensor):
     ) -> torch.Tensor:
         checkpoint = torch_checkpoint_file_handles[self.file_name]
-
         filename = os.path.basename(os.path.normpath(self.file_name)).split(".")[0]

-        # Most of the operations are just seeks, let's see if we can optimize that.
+        # Often we are using the same weight file to store multiple tensors, so
+        # let's cache the file handle to maintain a seek position and other
+        # fast stuff.
         if (
             CheckpointChunkCache.file_name != filename
             or CheckpointChunkCache.key != self.key
             or not CheckpointChunkCache.handle
         ):
-            # Flush cache if invalid
+            # Cache miss. Assuming weights are loaded in order of
+            # (key, seek_offset), this means we need to invalidate the cache.
             print("!", end="", flush=True)
+            CheckpointChunkCache.hit_data["misses"] += 1

-            if CheckpointChunkCache.handle:
-                CheckpointChunkCache.handle.close()
+            CheckpointChunkCache.clear()

             CheckpointChunkCache.file_name = filename
             CheckpointChunkCache.key = self.key

@@ -165,6 +183,10 @@ class TorchLazyTensor(LazyTensor):
             CheckpointChunkCache.handle = checkpoint.open(
                 f"{filename}/data/{self.key}", "r"
             )
+        else:
+            # Cache hit. Hip hip hooray! :^)
+            print(".", end="", flush=True)
+            CheckpointChunkCache.hit_data["hits"] += 1

         size = reduce(lambda x, y: x * y, self.shape, 1)
         dtype = self.dtype
@@ -173,7 +195,9 @@ class TorchLazyTensor(LazyTensor):
             if dtype is torch.bool
             else size
             * (
-                (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits
+                (torch.finfo if self.dtype.is_floating_point else torch.iinfo)(
+                    self.dtype
+                ).bits
                 >> 3
             )
         )
@@ -181,14 +205,14 @@ class TorchLazyTensor(LazyTensor):
         assert isinstance(checkpoint, zipfile.ZipFile)

         CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
-        storage = STORAGE_TYPE_MAP[dtype].from_buffer(
+        storage = STORAGE_TYPE_MAP[self.dtype].from_buffer(
             CheckpointChunkCache.handle.read(nbytes), "little"
         )

         storage = torch.serialization._get_restore_location(map_location)(
             storage, self.location
         )
-        tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
+        tensor = torch.tensor([], dtype=self.dtype, device=storage.device)
         tensor.set_(storage, 0, self.shape, self.stride)
         tensor.requires_grad = not no_grad and self.requires_grad
         tensor._backward_hooks = self.backward_hooks
@@ -449,3 +473,15 @@ def use_lazy_load(
         torch.load = old_torch_load
         if dematerialized_modules:
             init_empty_weights.__exit__(None, None, None)
+
+
+def post_load_cleanup() -> None:
+    global torch_checkpoint_file_handles
+
+    print("CheckpointChunkCache Hit Data:", CheckpointChunkCache.hit_data)
+    CheckpointChunkCache.clear()
+
+    for v in torch_checkpoint_file_handles.values():
+        v.close()
+
+    torch_checkpoint_file_handles = {}
diff --git a/modeling/patches.py b/modeling/patches.py
index 52fe9e10..5c0573f7 100644
--- a/modeling/patches.py
+++ b/modeling/patches.py
@@ -196,7 +196,7 @@ def patch_transformers_for_lazyload() -> None:
     ):

         if isinstance(param, LazyTensor):
-            print(".", end="", flush=True)
+            # Should always be true
             param = param.materialize()

     # END PATCH
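
Reviewer note: the `lazy_loader.py` hunks above implement a single-entry handle cache over the checkpoint zip. Below is a minimal standalone sketch of the same pattern, assuming reads arrive in ascending `(key, seek_offset)` order as the class docstring requires; `ChunkCache` and `read_chunk` are illustrative names, not the KoboldAI API.

```python
import os
import zipfile


class ChunkCache:
    """Single-entry cache: keep the last data file's handle open between reads."""

    file_name = None
    handle = None
    hit_data = {"hits": 0, "misses": 0}

    @classmethod
    def clear(cls) -> None:
        if cls.handle:
            cls.handle.close()
        cls.file_name = None
        cls.handle = None


def read_chunk(
    checkpoint: zipfile.ZipFile, file_name: str, offset: int, nbytes: int
) -> bytes:
    if ChunkCache.file_name != file_name or not ChunkCache.handle:
        # Cache miss: reopen only when the target data file changes.
        ChunkCache.hit_data["misses"] += 1
        ChunkCache.clear()
        ChunkCache.file_name = file_name
        ChunkCache.handle = checkpoint.open(file_name, "r")
    else:
        ChunkCache.hit_data["hits"] += 1

    # ZipExtFile seeks are cheap only when moving forward, hence the
    # ascending-order requirement: a backward seek forces a re-read
    # from the start of the compressed stream.
    ChunkCache.handle.seek(offset, os.SEEK_SET)
    return ChunkCache.handle.read(nbytes)
```

The real implementation additionally keys the cache on the storage key and turns the returned bytes into a torch storage via `STORAGE_TYPE_MAP[...].from_buffer`.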
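
The dtype-consistency fixes in the `nbytes` hunk compute the read size as element count times bytes per element (with `torch.bool` special-cased in the branch above the hunk). A quick sanity check using only standard torch APIs:

```python
import torch

# e.g. a 4096x4096 float16 tensor
size = 4096 * 4096
dtype = torch.float16

# Same expression as the diff: bits per element, shifted down to bytes.
bits = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits
nbytes = size * (bits >> 3)  # 16 bits >> 3 == 2 bytes/elem

assert nbytes == 4096 * 4096 * 2  # 33_554_432 bytes
```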
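
`post_load_cleanup()` is intended to run exactly once, after the model has fully materialized (see the `generic_hf_torch` hunk). A sketch of the call order, assuming `use_lazy_load` is the context manager whose teardown appears in the final `lazy_loader.py` hunk; the `from_pretrained` call and `model_path` are illustrative, not the backend's actual load path:

```python
from transformers import AutoModelForCausalLM

from modeling import lazy_loader

model_path = "models/my-model"  # hypothetical path

with lazy_loader.use_lazy_load(dematerialized_modules=True):
    # Weights load as LazyTensors and materialize on demand while
    # transformers builds the model.
    model = AutoModelForCausalLM.from_pretrained(model_path)

# Print the hit/miss counters, close the cached chunk handle via
# CheckpointChunkCache.clear(), and close every handle left in
# torch_checkpoint_file_handles.
lazy_loader.post_load_cleanup()
```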