From eaf190469d8cec8ea03886a3f9a3622599bb2785 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Thu, 17 Mar 2022 12:51:41 -0400
Subject: [PATCH 1/3] Add PyTorch 1.11 support for lazy loader

---
 requirements.txt     |  4 ++--
 requirements_mtj.txt |  1 +
 torch_lazy_loader.py | 28 +++++++++++++++++++++++-----
 tpu_mtj_backend.py   |  2 +-
 4 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 8b1b36a5..2563073c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ transformers>=4.17
 Flask
 Flask-SocketIO
 requests
-torch==1.10.*
+torch>=1.9
 flask-cloudflared
 flask-ngrok
 eventlet
@@ -10,4 +10,4 @@ lupa==1.10
 markdown
 bleach
 sentencepiece
-protobuf
\ No newline at end of file
+protobuf
diff --git a/requirements_mtj.txt b/requirements_mtj.txt
index 9805328f..59cb5a5f 100644
--- a/requirements_mtj.txt
+++ b/requirements_mtj.txt
@@ -1,3 +1,4 @@
+torch >= 1.9
 numpy
 tqdm
 requests
diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py
index d097675f..36cdcf2c 100644
--- a/torch_lazy_loader.py
+++ b/torch_lazy_loader.py
@@ -57,11 +57,26 @@ from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
 _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
 
 
+STORAGE_TYPE_MAP = {
+    torch.float64: torch.DoubleStorage,
+    torch.float32: torch.FloatStorage,
+    torch.float16: torch.HalfStorage,
+    torch.int64: torch.LongStorage,
+    torch.int32: torch.IntStorage,
+    torch.int16: torch.ShortStorage,
+    torch.int8: torch.CharStorage,
+    torch.uint8: torch.ByteStorage,
+    torch.bool: torch.BoolStorage,
+    torch.bfloat16: torch.BFloat16Storage,
+}
+
+
 class LazyTensor:
-    def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
+    def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
         self.storage_type = storage_type
         self.key = key
         self.location = location
+        self.dtype = dtype
         self.seek_offset = seek_offset
         self.shape = shape
         self.stride = stride
@@ -69,14 +84,14 @@ class LazyTensor:
         self.backward_hooks = backward_hooks
 
     def __view(self, f: Callable):
-        return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
+        return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, dtype={f(self.dtype)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
 
     def __repr__(self):
         return self.__view(repr)
 
     def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None) -> torch.Tensor:
         size = reduce(lambda x, y: x * y, self.shape, 1)
-        dtype = self.storage_type(0).dtype
+        dtype = self.dtype
         nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
         if isinstance(checkpoint, zipfile.ZipFile):
             f = checkpoint.open(f"archive/data/{self.key}", "r")
@@ -84,7 +99,7 @@ class LazyTensor:
         else:
             f = checkpoint
         try:
-            storage = self.storage_type.from_buffer(f.read(nbytes), "little")
+            storage = STORAGE_TYPE_MAP[dtype].from_buffer(f.read(nbytes), "little")
         finally:
             if isinstance(checkpoint, zipfile.ZipFile):
                 f.close()
@@ -120,7 +135,10 @@ class _LazyUnpickler(pickle.Unpickler):
 def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
     lazy_storage.shape = shape
     lazy_storage.stride = stride
-    dtype = lazy_storage.storage_type(0).dtype
+    dtype = lazy_storage.storage_type.dtype
+    if not isinstance(dtype, torch.dtype):
+        dtype = lazy_storage.storage_type(0).dtype
+    lazy_storage.dtype = dtype
     lazy_storage.seek_offset = storage_offset if dtype is torch.bool else storage_offset * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
     return lazy_storage
 
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index b13e3aa3..e2dc70a5 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -961,7 +961,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                     # the least possible memory usage, we create them as meta
                     # tensors, which don't take up any actual CPU or TPU memory.
                     if key not in model_spec:
-                        model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].storage_type(0).dtype, device="meta")
+                        model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
                         continue
 
                     storage_key = model_dict[key].key

From ef21ab9c91492c55323d1839a0a562b18d9b1687 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Thu, 17 Mar 2022 14:10:51 -0400
Subject: [PATCH 2/3] PyTorch 1.9 lazy loader compatibility bugfix

---
 torch_lazy_loader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py
index 36cdcf2c..09166f5a 100644
--- a/torch_lazy_loader.py
+++ b/torch_lazy_loader.py
@@ -195,7 +195,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
             missing_keys.append(key)
 
     extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-    if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
+    if hasattr(Module, "set_extra_state") and getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
         if extra_state_key in state_dict:
             self.set_extra_state(state_dict[extra_state_key])
         elif strict:

From c444260eac59a582f2d8bad71567147c0bdf8079 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Thu, 17 Mar 2022 15:16:56 -0400
Subject: [PATCH 3/3] Silence PyTorch warning about transposing tensors with
 dimension != 2

---
 tpu_mtj_backend.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index e2dc70a5..0ba07802 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -988,7 +988,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                         tensor /= params["cores_per_replica"]
                     if "vocab_pad" in transforms:
                         tensor = torch.nn.functional.pad(tensor, (0, 0, 0, params["n_vocab_padding"]))
-                    if "no_transpose" not in transforms:
+                    if "no_transpose" not in transforms and tensor.ndim == 2:
                         tensor = tensor.T
                     tensor.unsqueeze_(0)
                     if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
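
Illustrative sketch (not part of the patches): the version probe added in
_rebuild_tensor and the dtype-keyed STORAGE_TYPE_MAP lookup used in
LazyTensor.materialize are intended to behave roughly as below, both on
PyTorch >= 1.11 (where the typed storage class itself carries a torch.dtype)
and on PyTorch 1.9/1.10 (where an empty storage instance must be created to
read its dtype). The resolve_dtype helper, the demo buffer, and the printed
values are made up for illustration only.

    import torch

    STORAGE_TYPE_MAP = {
        torch.float32: torch.FloatStorage,
        torch.float16: torch.HalfStorage,
    }

    def resolve_dtype(storage_type):
        # PyTorch >= 1.11: the storage class exposes a torch.dtype directly.
        dtype = getattr(storage_type, "dtype", None)
        if not isinstance(dtype, torch.dtype):
            # PyTorch 1.9/1.10: build an empty storage and read its dtype.
            dtype = storage_type(0).dtype
        return dtype

    # Four little-endian float32 ones, standing in for bytes read out of a
    # checkpoint's data record.
    buf = torch.ones(4, dtype=torch.float32).numpy().tobytes()
    dtype = resolve_dtype(torch.FloatStorage)
    storage = STORAGE_TYPE_MAP[dtype].from_buffer(buf, "little")
    print(dtype, storage.tolist())  # torch.float32 [1.0, 1.0, 1.0, 1.0]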