Lazy loader Python 3.6 compatibility
The current lazy loader relies on a feature of the Python zipfile module (seekable ZipExtFile objects) that was added in Python 3.7.0: https://bugs.python.org/issue22908. This commit adds compatibility for Python 3.6 by replacing seek() calls on zip entries with read() of the intervening bytes and manual tracking of the current offset.
This commit is contained in:
parent
8368b20421
commit
fabbdf2bb1
|
@@ -95,7 +95,7 @@ class LazyTensor:
|
||||||
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
|
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
|
||||||
if isinstance(checkpoint, zipfile.ZipFile):
|
if isinstance(checkpoint, zipfile.ZipFile):
|
||||||
f = checkpoint.open(f"archive/data/{self.key}", "r")
|
f = checkpoint.open(f"archive/data/{self.key}", "r")
|
||||||
f.seek(self.seek_offset)
|
f.read(self.seek_offset)
|
||||||
else:
|
else:
|
||||||
f = checkpoint
|
f = checkpoint
|
||||||
try:
|
try:
|
||||||
|
|
|
@@ -887,6 +887,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
|
||||||
output_shards = config["cores_per_replica"] // checkpoint_shards
|
output_shards = config["cores_per_replica"] // checkpoint_shards
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.utils.dlpack
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
move_xmap = jax.experimental.maps.xmap(
|
move_xmap = jax.experimental.maps.xmap(
|
||||||
|
@@ -1154,12 +1155,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
||||||
import torch_lazy_loader
|
import torch_lazy_loader
|
||||||
import torch
|
import torch
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
import functools
|
||||||
|
|
||||||
def callback(model_dict, f, **_):
|
def callback(model_dict, f, **_):
|
||||||
with zipfile.ZipFile(f, "r") as z:
|
with zipfile.ZipFile(f, "r") as z:
|
||||||
try:
|
try:
|
||||||
last_storage_key = None
|
last_storage_key = None
|
||||||
f = None
|
f = None
|
||||||
|
current_offset = 0
|
||||||
print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
|
print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
|
||||||
for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
|
for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
|
||||||
|
|
||||||
|
@@ -1178,17 +1181,22 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
||||||
if isinstance(f, zipfile.ZipExtFile):
|
if isinstance(f, zipfile.ZipExtFile):
|
||||||
f.close()
|
f.close()
|
||||||
f = z.open(f"archive/data/{storage_key}")
|
f = z.open(f"archive/data/{storage_key}")
|
||||||
current_offset = f.tell()
|
current_offset = 0
|
||||||
if current_offset != model_dict[key].seek_offset:
|
if current_offset != model_dict[key].seek_offset:
|
||||||
f.seek(model_dict[key].seek_offset - current_offset, 1)
|
f.read(model_dict[key].seek_offset - current_offset)
|
||||||
|
current_offset = model_dict[key].seek_offset
|
||||||
spec = model_spec[key]
|
spec = model_spec[key]
|
||||||
transforms = set(spec.get("transforms", ()))
|
transforms = set(spec.get("transforms", ()))
|
||||||
if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
|
if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
|
||||||
error = f"Duplicate key {repr(key)}"
|
error = f"Duplicate key {repr(key)}"
|
||||||
print("\n\nERROR: " + error, file=sys.stderr)
|
print("\n\nERROR: " + error, file=sys.stderr)
|
||||||
raise RuntimeError(error)
|
raise RuntimeError(error)
|
||||||
|
size = functools.reduce(lambda x, y: x * y, model_dict[key].shape, 1)
|
||||||
|
dtype = model_dict[key].dtype
|
||||||
|
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
|
||||||
tensor = model_dict[key].materialize(f, map_location="cpu")
|
tensor = model_dict[key].materialize(f, map_location="cpu")
|
||||||
model_dict[key] = tensor.to("meta")
|
model_dict[key] = tensor.to("meta")
|
||||||
|
current_offset += nbytes
|
||||||
|
|
||||||
# MTJ requires certain mathematical operations to be performed
|
# MTJ requires certain mathematical operations to be performed
|
||||||
# on tensors in order for them to be in the correct format
|
# on tensors in order for them to be in the correct format
|
||||||
|
|
Loading…
Reference in New Issue