diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py index 5ff9655b..d097675f 100644 --- a/torch_lazy_loader.py +++ b/torch_lazy_loader.py @@ -58,18 +58,18 @@ _EXTRA_STATE_KEY_SUFFIX = '_extra_state' class LazyTensor: - def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, storage_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None): + def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None): self.storage_type = storage_type self.key = key self.location = location - self.storage_offset = storage_offset + self.seek_offset = seek_offset self.shape = shape self.stride = stride self.requires_grad = requires_grad self.backward_hooks = backward_hooks def __view(self, f: Callable): - return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, storage_offset={f(self.storage_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})" + return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})" def __repr__(self): return self.__view(repr) @@ -80,7 +80,7 @@ class LazyTensor: nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) if isinstance(checkpoint, zipfile.ZipFile): f = checkpoint.open(f"archive/data/{self.key}", "r") - f.seek(self.storage_offset) + f.seek(self.seek_offset) else: f = checkpoint try: @@ -118,9 +118,10 @@ class _LazyUnpickler(pickle.Unpickler): def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride): - lazy_storage.storage_offset = storage_offset lazy_storage.shape = shape lazy_storage.stride = stride + dtype = lazy_storage.storage_type(0).dtype + lazy_storage.seek_offset = storage_offset if dtype is torch.bool else storage_offset * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) return lazy_storage