From 6b83944e9be8e9687ead7dc0c8e98e7759a8be92 Mon Sep 17 00:00:00 2001
From: somebody
Date: Mon, 3 Jul 2023 20:58:31 -0500
Subject: [PATCH] Use VE's patched load_from_state_dict on TPU for loading
 empty weights

---
 modeling/lazy_loader.py | 107 ++++++++++++++++++++++++++++++++++++++--
 requirements_mtj.txt    |   3 +-
 2 files changed, 104 insertions(+), 6 deletions(-)

diff --git a/modeling/lazy_loader.py b/modeling/lazy_loader.py
index b61f9be6..4dcbe392 100644
--- a/modeling/lazy_loader.py
+++ b/modeling/lazy_loader.py
@@ -56,7 +56,6 @@ import collections
 import _codecs
 import os
 from typing import Any, Callable, Dict, Optional, Tuple, Type
-import accelerate
 
 from torch.nn import Module
 from torch.storage import UntypedStorage
@@ -70,6 +69,12 @@ try:
 except ModuleNotFoundError:
     HAS_SAFETENSORS = False
 
+try:
+    import accelerate
+    USE_TPU_EMPTY_MODULE_METHOD = False
+except ModuleNotFoundError:
+    USE_TPU_EMPTY_MODULE_METHOD = True
+
 import utils
 from logger import logger
 
@@ -400,6 +405,72 @@ def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler
         pickle.Unpickler = old_unpickler
         pickle.load = old_pickle_load
 
+def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+    for hook in self._load_state_dict_pre_hooks.values():
+        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+    persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+    local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+    local_state = {k: v for k, v in local_name_params if v is not None}
+
+    for name, param in local_state.items():
+        key = prefix + name
+        if key in state_dict:
+            input_param = state_dict[key]
+            if not torch.overrides.is_tensor_like(input_param):
+                error_msgs.append('While copying the parameter named "{}", '
+                                  'expected torch.Tensor or Tensor-like object from checkpoint but '
+                                  'received {}'
+                                  .format(key, type(input_param)))
+                continue
+
+            # This is used to avoid copying uninitialized parameters into
+            # non-lazy modules, since they dont have the hook to do the checks
+            # in such case, it will error when accessing the .shape attribute.
+            is_param_lazy = torch.nn.parameter.is_lazy(param)
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
+                input_param = input_param[0]
+
+            if not is_param_lazy and input_param.shape != param.shape:
+                # local shape should match the one in checkpoint
+                error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+                                  'the shape in current model is {}.'
+                                  .format(key, input_param.shape, param.shape))
+                continue
+            try:
+                with torch.no_grad():
+                    #param.copy_(input_param)
+                    new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad)  # This line is new
+                    if name in self._parameters:  # This line is new
+                        self._parameters[name] = new_param  # This line is new
+                    if name in persistent_buffers:  # This line is new
+                        self._buffers[name] = new_param  # This line is new
+            except Exception as ex:
+                error_msgs.append('While copying the parameter named "{}", '
+                                  'whose dimensions in the model are {} and '
+                                  'whose dimensions in the checkpoint are {}, '
+                                  'an exception occurred : {}.'
+                                  .format(key, param.size(), input_param.size(), ex.args))
+        elif strict:
+            missing_keys.append(key)
+
+    extra_state_key = prefix + "_extra_state"
+    if hasattr(Module, "set_extra_state") and getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:  # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
+        if extra_state_key in state_dict:
+            self.set_extra_state(state_dict[extra_state_key])
+        elif strict:
+            missing_keys.append(extra_state_key)
+    elif strict and (extra_state_key in state_dict):
+        unexpected_keys.append(extra_state_key)
+
+    if strict:
+        for key in state_dict.keys():
+            if key.startswith(prefix) and key != extra_state_key:
+                input_name = key[len(prefix):]
+                input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
+                if input_name not in self._modules and input_name not in local_state:
+                    unexpected_keys.append(key)
 
 @contextlib.contextmanager
 def use_lazy_load(
@@ -453,8 +524,30 @@ def use_lazy_load(
         patch_safetensors(callback)
 
     if dematerialized_modules:
-        init_empty_weights = accelerate.init_empty_weights()
-        init_empty_weights.__enter__()
+        # Most devices can just use Accelerate's implementation, but Transformers
+        # on the TPU complains about emptied weights unless we use VE's custom patches
+        if not USE_TPU_EMPTY_MODULE_METHOD:
+            init_empty_weights = accelerate.init_empty_weights()
+            init_empty_weights.__enter__()
+        else:
+            old_linear_init = torch.nn.Linear.__init__
+            old_embedding_init = torch.nn.Embedding.__init__
+            old_layernorm_init = torch.nn.LayerNorm.__init__
+
+            def linear_init(self, *args, device=None, **kwargs):
+                return old_linear_init(self, *args, device="meta", **kwargs)
+
+            def embedding_init(self, *args, device=None, **kwargs):
+                return old_embedding_init(self, *args, device="meta", **kwargs)
+
+            def layernorm_init(self, *args, device=None, **kwargs):
+                return old_layernorm_init(self, *args, device="meta", **kwargs)
+
+            torch.nn.Linear.__init__ = linear_init
+            torch.nn.Embedding.__init__ = embedding_init
+            torch.nn.LayerNorm.__init__ = layernorm_init
+            old_load_from_state_dict = torch.nn.Module._load_from_state_dict
+            torch.nn.Module._load_from_state_dict = _load_from_state_dict
 
     with use_custom_unpickler(_LazyUnpickler):
         yield True
@@ -469,7 +562,13 @@ def use_lazy_load(
         )
 
     if dematerialized_modules:
-        init_empty_weights.__exit__(None, None, None)
+        if not USE_TPU_EMPTY_MODULE_METHOD:
+            init_empty_weights.__exit__(None, None, None)
+        else:
+            torch.nn.Linear.__init__ = old_linear_init
+            torch.nn.Embedding.__init__ = old_embedding_init
+            torch.nn.LayerNorm.__init__ = old_layernorm_init
+            torch.nn.Module._load_from_state_dict = old_load_from_state_dict
 
 
 def post_load_cleanup() -> None:
diff --git a/requirements_mtj.txt b/requirements_mtj.txt
index b3521d03..fe2afbdd 100644
--- a/requirements_mtj.txt
+++ b/requirements_mtj.txt
@@ -33,5 +33,4 @@ flask_compress
 ijson
 ftfy
 pydub
-sentencepiece
-accelerate==0.18.0
\ No newline at end of file
+sentencepiece
\ No newline at end of file
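For reference, a minimal standalone sketch of the meta-device trick the TPU branch above relies on; it is illustrative only and not part of the patch. The variable names are invented for the example, and it assumes a PyTorch build whose nn.Linear accepts a device keyword and supports the "meta" device.

# Illustrative sketch only: construct a Linear on the meta device, then
# materialize it by replacing the Parameter, mirroring the patched
# _load_from_state_dict (copy_() would fail on a meta tensor).
import torch

old_linear_init = torch.nn.Linear.__init__

def linear_init(self, *args, device=None, **kwargs):
    # Force construction on the meta device so no real storage is allocated.
    return old_linear_init(self, *args, device="meta", **kwargs)

torch.nn.Linear.__init__ = linear_init
try:
    layer = torch.nn.Linear(4, 4)
    print(layer.weight.device)  # meta -- only shape/dtype metadata exists
finally:
    torch.nn.Linear.__init__ = old_linear_init  # always restore the patch

# Materialize the weight by swapping in a new Parameter backed by real data
# (the bias would be handled the same way when loading a full state dict).
checkpoint_weight = torch.zeros(4, 4)
layer._parameters["weight"] = torch.nn.Parameter(
    checkpoint_weight, requires_grad=layer.weight.requires_grad
)
print(layer.weight.device)  # cpu -- the real tensor now backs the module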