Just use accelerate on TPU

commit 7f869a54d8
parent 1bb2d2621c
Author: somebody
Date: 2023-07-03 17:18:48 -05:00

2 changed files with 6 additions and 151 deletions


@@ -56,6 +56,7 @@ import collections
 import _codecs
 import os
 from typing import Any, Callable, Dict, Optional, Tuple, Type
+import accelerate
 
 from torch.nn import Module
 from torch.storage import UntypedStorage
@@ -64,6 +65,7 @@ from torch.storage import UntypedStorage
 # support it yet.
 try:
     import safetensors
+
     HAS_SAFETENSORS = True
 except ModuleNotFoundError:
     HAS_SAFETENSORS = False
@@ -71,17 +73,6 @@ except ModuleNotFoundError:
 import utils
 from logger import logger
 
-# Accelerate is used to load with empty modules. TPU version doesn't come
-# packaged with it so we use an in-house solution in that case
-try:
-    import accelerate
-
-    HAS_ACCELERATE = True
-except ModuleNotFoundError:
-    HAS_ACCELERATE = False
-
-_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
-
 
 # Storage of zipfile handles for each shard
 torch_checkpoint_file_handles = {}
@@ -331,116 +322,6 @@ def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
     )
     return lazy_storage
 
 
-# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
-def _load_from_state_dict(
-    self,
-    state_dict,
-    prefix,
-    local_metadata,
-    strict,
-    missing_keys,
-    unexpected_keys,
-    error_msgs,
-):
-    for hook in self._load_state_dict_pre_hooks.values():
-        hook(
-            state_dict,
-            prefix,
-            local_metadata,
-            strict,
-            missing_keys,
-            unexpected_keys,
-            error_msgs,
-        )
-
-    persistent_buffers = {
-        k: v
-        for k, v in self._buffers.items()
-        if k not in self._non_persistent_buffers_set
-    }
-    local_name_params = itertools.chain(
-        self._parameters.items(), persistent_buffers.items()
-    )
-    local_state = {k: v for k, v in local_name_params if v is not None}
-
-    for name, param in local_state.items():
-        key = prefix + name
-        if key in state_dict:
-            input_param = state_dict[key]
-            if not torch.overrides.is_tensor_like(input_param):
-                error_msgs.append(
-                    'While copying the parameter named "{}", '
-                    "expected torch.Tensor or Tensor-like object from checkpoint but "
-                    "received {}".format(key, type(input_param))
-                )
-                continue
-
-            # This is used to avoid copying uninitialized parameters into
-            # non-lazy modules, since they dont have the hook to do the checks
-            # in such case, it will error when accessing the .shape attribute.
-            is_param_lazy = torch.nn.parameter.is_lazy(param)
-            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
-            if (
-                not is_param_lazy
-                and len(param.shape) == 0
-                and len(input_param.shape) == 1
-            ):
-                input_param = input_param[0]
-
-            if not is_param_lazy and input_param.shape != param.shape:
-                # local shape should match the one in checkpoint
-                error_msgs.append(
-                    "size mismatch for {}: copying a param with shape {} from checkpoint, "
-                    "the shape in current model is {}.".format(
-                        key, input_param.shape, param.shape
-                    )
-                )
-                continue
-            try:
-                with torch.no_grad():
-                    # param.copy_(input_param)
-                    new_param = torch.nn.Parameter(
-                        input_param, requires_grad=param.requires_grad
-                    )  # This line is new
-                    if name in self._parameters:  # This line is new
-                        self._parameters[name] = new_param  # This line is new
-                    if name in persistent_buffers:  # This line is new
-                        self._buffers[name] = new_param  # This line is new
-            except Exception as ex:
-                error_msgs.append(
-                    'While copying the parameter named "{}", '
-                    "whose dimensions in the model are {} and "
-                    "whose dimensions in the checkpoint are {}, "
-                    "an exception occurred : {}.".format(
-                        key, param.size(), input_param.size(), ex.args
-                    )
-                )
-        elif strict:
-            missing_keys.append(key)
-
-    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-    if (
-        hasattr(Module, "set_extra_state")
-        and getattr(self.__class__, "set_extra_state", Module.set_extra_state)
-        is not Module.set_extra_state
-    ):  # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
-        if extra_state_key in state_dict:
-            self.set_extra_state(state_dict[extra_state_key])
-        elif strict:
-            missing_keys.append(extra_state_key)
-    elif strict and (extra_state_key in state_dict):
-        unexpected_keys.append(extra_state_key)
-
-    if strict:
-        for key in state_dict.keys():
-            if key.startswith(prefix) and key != extra_state_key:
-                input_name = key[len(prefix) :]
-                input_name = input_name.split(".", 1)[
-                    0
-                ]  # get the name of param/buffer/child
-                if input_name not in self._modules and input_name not in local_state:
-                    unexpected_keys.append(key)
-
-
 def safetensors_load_tensor_independently(
     checkpoint_file: str, tensor_key: str, device: Any
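The hook deleted above was the in-house replacement for PyTorch's Module._load_from_state_dict: rather than copying checkpoint data into an existing parameter, it swapped in a brand-new torch.nn.Parameter, which is what makes loading into dematerialized (meta-device) modules possible. A minimal sketch of that idea against plain PyTorch, not this repository's loader; the layer and tensor below are illustrative assumptions only:

import torch

# A layer whose parameters live on the meta device: shapes and dtypes exist,
# but there is no storage to copy checkpoint data into.
layer = torch.nn.Linear(8, 8, device="meta")

checkpoint_weight = torch.randn(8, 8)  # stands in for a tensor read from a checkpoint

# Swap the Parameter object itself, mirroring the removed hook's
# self._parameters[name] = new_param path.
layer.weight = torch.nn.Parameter(
    checkpoint_weight, requires_grad=layer.weight.requires_grad
)

print(layer.weight.device)  # cpu: the real tensor has replaced the meta placeholder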
@@ -572,29 +453,8 @@ def use_lazy_load(
         patch_safetensors(callback)
 
         if dematerialized_modules:
-            if HAS_ACCELERATE:
-                init_empty_weights = accelerate.init_empty_weights()
-                init_empty_weights.__enter__()
-            else:
-                # TPU doesn't use accelerate package
-                old_linear_init = torch.nn.Linear.__init__
-                old_embedding_init = torch.nn.Embedding.__init__
-                old_layernorm_init = torch.nn.LayerNorm.__init__
-
-                def linear_init(self, *args, device=None, **kwargs):
-                    return old_linear_init(self, *args, device="meta", **kwargs)
-
-                def embedding_init(self, *args, device=None, **kwargs):
-                    return old_embedding_init(self, *args, device="meta", **kwargs)
-
-                def layernorm_init(self, *args, device=None, **kwargs):
-                    return old_layernorm_init(self, *args, device="meta", **kwargs)
-
-                torch.nn.Linear.__init__ = linear_init
-                torch.nn.Embedding.__init__ = embedding_init
-                torch.nn.LayerNorm.__init__ = layernorm_init
-                old_load_from_state_dict = torch.nn.Module._load_from_state_dict
-                torch.nn.Module._load_from_state_dict = _load_from_state_dict
+            init_empty_weights = accelerate.init_empty_weights()
+            init_empty_weights.__enter__()
 
         with use_custom_unpickler(_LazyUnpickler):
             yield True
@@ -609,13 +469,7 @@ def use_lazy_load(
         )
 
         if dematerialized_modules:
-            if HAS_ACCELERATE:
-                init_empty_weights.__exit__(None, None, None)
-            else:
-                torch.nn.Linear.__init__ = old_linear_init
-                torch.nn.Embedding.__init__ = old_embedding_init
-                torch.nn.LayerNorm.__init__ = old_layernorm_init
-                torch.nn.Module._load_from_state_dict = old_load_from_state_dict
+            init_empty_weights.__exit__(None, None, None)
 
 
 def post_load_cleanup() -> None:
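With the TPU fallback gone, use_lazy_load relies solely on accelerate's init_empty_weights context manager, entered and exited by hand so setup and teardown can live on different code paths. A short sketch of equivalent usage, assuming only that torch and the accelerate version pinned below are installed; the Linear layer is just an example module, not something this file constructs:

import accelerate
import torch

# The same context-manager object the patched use_lazy_load() creates; it is
# entered manually so it can be exited later from the cleanup path.
init_empty_weights = accelerate.init_empty_weights()
init_empty_weights.__enter__()
try:
    layer = torch.nn.Linear(4096, 4096)  # parameters are created on the meta device
finally:
    init_empty_weights.__exit__(None, None, None)

print(layer.weight.device)  # meta: no real memory was allocated for the weights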


@@ -34,3 +34,4 @@ ijson
 ftfy
 pydub
 sentencepiece
+accelerate==0.18.0
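The pin makes accelerate an unconditional dependency instead of an optional one guarded by HAS_ACCELERATE. A quick sanity check, not part of the repository, that an environment satisfies what the loader now imports at module level:

import accelerate

print(accelerate.__version__)  # expected to match the 0.18.0 pin above
assert hasattr(accelerate, "init_empty_weights")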