From fabbdf2bb1f1b0f3a9e383d7781aa17e066aac5b Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Sat, 2 Apr 2022 15:02:54 -0400
Subject: [PATCH] Lazy loader Python 3.6 compatibility

The current lazy loader relies on a feature of the Python zipfile module
that was added in Python 3.7.0:
https://bugs.python.org/issue22908

This commit adds compatibility for Python 3.6.
---
 torch_lazy_loader.py |  2 +-
 tpu_mtj_backend.py   | 12 ++++++++++--
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py
index 09166f5a..49a30931 100644
--- a/torch_lazy_loader.py
+++ b/torch_lazy_loader.py
@@ -95,7 +95,7 @@ class LazyTensor:
         nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
         if isinstance(checkpoint, zipfile.ZipFile):
             f = checkpoint.open(f"archive/data/{self.key}", "r")
-            f.seek(self.seek_offset)
+            f.read(self.seek_offset)
         else:
             f = checkpoint
         try:
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index c7a8840f..9ccc4f30 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -887,6 +887,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
     output_shards = config["cores_per_replica"] // checkpoint_shards
 
     import torch
+    import torch.utils.dlpack
     from tqdm.auto import tqdm
 
     move_xmap = jax.experimental.maps.xmap(
@@ -1154,12 +1155,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     import torch_lazy_loader
     import torch
     from tqdm.auto import tqdm
+    import functools
 
     def callback(model_dict, f, **_):
         with zipfile.ZipFile(f, "r") as z:
             try:
                 last_storage_key = None
                 f = None
+                current_offset = 0
                 print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
                 for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
 
@@ -1178,17 +1181,22 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                         if isinstance(f, zipfile.ZipExtFile):
                             f.close()
                         f = z.open(f"archive/data/{storage_key}")
-                        current_offset = f.tell()
+                        current_offset = 0
                     if current_offset != model_dict[key].seek_offset:
-                        f.seek(model_dict[key].seek_offset - current_offset, 1)
+                        f.read(model_dict[key].seek_offset - current_offset)
+                        current_offset = model_dict[key].seek_offset
                     spec = model_spec[key]
                     transforms = set(spec.get("transforms", ()))
                     if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
                         error = f"Duplicate key {repr(key)}"
                         print("\n\nERROR: " + error, file=sys.stderr)
                         raise RuntimeError(error)
+                    size = functools.reduce(lambda x, y: x * y, model_dict[key].shape, 1)
+                    dtype = model_dict[key].dtype
+                    nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
                     tensor = model_dict[key].materialize(f, map_location="cpu")
                     model_dict[key] = tensor.to("meta")
+                    current_offset += nbytes
 
                     # MTJ requires certain mathematical operations to be performed
                     # on tensors in order for them to be in the correct format
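
Note on the seek-to-read substitution: before Python 3.7 (bpo-22908), zipfile.ZipExtFile
objects cannot seek, so the patch skips ahead by reading and discarding bytes instead.
The idea can be sketched as a small standalone helper; the name skip_forward is
illustrative and not part of the patch:

    import zipfile

    def skip_forward(f: zipfile.ZipExtFile, nbytes: int) -> None:
        # Python 3.6's ZipExtFile has no seek(); emulate a forward-only seek
        # from the current position by reading and discarding bytes, which is
        # what the patch's f.read(...) calls do in both files.
        discarded = f.read(nbytes)
        if len(discarded) < nbytes:
            raise EOFError("tried to skip past the end of the zip member")

Unlike a real seek, this only moves forward, which is why the callback processes
tensors in ascending (storage key, seek_offset) order and reopens the storage member
whenever the storage key changes.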
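
The other half of the change is bookkeeping: ZipExtFile.tell() is likewise unavailable
before 3.7, so the callback tracks the current position itself and advances it by each
tensor's size in bytes after materializing it. A sketch of that size computation, with
the hypothetical helper name tensor_nbytes (the arithmetic mirrors the nbytes line
added by the patch):

    import functools
    import torch

    def tensor_nbytes(shape, dtype) -> int:
        # Element count times element width in bytes; torch.bool is stored as
        # one byte per element, so its size is just the element count.
        size = functools.reduce(lambda x, y: x * y, shape, 1)
        if dtype is torch.bool:
            return size
        info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
        return size * (info.bits >> 3)

    tensor_nbytes((1024, 1024), torch.float16)  # -> 2097152

Adding this value to current_offset after each materialize() call tells the next
iteration how many bytes still separate the file position from the next tensor's
seek_offset, so the gap can be consumed with read() instead of seek().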