diff --git a/modeling/lazy_loader.py b/modeling/lazy_loader.py
index b8870cee..b61f9be6 100644
--- a/modeling/lazy_loader.py
+++ b/modeling/lazy_loader.py
@@ -56,6 +56,7 @@ import collections
 import _codecs
 import os
 from typing import Any, Callable, Dict, Optional, Tuple, Type
+import accelerate
 
 from torch.nn import Module
 from torch.storage import UntypedStorage
@@ -64,6 +65,7 @@ from torch.storage import UntypedStorage
 # support it yet.
 try:
     import safetensors
+    HAS_SAFETENSORS = True
 except ModuleNotFoundError:
     HAS_SAFETENSORS = False
 
@@ -71,17 +73,6 @@ except ModuleNotFoundError:
 import utils
 from logger import logger
 
-# Accelerate is used to load with empty modules. TPU version doesn't come
-# packaged with it so we use an in-house solution in that case
-try:
-    import accelerate
-    HAS_ACCELERATE = True
-except ModuleNotFoundError:
-    HAS_ACCELERATE = False
-
-_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
-
-
 # Storage of zipfile handles for each shard
 torch_checkpoint_file_handles = {}
 
@@ -331,116 +322,6 @@ def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
     )
     return lazy_storage
 
-# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
-def _load_from_state_dict(
-    self,
-    state_dict,
-    prefix,
-    local_metadata,
-    strict,
-    missing_keys,
-    unexpected_keys,
-    error_msgs,
-):
-    for hook in self._load_state_dict_pre_hooks.values():
-        hook(
-            state_dict,
-            prefix,
-            local_metadata,
-            strict,
-            missing_keys,
-            unexpected_keys,
-            error_msgs,
-        )
-
-    persistent_buffers = {
-        k: v
-        for k, v in self._buffers.items()
-        if k not in self._non_persistent_buffers_set
-    }
-    local_name_params = itertools.chain(
-        self._parameters.items(), persistent_buffers.items()
-    )
-    local_state = {k: v for k, v in local_name_params if v is not None}
-
-    for name, param in local_state.items():
-        key = prefix + name
-        if key in state_dict:
-            input_param = state_dict[key]
-            if not torch.overrides.is_tensor_like(input_param):
-                error_msgs.append(
-                    'While copying the parameter named "{}", '
-                    "expected torch.Tensor or Tensor-like object from checkpoint but "
-                    "received {}".format(key, type(input_param))
-                )
-                continue
-
-            # This is used to avoid copying uninitialized parameters into
-            # non-lazy modules, since they dont have the hook to do the checks
-            # in such case, it will error when accessing the .shape attribute.
-            is_param_lazy = torch.nn.parameter.is_lazy(param)
-            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
-            if (
-                not is_param_lazy
-                and len(param.shape) == 0
-                and len(input_param.shape) == 1
-            ):
-                input_param = input_param[0]
-
-            if not is_param_lazy and input_param.shape != param.shape:
-                # local shape should match the one in checkpoint
-                error_msgs.append(
-                    "size mismatch for {}: copying a param with shape {} from checkpoint, "
-                    "the shape in current model is {}.".format(
-                        key, input_param.shape, param.shape
-                    )
-                )
-                continue
-            try:
-                with torch.no_grad():
-                    # param.copy_(input_param)
-                    new_param = torch.nn.Parameter(
-                        input_param, requires_grad=param.requires_grad
-                    )  # This line is new
-                    if name in self._parameters:  # This line is new
-                        self._parameters[name] = new_param  # This line is new
-                    if name in persistent_buffers:  # This line is new
-                        self._buffers[name] = new_param  # This line is new
-            except Exception as ex:
-                error_msgs.append(
-                    'While copying the parameter named "{}", '
-                    "whose dimensions in the model are {} and "
-                    "whose dimensions in the checkpoint are {}, "
-                    "an exception occurred : {}.".format(
-                        key, param.size(), input_param.size(), ex.args
-                    )
-                )
-        elif strict:
-            missing_keys.append(key)
-
-    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-    if (
-        hasattr(Module, "set_extra_state")
-        and getattr(self.__class__, "set_extra_state", Module.set_extra_state)
-        is not Module.set_extra_state
-    ):  # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
-        if extra_state_key in state_dict:
-            self.set_extra_state(state_dict[extra_state_key])
-        elif strict:
-            missing_keys.append(extra_state_key)
-    elif strict and (extra_state_key in state_dict):
-        unexpected_keys.append(extra_state_key)
-
-    if strict:
-        for key in state_dict.keys():
-            if key.startswith(prefix) and key != extra_state_key:
-                input_name = key[len(prefix) :]
-                input_name = input_name.split(".", 1)[
-                    0
-                ]  # get the name of param/buffer/child
-                if input_name not in self._modules and input_name not in local_state:
-                    unexpected_keys.append(key)
-
 
 def safetensors_load_tensor_independently(
     checkpoint_file: str, tensor_key: str, device: Any
@@ -572,29 +453,8 @@ def use_lazy_load(
             patch_safetensors(callback)
 
         if dematerialized_modules:
-            if HAS_ACCELERATE:
-                init_empty_weights = accelerate.init_empty_weights()
-                init_empty_weights.__enter__()
-            else:
-                # TPU doesn't use accelerate package
-                old_linear_init = torch.nn.Linear.__init__
-                old_embedding_init = torch.nn.Embedding.__init__
-                old_layernorm_init = torch.nn.LayerNorm.__init__
-
-                def linear_init(self, *args, device=None, **kwargs):
-                    return old_linear_init(self, *args, device="meta", **kwargs)
-
-                def embedding_init(self, *args, device=None, **kwargs):
-                    return old_embedding_init(self, *args, device="meta", **kwargs)
-
-                def layernorm_init(self, *args, device=None, **kwargs):
-                    return old_layernorm_init(self, *args, device="meta", **kwargs)
-
-                torch.nn.Linear.__init__ = linear_init
-                torch.nn.Embedding.__init__ = embedding_init
-                torch.nn.LayerNorm.__init__ = layernorm_init
-                old_load_from_state_dict = torch.nn.Module._load_from_state_dict
-                torch.nn.Module._load_from_state_dict = _load_from_state_dict
+            init_empty_weights = accelerate.init_empty_weights()
+            init_empty_weights.__enter__()
 
         with use_custom_unpickler(_LazyUnpickler):
             yield True
@@ -609,13 +469,7 @@ def use_lazy_load(
         )
 
         if dematerialized_modules:
-            if HAS_ACCELERATE:
-                init_empty_weights.__exit__(None, None, None)
-            else:
-                torch.nn.Linear.__init__ = old_linear_init
-                torch.nn.Embedding.__init__ = old_embedding_init
-                torch.nn.LayerNorm.__init__ = old_layernorm_init
-                torch.nn.Module._load_from_state_dict = old_load_from_state_dict
+            init_empty_weights.__exit__(None, None, None)
 
 
 def post_load_cleanup() -> None:
diff --git a/requirements_mtj.txt b/requirements_mtj.txt
index e3f91282..b3521d03 100644
--- a/requirements_mtj.txt
+++ b/requirements_mtj.txt
@@ -34,3 +34,4 @@ ijson
 ftfy
 pydub
 sentencepiece
+accelerate==0.18.0
\ No newline at end of file
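Note: the change above relies on accelerate.init_empty_weights() doing what the removed TPU-only monkey-patching of Linear/Embedding/LayerNorm did by hand, namely instantiating modules on the meta device so no real weight storage is allocated. The following is a minimal illustrative sketch of that behaviour, not part of the diff; the module and its sizes are arbitrary, and the context manager is entered and exited manually to mirror use_lazy_load().

    import accelerate
    import torch

    # Entered/exited by hand, the same way use_lazy_load() does in the patch.
    init_empty_weights = accelerate.init_empty_weights()
    init_empty_weights.__enter__()
    try:
        layer = torch.nn.Linear(8, 8)   # arbitrary example module
        print(layer.weight.device)      # -> meta: parameters have no real storage
    finally:
        init_empty_weights.__exit__(None, None, None)

    # Outside the context, modules are materialized normally.
    print(torch.nn.Linear(8, 8).weight.device)  # -> cpu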