Merge pull request #108 from VE-FORBRYDERNE/lazy-loader

Lazy loader Python 3.6 compatibility
This commit is contained in:
henk717 2022-04-03 01:15:48 +02:00 committed by GitHub
commit 0882ba165c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 3 deletions

View File

@@ -95,7 +95,7 @@ class LazyTensor:
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
if isinstance(checkpoint, zipfile.ZipFile):
f = checkpoint.open(f"archive/data/{self.key}", "r")
f.seek(self.seek_offset)
f.read(self.seek_offset)
else:
f = checkpoint
try:

View File

@@ -887,6 +887,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
output_shards = config["cores_per_replica"] // checkpoint_shards
import torch
import torch.utils.dlpack
from tqdm.auto import tqdm
move_xmap = jax.experimental.maps.xmap(
@@ -1154,12 +1155,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
import torch_lazy_loader
import torch
from tqdm.auto import tqdm
import functools
def callback(model_dict, f, **_):
with zipfile.ZipFile(f, "r") as z:
try:
last_storage_key = None
f = None
current_offset = 0
print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
@@ -1178,17 +1181,22 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
if isinstance(f, zipfile.ZipExtFile):
f.close()
f = z.open(f"archive/data/{storage_key}")
current_offset = f.tell()
current_offset = 0
if current_offset != model_dict[key].seek_offset:
f.seek(model_dict[key].seek_offset - current_offset, 1)
f.read(model_dict[key].seek_offset - current_offset)
current_offset = model_dict[key].seek_offset
spec = model_spec[key]
transforms = set(spec.get("transforms", ()))
if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
error = f"Duplicate key {repr(key)}"
print("\n\nERROR: " + error, file=sys.stderr)
raise RuntimeError(error)
size = functools.reduce(lambda x, y: x * y, model_dict[key].shape, 1)
dtype = model_dict[key].dtype
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
tensor = model_dict[key].materialize(f, map_location="cpu")
model_dict[key] = tensor.to("meta")
current_offset += nbytes
# MTJ requires certain mathematical operations to be performed
# on tensors in order for them to be in the correct format