Fix torch_lazy_loader seek offset calculation

Gnome Ann 2022-03-03 23:53:40 -05:00
parent 24bc0f81ea
commit 1515996fca
1 changed file with 6 additions and 5 deletions

@@ -58,18 +58,18 @@ _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
 class LazyTensor:
-    def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, storage_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
+    def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
         self.storage_type = storage_type
         self.key = key
         self.location = location
-        self.storage_offset = storage_offset
+        self.seek_offset = seek_offset
         self.shape = shape
         self.stride = stride
         self.requires_grad = requires_grad
         self.backward_hooks = backward_hooks
     def __view(self, f: Callable):
-        return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, storage_offset={f(self.storage_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
+        return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
     def __repr__(self):
         return self.__view(repr)
@@ -80,7 +80,7 @@ class LazyTensor:
         nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
         if isinstance(checkpoint, zipfile.ZipFile):
             f = checkpoint.open(f"archive/data/{self.key}", "r")
-            f.seek(self.storage_offset)
+            f.seek(self.seek_offset)
         else:
             f = checkpoint
         try:
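
For context on the zipfile branch above: the file object returned by ZipFile.open() is a ZipExtFile, which supports seek() on Python 3.7 and later, so the loader can jump straight to a tensor's byte offset inside the checkpoint archive. Below is a minimal standalone sketch of that pattern; the in-memory archive and the member name "archive/data/0" are made up for illustration and are not part of the patch.

import io
import zipfile

# Build a throwaway in-memory zip with one fake storage record.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as z:
    z.writestr("archive/data/0", bytes(range(16)))

# Re-open it and read 4 bytes starting at byte offset 4, the same way the
# loader seeks to self.seek_offset before reading a tensor's storage.
with zipfile.ZipFile(buf, "r") as z:
    with z.open("archive/data/0", "r") as f:
        f.seek(4)
        print(f.read(4))  # b'\x04\x05\x06\x07'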
@@ -118,9 +118,10 @@ class _LazyUnpickler(pickle.Unpickler):
 def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
-    lazy_storage.storage_offset = storage_offset
     lazy_storage.shape = shape
     lazy_storage.stride = stride
+    dtype = lazy_storage.storage_type(0).dtype
+    lazy_storage.seek_offset = storage_offset if dtype is torch.bool else storage_offset * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
     return lazy_storage
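
The core of the fix is in the two added lines above: the storage offset recorded in the pickle is measured in elements, while seek() needs a position in bytes, so the offset is scaled by the element size derived from the dtype (with bool treated as one byte per element). Here is a minimal sketch of that conversion; the helper name element_offset_to_bytes is hypothetical and not part of the patch.

import torch

def element_offset_to_bytes(storage_offset: int, dtype: torch.dtype) -> int:
    # bool storages are treated as one byte per element, matching the patch
    if dtype is torch.bool:
        return storage_offset
    # finfo/iinfo report the element width in bits; >> 3 converts bits to bytes
    bits = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits
    return storage_offset * (bits >> 3)

# Example: the element at storage offset 10 of a float16 storage starts
# 20 bytes into the serialized storage record.
print(element_offset_to_bytes(10, torch.float16))  # 20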