diff --git a/modeling/lazy_loader.py b/modeling/lazy_loader.py
index 85ed495d..8591bc96 100644
--- a/modeling/lazy_loader.py
+++ b/modeling/lazy_loader.py
@@ -46,7 +46,6 @@ POSSIBILITY OF SUCH DAMAGE.
 
 import contextlib
 from functools import reduce
-import itertools
 import zipfile
 import pickle
 import torch
@@ -54,8 +53,7 @@ import numpy as np
 import collections
 import _codecs
 import os
-from torch.nn import Module
-from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Type
 
 # Safetensors is a dependency for the local version, TPU/Colab doesn't
 # support it yet.
@@ -85,6 +83,22 @@ STORAGE_TYPE_MAP = {
 
 # Storage of zipfile handles for each shard
 torch_checkpoint_file_handles = {}
 
+
+class CheckpointChunkCache:
+    """Storage for common checkpoint weight files to speed up loading. In order
+    for this to be effective at all, weights must be loaded in ascending order
+    of (key, seek_offset)."""
+    file_name = None
+    key = None
+    handle = None
+
+    @classmethod
+    def clear(cls) -> None:
+        cls.file_name = None
+        cls.key = None
+        cls.handle = None
+
+
 class LazyTensor:
     pass
@@ -121,36 +135,37 @@ class TorchLazyTensor(LazyTensor):
 
     def materialize(
         self,
-        checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile] = None,
         map_location=None,
         no_grad=True,
-        filename="pytorch_model.bin",
     ) -> torch.Tensor:
+        checkpoint = torch_checkpoint_file_handles[self.file_name]
+        filename = os.path.basename(os.path.normpath(self.file_name)).split(".")[0]
 
-        # if f not in torch_tensor_container_file_map:
-        #     torch_tensor_container_file_map[f] = []
+        # Most of the operations are just seeks, let's see if we can optimize that.
+        if (
+            CheckpointChunkCache.file_name != filename
+            or CheckpointChunkCache.key != self.key
+            or not CheckpointChunkCache.handle
+        ):
+            # Flush cache if invalid
+            print("!", end="", flush=True)
 
-        #     with zipfile.ZipFile(f, "r") as z:
-        #         paths = z.namelist()
+            if CheckpointChunkCache.handle:
+                CheckpointChunkCache.handle.close()
 
-        #         for name in paths:
-        #             val = name.split("/data/")[-1]
-        #             if not val.isdecimal():
-        #                 continue
-        #             torch_tensor_container_file_map[f].append(int(val))
-        #     torch_tensor_container_file_map[f].sort()
-        # print(torch_tensor_container_file_map)
+            CheckpointChunkCache.file_name = filename
+            CheckpointChunkCache.key = self.key
+            try:
+                CheckpointChunkCache.handle = checkpoint.open(
+                    f"archive/data/{self.key}", "r"
+                )
+            except KeyError:
+                CheckpointChunkCache.handle = checkpoint.open(
+                    f"{filename}/data/{self.key}", "r"
+                )
-
-
-
-        if not checkpoint:
-            checkpoint = torch_checkpoint_file_handles[self.file_name]
-            filename = self.file_name
-
-        filename = os.path.basename(os.path.normpath(filename)).split(".")[0]
 
         size = reduce(lambda x, y: x * y, self.shape, 1)
         dtype = self.dtype
         nbytes = (
@@ -163,21 +178,12 @@ class TorchLazyTensor(LazyTensor):
             )
         )
 
-        if isinstance(checkpoint, zipfile.ZipFile):
-            try:
-                f = checkpoint.open(f"archive/data/{self.key}", "r")
-            except:
-                f = checkpoint.open(f"{filename}/data/{self.key}", "r")
-            f.seek(self.seek_offset, os.SEEK_CUR)
-            # f.read(self.seek_offset)
-        else:
-            f = checkpoint
+        assert isinstance(checkpoint, zipfile.ZipFile)
 
-        try:
-            storage = STORAGE_TYPE_MAP[dtype].from_buffer(f.read(nbytes), "little")
-        finally:
-            if isinstance(checkpoint, zipfile.ZipFile):
-                f.close()
+        CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
+        storage = STORAGE_TYPE_MAP[dtype].from_buffer(
+            CheckpointChunkCache.handle.read(nbytes), "little"
+        )
 
         storage = torch.serialization._get_restore_location(map_location)(
             storage, self.location
@@ -277,9 +283,7 @@ class _LazyUnpickler(RestrictedUnpickler):
             typename == "storage"
         ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
         storage_type, key, location, _ = saved_id[1:]
-        return TorchLazyTensor(
-            storage_type, key, location
-        )
+        return TorchLazyTensor(storage_type, key, location)
 
     def load(self, *args, **kwargs):
         retval = super().load(*args, **kwargs)
@@ -361,13 +365,6 @@ def patch_safetensors(callback):
     transformers.modeling_utils.safe_load_file = safetensors_load
 
 
-def get_torch_tensor_file(file: str, lazy_tensor: TorchLazyTensor):
-    with zipfile.ZipFile(file, "r") as z:
-        storage_key = lazy_tensor.key
-        ziproot = z.namelist()[0].split("/")[0]
-        f = z.open(f"{ziproot}/data/{storage_key}")
-        # TODO: Maybe some file seeking
-        return f
 
 
 @contextlib.contextmanager
 def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
@@ -418,7 +415,7 @@ def use_lazy_load(
             if f not in torch_checkpoint_file_handles:
                 torch_checkpoint_file_handles[f] = zipfile.ZipFile(f, "r")
 
-            for k,v in model_dict.items():
+            for k, v in model_dict.items():
                 v.file_name = f
 
             if callback is not None:
diff --git a/modeling/patches.py b/modeling/patches.py
index 23d0301c..52fe9e10 100644
--- a/modeling/patches.py
+++ b/modeling/patches.py
@@ -126,7 +126,6 @@ def patch_transformers_generation() -> None:
     transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
 
 
-CURRENT_CHECKPOINT = None
 def patch_transformers_for_lazyload() -> None:
     import torch
     import inspect
@@ -158,8 +157,6 @@ def patch_transformers_for_lazyload() -> None:
         """
 
-        print("DEVMAP", device_map)
-
         # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
         # - deepspeed zero 3 support
         # - need to copy metadata if any - see _load_state_dict_into_model
 
@@ -186,13 +183,22 @@ def patch_transformers_for_lazyload() -> None:
 
         for old_key, new_key in zip(old_keys, new_keys):
             state_dict[new_key] = state_dict.pop(old_key)
 
-        for param_name, param in state_dict.items():
+# BEGIN PATCH
+        for param_name, param in sorted(
+            state_dict.items(),
+            # State dict must be ordered in this manner to make the caching in
+            # lazy_loader.py effective
+            key=lambda x: (
+                # NOTE: Assuming key is just decimal
+                int(x[1].key),
+                x[1].seek_offset,
+            ),
+        ):
-            # BEGIN PATCH
             if isinstance(param, LazyTensor):
                 print(".", end="", flush=True)
                 param = param.materialize()
-            # END PATCH
+# END PATCH
 
         # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
         if (
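Reviewer note on the pattern above: the two halves of this diff only work together. CheckpointChunkCache keeps a single zip-member handle open for the storage chunk currently being read, and the sorted() patch in modeling/patches.py makes materialization requests arrive in ascending (key, seek_offset) order, so the cached handle only ever seeks forward; a backward seek on a compressed zipfile.ZipExtFile makes it decompress the member again from the start, which is the overhead the cache avoids. Below is a minimal standalone sketch of that access pattern; ChunkCache, read_storage, and reads are illustrative names that do not appear in the diff.

import os
import zipfile


class ChunkCache:
    """One cached handle for the chunk currently being read (cf. CheckpointChunkCache)."""

    key = None
    handle = None

    @classmethod
    def get(cls, checkpoint: zipfile.ZipFile, key: str) -> zipfile.ZipExtFile:
        # Cache miss: drop the stale handle and open the requested chunk.
        if cls.key != key or cls.handle is None:
            if cls.handle is not None:
                cls.handle.close()
            cls.key = key
            cls.handle = checkpoint.open(f"archive/data/{key}", "r")
        return cls.handle


def read_storage(checkpoint: zipfile.ZipFile, key: str, offset: int, nbytes: int) -> bytes:
    # Cheap when callers are sorted by (key, offset): the seek is forward-only
    # within the currently open chunk.
    handle = ChunkCache.get(checkpoint, key)
    handle.seek(offset, os.SEEK_SET)
    return handle.read(nbytes)


# Usage, mirroring the sorted() change in patches.py:
#   reads = [(key, offset, nbytes), ...]
#   for key, offset, nbytes in sorted(reads):
#       data = read_storage(checkpoint, key, offset, nbytes)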