From a0344b429c17f1d5d9cea5b5ffaa77c36fc17850 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 1 Mar 2022 15:40:44 -0500 Subject: [PATCH 01/14] Upload torch_lazy_loader.py --- torch_lazy_loader.py | 95 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 torch_lazy_loader.py diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py new file mode 100644 index 00000000..c5211050 --- /dev/null +++ b/torch_lazy_loader.py @@ -0,0 +1,95 @@ +import contextlib +import pickle +import torch +from typing import Any, Callable, Dict, Optional, Tuple, Type + + +class LazyTensor: + def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, nelements: int, storage_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None): + self.storage_type = storage_type + self.key = key + self.location = location + self.nelements = nelements + self.storage_offset = storage_offset + self.shape = shape + self.stride = stride + self.requires_grad = requires_grad + self.backward_hooks = backward_hooks + + def __view(self, f: Callable): + return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, nelements={f(self.nelements)}, storage_offset={f(self.storage_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})" + + def __repr__(self): + return self.__view(repr) + + def materialize(self, checkpoint: torch._C.PyTorchFileReader, map_location=None) -> torch.Tensor: + storage_dtype = self.storage_type(0).dtype + storage = checkpoint.get_storage_from_record(f"data/{self.key}", self.nelements, storage_dtype).storage() + storage = torch.serialization._get_restore_location(map_location)(storage, self.location) + tensor = torch.tensor([], dtype=storage.dtype, device=storage.device) + tensor.set_(storage, self.storage_offset, self.shape, self.stride) + tensor.requires_grad = self.requires_grad + tensor._backward_hooks = self.backward_hooks + return tensor + + +class _LazyUnpickler(pickle.Unpickler): + lazy_loaded_storages: Dict[str, LazyTensor] + + def __init__(self, *args, **kwargs): + self.lazy_loaded_storages = {} + return super().__init__(*args, **kwargs) + + def forced_persistent_load(self, saved_id): + assert isinstance(saved_id, tuple) + typename = saved_id[0] + assert typename == "storage", f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" + + storage_type, key, location, nelements = saved_id[1:] + + if key not in self.lazy_loaded_storages: + self.lazy_loaded_storages[key] = LazyTensor(storage_type, key, location, nelements) + + return self.lazy_loaded_storages[key] + + def load(self, *args, **kwargs): + self.persistent_load = self.forced_persistent_load + retval = super().load(*args, **kwargs) + self.lazy_loaded_storages = {} + return retval + + +def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride): + lazy_storage.storage_offset = storage_offset + lazy_storage.shape = shape + lazy_storage.stride = stride + return lazy_storage + + +@contextlib.contextmanager +def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None): + if not enable: + yield False + return + + old_unpickler = pickle.Unpickler + pickle.Unpickler = _LazyUnpickler + + old_rebuild_tensor = torch._utils._rebuild_tensor + torch._utils._rebuild_tensor = _rebuild_tensor + + 
old_torch_load = torch.load + + def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args): + retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) + if callback is not None: + callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) + return retval + + torch.load = torch_load + + yield True + + pickle.Unpickler = old_unpickler + torch._utils._rebuild_tensor = old_rebuild_tensor + torch.load = old_torch_load From 4fa4dbac50e4aa9a7fa0f83a315376b80fa980d0 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 1 Mar 2022 19:30:22 -0500 Subject: [PATCH 02/14] Clean up when error is thrown in `use_lazy_torch_load` --- torch_lazy_loader.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py index c5211050..604f6e69 100644 --- a/torch_lazy_loader.py +++ b/torch_lazy_loader.py @@ -49,7 +49,7 @@ class _LazyUnpickler(pickle.Unpickler): if key not in self.lazy_loaded_storages: self.lazy_loaded_storages[key] = LazyTensor(storage_type, key, location, nelements) - + return self.lazy_loaded_storages[key] def load(self, *args, **kwargs): @@ -72,24 +72,26 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None): yield False return - old_unpickler = pickle.Unpickler - pickle.Unpickler = _LazyUnpickler + try: + old_unpickler = pickle.Unpickler + pickle.Unpickler = _LazyUnpickler - old_rebuild_tensor = torch._utils._rebuild_tensor - torch._utils._rebuild_tensor = _rebuild_tensor + old_rebuild_tensor = torch._utils._rebuild_tensor + torch._utils._rebuild_tensor = _rebuild_tensor - old_torch_load = torch.load + old_torch_load = torch.load - def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args): - retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) - if callback is not None: - callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) - return retval + def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args): + retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) + if callback is not None: + callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) + return retval - torch.load = torch_load + torch.load = torch_load - yield True + yield True - pickle.Unpickler = old_unpickler - torch._utils._rebuild_tensor = old_rebuild_tensor - torch.load = old_torch_load + finally: + pickle.Unpickler = old_unpickler + torch._utils._rebuild_tensor = old_rebuild_tensor + torch.load = old_torch_load From c338b52d68fdd909b30246e6c6b1c0ef29ed86dc Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Wed, 2 Mar 2022 01:02:35 -0500 Subject: [PATCH 03/14] (torch_lazy_loader.py) Handle checkpoints with merged storage blocks --- torch_lazy_loader.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py index 604f6e69..85e1310e 100644 --- a/torch_lazy_loader.py +++ b/torch_lazy_loader.py @@ -1,15 +1,16 @@ import contextlib +from functools import reduce +import zipfile import pickle import torch from typing import Any, Callable, Dict, Optional, Tuple, Type class LazyTensor: - def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, nelements: int, storage_offset: Optional[int] = 
None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None): + def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, storage_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None): self.storage_type = storage_type self.key = key self.location = location - self.nelements = nelements self.storage_offset = storage_offset self.shape = shape self.stride = stride @@ -17,17 +18,21 @@ class LazyTensor: self.backward_hooks = backward_hooks def __view(self, f: Callable): - return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, nelements={f(self.nelements)}, storage_offset={f(self.storage_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})" + return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, storage_offset={f(self.storage_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})" def __repr__(self): return self.__view(repr) - def materialize(self, checkpoint: torch._C.PyTorchFileReader, map_location=None) -> torch.Tensor: - storage_dtype = self.storage_type(0).dtype - storage = checkpoint.get_storage_from_record(f"data/{self.key}", self.nelements, storage_dtype).storage() + def materialize(self, checkpoint: zipfile.ZipFile, map_location=None) -> torch.Tensor: + size = reduce(lambda x, y: x * y, self.shape, 1) + dtype = self.storage_type(0).dtype + nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) + with checkpoint.open(f"archive/data/{self.key}", "r") as f: + f.seek(self.storage_offset) + storage = self.storage_type.from_buffer(f.read(nbytes), "little") storage = torch.serialization._get_restore_location(map_location)(storage, self.location) tensor = torch.tensor([], dtype=storage.dtype, device=storage.device) - tensor.set_(storage, self.storage_offset, self.shape, self.stride) + tensor.set_(storage, 0, self.shape, self.stride) tensor.requires_grad = self.requires_grad tensor._backward_hooks = self.backward_hooks return tensor @@ -44,13 +49,8 @@ class _LazyUnpickler(pickle.Unpickler): assert isinstance(saved_id, tuple) typename = saved_id[0] assert typename == "storage", f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" - - storage_type, key, location, nelements = saved_id[1:] - - if key not in self.lazy_loaded_storages: - self.lazy_loaded_storages[key] = LazyTensor(storage_type, key, location, nelements) - - return self.lazy_loaded_storages[key] + storage_type, key, location, _ = saved_id[1:] + return LazyTensor(storage_type, key, location) def load(self, *args, **kwargs): self.persistent_load = self.forced_persistent_load From 1ecc452dc813b5c32d3c78534d314102ab3b56f1 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Wed, 2 Mar 2022 13:08:21 -0500 Subject: [PATCH 04/14] (torch_lazy_loader.py) Add support for materializing from a ZipExtFile --- torch_lazy_loader.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py index 85e1310e..4a29d0c8 100644 --- a/torch_lazy_loader.py +++ b/torch_lazy_loader.py @@ -3,7 +3,7 @@ from 
functools import reduce import zipfile import pickle import torch -from typing import Any, Callable, Dict, Optional, Tuple, Type +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union class LazyTensor: @@ -23,13 +23,20 @@ class LazyTensor: def __repr__(self): return self.__view(repr) - def materialize(self, checkpoint: zipfile.ZipFile, map_location=None) -> torch.Tensor: + def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None) -> torch.Tensor: size = reduce(lambda x, y: x * y, self.shape, 1) dtype = self.storage_type(0).dtype nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) - with checkpoint.open(f"archive/data/{self.key}", "r") as f: + if isinstance(checkpoint, zipfile.ZipFile): + f = checkpoint.open(f"archive/data/{self.key}", "r") f.seek(self.storage_offset) + else: + f = checkpoint + try: storage = self.storage_type.from_buffer(f.read(nbytes), "little") + finally: + if isinstance(checkpoint, zipfile.ZipFile): + f.close() storage = torch.serialization._get_restore_location(map_location)(storage, self.location) tensor = torch.tensor([], dtype=storage.dtype, device=storage.device) tensor.set_(storage, 0, self.shape, self.stride) From 8e6e04be5f3e91baa3ac23d0f33c49bcc434c58f Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Thu, 3 Mar 2022 11:17:59 -0500 Subject: [PATCH 05/14] (torch_lazy_loader.py) Add dematerialized modules setting --- torch_lazy_loader.py | 155 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 154 insertions(+), 1 deletion(-) diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py index 4a29d0c8..d9358442 100644 --- a/torch_lazy_loader.py +++ b/torch_lazy_loader.py @@ -1,11 +1,62 @@ +''' +This file is AGPL-licensed. + +Some of the code in this file is copied from PyTorch. + +The license for PyTorch is shown below: + +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. +''' + + import contextlib from functools import reduce +import itertools import zipfile import pickle import torch +from torch.nn import Module from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +_EXTRA_STATE_KEY_SUFFIX = '_extra_state' + + class LazyTensor: def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, storage_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None): self.storage_type = storage_type @@ -73,8 +124,77 @@ def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride): return lazy_storage +# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438 +def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + for name, param in local_state.items(): + key = prefix + name + if key in state_dict: + input_param = state_dict[key] + if not torch.overrides.is_tensor_like(input_param): + error_msgs.append('While copying the parameter named "{}", ' + 'expected torch.Tensor or Tensor-like object from checkpoint but ' + 'received {}' + .format(key, type(input_param))) + continue + + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' + 'the shape in current model is {}.' 
+ .format(key, input_param.shape, param.shape)) + continue + try: + with torch.no_grad(): + #param.copy_(input_param) + new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad) # This line is new + if name in self._parameters: # This line is new + self._parameters[name] = new_param # This line is new + if name in persistent_buffers: # This line is new + self._buffers[name] = new_param # This line is new + except Exception as ex: + error_msgs.append('While copying the parameter named "{}", ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}, ' + 'an exception occurred : {}.' + .format(key, param.size(), input_param.size(), ex.args)) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix):] + input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child + if input_name not in self._modules and input_name not in local_state: + unexpected_keys.append(key) + + @contextlib.contextmanager -def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None): +def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False): if not enable: yield False return @@ -96,9 +216,42 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None): torch.load = torch_load + def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args): + retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) + if callback is not None: + callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) + return retval + + torch.load = torch_load + + if dematerialized_modules: + old_linear_init = torch.nn.Linear.__init__ + old_embedding_init = torch.nn.Embedding.__init__ + old_layernorm_init = torch.nn.LayerNorm.__init__ + + def linear_init(self, *args, device=None, **kwargs): + return old_linear_init(self, *args, device="meta", **kwargs) + + def embedding_init(self, *args, device=None, **kwargs): + return old_embedding_init(self, *args, device="meta", **kwargs) + + def layernorm_init(self, *args, device=None, **kwargs): + return old_layernorm_init(self, *args, device="meta", **kwargs) + + torch.nn.Linear.__init__ = linear_init + torch.nn.Embedding.__init__ = embedding_init + torch.nn.LayerNorm.__init__ = layernorm_init + old_load_from_state_dict = torch.nn.Module._load_from_state_dict + torch.nn.Module._load_from_state_dict = _load_from_state_dict + yield True finally: pickle.Unpickler = old_unpickler torch._utils._rebuild_tensor = old_rebuild_tensor torch.load = old_torch_load + if dematerialized_modules: + torch.nn.Linear.__init__ = old_linear_init + torch.nn.Embedding.__init__ = old_embedding_init + torch.nn.LayerNorm.__init__ = old_layernorm_init + torch.nn.Module._load_from_state_dict = old_load_from_state_dict From 24bc0f81ea2d321682358c36a77015974d4836c1 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Thu, 3 Mar 2022 19:55:31 -0500 Subject: [PATCH 06/14] Remove duplicate 
`torch_load` definition --- torch_lazy_loader.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py index d9358442..5ff9655b 100644 --- a/torch_lazy_loader.py +++ b/torch_lazy_loader.py @@ -216,14 +216,6 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate torch.load = torch_load - def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args): - retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) - if callback is not None: - callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) - return retval - - torch.load = torch_load - if dematerialized_modules: old_linear_init = torch.nn.Linear.__init__ old_embedding_init = torch.nn.Embedding.__init__ From 1515996fcad66bbf7875862b670013e1c3008fba Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Thu, 3 Mar 2022 23:53:40 -0500 Subject: [PATCH 07/14] Fix torch_lazy_loader seek offset calculation --- torch_lazy_loader.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py index 5ff9655b..d097675f 100644 --- a/torch_lazy_loader.py +++ b/torch_lazy_loader.py @@ -58,18 +58,18 @@ _EXTRA_STATE_KEY_SUFFIX = '_extra_state' class LazyTensor: - def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, storage_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None): + def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None): self.storage_type = storage_type self.key = key self.location = location - self.storage_offset = storage_offset + self.seek_offset = seek_offset self.shape = shape self.stride = stride self.requires_grad = requires_grad self.backward_hooks = backward_hooks def __view(self, f: Callable): - return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, storage_offset={f(self.storage_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})" + return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})" def __repr__(self): return self.__view(repr) @@ -80,7 +80,7 @@ class LazyTensor: nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) if isinstance(checkpoint, zipfile.ZipFile): f = checkpoint.open(f"archive/data/{self.key}", "r") - f.seek(self.storage_offset) + f.seek(self.seek_offset) else: f = checkpoint try: @@ -118,9 +118,10 @@ class _LazyUnpickler(pickle.Unpickler): def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride): - lazy_storage.storage_offset = storage_offset lazy_storage.shape = shape lazy_storage.stride = stride + dtype = lazy_storage.storage_type(0).dtype + lazy_storage.seek_offset = storage_offset if dtype is torch.bool else storage_offset * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits 
>> 3) return lazy_storage From 58a2c188219b55ba415dfb06dd53b9627321258f Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 4 Mar 2022 00:33:10 -0500 Subject: [PATCH 08/14] Add lazy torch loading support to transformers backend --- aiserver.py | 162 ++++++++++++++++++++++++++++++++++------------ maps/gpt_neo.json | 25 +++++++ maps/gptj.json | 24 +++++++ maps/xglm.json | 26 ++++++++ 4 files changed, 196 insertions(+), 41 deletions(-) create mode 100644 maps/gpt_neo.json create mode 100644 maps/gptj.json create mode 100644 maps/xglm.json diff --git a/aiserver.py b/aiserver.py index 66c0e197..2a22d6b1 100644 --- a/aiserver.py +++ b/aiserver.py @@ -26,6 +26,8 @@ import traceback import threading import markdown import bleach +import itertools +import bisect from collections.abc import Iterable from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List @@ -248,6 +250,7 @@ class vars: newlinemode = "n" quiet = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page) debug = False # If set to true, will send debug information to the client for display + lazy_load = True # Whether or not to use torch_lazy_loader.py for transformers models in order to reduce CPU memory usage utils.vars = vars @@ -335,10 +338,10 @@ def device_list(n_layers, primary=None, selected=None): sep_color = colors.YELLOW print(f"{row_color} {' '*9} N/A {sep_color}|{row_color} {n_layers:3} {sep_color}|{row_color} (CPU){colors.END}") -def device_config(model): +def device_config(config): global breakmodel, generator import breakmodel - n_layers = model.config.num_layers if hasattr(model.config, "num_layers") else model.config.n_layer + n_layers = config.num_layers if hasattr(config, "num_layers") else config.n_layer if(args.breakmodel_gpulayers is not None): try: breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(','))) @@ -411,22 +414,30 @@ def device_config(model): # If all layers are on the same device, use the old GPU generation mode while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0): breakmodel.gpu_blocks.pop() - if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, model.config.num_layers if hasattr(model.config, "num_layers") else model.config.n_layer)): + if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, config.num_layers if hasattr(config, "num_layers") else config.n_layer)): vars.breakmodel = False vars.usegpu = True vars.gpu_device = len(breakmodel.gpu_blocks)-1 - model = model.half().to(vars.gpu_device) - generator = model.generate return if(not breakmodel.gpu_blocks): print("Nothing assigned to a GPU, reverting to CPU only mode") vars.breakmodel = False vars.usegpu = False - model = model.to('cpu').float() + return + +def move_model_to_devices(model): + global generator + + if(not vars.breakmodel): + if(vars.usegpu): + model = model.half().to(vars.gpu_device) + else: + model = model.to('cpu').float() generator = model.generate return - model.half().to('cpu') + + model.half() gc.collect() if(hasattr(model, "transformer")): model.transformer.wte.to(breakmodel.primary_device) @@ -1013,6 +1024,67 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme import transformers.generation_utils from transformers import __version__ as transformers_version + # Lazy loader + import torch_lazy_loader + def get_lazy_load_callback(n_layers): + if not vars.lazy_load: + return + + from tqdm import tqdm + + if "breakmodel" in globals(): + gpu_blocks = 
breakmodel.gpu_blocks + ram_blocks = ram_blocks = n_layers - sum(gpu_blocks) + cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks)) + else: + ram_blocks = gpu_blocks = cumulative_gpu_blocks = None + + def lazy_load_callback(model_dict, f, **_): + device_map = {} + + for _key, spec in lazy_load_spec.get("layer_weights", {}).items(): + for layer in range(n_layers): + key = _key.format(layer=layer) + if key not in model_dict: + continue + device = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel or layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks) + device_map[key] = device + + for key, value in model_dict.items(): + if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map: + device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" + + with zipfile.ZipFile(f, "r") as z: + try: + last_storage_key = None + f = None + for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"): + storage_key = model_dict[key].key + if storage_key != last_storage_key: + last_storage_key = storage_key + if isinstance(f, zipfile.ZipExtFile): + f.close() + f = z.open(f"archive/data/{storage_key}") + current_offset = f.tell() + if current_offset != model_dict[key].seek_offset: + f.seek(model_dict[key].seek_offset - current_offset, 1) + device = device_map[key] + #print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... ", end="", flush=True) + model_dict[key] = model_dict[key].materialize(f, map_location=torch.device(device)) + #print("OK", flush=True) + finally: + if isinstance(f, zipfile.ZipExtFile): + f.close() + + return lazy_load_callback + + if(vars.lazy_load and "model_config" in globals() and vars.model_type in ("gpt_neo", "gptj", "xglm")): + with open(os.path.join(path.dirname(path.realpath(__file__)), "maps", vars.model_type + ".json")) as f: + lazy_load_spec = json.load(f) + + else: + vars.lazy_load = False + # Temporary fix for XGLM positional embedding issues until # https://github.com/huggingface/transformers/issues/15736 # is resolved @@ -1247,6 +1319,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme # If custom GPT2 model was chosen if(vars.model == "GPT2Custom"): + vars.lazy_load = False model_config = open(vars.custmodpth + "/config.json", "r") js = json.load(model_config) with(maybe_use_float16()): @@ -1268,6 +1341,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme # feature yet if(vars.model_type == "gpt2"): lowmem = {} + + # If we're using torch_lazy_loader, we need to get breakmodel config + # early so that it knows where to load the individual model tensors + if(vars.lazy_load and vars.hascuda and vars.breakmodel): + device_config(model_config) # Download model from Huggingface if it does not exist, otherwise load locally @@ -1275,43 +1353,43 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme if os.path.isdir(vars.model.replace('/', '_')): import shutil shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_'))) - if(os.path.isdir(vars.custmodpth)): - with(maybe_use_float16()): - try: - tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache") - except ValueError as e: - tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache") - try: - model = 
AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem) - except ValueError as e: - model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem) - elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))): - with(maybe_use_float16()): - try: - tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") - except ValueError as e: - tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") - try: - model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) - except ValueError as e: - model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) - else: - try: - tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache") - except ValueError as e: - tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache") - with(maybe_use_float16()): + with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer), dematerialized_modules=True): + if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time + lowmem = {} + if(os.path.isdir(vars.custmodpth)): + try: + tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache") + except ValueError as e: + tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache") + try: + model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem) + except ValueError as e: + model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem) + elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))): + try: + tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") + except ValueError as e: + tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") + try: + model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) + except ValueError as e: + model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) + else: + try: + tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache") + except ValueError as e: + tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache") try: model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem) except ValueError as e: model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem) - - if not args.colab: - import shutil - model = model.half() - model.save_pretrained("models/{}".format(vars.model.replace('/', '_'))) - tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_'))) - shutil.rmtree("cache/") + + if not args.colab: + import shutil + model = model.half() + model.save_pretrained("models/{}".format(vars.model.replace('/', '_'))) + tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_'))) + shutil.rmtree("cache/") if(vars.hascuda): if(vars.usegpu): @@ -1320,7 +1398,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme generator = model.generate elif(vars.breakmodel): 
# Use both RAM and VRAM (breakmodel) vars.modeldim = get_hidden_size_from_model(model) - device_config(model) + if(not vars.lazy_load): + device_config(model.config) + move_model_to_devices(model) else: model = model.to('cpu').float() vars.modeldim = get_hidden_size_from_model(model) diff --git a/maps/gpt_neo.json b/maps/gpt_neo.json new file mode 100644 index 00000000..a93fb26e --- /dev/null +++ b/maps/gpt_neo.json @@ -0,0 +1,25 @@ +{ + "static_weights": { + "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}}, + "transformer.wpe.weight": {"mtj": {"module": "embedding_shard/~/pos_embs", "param": "w", "axis": 2, "transforms": ["transpose"]}}, + "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}}, + "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}} + }, + "layer_weights": { + "transformer.h.{layer}.attn.attention.bias": {}, + "transformer.h.{layer}.attn.attention.masked_bias": {}, + "transformer.h.{layer}.attn.attention.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}}, + "transformer.h.{layer}.attn.attention.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}}, + "transformer.h.{layer}.attn.attention.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}}, + "transformer.h.{layer}.attn.attention.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}}, + "transformer.h.{layer}.attn.attention.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}}, + "transformer.h.{layer}.mlp.c_fc.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}}, + "transformer.h.{layer}.mlp.c_fc.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}}, + "transformer.h.{layer}.mlp.c_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}}, + "transformer.h.{layer}.mlp.c_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, + "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}}, + "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}}, + "transformer.h.{layer}.ln_2.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}}, + "transformer.h.{layer}.ln_2.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}} + } +} diff --git a/maps/gptj.json b/maps/gptj.json new file mode 100644 index 00000000..51a788d7 --- /dev/null +++ b/maps/gptj.json @@ -0,0 +1,24 @@ +{ + "static_weights": { + "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}}, + "transformer.wte.bias": {"mtj": {"module": "embedding_shard/~/linear", "param": "b"}}, + "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}}, + "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}}, + "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}}, + "lm_head.bias": {"mtj": {"module": "projection_shard/~/linear", "param": "b"}} + }, + "layer_weights": { + "transformer.h.{layer}.attn.bias": 
{}, + "transformer.h.{layer}.attn.masked_bias": {}, + "transformer.h.{layer}.attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}}, + "transformer.h.{layer}.attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}}, + "transformer.h.{layer}.attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}}, + "transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}}, + "transformer.h.{layer}.mlp.fc_in.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}}, + "transformer.h.{layer}.mlp.fc_in.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}}, + "transformer.h.{layer}.mlp.fc_out.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}}, + "transformer.h.{layer}.mlp.fc_out.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, + "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}}, + "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}} + } +} diff --git a/maps/xglm.json b/maps/xglm.json new file mode 100644 index 00000000..beb90985 --- /dev/null +++ b/maps/xglm.json @@ -0,0 +1,26 @@ +{ + "static_weights": { + "model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}}, + "model.layer_norm.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}}, + "model.layer_norm.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}}, + "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}} + }, + "layer_weights": { + "model.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}}, + "model.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b", "axis": 1}}, + "model.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}}, + "model.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b", "axis": 1}}, + "model.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}}, + "model.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b", "axis": 1}}, + "model.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}}, + "model.layers.{layer}.self_attn.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}}, + "model.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}}, + "model.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}}, + "model.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}}, + "model.layers.{layer}.fc2.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, + "model.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}}, + 
"model.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}}, + "model.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}}, + "model.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}} + } +} From a1fedca2c87c352ffeddcafbdfe53865a00d52ff Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 4 Mar 2022 11:11:33 -0500 Subject: [PATCH 09/14] Use lazy loading automatically if a config file exists for the model --- aiserver.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/aiserver.py b/aiserver.py index f8b67693..4513d486 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1080,8 +1080,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme return lazy_load_callback - if(vars.lazy_load and "model_config" in globals() and vars.model_type in ("gpt_neo", "gptj", "xglm")): - with open(os.path.join(path.dirname(path.realpath(__file__)), "maps", vars.model_type + ".json")) as f: + lazy_load_config_path = os.path.join(path.dirname(path.realpath(__file__)), "maps", vars.model_type + ".json") + if(vars.lazy_load and "model_config" in globals() and os.path.isfile(lazy_load_config_path)): + with open(lazy_load_config_path) as f: lazy_load_spec = json.load(f) else: From 86ac562b0c25197fe8805fbdd01ba719731dae0a Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sat, 5 Mar 2022 11:31:34 -0500 Subject: [PATCH 10/14] Lazy loader should convert model tensors to float16 before moving them --- aiserver.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aiserver.py b/aiserver.py index 19bd9be9..e78211ae 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1028,7 +1028,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme # Lazy loader import torch_lazy_loader - def get_lazy_load_callback(n_layers): + def get_lazy_load_callback(n_layers, convert_to_float16=True): if not vars.lazy_load: return @@ -1072,7 +1072,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme f.seek(model_dict[key].seek_offset - current_offset, 1) device = device_map[key] #print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... 
", end="", flush=True) - model_dict[key] = model_dict[key].materialize(f, map_location=torch.device(device)) + model_dict[key] = model_dict[key].materialize(f, map_location="cpu") + if convert_to_float16 and model_dict[key].dtype is torch.float32: + model_dict[key] = model_dict[key].to(torch.float16) + model_dict[key] = model_dict[key].to(device) #print("OK", flush=True) finally: if isinstance(f, zipfile.ZipExtFile): From 0a258a6282c2a7a3e825baa50e3cb74eed279eb6 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sat, 5 Mar 2022 12:33:33 -0500 Subject: [PATCH 11/14] Support for loading HF models on TPU with `--colab_tpu` --- aiserver.py | 31 +++--- maps/gpt_neo.json | 37 +++++--- maps/gptj.json | 34 ++++--- maps/xglm.json | 43 +++++---- tpu_mtj_backend.py | 229 +++++++++++++++++++++++++++++++++++++++++++-- 5 files changed, 306 insertions(+), 68 deletions(-) diff --git a/aiserver.py b/aiserver.py index e78211ae..2d1998e6 100644 --- a/aiserver.py +++ b/aiserver.py @@ -695,7 +695,7 @@ def spRequest(filename): vars.sp_length = tensor.shape[-2] vars.spmeta["n_tokens"] = vars.sp_length - if(vars.model in ("TPUMeshTransformerGPTJ",)): + if(args.colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): rows = tensor.shape[0] padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows tensor = np.pad(tensor, ((0, padding_amount), (0, 0))) @@ -730,6 +730,7 @@ parser.add_argument("--override_delete", action='store_true', help="Deleting sto parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.") parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.") parser.add_argument("--colab", action='store_true', help="Optimize for Google Colab.") +parser.add_argument("--colab_tpu", action='store_true', help="If you're running KoboldAI in a Google Colab TPU instance, enable this to load Hugging Face models onto the TPU.") parser.add_argument("--nobreakmodel", action='store_true', help="Disables Breakmodel support completely.") parser.add_argument("--unblock", action='store_true', default=False, help="Unblocks the KoboldAI port to be accessible from other machines without optimizing for remote play (It is recommended to use --host instead)") parser.add_argument("--quiet", action='store_true', default=False, help="If present will suppress any story related text from showing on the console") @@ -783,7 +784,7 @@ else: getModelSelection(mainmenu) # If transformers model was selected & GPU available, ask to use CPU or GPU -if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]): +if(vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]): vars.allowsp = True # Test for GPU support import torch @@ -822,6 +823,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme elif(vars.model_type == "not_found"): print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)") vars.model_type = "gpt_neo" + +if(not args.colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]): loadmodelsettings() loadsettings() print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), 
end="") @@ -1014,7 +1017,7 @@ socketio = SocketIO(app, async_method="eventlet") print("{0}OK!{1}".format(colors.GREEN, colors.END)) # Start transformers and create pipeline -if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]): +if(not args.colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]): if(not vars.noai): print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END)) from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer @@ -1523,9 +1526,9 @@ else: tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/") loadsettings() # Load the TPU backend if requested - elif(vars.model == "TPUMeshTransformerGPTJ"): + elif(args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END)) - if not vars.custmodpth or not os.path.isdir(vars.custmodpth): + if vars.model == "TPUMeshTransformerGPTJ" and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)): raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder") import tpu_mtj_backend tpu_mtj_backend.vars = vars @@ -1537,7 +1540,7 @@ else: vars.allowsp = True loadmodelsettings() loadsettings() - tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig) + tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=args.colab_tpu, **vars.modelconfig) vars.modeldim = int(tpu_mtj_backend.params["d_model"]) tokenizer = tpu_mtj_backend.tokenizer else: @@ -2068,7 +2071,7 @@ def lua_get_modeltype(): return "readonly" if(vars.model in ("Colab", "OAI", "InferKit")): return "api" - if(vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))): + if(not args.colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))): hidden_size = get_hidden_size_from_model(model) if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)): return "gpt2" @@ -2084,7 +2087,7 @@ def lua_get_modeltype(): return "gpt-neo-1.3B" if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model_type == "gpt_neo" and hidden_size == 2560)): return "gpt-neo-2.7B" - if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)): + if(vars.model in ("EleutherAI/gpt-j-6B",) or ((args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ") and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)): return "gpt-j-6B" return "unknown" @@ -2097,7 +2100,7 @@ def lua_get_modelbackend(): return "readonly" if(vars.model in ("Colab", "OAI", "InferKit")): return "api" - if(vars.model in ("TPUMeshTransformerGPTJ",)): + if(args.colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): return "mtj" return "transformers" @@ -3044,22 +3047,22 @@ def calcsubmit(txt): if(vars.model != "InferKit"): subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt) if(actionlen == 0): - if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): + if(not args.colab_tpu and vars.model not in ["Colab", "OAI", 
"TPUMeshTransformerGPTJ"]): generate(subtxt, min, max, found_entries=found_entries) elif(vars.model == "Colab"): sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) elif(vars.model == "OAI"): oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) - elif(vars.model == "TPUMeshTransformerGPTJ"): + elif(args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): tpumtjgenerate(subtxt, min, max, found_entries=found_entries) else: - if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): + if(not args.colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): generate(subtxt, min, max, found_entries=found_entries) elif(vars.model == "Colab"): sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) elif(vars.model == "OAI"): oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) - elif(vars.model == "TPUMeshTransformerGPTJ"): + elif(args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): tpumtjgenerate(subtxt, min, max, found_entries=found_entries) # For InferKit web API @@ -5071,7 +5074,7 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")): file.close() # Precompile TPU backend if required -if(vars.model in ("TPUMeshTransformerGPTJ",)): +if(args.colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): soft_tokens = tpumtjgetsofttokens() if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)): threading.Thread( diff --git a/maps/gpt_neo.json b/maps/gpt_neo.json index a93fb26e..fa2d4084 100644 --- a/maps/gpt_neo.json +++ b/maps/gpt_neo.json @@ -1,25 +1,32 @@ { + "mtj_compat": "neo", + "mtj_pe": "fixed", + "mtj_config_map": { + "d_model": "hidden_size", + "n_heads": "num_heads", + "layers": "num_layers" + }, "static_weights": { - "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}}, - "transformer.wpe.weight": {"mtj": {"module": "embedding_shard/~/pos_embs", "param": "w", "axis": 2, "transforms": ["transpose"]}}, - "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}}, - "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}} + "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}}, + "transformer.wpe.weight": {"mtj": {"module": "embedding_shard", "param": "pos_embs", "transforms": ["no_transpose"]}}, + "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}}, + "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}} }, "layer_weights": { "transformer.h.{layer}.attn.attention.bias": {}, "transformer.h.{layer}.attn.attention.masked_bias": {}, - "transformer.h.{layer}.attn.attention.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}}, - "transformer.h.{layer}.attn.attention.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}}, - "transformer.h.{layer}.attn.attention.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}}, - "transformer.h.{layer}.attn.attention.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}}, + "transformer.h.{layer}.attn.attention.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}}, + "transformer.h.{layer}.attn.attention.v_proj.weight": 
{"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}}, + "transformer.h.{layer}.attn.attention.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}}, + "transformer.h.{layer}.attn.attention.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}}, "transformer.h.{layer}.attn.attention.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}}, - "transformer.h.{layer}.mlp.c_fc.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}}, - "transformer.h.{layer}.mlp.c_fc.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}}, - "transformer.h.{layer}.mlp.c_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}}, + "transformer.h.{layer}.mlp.c_fc.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}}, + "transformer.h.{layer}.mlp.c_fc.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}}, + "transformer.h.{layer}.mlp.c_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}}, "transformer.h.{layer}.mlp.c_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, - "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}}, - "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}}, - "transformer.h.{layer}.ln_2.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}}, - "transformer.h.{layer}.ln_2.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}} + "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}}, + "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}}, + "transformer.h.{layer}.ln_2.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "scale"}}, + "transformer.h.{layer}.ln_2.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}} } } diff --git a/maps/gptj.json b/maps/gptj.json index 51a788d7..8e0bc9da 100644 --- a/maps/gptj.json +++ b/maps/gptj.json @@ -1,24 +1,32 @@ { + "mtj_compat": "j", + "mtj_pe": "rotary", + "mtj_config_map": { + "pe_rotary_dims": ["rotary_dim", 64], + "d_model": "n_embd", + "n_heads": "n_head", + "layers": "n_layer" + }, "static_weights": { - "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}}, + "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}}, "transformer.wte.bias": {"mtj": {"module": "embedding_shard/~/linear", "param": "b"}}, - "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}}, - "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}}, - "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}}, + "transformer.ln_f.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}}, + "transformer.ln_f.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}}, + "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "transforms": ["vocab_pad"]}}, 
"lm_head.bias": {"mtj": {"module": "projection_shard/~/linear", "param": "b"}} }, "layer_weights": { "transformer.h.{layer}.attn.bias": {}, "transformer.h.{layer}.attn.masked_bias": {}, - "transformer.h.{layer}.attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}}, - "transformer.h.{layer}.attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}}, - "transformer.h.{layer}.attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}}, - "transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}}, - "transformer.h.{layer}.mlp.fc_in.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}}, - "transformer.h.{layer}.mlp.fc_in.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}}, - "transformer.h.{layer}.mlp.fc_out.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}}, + "transformer.h.{layer}.attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}}, + "transformer.h.{layer}.attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}}, + "transformer.h.{layer}.attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}}, + "transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}}, + "transformer.h.{layer}.mlp.fc_in.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}}, + "transformer.h.{layer}.mlp.fc_in.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}}, + "transformer.h.{layer}.mlp.fc_out.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}}, "transformer.h.{layer}.mlp.fc_out.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, - "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}}, - "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}} + "transformer.h.{layer}.ln_1.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}}, + "transformer.h.{layer}.ln_1.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}} } } diff --git a/maps/xglm.json b/maps/xglm.json index beb90985..3ba4b1f2 100644 --- a/maps/xglm.json +++ b/maps/xglm.json @@ -1,26 +1,33 @@ { + "mtj_compat": "fairseq_lm", + "mtj_pe": "fairseq_sinusoidal", + "mtj_config_map": { + "d_model": "d_model", + "n_heads": "attention_heads", + "layers": "num_layers" + }, "static_weights": { - "model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1, "transforms": ["transpose", "vocab_pad"]}}, - "model.layer_norm.weight": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "scale"}}, - "model.layer_norm.bias": {"mtj": {"module": "projection_shard/~/layer_norm", "param": "offset"}}, - "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "axis": 2, "transforms": ["vocab_pad"]}} + "model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}}, + "model.replicated_layer_norm.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}}, + "model.replicated_layer_norm.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}}, 
+ "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "transforms": ["vocab_pad"]}} }, "layer_weights": { - "model.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w", "axis": 2}}, - "model.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b", "axis": 1}}, - "model.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w", "axis": 2}}, - "model.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b", "axis": 1}}, - "model.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w", "axis": 2}}, - "model.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b", "axis": 1}}, - "model.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w", "axis": 1}}, + "model.layers.{layer}.self_attn.q_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear", "param": "w"}}, + "model.layers.{layer}.self_attn.q_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear", "param": "b"}}, + "model.layers.{layer}.self_attn.v_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "w"}}, + "model.layers.{layer}.self_attn.v_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_1", "param": "b"}}, + "model.layers.{layer}.self_attn.k_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "w"}}, + "model.layers.{layer}.self_attn.k_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_2", "param": "b"}}, + "model.layers.{layer}.self_attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}}, "model.layers.{layer}.self_attn.out_proj.bias": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "b", "transforms": ["divide_by_shards"]}}, - "model.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w", "axis": 2}}, - "model.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b", "axis": 1}}, - "model.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w", "axis": 1}}, + "model.layers.{layer}.fc1.weight": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "w"}}, + "model.layers.{layer}.fc1.bias": {"mtj": {"module": "layer_{layer}/~/linear_4", "param": "b"}}, + "model.layers.{layer}.fc2.weight": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "w"}}, "model.layers.{layer}.fc2.bias": {"mtj": {"module": "layer_{layer}/~/linear_5", "param": "b", "transforms": ["divide_by_shards"]}}, - "model.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "scale"}}, - "model.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm", "param": "offset"}}, - "model.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "scale"}}, - "model.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/layer_norm_1", "param": "offset"}} + "model.layers.{layer}.self_attn_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "scale"}}, + "model.layers.{layer}.self_attn_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm", "param": "offset"}}, + "model.layers.{layer}.final_layer_norm.weight": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", 
"param": "scale"}}, + "model.layers.{layer}.final_layer_norm.bias": {"mtj": {"module": "layer_{layer}/~/replicated_layer_norm_1", "param": "offset"}} } } diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index a78f93f2..6f5500b7 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -32,6 +32,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar import progressbar import time import os +import sys +import json +import zipfile import requests import random import jax @@ -41,9 +44,10 @@ import jax.numpy as jnp import numpy as np import optax import haiku as hk -import transformers +from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM from mesh_transformer.checkpoint import read_ckpt_lowmem -from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard +from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor +from mesh_transformer.util import to_bf16 params: Dict[str, Any] = {} @@ -776,7 +780,26 @@ def infer_static( return samples -def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None: +def reshard_reverse(x, total_shards, old_shape): + assert len(x.shape) != 1 + if len(x.shape) == 2: + if old_shape[1] == x.shape[1]: + out = x[0:1].tile((total_shards, 1)) + else: + out = x.reshape(old_shape) + elif len(x.shape) == 3: + if x.shape[0] * x.shape[2] == old_shape[2]: + out = x.reshape(old_shape) + elif x.shape[0] * x.shape[1] == old_shape[1]: + out = x.reshape((old_shape[1], old_shape[0], old_shape[2])).permute((1, 0, 2)) + else: + assert False + else: + assert False + return out + + +def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None: global thread_resources_env, seq, tokenizer, network, params default_params = { @@ -795,6 +818,53 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) "tokenizer": "gpt2", } params = kwargs + + # Try to convert HF config.json to MTJ config + if hf_checkpoint: + spec_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "maps", vars.model_type + ".json") + if not os.path.isfile(spec_path): + raise NotImplementedError(f"Unsupported model type {repr(vars.model_type)}") + with open(spec_path) as f: + lazy_load_spec = json.load(f) + + if "mtj_compat" in lazy_load_spec: + params["compat"] = lazy_load_spec["mtj_compat"] + if "mtj_pe" in lazy_load_spec: + params["pe"] = lazy_load_spec["mtj_pe"] + for k, v in lazy_load_spec.get("mtj_config_map", {}).items(): + if type(v) is not list: + params[k] = params[v] + continue + for i in range(len(v)): + if i == len(v) - 1: + params[k] = v[i] + elif v[i] in params: + params[k] = params[v[i]] + break + + params["n_vocab"] = params["vocab_size"] + + if "activation_function" in params: + params["activation"] = params["activation_function"] + + # Both the number of attention heads in the model and the embedding + # dimension of the model need to be divisible by the number of TPU cores + # that we use, and JAX also requires the number of TPU cores used to be + # an even number if we're using more than one core, so logically we try + # to pick the largest possible even number of TPU cores such that the + # number of attention heads and embedding dimension are both divisible + # by the number of TPU cores, and fall back to one core if an even + # number of TPU cores is not possible. 
+ for c in (8, 6, 4, 2, 1): + if 0 == params["n_heads"] % c == params["d_model"] % c: + params["cores_per_replica"] = c + break + + # The vocabulary size of the model also has to be divisible by the + # number of TPU cores, so we pad the vocabulary with the minimum + # possible number of dummy tokens such that it's divisible. + params["n_vocab_padding"] = -(params["n_vocab"] % -params["cores_per_replica"]) + if "compat" in params: default_params["compat"] = params["compat"] if default_params["compat"] == "fairseq_lm": @@ -804,10 +874,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) params[param] = default_params[param] # Load tokenizer - if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")): - raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'") - tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"]) - tokenizer = tokenizer_class.from_pretrained(params["tokenizer"]) + if not hf_checkpoint: + if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")): + raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'") + tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"]) + tokenizer = tokenizer_class.from_pretrained(params["tokenizer"]) # Disable JAX warnings about these two functions having been renamed jax.host_count = jax.process_count @@ -844,5 +915,147 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) path += "/" network = PenalizingCausalTransformer(params, dematerialized=True) - network.state = read_ckpt_lowmem(network.state, path, devices.shape[1]) + + if not hf_checkpoint: + network.state = read_ckpt_lowmem(network.state, path, devices.shape[1]) + network.state = network.move_xmap(network.state, np.zeros(cores_per_replica)) + return + + # Convert from HF checkpoint + + move_xmap = jax.experimental.maps.xmap( + fun=lambda x, _: to_bf16(x), + in_axes=(["shard", ...], ["batch", ...]), + out_axes=["shard", ...], + axis_resources={'shard': 'mp', 'batch': 'dp'} + ) + + model_spec = {} + for key, spec in lazy_load_spec.get("static_weights", {}).items(): + if spec.get("mtj") is not None: + model_spec[key] = spec["mtj"].copy() + model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"] + for _key, spec in lazy_load_spec.get("layer_weights", {}).items(): + for layer in range(params["layers"]): + if spec.get("mtj") is not None: + key = _key.format(layer=layer) + model_spec[key] = spec["mtj"].copy() + model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"].format(layer=layer) + + import torch_lazy_loader + import torch + from tqdm import tqdm + + def callback(model_dict, f, **_): + with zipfile.ZipFile(f, "r") as z: + try: + last_storage_key = None + f = None + print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n") + for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"): + + # Some model weights are used by transformers but not by MTJ. + # We have to materialize these weights anyways because + # transformers will throw a tantrum otherwise. 
To attain + # the least possible memory usage, we create them as meta + # tensors, which don't take up any actual CPU or TPU memory. + if key not in model_spec: + model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].storage_type(0).dtype, device="meta") + continue + + storage_key = model_dict[key].key + if storage_key != last_storage_key: + last_storage_key = storage_key + if isinstance(f, zipfile.ZipExtFile): + f.close() + f = z.open(f"archive/data/{storage_key}") + current_offset = f.tell() + if current_offset != model_dict[key].seek_offset: + f.seek(model_dict[key].seek_offset - current_offset, 1) + spec = model_spec[key] + transforms = set(spec.get("transforms", ())) + if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor): + error = f"Duplicate key {repr(key)}" + print("\n\nERROR: " + error, file=sys.stderr) + raise RuntimeError(error) + tensor = model_dict[key].materialize(f, map_location="cpu") + model_dict[key] = tensor.to("meta") + + # MTJ requires certain mathematical operations to be performed + # on tensors in order for them to be in the correct format + if "divide_by_shards" in transforms: + tensor /= params["cores_per_replica"] + if "vocab_pad" in transforms: + tensor = torch.nn.functional.pad(tensor, (0, 0, 0, params["n_vocab_padding"])) + if "no_transpose" not in transforms: + tensor = tensor.T + tensor.unsqueeze_(0) + + # Shard the tensor so that parts of the tensor can be used + # on different TPU cores + network.state["params"][spec["module"]][spec["param"]] = move_xmap( + jnp.array( + reshard_reverse( + tensor, + params["cores_per_replica"], + network.state["params"][spec["module"]][spec["param"]].shape, + ), + dtype=jnp.bfloat16, + ), + np.empty(params["cores_per_replica"]), + ) + + # Check for tensors that MTJ needs that were not provided in the + # HF model + for mk, mv in network.state["params"].items(): + for pk, pv in mv.items(): + if isinstance(pv, PlaceholderTensor): + # The transformers GPT-J models apparently do not + # have embedding bias, whereas MTJ GPT-J models do, + # so we have to supplement an embedding bias tensor + # by creating a tensor with the necessary shape, filled + # with zeros. 
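# (Illustrative note, not from the original patch: MTJ's embedding shard is an
# affine projection, so a bias of all zeros reproduces the bias-free Hugging
# Face embedding exactly, and taking the shape from the dematerialized
# network.state entry guarantees the supplemented tensor matches what MTJ
# expects.)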
+ if mk == "causal_transformer_shard/~/embedding_shard/~/linear" and pk == "b": + mv[pk] = move_xmap(jnp.zeros(mv[pk].shape, dtype=jnp.bfloat16), np.empty(params["cores_per_replica"])) + + else: + error = f"{mk} {pk} could not be found in the model checkpoint" + print("\n\nERROR: " + error, file=sys.stderr) + raise RuntimeError(error) + finally: + if isinstance(f, zipfile.ZipExtFile): + f.close() + + if os.path.isdir(vars.model.replace('/', '_')): + import shutil + shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_'))) + with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True): + if(os.path.isdir(vars.custmodpth)): + try: + tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache") + except ValueError as e: + tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache") + try: + model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache") + except ValueError as e: + model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache") + elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))): + try: + tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") + except ValueError as e: + tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") + try: + model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) + except ValueError as e: + model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem) + else: + try: + tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache") + except ValueError as e: + tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache") + try: + model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache") + except ValueError as e: + model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache") + network.state = network.move_xmap(network.state, np.zeros(cores_per_replica)) From 4625158d30da3a2a1236a9e0a28a936fb064093d Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sat, 5 Mar 2022 12:56:42 -0500 Subject: [PATCH 12/14] Fix typo in previous commit --- maps/xglm.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/maps/xglm.json b/maps/xglm.json index 3ba4b1f2..65ab5e5e 100644 --- a/maps/xglm.json +++ b/maps/xglm.json @@ -8,8 +8,8 @@ }, "static_weights": { "model.embed_tokens.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}}, - "model.replicated_layer_norm.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}}, - "model.replicated_layer_norm.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}}, + "model.layer_norm.weight": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale"}}, + "model.layer_norm.bias": {"mtj": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset"}}, "lm_head.weight": {"mtj": {"module": "projection_shard/~/linear", "param": "w", "transforms": ["vocab_pad"]}} }, "layer_weights": { From 2e19ea1bb67ed722a0518c5ffff7157dc2372dc6 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sat, 5 Mar 2022 14:07:23 -0500 Subject: [PATCH 13/14] Auto detect if we're in a Colab TPU instance --- aiserver.py | 31 
+++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/aiserver.py b/aiserver.py index 2d1998e6..711345cd 100644 --- a/aiserver.py +++ b/aiserver.py @@ -253,6 +253,7 @@ class vars: quiet = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page) debug = False # If set to true, will send debug information to the client for display lazy_load = True # Whether or not to use torch_lazy_loader.py for transformers models in order to reduce CPU memory usage + use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" # Whether or not we're in a Colab TPU instance and are going to use the TPU rather than the CPU utils.vars = vars @@ -695,7 +696,7 @@ def spRequest(filename): vars.sp_length = tensor.shape[-2] vars.spmeta["n_tokens"] = vars.sp_length - if(args.colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): + if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): rows = tensor.shape[0] padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows tensor = np.pad(tensor, ((0, padding_amount), (0, 0))) @@ -730,7 +731,6 @@ parser.add_argument("--override_delete", action='store_true', help="Deleting sto parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.") parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.") parser.add_argument("--colab", action='store_true', help="Optimize for Google Colab.") -parser.add_argument("--colab_tpu", action='store_true', help="If you're running KoboldAI in a Google Colab TPU instance, enable this to load Hugging Face models onto the TPU.") parser.add_argument("--nobreakmodel", action='store_true', help="Disables Breakmodel support completely.") parser.add_argument("--unblock", action='store_true', default=False, help="Unblocks the KoboldAI port to be accessible from other machines without optimizing for remote play (It is recommended to use --host instead)") parser.add_argument("--quiet", action='store_true', default=False, help="If present will suppress any story related text from showing on the console") @@ -768,6 +768,9 @@ if args.ngrok: if args.host: vars.host = True; +if args.cpu: + vars.use_colab_tpu = False + vars.smandelete = vars.host == args.override_delete vars.smanrename = vars.host == args.override_rename @@ -824,7 +827,7 @@ if(vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)") vars.model_type = "gpt_neo" -if(not args.colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]): +if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]): loadmodelsettings() loadsettings() print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="") @@ -1017,7 +1020,7 @@ socketio = SocketIO(app, async_method="eventlet") print("{0}OK!{1}".format(colors.GREEN, colors.END)) # Start transformers and create pipeline -if(not args.colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]): +if(not 
vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]): if(not vars.noai): print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END)) from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer @@ -1526,7 +1529,7 @@ else: tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/") loadsettings() # Load the TPU backend if requested - elif(args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): + elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END)) if vars.model == "TPUMeshTransformerGPTJ" and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)): raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder") @@ -1540,7 +1543,7 @@ else: vars.allowsp = True loadmodelsettings() loadsettings() - tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=args.colab_tpu, **vars.modelconfig) + tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model != "TPUMeshTransformerGPTJ" and vars.use_colab_tpu, **vars.modelconfig) vars.modeldim = int(tpu_mtj_backend.params["d_model"]) tokenizer = tpu_mtj_backend.tokenizer else: @@ -2071,7 +2074,7 @@ def lua_get_modeltype(): return "readonly" if(vars.model in ("Colab", "OAI", "InferKit")): return "api" - if(not args.colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))): + if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ",) and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))): hidden_size = get_hidden_size_from_model(model) if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)): return "gpt2" @@ -2087,7 +2090,7 @@ def lua_get_modeltype(): return "gpt-neo-1.3B" if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model_type == "gpt_neo" and hidden_size == 2560)): return "gpt-neo-2.7B" - if(vars.model in ("EleutherAI/gpt-j-6B",) or ((args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ") and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)): + if(vars.model in ("EleutherAI/gpt-j-6B",) or ((vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ") and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)): return "gpt-j-6B" return "unknown" @@ -2100,7 +2103,7 @@ def lua_get_modelbackend(): return "readonly" if(vars.model in ("Colab", "OAI", "InferKit")): return "api" - if(args.colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): + if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): return "mtj" return "transformers" @@ -3047,22 +3050,22 @@ def calcsubmit(txt): if(vars.model != "InferKit"): subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt) if(actionlen == 0): - if(not args.colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): + if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): generate(subtxt, min, max, found_entries=found_entries) elif(vars.model == "Colab"): sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) elif(vars.model 
== "OAI"): oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) - elif(args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): + elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): tpumtjgenerate(subtxt, min, max, found_entries=found_entries) else: - if(not args.colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): + if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): generate(subtxt, min, max, found_entries=found_entries) elif(vars.model == "Colab"): sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) elif(vars.model == "OAI"): oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) - elif(args.colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): + elif(vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ"): tpumtjgenerate(subtxt, min, max, found_entries=found_entries) # For InferKit web API @@ -5074,7 +5077,7 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")): file.close() # Precompile TPU backend if required -if(args.colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): +if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ",)): soft_tokens = tpumtjgetsofttokens() if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)): threading.Thread( From 373f7b9bd5f0614fe3e26a7e0d818efd85abf9e1 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sat, 5 Mar 2022 14:30:26 -0500 Subject: [PATCH 14/14] Don't convert tensors to float16 if using CPU-only mode --- aiserver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiserver.py b/aiserver.py index 8c882646..8444080a 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1079,7 +1079,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Re device = device_map[key] #print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... ", end="", flush=True) model_dict[key] = model_dict[key].materialize(f, map_location="cpu") - if convert_to_float16 and model_dict[key].dtype is torch.float32: + if convert_to_float16 and vars.hascuda and (vars.breakmodel or vars.usegpu) and model_dict[key].dtype is torch.float32: model_dict[key] = model_dict[key].to(torch.float16) model_dict[key] = model_dict[key].to(device) #print("OK", flush=True)