mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge pull request #386 from one-some/accelerate-offloading
Use VE's patched load_from_state_dict on TPU for loading empty weights
This commit is contained in:
@@ -56,7 +56,6 @@ import collections
|
|||||||
import _codecs
|
import _codecs
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
||||||
import accelerate
|
|
||||||
|
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.storage import UntypedStorage
|
from torch.storage import UntypedStorage
|
||||||
@@ -70,6 +69,12 @@ try:
|
|||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
HAS_SAFETENSORS = False
|
HAS_SAFETENSORS = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import accelerate
|
||||||
|
USE_TPU_EMPTY_MODULE_METHOD = False
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
USE_TPU_EMPTY_MODULE_METHOD = True
|
||||||
|
|
||||||
import utils
|
import utils
|
||||||
from logger import logger
|
from logger import logger
|
||||||
|
|
||||||
@@ -400,6 +405,72 @@ def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler
|
|||||||
pickle.Unpickler = old_unpickler
|
pickle.Unpickler = old_unpickler
|
||||||
pickle.load = old_pickle_load
|
pickle.load = old_pickle_load
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||||
|
for hook in self._load_state_dict_pre_hooks.values():
|
||||||
|
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||||
|
|
||||||
|
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||||
|
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
|
||||||
|
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||||
|
|
||||||
|
for name, param in local_state.items():
|
||||||
|
key = prefix + name
|
||||||
|
if key in state_dict:
|
||||||
|
input_param = state_dict[key]
|
||||||
|
if not torch.overrides.is_tensor_like(input_param):
|
||||||
|
error_msgs.append('While copying the parameter named "{}", '
|
||||||
|
'expected torch.Tensor or Tensor-like object from checkpoint but '
|
||||||
|
'received {}'
|
||||||
|
.format(key, type(input_param)))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# This is used to avoid copying uninitialized parameters into
|
||||||
|
# non-lazy modules, since they dont have the hook to do the checks
|
||||||
|
# in such case, it will error when accessing the .shape attribute.
|
||||||
|
is_param_lazy = torch.nn.parameter.is_lazy(param)
|
||||||
|
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
|
||||||
|
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
|
||||||
|
input_param = input_param[0]
|
||||||
|
|
||||||
|
if not is_param_lazy and input_param.shape != param.shape:
|
||||||
|
# local shape should match the one in checkpoint
|
||||||
|
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
|
||||||
|
'the shape in current model is {}.'
|
||||||
|
.format(key, input_param.shape, param.shape))
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
with torch.no_grad():
|
||||||
|
#param.copy_(input_param)
|
||||||
|
new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad) # This line is new
|
||||||
|
if name in self._parameters: # This line is new
|
||||||
|
self._parameters[name] = new_param # This line is new
|
||||||
|
if name in persistent_buffers: # This line is new
|
||||||
|
self._buffers[name] = new_param # This line is new
|
||||||
|
except Exception as ex:
|
||||||
|
error_msgs.append('While copying the parameter named "{}", '
|
||||||
|
'whose dimensions in the model are {} and '
|
||||||
|
'whose dimensions in the checkpoint are {}, '
|
||||||
|
'an exception occurred : {}.'
|
||||||
|
.format(key, param.size(), input_param.size(), ex.args))
|
||||||
|
elif strict:
|
||||||
|
missing_keys.append(key)
|
||||||
|
|
||||||
|
extra_state_key = prefix + "_extra_state"
|
||||||
|
if hasattr(Module, "set_extra_state") and getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
|
||||||
|
if extra_state_key in state_dict:
|
||||||
|
self.set_extra_state(state_dict[extra_state_key])
|
||||||
|
elif strict:
|
||||||
|
missing_keys.append(extra_state_key)
|
||||||
|
elif strict and (extra_state_key in state_dict):
|
||||||
|
unexpected_keys.append(extra_state_key)
|
||||||
|
|
||||||
|
if strict:
|
||||||
|
for key in state_dict.keys():
|
||||||
|
if key.startswith(prefix) and key != extra_state_key:
|
||||||
|
input_name = key[len(prefix):]
|
||||||
|
input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
|
||||||
|
if input_name not in self._modules and input_name not in local_state:
|
||||||
|
unexpected_keys.append(key)
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def use_lazy_load(
|
def use_lazy_load(
|
||||||
@@ -453,8 +524,30 @@ def use_lazy_load(
|
|||||||
patch_safetensors(callback)
|
patch_safetensors(callback)
|
||||||
|
|
||||||
if dematerialized_modules:
|
if dematerialized_modules:
|
||||||
init_empty_weights = accelerate.init_empty_weights()
|
# Most devices can just use Accelerate's implementation, but the Transformers on
|
||||||
init_empty_weights.__enter__()
|
# the TPU complains about emptied weights unless we use VE's custom patches
|
||||||
|
if not USE_TPU_EMPTY_MODULE_METHOD:
|
||||||
|
init_empty_weights = accelerate.init_empty_weights()
|
||||||
|
init_empty_weights.__enter__()
|
||||||
|
else:
|
||||||
|
old_linear_init = torch.nn.Linear.__init__
|
||||||
|
old_embedding_init = torch.nn.Embedding.__init__
|
||||||
|
old_layernorm_init = torch.nn.LayerNorm.__init__
|
||||||
|
|
||||||
|
def linear_init(self, *args, device=None, **kwargs):
|
||||||
|
return old_linear_init(self, *args, device="meta", **kwargs)
|
||||||
|
|
||||||
|
def embedding_init(self, *args, device=None, **kwargs):
|
||||||
|
return old_embedding_init(self, *args, device="meta", **kwargs)
|
||||||
|
|
||||||
|
def layernorm_init(self, *args, device=None, **kwargs):
|
||||||
|
return old_layernorm_init(self, *args, device="meta", **kwargs)
|
||||||
|
|
||||||
|
torch.nn.Linear.__init__ = linear_init
|
||||||
|
torch.nn.Embedding.__init__ = embedding_init
|
||||||
|
torch.nn.LayerNorm.__init__ = layernorm_init
|
||||||
|
old_load_from_state_dict = torch.nn.Module._load_from_state_dict
|
||||||
|
torch.nn.Module._load_from_state_dict = _load_from_state_dict
|
||||||
|
|
||||||
with use_custom_unpickler(_LazyUnpickler):
|
with use_custom_unpickler(_LazyUnpickler):
|
||||||
yield True
|
yield True
|
||||||
@@ -469,7 +562,13 @@ def use_lazy_load(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if dematerialized_modules:
|
if dematerialized_modules:
|
||||||
init_empty_weights.__exit__(None, None, None)
|
if not USE_TPU_EMPTY_MODULE_METHOD:
|
||||||
|
init_empty_weights.__exit__(None, None, None)
|
||||||
|
else:
|
||||||
|
torch.nn.Linear.__init__ = old_linear_init
|
||||||
|
torch.nn.Embedding.__init__ = old_embedding_init
|
||||||
|
torch.nn.LayerNorm.__init__ = old_layernorm_init
|
||||||
|
torch.nn.Module._load_from_state_dict = old_load_from_state_dict
|
||||||
|
|
||||||
|
|
||||||
def post_load_cleanup() -> None:
|
def post_load_cleanup() -> None:
|
||||||
|
@@ -34,4 +34,3 @@ ijson
|
|||||||
ftfy
|
ftfy
|
||||||
pydub
|
pydub
|
||||||
sentencepiece
|
sentencepiece
|
||||||
accelerate==0.18.0
|
|
Reference in New Issue
Block a user