(torch_lazy_loader.py) Add support for materializing from a ZipExtFile

This commit is contained in:
Gnome Ann 2022-03-02 13:08:21 -05:00
parent c338b52d68
commit 1ecc452dc8

View File

@ -3,7 +3,7 @@ from functools import reduce
import zipfile import zipfile
import pickle import pickle
import torch import torch
from typing import Any, Callable, Dict, Optional, Tuple, Type from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
class LazyTensor: class LazyTensor:
@ -23,13 +23,20 @@ class LazyTensor:
def __repr__(self): def __repr__(self):
return self.__view(repr) return self.__view(repr)
def materialize(self, checkpoint: zipfile.ZipFile, map_location=None) -> torch.Tensor: def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None) -> torch.Tensor:
size = reduce(lambda x, y: x * y, self.shape, 1) size = reduce(lambda x, y: x * y, self.shape, 1)
dtype = self.storage_type(0).dtype dtype = self.storage_type(0).dtype
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
with checkpoint.open(f"archive/data/{self.key}", "r") as f: if isinstance(checkpoint, zipfile.ZipFile):
f = checkpoint.open(f"archive/data/{self.key}", "r")
f.seek(self.storage_offset) f.seek(self.storage_offset)
else:
f = checkpoint
try:
storage = self.storage_type.from_buffer(f.read(nbytes), "little") storage = self.storage_type.from_buffer(f.read(nbytes), "little")
finally:
if isinstance(checkpoint, zipfile.ZipFile):
f.close()
storage = torch.serialization._get_restore_location(map_location)(storage, self.location) storage = torch.serialization._get_restore_location(map_location)(storage, self.location)
tensor = torch.tensor([], dtype=storage.dtype, device=storage.device) tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
tensor.set_(storage, 0, self.shape, self.stride) tensor.set_(storage, 0, self.shape, self.stride)