Fix torch_lazy_loader seek offset calculation
This commit is contained in:
parent
24bc0f81ea
commit
1515996fca
|
@ -58,18 +58,18 @@ _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
|||
|
||||
|
||||
class LazyTensor:
|
||||
def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, storage_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
|
||||
def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
|
||||
self.storage_type = storage_type
|
||||
self.key = key
|
||||
self.location = location
|
||||
self.storage_offset = storage_offset
|
||||
self.seek_offset = seek_offset
|
||||
self.shape = shape
|
||||
self.stride = stride
|
||||
self.requires_grad = requires_grad
|
||||
self.backward_hooks = backward_hooks
|
||||
|
||||
def __view(self, f: Callable):
|
||||
return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, storage_offset={f(self.storage_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
|
||||
return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__view(repr)
|
||||
|
@ -80,7 +80,7 @@ class LazyTensor:
|
|||
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
|
||||
if isinstance(checkpoint, zipfile.ZipFile):
|
||||
f = checkpoint.open(f"archive/data/{self.key}", "r")
|
||||
f.seek(self.storage_offset)
|
||||
f.seek(self.seek_offset)
|
||||
else:
|
||||
f = checkpoint
|
||||
try:
|
||||
|
@ -118,9 +118,10 @@ class _LazyUnpickler(pickle.Unpickler):
|
|||
|
||||
|
||||
def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
|
||||
lazy_storage.storage_offset = storage_offset
|
||||
lazy_storage.shape = shape
|
||||
lazy_storage.stride = stride
|
||||
dtype = lazy_storage.storage_type(0).dtype
|
||||
lazy_storage.seek_offset = storage_offset if dtype is torch.bool else storage_offset * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
|
||||
return lazy_storage
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue