(torch_lazy_loader.py) Handle checkpoints with merged storage blocks

Gnome Ann
2022-03-02 01:02:35 -05:00
parent 4fa4dbac50
commit c338b52d68


@@ -1,15 +1,16 @@
 import contextlib
+from functools import reduce
+import zipfile
 import pickle
 import torch
 from typing import Any, Callable, Dict, Optional, Tuple, Type
 
 
 class LazyTensor:
-    def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, nelements: int, storage_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
+    def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, storage_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
         self.storage_type = storage_type
         self.key = key
         self.location = location
-        self.nelements = nelements
         self.storage_offset = storage_offset
         self.shape = shape
         self.stride = stride
@@ -17,17 +18,21 @@ class LazyTensor:
         self.backward_hooks = backward_hooks
 
     def __view(self, f: Callable):
-        return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, nelements={f(self.nelements)}, storage_offset={f(self.storage_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
+        return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, storage_offset={f(self.storage_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
 
     def __repr__(self):
         return self.__view(repr)
 
-    def materialize(self, checkpoint: torch._C.PyTorchFileReader, map_location=None) -> torch.Tensor:
-        storage_dtype = self.storage_type(0).dtype
-        storage = checkpoint.get_storage_from_record(f"data/{self.key}", self.nelements, storage_dtype).storage()
+    def materialize(self, checkpoint: zipfile.ZipFile, map_location=None) -> torch.Tensor:
+        size = reduce(lambda x, y: x * y, self.shape, 1)
+        dtype = self.storage_type(0).dtype
+        nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
+        with checkpoint.open(f"archive/data/{self.key}", "r") as f:
+            f.seek(self.storage_offset)
+            storage = self.storage_type.from_buffer(f.read(nbytes), "little")
         storage = torch.serialization._get_restore_location(map_location)(storage, self.location)
         tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
-        tensor.set_(storage, self.storage_offset, self.shape, self.stride)
+        tensor.set_(storage, 0, self.shape, self.stride)
         tensor.requires_grad = self.requires_grad
         tensor._backward_hooks = self.backward_hooks
         return tensor
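
For reference, the nbytes computation added to materialize() can be checked in isolation. A minimal sketch of the same arithmetic (the helper name slice_nbytes is ours, not part of the module):

from functools import reduce

import torch

def slice_nbytes(shape, dtype):
    # Element count is the product of the shape's dimensions.
    size = reduce(lambda x, y: x * y, shape, 1)
    if dtype is torch.bool:
        return size  # bool storages are serialized one byte per element
    # finfo/iinfo expose the dtype's bit width; >> 3 converts bits to bytes.
    info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    return size * (info.bits >> 3)

assert slice_nbytes((3, 4), torch.float16) == 24  # 12 elements x 2 bytes
assert slice_nbytes((3, 4), torch.int32) == 48    # 12 elements x 4 bytes

Seeking to storage_offset and reading exactly this many bytes is what lets several tensors share one merged storage record: each tensor reads only its own slice instead of the whole block.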
@@ -44,13 +49,8 @@ class _LazyUnpickler(pickle.Unpickler):
         assert isinstance(saved_id, tuple)
         typename = saved_id[0]
         assert typename == "storage", f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
-        storage_type, key, location, nelements = saved_id[1:]
-        if key not in self.lazy_loaded_storages:
-            self.lazy_loaded_storages[key] = LazyTensor(storage_type, key, location, nelements)
-        return self.lazy_loaded_storages[key]
+        storage_type, key, location, _ = saved_id[1:]
+        return LazyTensor(storage_type, key, location)
 
     def load(self, *args, **kwargs):
         self.persistent_load = self.forced_persistent_load
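
A rough usage sketch of the lazy-loading path, assuming the checkpoint zip's root directory is named archive (matching the f"archive/data/{self.key}" paths above), that the pickle sits at archive/data.pkl as in a standard torch.save zip, and that the module's tensor-rebuilding hook (outside this diff) fills in each LazyTensor's shape, stride and storage_offset before materialization; "model.ckpt" is a placeholder path:

import zipfile

with zipfile.ZipFile("model.ckpt") as z:
    with z.open("archive/data.pkl") as f:
        state_dict = _LazyUnpickler(f).load()
    # Each leaf is now a LazyTensor; materialize() seeks to that tensor's
    # slice of the (possibly merged) storage record and reads only nbytes,
    # so tensors sharing one storage block never load the whole block.
    tensors = {name: lazy.materialize(z, map_location="cpu")
               for name, lazy in state_dict.items()}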