(torch_lazy_loader.py) Add support for materializing from a ZipExtFile

This commit is contained in:
Gnome Ann 2022-03-02 13:08:21 -05:00
parent c338b52d68
commit 1ecc452dc8

View File

@ -3,7 +3,7 @@ from functools import reduce
import zipfile
import pickle
import torch
from typing import Any, Callable, Dict, Optional, Tuple, Type
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
class LazyTensor:
@ -23,13 +23,20 @@ class LazyTensor:
def __repr__(self):
return self.__view(repr)
def materialize(self, checkpoint: zipfile.ZipFile, map_location=None) -> torch.Tensor:
def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None) -> torch.Tensor:
size = reduce(lambda x, y: x * y, self.shape, 1)
dtype = self.storage_type(0).dtype
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
with checkpoint.open(f"archive/data/{self.key}", "r") as f:
if isinstance(checkpoint, zipfile.ZipFile):
f = checkpoint.open(f"archive/data/{self.key}", "r")
f.seek(self.storage_offset)
else:
f = checkpoint
try:
storage = self.storage_type.from_buffer(f.read(nbytes), "little")
finally:
if isinstance(checkpoint, zipfile.ZipFile):
f.close()
storage = torch.serialization._get_restore_location(map_location)(storage, self.location)
tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
tensor.set_(storage, 0, self.shape, self.stride)