Add PyTorch 1.11 support for lazy loader

This commit is contained in:
Gnome Ann 2022-03-17 12:51:41 -04:00
parent 9235754eb9
commit eaf190469d
4 changed files with 27 additions and 8 deletions

View File

@ -2,7 +2,7 @@ transformers>=4.17
Flask Flask
Flask-SocketIO Flask-SocketIO
requests requests
torch==1.10.* torch>=1.9
flask-cloudflared flask-cloudflared
flask-ngrok flask-ngrok
eventlet eventlet

View File

@ -1,3 +1,4 @@
torch >= 1.9
numpy numpy
tqdm tqdm
requests requests

View File

@ -57,11 +57,26 @@ from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
_EXTRA_STATE_KEY_SUFFIX = '_extra_state' _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
STORAGE_TYPE_MAP = {
torch.float64: torch.DoubleStorage,
torch.float32: torch.FloatStorage,
torch.float16: torch.HalfStorage,
torch.int64: torch.LongStorage,
torch.int32: torch.IntStorage,
torch.int16: torch.ShortStorage,
torch.int8: torch.CharStorage,
torch.uint8: torch.ByteStorage,
torch.bool: torch.BoolStorage,
torch.bfloat16: torch.BFloat16Storage,
}
class LazyTensor: class LazyTensor:
def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None): def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
self.storage_type = storage_type self.storage_type = storage_type
self.key = key self.key = key
self.location = location self.location = location
self.dtype = dtype
self.seek_offset = seek_offset self.seek_offset = seek_offset
self.shape = shape self.shape = shape
self.stride = stride self.stride = stride
@ -69,14 +84,14 @@ class LazyTensor:
self.backward_hooks = backward_hooks self.backward_hooks = backward_hooks
def __view(self, f: Callable): def __view(self, f: Callable):
return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})" return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, dtype={f(self.dtype)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
def __repr__(self): def __repr__(self):
return self.__view(repr) return self.__view(repr)
def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None) -> torch.Tensor: def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None) -> torch.Tensor:
size = reduce(lambda x, y: x * y, self.shape, 1) size = reduce(lambda x, y: x * y, self.shape, 1)
dtype = self.storage_type(0).dtype dtype = self.dtype
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
if isinstance(checkpoint, zipfile.ZipFile): if isinstance(checkpoint, zipfile.ZipFile):
f = checkpoint.open(f"archive/data/{self.key}", "r") f = checkpoint.open(f"archive/data/{self.key}", "r")
@ -84,7 +99,7 @@ class LazyTensor:
else: else:
f = checkpoint f = checkpoint
try: try:
storage = self.storage_type.from_buffer(f.read(nbytes), "little") storage = STORAGE_TYPE_MAP[dtype].from_buffer(f.read(nbytes), "little")
finally: finally:
if isinstance(checkpoint, zipfile.ZipFile): if isinstance(checkpoint, zipfile.ZipFile):
f.close() f.close()
@ -120,7 +135,10 @@ class _LazyUnpickler(pickle.Unpickler):
def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride): def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
lazy_storage.shape = shape lazy_storage.shape = shape
lazy_storage.stride = stride lazy_storage.stride = stride
dtype = lazy_storage.storage_type(0).dtype dtype = lazy_storage.storage_type.dtype
if not isinstance(dtype, torch.dtype):
dtype = lazy_storage.storage_type(0).dtype
lazy_storage.dtype = dtype
lazy_storage.seek_offset = storage_offset if dtype is torch.bool else storage_offset * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) lazy_storage.seek_offset = storage_offset if dtype is torch.bool else storage_offset * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
return lazy_storage return lazy_storage

View File

@ -961,7 +961,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
# the least possible memory usage, we create them as meta # the least possible memory usage, we create them as meta
# tensors, which don't take up any actual CPU or TPU memory. # tensors, which don't take up any actual CPU or TPU memory.
if key not in model_spec: if key not in model_spec:
model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].storage_type(0).dtype, device="meta") model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
continue continue
storage_key = model_dict[key].key storage_key = model_dict[key].key