mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Just use accelerate on tpu
This commit is contained in:
@@ -56,6 +56,7 @@ import collections
|
||||
import _codecs
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
||||
import accelerate
|
||||
|
||||
from torch.nn import Module
|
||||
from torch.storage import UntypedStorage
|
||||
@@ -64,6 +65,7 @@ from torch.storage import UntypedStorage
|
||||
# support it yet.
|
||||
try:
|
||||
import safetensors
|
||||
|
||||
HAS_SAFETENSORS = True
|
||||
except ModuleNotFoundError:
|
||||
HAS_SAFETENSORS = False
|
||||
@@ -71,17 +73,6 @@ except ModuleNotFoundError:
|
||||
import utils
|
||||
from logger import logger
|
||||
|
||||
# Accelerate is used to load with empty modules. TPU version doesn't come
|
||||
# packaged with it so we use an in-house solution in that case
|
||||
try:
|
||||
import accelerate
|
||||
HAS_ACCELERATE = True
|
||||
except ModuleNotFoundError:
|
||||
HAS_ACCELERATE = False
|
||||
|
||||
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
|
||||
|
||||
|
||||
# Storage of zipfile handles for each shard
|
||||
torch_checkpoint_file_handles = {}
|
||||
|
||||
@@ -331,116 +322,6 @@ def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
|
||||
)
|
||||
return lazy_storage
|
||||
|
||||
# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
|
||||
persistent_buffers = {
|
||||
k: v
|
||||
for k, v in self._buffers.items()
|
||||
if k not in self._non_persistent_buffers_set
|
||||
}
|
||||
local_name_params = itertools.chain(
|
||||
self._parameters.items(), persistent_buffers.items()
|
||||
)
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
for name, param in local_state.items():
|
||||
key = prefix + name
|
||||
if key in state_dict:
|
||||
input_param = state_dict[key]
|
||||
if not torch.overrides.is_tensor_like(input_param):
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"expected torch.Tensor or Tensor-like object from checkpoint but "
|
||||
"received {}".format(key, type(input_param))
|
||||
)
|
||||
continue
|
||||
|
||||
# This is used to avoid copying uninitialized parameters into
|
||||
# non-lazy modules, since they dont have the hook to do the checks
|
||||
# in such case, it will error when accessing the .shape attribute.
|
||||
is_param_lazy = torch.nn.parameter.is_lazy(param)
|
||||
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
|
||||
if (
|
||||
not is_param_lazy
|
||||
and len(param.shape) == 0
|
||||
and len(input_param.shape) == 1
|
||||
):
|
||||
input_param = input_param[0]
|
||||
|
||||
if not is_param_lazy and input_param.shape != param.shape:
|
||||
# local shape should match the one in checkpoint
|
||||
error_msgs.append(
|
||||
"size mismatch for {}: copying a param with shape {} from checkpoint, "
|
||||
"the shape in current model is {}.".format(
|
||||
key, input_param.shape, param.shape
|
||||
)
|
||||
)
|
||||
continue
|
||||
try:
|
||||
with torch.no_grad():
|
||||
# param.copy_(input_param)
|
||||
new_param = torch.nn.Parameter(
|
||||
input_param, requires_grad=param.requires_grad
|
||||
) # This line is new
|
||||
if name in self._parameters: # This line is new
|
||||
self._parameters[name] = new_param # This line is new
|
||||
if name in persistent_buffers: # This line is new
|
||||
self._buffers[name] = new_param # This line is new
|
||||
except Exception as ex:
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(
|
||||
key, param.size(), input_param.size(), ex.args
|
||||
)
|
||||
)
|
||||
elif strict:
|
||||
missing_keys.append(key)
|
||||
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if (
|
||||
hasattr(Module, "set_extra_state")
|
||||
and getattr(self.__class__, "set_extra_state", Module.set_extra_state)
|
||||
is not Module.set_extra_state
|
||||
): # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
|
||||
if extra_state_key in state_dict:
|
||||
self.set_extra_state(state_dict[extra_state_key])
|
||||
elif strict:
|
||||
missing_keys.append(extra_state_key)
|
||||
elif strict and (extra_state_key in state_dict):
|
||||
unexpected_keys.append(extra_state_key)
|
||||
|
||||
if strict:
|
||||
for key in state_dict.keys():
|
||||
if key.startswith(prefix) and key != extra_state_key:
|
||||
input_name = key[len(prefix) :]
|
||||
input_name = input_name.split(".", 1)[
|
||||
0
|
||||
] # get the name of param/buffer/child
|
||||
if input_name not in self._modules and input_name not in local_state:
|
||||
unexpected_keys.append(key)
|
||||
|
||||
|
||||
def safetensors_load_tensor_independently(
|
||||
checkpoint_file: str, tensor_key: str, device: Any
|
||||
@@ -572,29 +453,8 @@ def use_lazy_load(
|
||||
patch_safetensors(callback)
|
||||
|
||||
if dematerialized_modules:
|
||||
if HAS_ACCELERATE:
|
||||
init_empty_weights = accelerate.init_empty_weights()
|
||||
init_empty_weights.__enter__()
|
||||
else:
|
||||
# TPU doesn't use accelerate package
|
||||
old_linear_init = torch.nn.Linear.__init__
|
||||
old_embedding_init = torch.nn.Embedding.__init__
|
||||
old_layernorm_init = torch.nn.LayerNorm.__init__
|
||||
|
||||
def linear_init(self, *args, device=None, **kwargs):
|
||||
return old_linear_init(self, *args, device="meta", **kwargs)
|
||||
|
||||
def embedding_init(self, *args, device=None, **kwargs):
|
||||
return old_embedding_init(self, *args, device="meta", **kwargs)
|
||||
|
||||
def layernorm_init(self, *args, device=None, **kwargs):
|
||||
return old_layernorm_init(self, *args, device="meta", **kwargs)
|
||||
|
||||
torch.nn.Linear.__init__ = linear_init
|
||||
torch.nn.Embedding.__init__ = embedding_init
|
||||
torch.nn.LayerNorm.__init__ = layernorm_init
|
||||
old_load_from_state_dict = torch.nn.Module._load_from_state_dict
|
||||
torch.nn.Module._load_from_state_dict = _load_from_state_dict
|
||||
|
||||
with use_custom_unpickler(_LazyUnpickler):
|
||||
yield True
|
||||
@@ -609,13 +469,7 @@ def use_lazy_load(
|
||||
)
|
||||
|
||||
if dematerialized_modules:
|
||||
if HAS_ACCELERATE:
|
||||
init_empty_weights.__exit__(None, None, None)
|
||||
else:
|
||||
torch.nn.Linear.__init__ = old_linear_init
|
||||
torch.nn.Embedding.__init__ = old_embedding_init
|
||||
torch.nn.LayerNorm.__init__ = old_layernorm_init
|
||||
torch.nn.Module._load_from_state_dict = old_load_from_state_dict
|
||||
|
||||
|
||||
def post_load_cleanup() -> None:
|
||||
|
@@ -34,3 +34,4 @@ ijson
|
||||
ftfy
|
||||
pydub
|
||||
sentencepiece
|
||||
accelerate==0.18.0
|
Reference in New Issue
Block a user