diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py
index 604f6e69..85e1310e 100644
--- a/torch_lazy_loader.py
+++ b/torch_lazy_loader.py
@@ -1,15 +1,16 @@
 import contextlib
+from functools import reduce
+import zipfile
 import pickle
 import torch
 from typing import Any, Callable, Dict, Optional, Tuple, Type
 
 
 class LazyTensor:
-    def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, nelements: int, storage_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
+    def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, storage_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
         self.storage_type = storage_type
         self.key = key
         self.location = location
-        self.nelements = nelements
         self.storage_offset = storage_offset
         self.shape = shape
         self.stride = stride
@@ -17,17 +18,21 @@ class LazyTensor:
         self.backward_hooks = backward_hooks
 
     def __view(self, f: Callable):
-        return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, nelements={f(self.nelements)}, storage_offset={f(self.storage_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
+        return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, storage_offset={f(self.storage_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
 
     def __repr__(self):
         return self.__view(repr)
 
-    def materialize(self, checkpoint: torch._C.PyTorchFileReader, map_location=None) -> torch.Tensor:
-        storage_dtype = self.storage_type(0).dtype
-        storage = checkpoint.get_storage_from_record(f"data/{self.key}", self.nelements, storage_dtype).storage()
+    def materialize(self, checkpoint: zipfile.ZipFile, map_location=None) -> torch.Tensor:
+        size = reduce(lambda x, y: x * y, self.shape, 1)
+        dtype = self.storage_type(0).dtype
+        nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
+        with checkpoint.open(f"archive/data/{self.key}", "r") as f:
+            f.seek(self.storage_offset)
+            storage = self.storage_type.from_buffer(f.read(nbytes), "little")
         storage = torch.serialization._get_restore_location(map_location)(storage, self.location)
         tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
-        tensor.set_(storage, self.storage_offset, self.shape, self.stride)
+        tensor.set_(storage, 0, self.shape, self.stride)
         tensor.requires_grad = self.requires_grad
         tensor._backward_hooks = self.backward_hooks
         return tensor
@@ -44,13 +49,8 @@ class _LazyUnpickler(pickle.Unpickler):
         assert isinstance(saved_id, tuple)
         typename = saved_id[0]
         assert typename == "storage", f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
-
-        storage_type, key, location, nelements = saved_id[1:]
-
-        if key not in self.lazy_loaded_storages:
-            self.lazy_loaded_storages[key] = LazyTensor(storage_type, key, location, nelements)
-
-        return self.lazy_loaded_storages[key]
+        storage_type, key, location, _ = saved_id[1:]
+        return LazyTensor(storage_type, key, location)
 
     def load(self, *args, **kwargs):
         self.persistent_load = self.forced_persistent_load
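
For reference, here is a minimal worked sketch of the byte-size arithmetic the new materialize() performs; the formula is the one from the patch, while the shape and dtype values are made up for illustration:

    import torch
    from functools import reduce

    # Worked example of the nbytes computation in the patched materialize():
    # a float16 tensor of shape (2, 3) has 6 elements of 16 bits (2 bytes) each.
    shape = (2, 3)
    dtype = torch.float16
    size = reduce(lambda x, y: x * y, shape, 1)  # 6 elements
    bits = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits
    nbytes = size if dtype is torch.bool else size * (bits >> 3)  # bool storage is 1 byte/element
    assert nbytes == 12

Note that with the record now read by hand, storage_offset is treated as a byte offset within the zip record (consumed by f.seek), which is why tensor.set_ receives 0 instead of self.storage_offset.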
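
And a hypothetical end-to-end sketch of how the lazy loader might be driven. The checkpoint filename is invented; the "archive" prefix matches the path materialize() hardcodes; and it assumes the rest of the module (not shown in this diff) hooks tensor reconstruction so that each LazyTensor's shape, stride, and storage_offset are filled in during load():

    import zipfile
    from torch_lazy_loader import _LazyUnpickler

    # Hypothetical usage sketch. Assumes "model.pt" is a zip-format PyTorch
    # checkpoint whose internal prefix is "archive".
    with zipfile.ZipFile("model.pt") as checkpoint:
        with checkpoint.open("archive/data.pkl", "r") as f:
            state_dict = _LazyUnpickler(f).load()  # values are LazyTensor stubs

        # No tensor data has been read yet; bytes for each tensor are fetched
        # from its zip record only when that tensor is materialized.
        for name, lazy_tensor in state_dict.items():
            tensor = lazy_tensor.materialize(checkpoint, map_location="cpu")
            print(name, tuple(tensor.shape), tensor.dtype)

Dropping the lazy_loaded_storages deduplication in persistent_load means each tensor gets its own LazyTensor and reads exactly the bytes it needs, at the cost of no longer sharing storage between views that aliased the same record.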