Load empty modules without accelerate

somebody
2023-07-03 17:07:18 -05:00
parent 686c3d1592
commit 31a3046a18


@@ -46,6 +46,7 @@ POSSIBILITY OF SUCH DAMAGE.
import contextlib
from functools import reduce
import itertools
import time
import zipfile
import pickle
@@ -56,13 +57,13 @@ import _codecs
import os
from typing import Any, Callable, Dict, Optional, Tuple, Type
from torch.nn import Module
from torch.storage import UntypedStorage
# Safetensors is a dependency for the local version, TPU/Colab doesn't
# support it yet.
try:
    import safetensors
    HAS_SAFETENSORS = True
except ModuleNotFoundError:
    HAS_SAFETENSORS = False
@@ -70,6 +71,16 @@ except ModuleNotFoundError:
import utils
from logger import logger
# Accelerate is used to load with empty modules. TPU version doesn't come
# packaged with it so we use an in-house solution in that case
try:
    import accelerate
    HAS_ACCELERATE = True
except ModuleNotFoundError:
    HAS_ACCELERATE = False
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
# Storage of zipfile handles for each shard
torch_checkpoint_file_handles = {}
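
For context (not part of this commit): the HAS_ACCELERATE flag set above gates accelerate.init_empty_weights(), a context manager under which newly constructed layers put their parameters on the meta device instead of allocating real memory. A minimal sketch of that behaviour, assuming accelerate and torch are installed:

    import accelerate
    import torch

    with accelerate.init_empty_weights():
        layer = torch.nn.Linear(4096, 4096)  # no real weight storage is allocated

    print(layer.weight.is_meta)  # True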
@@ -145,7 +156,6 @@ class TorchLazyTensor(LazyTensor):
        map_location=None,
        no_grad=True,
    ) -> torch.Tensor:
        checkpoint = torch_checkpoint_file_handles[self.file_name]
        filename = os.path.basename(os.path.normpath(self.file_name)).split(".")[0]
@@ -321,6 +331,116 @@ def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
    )
    return lazy_storage
# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
def _load_from_state_dict(
    self,
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    for hook in self._load_state_dict_pre_hooks.values():
        hook(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    persistent_buffers = {
        k: v
        for k, v in self._buffers.items()
        if k not in self._non_persistent_buffers_set
    }
    local_name_params = itertools.chain(
        self._parameters.items(), persistent_buffers.items()
    )
    local_state = {k: v for k, v in local_name_params if v is not None}

    for name, param in local_state.items():
        key = prefix + name
        if key in state_dict:
            input_param = state_dict[key]
            if not torch.overrides.is_tensor_like(input_param):
                error_msgs.append(
                    'While copying the parameter named "{}", '
                    "expected torch.Tensor or Tensor-like object from checkpoint but "
                    "received {}".format(key, type(input_param))
                )
                continue

            # This is used to avoid copying uninitialized parameters into
            # non-lazy modules, since they dont have the hook to do the checks
            # in such case, it will error when accessing the .shape attribute.
            is_param_lazy = torch.nn.parameter.is_lazy(param)
            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
            if (
                not is_param_lazy
                and len(param.shape) == 0
                and len(input_param.shape) == 1
            ):
                input_param = input_param[0]

            if not is_param_lazy and input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append(
                    "size mismatch for {}: copying a param with shape {} from checkpoint, "
                    "the shape in current model is {}.".format(
                        key, input_param.shape, param.shape
                    )
                )
                continue
            try:
                with torch.no_grad():
                    # param.copy_(input_param)
                    new_param = torch.nn.Parameter(
                        input_param, requires_grad=param.requires_grad
                    )  # This line is new
                    if name in self._parameters:  # This line is new
                        self._parameters[name] = new_param  # This line is new
                    if name in persistent_buffers:  # This line is new
                        self._buffers[name] = new_param  # This line is new
            except Exception as ex:
                error_msgs.append(
                    'While copying the parameter named "{}", '
                    "whose dimensions in the model are {} and "
                    "whose dimensions in the checkpoint are {}, "
                    "an exception occurred : {}.".format(
                        key, param.size(), input_param.size(), ex.args
                    )
                )
        elif strict:
            missing_keys.append(key)

    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
    if (
        hasattr(Module, "set_extra_state")
        and getattr(self.__class__, "set_extra_state", Module.set_extra_state)
        is not Module.set_extra_state
    ):  # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        elif strict:
            missing_keys.append(extra_state_key)
    elif strict and (extra_state_key in state_dict):
        unexpected_keys.append(extra_state_key)

    if strict:
        for key in state_dict.keys():
            if key.startswith(prefix) and key != extra_state_key:
                input_name = key[len(prefix) :]
                input_name = input_name.split(".", 1)[
                    0
                ]  # get the name of param/buffer/child
                if input_name not in self._modules and input_name not in local_state:
                    unexpected_keys.append(key)
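
A note for context (not code from this commit): the lines marked "This line is new" replace the upstream param.copy_(input_param) because parameters created on the meta device have no backing storage for an in-place copy to fill, so the Parameter object itself is swapped out instead. A small sketch of that idea, assuming a layer whose weights were created on the meta device:

    import torch

    layer = torch.nn.Linear(4, 4, device="meta")
    real_weight = torch.randn(4, 4)

    # copy_() cannot give a meta tensor real storage, so replace the Parameter object
    layer._parameters["weight"] = torch.nn.Parameter(
        real_weight, requires_grad=layer.weight.requires_grad
    )
    print(layer.weight.device)  # cpu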
def safetensors_load_tensor_independently(
    checkpoint_file: str, tensor_key: str, device: Any
@@ -360,7 +480,6 @@ def patch_safetensors(callback):
                tensors[key] = None

        for key in tensors.keys():
            tensors[key] = SafetensorsLazyTensor(
                checkpoint_file=checkpoint_file,
                key=key,
@@ -453,10 +572,29 @@ def use_lazy_load(
    patch_safetensors(callback)

    if dematerialized_modules:
        if HAS_ACCELERATE:
            init_empty_weights = accelerate.init_empty_weights()
            init_empty_weights.__enter__()
        else:
            # TPU doesn't use accelerate package
            old_linear_init = torch.nn.Linear.__init__
            old_embedding_init = torch.nn.Embedding.__init__
            old_layernorm_init = torch.nn.LayerNorm.__init__

            def linear_init(self, *args, device=None, **kwargs):
                return old_linear_init(self, *args, device="meta", **kwargs)

            def embedding_init(self, *args, device=None, **kwargs):
                return old_embedding_init(self, *args, device="meta", **kwargs)

            def layernorm_init(self, *args, device=None, **kwargs):
                return old_layernorm_init(self, *args, device="meta", **kwargs)

            torch.nn.Linear.__init__ = linear_init
            torch.nn.Embedding.__init__ = embedding_init
            torch.nn.LayerNorm.__init__ = layernorm_init

            old_load_from_state_dict = torch.nn.Module._load_from_state_dict
            torch.nn.Module._load_from_state_dict = _load_from_state_dict

    with use_custom_unpickler(_LazyUnpickler):
        yield True
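
To illustrate the fallback path above (an editor-supplied sketch, not code from this commit): forcing device="meta" means modules are built with only shape and dtype metadata and no backing storage, so large models can be instantiated without RAM and later filled in by the patched _load_from_state_dict. Assuming a torch version that accepts the device keyword in module constructors:

    import torch

    layer = torch.nn.Linear(8192, 8192, device="meta")
    print(layer.weight.is_meta)  # True
    print(layer.weight.nelement() * layer.weight.element_size())  # bytes it *would* need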
@@ -471,14 +609,22 @@ def use_lazy_load(
    )

    if dematerialized_modules:
        if HAS_ACCELERATE:
            init_empty_weights.__exit__(None, None, None)
        else:
            torch.nn.Linear.__init__ = old_linear_init
            torch.nn.Embedding.__init__ = old_embedding_init
            torch.nn.LayerNorm.__init__ = old_layernorm_init
            torch.nn.Module._load_from_state_dict = old_load_from_state_dict
def post_load_cleanup() -> None:
"""Close dangling file pointers and clear caches after the load is complete."""
global torch_checkpoint_file_handles
logger.debug(f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}")
logger.debug(
f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}"
)
CheckpointChunkCache.clear(unload_model=True)
# Bar is initialized in