Load empty modules without accelerate

somebody
2023-07-03 17:07:18 -05:00
parent 686c3d1592
commit 31a3046a18


@@ -46,6 +46,7 @@ POSSIBILITY OF SUCH DAMAGE.
 import contextlib
 from functools import reduce
+import itertools
 import time
 import zipfile
 import pickle
@@ -56,13 +57,13 @@ import _codecs
 import os
 from typing import Any, Callable, Dict, Optional, Tuple, Type
+from torch.nn import Module
 from torch.storage import UntypedStorage

 # Safetensors is a dependency for the local version, TPU/Colab doesn't
 # support it yet.
 try:
     import safetensors

     HAS_SAFETENSORS = True
 except ModuleNotFoundError:
     HAS_SAFETENSORS = False
@@ -70,6 +71,16 @@ except ModuleNotFoundError:
 import utils
 from logger import logger

+# Accelerate is used to load with empty modules. TPU version doesn't come
+# packaged with it so we use an in-house solution in that case
+try:
+    import accelerate
+
+    HAS_ACCELERATE = True
+except ModuleNotFoundError:
+    HAS_ACCELERATE = False
+
+_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
+
 # Storage of zipfile handles for each shard
 torch_checkpoint_file_handles = {}
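Note: for readers unfamiliar with the accelerate path this commit guards, the snippet below is a minimal standalone sketch (not part of the commit) of how accelerate.init_empty_weights() is normally used: any module constructed inside the context gets meta-device parameters, so no real weight memory is allocated. The layer size is arbitrary.

import torch
import accelerate

with accelerate.init_empty_weights():
    layer = torch.nn.Linear(4096, 4096)  # no real weight storage is allocated

print(layer.weight.device)  # meta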
@@ -145,7 +156,6 @@ class TorchLazyTensor(LazyTensor):
         map_location=None,
         no_grad=True,
     ) -> torch.Tensor:
         checkpoint = torch_checkpoint_file_handles[self.file_name]
         filename = os.path.basename(os.path.normpath(self.file_name)).split(".")[0]
@@ -321,6 +331,116 @@ def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
     )
     return lazy_storage


+# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
+def _load_from_state_dict(
+    self,
+    state_dict,
+    prefix,
+    local_metadata,
+    strict,
+    missing_keys,
+    unexpected_keys,
+    error_msgs,
+):
+    for hook in self._load_state_dict_pre_hooks.values():
+        hook(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+
+    persistent_buffers = {
+        k: v
+        for k, v in self._buffers.items()
+        if k not in self._non_persistent_buffers_set
+    }
+    local_name_params = itertools.chain(
+        self._parameters.items(), persistent_buffers.items()
+    )
+    local_state = {k: v for k, v in local_name_params if v is not None}
+
+    for name, param in local_state.items():
+        key = prefix + name
+        if key in state_dict:
+            input_param = state_dict[key]
+            if not torch.overrides.is_tensor_like(input_param):
+                error_msgs.append(
+                    'While copying the parameter named "{}", '
+                    "expected torch.Tensor or Tensor-like object from checkpoint but "
+                    "received {}".format(key, type(input_param))
+                )
+                continue
+
+            # This is used to avoid copying uninitialized parameters into
+            # non-lazy modules, since they dont have the hook to do the checks
+            # in such case, it will error when accessing the .shape attribute.
+            is_param_lazy = torch.nn.parameter.is_lazy(param)
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if (
+                not is_param_lazy
+                and len(param.shape) == 0
+                and len(input_param.shape) == 1
+            ):
+                input_param = input_param[0]
+
+            if not is_param_lazy and input_param.shape != param.shape:
+                # local shape should match the one in checkpoint
+                error_msgs.append(
+                    "size mismatch for {}: copying a param with shape {} from checkpoint, "
+                    "the shape in current model is {}.".format(
+                        key, input_param.shape, param.shape
+                    )
+                )
+                continue
+            try:
+                with torch.no_grad():
+                    # param.copy_(input_param)
+                    new_param = torch.nn.Parameter(
+                        input_param, requires_grad=param.requires_grad
+                    )  # This line is new
+                    if name in self._parameters:  # This line is new
+                        self._parameters[name] = new_param  # This line is new
+                    if name in persistent_buffers:  # This line is new
+                        self._buffers[name] = new_param  # This line is new
+            except Exception as ex:
+                error_msgs.append(
+                    'While copying the parameter named "{}", '
+                    "whose dimensions in the model are {} and "
+                    "whose dimensions in the checkpoint are {}, "
+                    "an exception occurred : {}.".format(
+                        key, param.size(), input_param.size(), ex.args
+                    )
+                )
+        elif strict:
+            missing_keys.append(key)
+
+    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+    if (
+        hasattr(Module, "set_extra_state")
+        and getattr(self.__class__, "set_extra_state", Module.set_extra_state)
+        is not Module.set_extra_state
+    ):  # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
+        if extra_state_key in state_dict:
+            self.set_extra_state(state_dict[extra_state_key])
+        elif strict:
+            missing_keys.append(extra_state_key)
+    elif strict and (extra_state_key in state_dict):
+        unexpected_keys.append(extra_state_key)
+
+    if strict:
+        for key in state_dict.keys():
+            if key.startswith(prefix) and key != extra_state_key:
+                input_name = key[len(prefix) :]
+                input_name = input_name.split(".", 1)[
+                    0
+                ]  # get the name of param/buffer/child
+                if input_name not in self._modules and input_name not in local_state:
+                    unexpected_keys.append(key)
+
+
 def safetensors_load_tensor_independently(
     checkpoint_file: str, tensor_key: str, device: Any
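The substantive change in this copy of _load_from_state_dict is the block marked "This line is new": instead of param.copy_(input_param), the loader builds a fresh torch.nn.Parameter from the checkpoint tensor and swaps it into _parameters / _buffers. That is what lets it populate modules whose parameters were created on the meta device, which have no underlying storage to copy into. A small illustrative sketch (not from the commit; the names are made up):

import torch

layer = torch.nn.Linear(2, 2, device="meta")  # weightless placeholder module
real_weight = torch.zeros(2, 2)

# layer.weight.copy_(real_weight) cannot yield usable data here, because the
# meta parameter has no real storage; replacing the Parameter object works:
layer._parameters["weight"] = torch.nn.Parameter(real_weight, requires_grad=True)
print(layer.weight.device)  # cpu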
@@ -360,7 +480,6 @@ def patch_safetensors(callback):
             tensors[key] = None

         for key in tensors.keys():
             tensors[key] = SafetensorsLazyTensor(
                 checkpoint_file=checkpoint_file,
                 key=key,
@@ -453,10 +572,29 @@ def use_lazy_load(
             patch_safetensors(callback)

         if dematerialized_modules:
-            import accelerate
-
-            init_empty_weights = accelerate.init_empty_weights()
-            init_empty_weights.__enter__()
+            if HAS_ACCELERATE:
+                init_empty_weights = accelerate.init_empty_weights()
+                init_empty_weights.__enter__()
+            else:
+                # TPU doesn't use accelerate package
+                old_linear_init = torch.nn.Linear.__init__
+                old_embedding_init = torch.nn.Embedding.__init__
+                old_layernorm_init = torch.nn.LayerNorm.__init__
+
+                def linear_init(self, *args, device=None, **kwargs):
+                    return old_linear_init(self, *args, device="meta", **kwargs)
+
+                def embedding_init(self, *args, device=None, **kwargs):
+                    return old_embedding_init(self, *args, device="meta", **kwargs)
+
+                def layernorm_init(self, *args, device=None, **kwargs):
+                    return old_layernorm_init(self, *args, device="meta", **kwargs)
+
+                torch.nn.Linear.__init__ = linear_init
+                torch.nn.Embedding.__init__ = embedding_init
+                torch.nn.LayerNorm.__init__ = layernorm_init
+
+                old_load_from_state_dict = torch.nn.Module._load_from_state_dict
+                torch.nn.Module._load_from_state_dict = _load_from_state_dict

         with use_custom_unpickler(_LazyUnpickler):
             yield True
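The non-accelerate branch above is plain monkeypatching: the constructors of the common layer types are wrapped so every new instance is created on the meta device, and the originals are restored on exit (next hunk). A standalone sketch of the same pattern, with hypothetical underscore-prefixed names, not taken from the commit:

import torch

_old_linear_init = torch.nn.Linear.__init__

def _meta_linear_init(self, *args, device=None, **kwargs):
    # Ignore any requested device and build the layer weightless on "meta".
    return _old_linear_init(self, *args, device="meta", **kwargs)

torch.nn.Linear.__init__ = _meta_linear_init
try:
    layer = torch.nn.Linear(8, 8)
    print(layer.weight.device)  # meta
finally:
    torch.nn.Linear.__init__ = _old_linear_init  # always restore the original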
@@ -471,14 +609,22 @@ def use_lazy_load(
         )

         if dematerialized_modules:
-            init_empty_weights.__exit__(None, None, None)
+            if HAS_ACCELERATE:
+                init_empty_weights.__exit__(None, None, None)
+            else:
+                torch.nn.Linear.__init__ = old_linear_init
+                torch.nn.Embedding.__init__ = old_embedding_init
+                torch.nn.LayerNorm.__init__ = old_layernorm_init
+                torch.nn.Module._load_from_state_dict = old_load_from_state_dict


 def post_load_cleanup() -> None:
     """Close dangling file pointers and clear caches after the load is complete."""
     global torch_checkpoint_file_handles

-    logger.debug(f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}")
+    logger.debug(
+        f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}"
+    )
     CheckpointChunkCache.clear(unload_model=True)

     # Bar is initialized in