From 31a3046a189cc05984c1569888ffd3c1f1468a10 Mon Sep 17 00:00:00 2001
From: somebody
Date: Mon, 3 Jul 2023 17:07:18 -0500
Subject: [PATCH] Load empty modules without accelerate

---
 modeling/lazy_loader.py | 162 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 154 insertions(+), 8 deletions(-)

diff --git a/modeling/lazy_loader.py b/modeling/lazy_loader.py
index c14c7967..b8870cee 100644
--- a/modeling/lazy_loader.py
+++ b/modeling/lazy_loader.py
@@ -46,6 +46,7 @@ POSSIBILITY OF SUCH DAMAGE.
 
 import contextlib
 from functools import reduce
+import itertools
 import time
 import zipfile
 import pickle
@@ -56,13 +57,13 @@ import _codecs
 import os
 from typing import Any, Callable, Dict, Optional, Tuple, Type
 
+from torch.nn import Module
 from torch.storage import UntypedStorage
 
 # Safetensors is a dependency for the local version, TPU/Colab doesn't
 # support it yet.
 try:
     import safetensors
-
     HAS_SAFETENSORS = True
 except ModuleNotFoundError:
     HAS_SAFETENSORS = False
@@ -70,6 +71,16 @@ except ModuleNotFoundError:
 
 import utils
 from logger import logger
 
+# Accelerate is used to load with empty modules. TPU version doesn't come
+# packaged with it so we use an in-house solution in that case
+try:
+    import accelerate
+    HAS_ACCELERATE = True
+except ModuleNotFoundError:
+    HAS_ACCELERATE = False
+
+_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
+
 # Storage of zipfile handles for each shard
 torch_checkpoint_file_handles = {}
@@ -145,7 +156,6 @@ class TorchLazyTensor(LazyTensor):
         map_location=None,
         no_grad=True,
     ) -> torch.Tensor:
-
         checkpoint = torch_checkpoint_file_handles[self.file_name]
         filename = os.path.basename(os.path.normpath(self.file_name)).split(".")[0]
 
@@ -321,6 +331,116 @@ def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
     )
     return lazy_storage
 
 
+# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
+def _load_from_state_dict(
+    self,
+    state_dict,
+    prefix,
+    local_metadata,
+    strict,
+    missing_keys,
+    unexpected_keys,
+    error_msgs,
+):
+    for hook in self._load_state_dict_pre_hooks.values():
+        hook(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+
+    persistent_buffers = {
+        k: v
+        for k, v in self._buffers.items()
+        if k not in self._non_persistent_buffers_set
+    }
+    local_name_params = itertools.chain(
+        self._parameters.items(), persistent_buffers.items()
+    )
+    local_state = {k: v for k, v in local_name_params if v is not None}
+
+    for name, param in local_state.items():
+        key = prefix + name
+        if key in state_dict:
+            input_param = state_dict[key]
+            if not torch.overrides.is_tensor_like(input_param):
+                error_msgs.append(
+                    'While copying the parameter named "{}", '
+                    "expected torch.Tensor or Tensor-like object from checkpoint but "
+                    "received {}".format(key, type(input_param))
+                )
+                continue
+
+            # This is used to avoid copying uninitialized parameters into
+            # non-lazy modules, since they dont have the hook to do the checks
+            # in such case, it will error when accessing the .shape attribute.
+            is_param_lazy = torch.nn.parameter.is_lazy(param)
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if (
+                not is_param_lazy
+                and len(param.shape) == 0
+                and len(input_param.shape) == 1
+            ):
+                input_param = input_param[0]
+
+            if not is_param_lazy and input_param.shape != param.shape:
+                # local shape should match the one in checkpoint
+                error_msgs.append(
+                    "size mismatch for {}: copying a param with shape {} from checkpoint, "
+                    "the shape in current model is {}.".format(
+                        key, input_param.shape, param.shape
+                    )
+                )
+                continue
+            try:
+                with torch.no_grad():
+                    # param.copy_(input_param)
+                    new_param = torch.nn.Parameter(
+                        input_param, requires_grad=param.requires_grad
+                    )  # This line is new
+                    if name in self._parameters:  # This line is new
+                        self._parameters[name] = new_param  # This line is new
+                    if name in persistent_buffers:  # This line is new
+                        self._buffers[name] = new_param  # This line is new
+            except Exception as ex:
+                error_msgs.append(
+                    'While copying the parameter named "{}", '
+                    "whose dimensions in the model are {} and "
+                    "whose dimensions in the checkpoint are {}, "
+                    "an exception occurred : {}.".format(
+                        key, param.size(), input_param.size(), ex.args
+                    )
+                )
+        elif strict:
+            missing_keys.append(key)
+
+    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+    if (
+        hasattr(Module, "set_extra_state")
+        and getattr(self.__class__, "set_extra_state", Module.set_extra_state)
+        is not Module.set_extra_state
+    ):  # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
+        if extra_state_key in state_dict:
+            self.set_extra_state(state_dict[extra_state_key])
+        elif strict:
+            missing_keys.append(extra_state_key)
+    elif strict and (extra_state_key in state_dict):
+        unexpected_keys.append(extra_state_key)
+
+    if strict:
+        for key in state_dict.keys():
+            if key.startswith(prefix) and key != extra_state_key:
+                input_name = key[len(prefix) :]
+                input_name = input_name.split(".", 1)[
+                    0
+                ]  # get the name of param/buffer/child
+                if input_name not in self._modules and input_name not in local_state:
+                    unexpected_keys.append(key)
+
 def safetensors_load_tensor_independently(
     checkpoint_file: str, tensor_key: str, device: Any
@@ -360,7 +480,6 @@ def patch_safetensors(callback):
                 tensors[key] = None
 
         for key in tensors.keys():
-
             tensors[key] = SafetensorsLazyTensor(
                 checkpoint_file=checkpoint_file,
                 key=key,
@@ -453,10 +572,29 @@ def use_lazy_load(
         patch_safetensors(callback)
 
     if dematerialized_modules:
-        import accelerate
+        if HAS_ACCELERATE:
+            init_empty_weights = accelerate.init_empty_weights()
+            init_empty_weights.__enter__()
+        else:
+            # TPU doesn't use accelerate package
+            old_linear_init = torch.nn.Linear.__init__
+            old_embedding_init = torch.nn.Embedding.__init__
+            old_layernorm_init = torch.nn.LayerNorm.__init__
 
-        init_empty_weights = accelerate.init_empty_weights()
-        init_empty_weights.__enter__()
+            def linear_init(self, *args, device=None, **kwargs):
+                return old_linear_init(self, *args, device="meta", **kwargs)
+
+            def embedding_init(self, *args, device=None, **kwargs):
+                return old_embedding_init(self, *args, device="meta", **kwargs)
+
+            def layernorm_init(self, *args, device=None, **kwargs):
+                return old_layernorm_init(self, *args, device="meta", **kwargs)
+
+            torch.nn.Linear.__init__ = linear_init
+            torch.nn.Embedding.__init__ = embedding_init
+            torch.nn.LayerNorm.__init__ = layernorm_init
+            old_load_from_state_dict = torch.nn.Module._load_from_state_dict
+            torch.nn.Module._load_from_state_dict = _load_from_state_dict
 
     with use_custom_unpickler(_LazyUnpickler):
         yield True
@@ -471,14 +609,22 @@ def use_lazy_load(
     )
 
     if dematerialized_modules:
-        init_empty_weights.__exit__(None, None, None)
+        if HAS_ACCELERATE:
+            init_empty_weights.__exit__(None, None, None)
+        else:
+            torch.nn.Linear.__init__ = old_linear_init
+            torch.nn.Embedding.__init__ = old_embedding_init
+            torch.nn.LayerNorm.__init__ = old_layernorm_init
+            torch.nn.Module._load_from_state_dict = old_load_from_state_dict
 
 
 def post_load_cleanup() -> None:
     """Close dangling file pointers and clear caches after the load is complete."""
    global torch_checkpoint_file_handles
 
-    logger.debug(f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}")
+    logger.debug(
+        f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}"
+    )
     CheckpointChunkCache.clear(unload_model=True)
 
     # Bar is initialized in
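
Note (editor's addition, not part of the patch): the change combines two tricks — module constructors are forced onto the meta device when accelerate is unavailable, and checkpoint tensors are later materialized by rebinding module parameters rather than copying into them, since meta tensors have no storage for param.copy_() to write to. The sketch below shows the same idea in isolation under stated assumptions: it requires a PyTorch build whose nn constructors accept device="meta", and TinyModel, the hand-built state_dict, and the final print are illustrative stand-ins, not code from this repository.

    # Minimal sketch: build a module with no real weight storage, then
    # materialize it from a checkpoint by rebinding its parameters.
    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(8, 4)
            self.proj = nn.Linear(4, 2)

    # Route device="meta" into the constructors, as the patch's TPU fallback
    # does for Linear, Embedding and LayerNorm.
    old_linear_init = nn.Linear.__init__
    old_embedding_init = nn.Embedding.__init__

    def linear_init(self, *args, device=None, **kwargs):
        return old_linear_init(self, *args, device="meta", **kwargs)

    def embedding_init(self, *args, device=None, **kwargs):
        return old_embedding_init(self, *args, device="meta", **kwargs)

    nn.Linear.__init__ = linear_init
    nn.Embedding.__init__ = embedding_init
    try:
        model = TinyModel()  # no real tensor storage is allocated
    finally:
        nn.Linear.__init__ = old_linear_init
        nn.Embedding.__init__ = old_embedding_init

    assert all(p.device.type == "meta" for p in model.parameters())

    # Stand-in for tensors coming out of the lazy loader.
    state_dict = {
        "embed.weight": torch.randn(8, 4),
        "proj.weight": torch.randn(2, 4),
        "proj.bias": torch.randn(2),
    }

    # Rebind instead of copy, mirroring the patched _load_from_state_dict
    # (self._parameters[name] = new_param rather than param.copy_(input_param)).
    for module_name, module in model.named_modules():
        prefix = f"{module_name}." if module_name else ""
        for name, param in list(module._parameters.items()):
            key = prefix + name
            if param is not None and key in state_dict:
                module._parameters[name] = nn.Parameter(
                    state_dict[key], requires_grad=param.requires_grad
                )

    assert all(p.device.type == "cpu" for p in model.parameters())
    print(model.proj(model.embed(torch.tensor([1, 2]))).shape)  # torch.Size([2, 2])

Restoring the saved __init__ methods (here in a finally block) mirrors the exit path of use_lazy_load, which puts Linear, Embedding, LayerNorm and Module._load_from_state_dict back once loading is finished.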