Merge pull request #386 from one-some/accelerate-offloading

Use VE's patched load_from_state_dict on TPU for loading empty weights
henk717 authored 2023-07-06 14:54:54 +02:00, committed by GitHub
2 changed files with 104 additions and 6 deletions
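On backends where Hugging Face Accelerate is available, the lazy loader keeps using accelerate.init_empty_weights(), which constructs modules on PyTorch's meta device so no weight memory is allocated until the checkpoint is actually read. A minimal sketch of that mechanism (illustrative only, not code from this PR):

    # Inside init_empty_weights(), newly created parameters land on the "meta"
    # device: they have shapes and dtypes but no backing storage.
    import torch
    from accelerate import init_empty_weights

    with init_empty_weights():
        layer = torch.nn.Linear(4096, 4096)

    print(layer.weight.device)  # meta -- the ~64 MB fp32 weight is never allocated

The requirements file in the second diff no longer pins accelerate, so when the import fails (as on the TPU backend) the loader falls back to VE's patches instead: the common module constructors are forced onto the meta device and Module._load_from_state_dict is replaced so real checkpoint tensors can be swapped in later.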


@@ -56,7 +56,6 @@ import collections
 import _codecs
 import os
 from typing import Any, Callable, Dict, Optional, Tuple, Type
 
-import accelerate
 from torch.nn import Module
 from torch.storage import UntypedStorage
@@ -70,6 +69,12 @@ try:
 except ModuleNotFoundError:
     HAS_SAFETENSORS = False
 
+try:
+    import accelerate
+    USE_TPU_EMPTY_MODULE_METHOD = False
+except ModuleNotFoundError:
+    USE_TPU_EMPTY_MODULE_METHOD = True
+
 import utils
 from logger import logger
@@ -400,6 +405,72 @@ def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler
         pickle.Unpickler = old_unpickler
         pickle.load = old_pickle_load
 
+
+def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+    for hook in self._load_state_dict_pre_hooks.values():
+        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+    persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+    local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+    local_state = {k: v for k, v in local_name_params if v is not None}
+
+    for name, param in local_state.items():
+        key = prefix + name
+        if key in state_dict:
+            input_param = state_dict[key]
+            if not torch.overrides.is_tensor_like(input_param):
+                error_msgs.append('While copying the parameter named "{}", '
+                                  'expected torch.Tensor or Tensor-like object from checkpoint but '
+                                  'received {}'
+                                  .format(key, type(input_param)))
+                continue
+
+            # This is used to avoid copying uninitialized parameters into
+            # non-lazy modules, since they dont have the hook to do the checks
+            # in such case, it will error when accessing the .shape attribute.
+            is_param_lazy = torch.nn.parameter.is_lazy(param)
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
+                input_param = input_param[0]
+
+            if not is_param_lazy and input_param.shape != param.shape:
+                # local shape should match the one in checkpoint
+                error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+                                  'the shape in current model is {}.'
+                                  .format(key, input_param.shape, param.shape))
+                continue
+            try:
+                with torch.no_grad():
+                    #param.copy_(input_param)
+                    new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad)  # This line is new
+                    if name in self._parameters:  # This line is new
+                        self._parameters[name] = new_param  # This line is new
+
+                    if name in persistent_buffers:  # This line is new
+                        self._buffers[name] = new_param  # This line is new
+            except Exception as ex:
+                error_msgs.append('While copying the parameter named "{}", '
+                                  'whose dimensions in the model are {} and '
+                                  'whose dimensions in the checkpoint are {}, '
+                                  'an exception occurred : {}.'
+                                  .format(key, param.size(), input_param.size(), ex.args))
+        elif strict:
+            missing_keys.append(key)
+
+    extra_state_key = prefix + "_extra_state"
+    if hasattr(Module, "set_extra_state") and getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:  # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
+        if extra_state_key in state_dict:
+            self.set_extra_state(state_dict[extra_state_key])
+        elif strict:
+            missing_keys.append(extra_state_key)
+    elif strict and (extra_state_key in state_dict):
+        unexpected_keys.append(extra_state_key)
+
+    if strict:
+        for key in state_dict.keys():
+            if key.startswith(prefix) and key != extra_state_key:
+                input_name = key[len(prefix):]
+                input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
+                if input_name not in self._modules and input_name not in local_state:
+                    unexpected_keys.append(key)
 @contextlib.contextmanager
 def use_lazy_load(
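The key change in the function above, relative to stock PyTorch's Module._load_from_state_dict, is inside the torch.no_grad() block (the lines marked "This line is new"): param.copy_(input_param) cannot materialize data for a parameter that lives on the meta device, because there is no storage to copy into, so the checkpoint tensor is wrapped in a fresh nn.Parameter and assigned directly into self._parameters / self._buffers. A standalone sketch of that trick, using a plain nn.Linear rather than a transformer module (illustrative only):

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 4, device="meta")   # shape-only parameters, no storage
    checkpoint = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

    for name, tensor in checkpoint.items():
        old = layer._parameters[name]
        # Replace the meta Parameter outright instead of copying into it.
        layer._parameters[name] = nn.Parameter(tensor, requires_grad=old.requires_grad)

    print(layer.weight.device)               # cpu -- weights are now materialized
    print(layer(torch.ones(4)).shape)        # torch.Size([4])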
@@ -453,8 +524,30 @@ def use_lazy_load(
             patch_safetensors(callback)
 
         if dematerialized_modules:
-            init_empty_weights = accelerate.init_empty_weights()
-            init_empty_weights.__enter__()
+            # Most devices can just use Accelerate's implementation, but the Transformers on
+            # the TPU complains about emptied weights unless we use VE's custom patches
+            if not USE_TPU_EMPTY_MODULE_METHOD:
+                init_empty_weights = accelerate.init_empty_weights()
+                init_empty_weights.__enter__()
+            else:
+                old_linear_init = torch.nn.Linear.__init__
+                old_embedding_init = torch.nn.Embedding.__init__
+                old_layernorm_init = torch.nn.LayerNorm.__init__
+
+                def linear_init(self, *args, device=None, **kwargs):
+                    return old_linear_init(self, *args, device="meta", **kwargs)
+
+                def embedding_init(self, *args, device=None, **kwargs):
+                    return old_embedding_init(self, *args, device="meta", **kwargs)
+
+                def layernorm_init(self, *args, device=None, **kwargs):
+                    return old_layernorm_init(self, *args, device="meta", **kwargs)
+
+                torch.nn.Linear.__init__ = linear_init
+                torch.nn.Embedding.__init__ = embedding_init
+                torch.nn.LayerNorm.__init__ = layernorm_init
+
+                old_load_from_state_dict = torch.nn.Module._load_from_state_dict
+                torch.nn.Module._load_from_state_dict = _load_from_state_dict
 
         with use_custom_unpickler(_LazyUnpickler):
             yield True
@@ -469,7 +562,13 @@ def use_lazy_load(
         )
 
         if dematerialized_modules:
-            init_empty_weights.__exit__(None, None, None)
+            if not USE_TPU_EMPTY_MODULE_METHOD:
+                init_empty_weights.__exit__(None, None, None)
+            else:
+                torch.nn.Linear.__init__ = old_linear_init
+                torch.nn.Embedding.__init__ = old_embedding_init
+                torch.nn.LayerNorm.__init__ = old_layernorm_init
+                torch.nn.Module._load_from_state_dict = old_load_from_state_dict
 
 
 def post_load_cleanup() -> None:
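The constructor patches installed in the previous hunk are process-global, which is why this hunk puts the saved originals back as soon as the lazy-load context exits. A condensed sketch of the same patch-and-restore pattern in isolation (the names mirror the diff; this is not additional code from the PR):

    import torch

    old_linear_init = torch.nn.Linear.__init__

    def linear_init(self, *args, device=None, **kwargs):
        # Drop the caller's device and force "meta" so no storage is allocated.
        return old_linear_init(self, *args, device="meta", **kwargs)

    torch.nn.Linear.__init__ = linear_init
    try:
        print(torch.nn.Linear(8, 8).weight.device)  # meta
    finally:
        torch.nn.Linear.__init__ = old_linear_init  # always undo the global patch

    print(torch.nn.Linear(8, 8).weight.device)      # cpu again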


@@ -34,4 +34,3 @@ ijson
 ftfy
 pydub
 sentencepiece
-accelerate==0.18.0