diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py index 85e1310e..4a29d0c8 100644 --- a/torch_lazy_loader.py +++ b/torch_lazy_loader.py @@ -3,7 +3,7 @@ from functools import reduce import zipfile import pickle import torch -from typing import Any, Callable, Dict, Optional, Tuple, Type +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union class LazyTensor: @@ -23,13 +23,20 @@ class LazyTensor: def __repr__(self): return self.__view(repr) - def materialize(self, checkpoint: zipfile.ZipFile, map_location=None) -> torch.Tensor: + def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None) -> torch.Tensor: size = reduce(lambda x, y: x * y, self.shape, 1) dtype = self.storage_type(0).dtype nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) - with checkpoint.open(f"archive/data/{self.key}", "r") as f: + if isinstance(checkpoint, zipfile.ZipFile): + f = checkpoint.open(f"archive/data/{self.key}", "r") f.seek(self.storage_offset) + else: + f = checkpoint + try: storage = self.storage_type.from_buffer(f.read(nbytes), "little") + finally: + if isinstance(checkpoint, zipfile.ZipFile): + f.close() storage = torch.serialization._get_restore_location(map_location)(storage, self.location) tensor = torch.tensor([], dtype=storage.dtype, device=storage.device) tensor.set_(storage, 0, self.shape, self.stride)