Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)

Commit: Work on lazyload
@@ -18,8 +18,8 @@ from modeling.inference_models.hf_torch import HFTorchInferenceModel
 
 model_backend_name = "Huggingface"
 
-class model_backend(HFTorchInferenceModel):
 
+class model_backend(HFTorchInferenceModel):
     def _initialize_model(self):
         return
 
@@ -34,9 +34,7 @@ class model_backend(HFTorchInferenceModel):
         # utils.koboldai_vars.custmodpth = utils.koboldai_vars.model
 
         if self.model_name == "NeoCustom":
-            self.model_name = os.path.basename(
-                os.path.normpath(self.path)
-            )
+            self.model_name = os.path.basename(os.path.normpath(self.path))
             utils.koboldai_vars.model = self.model_name
 
         # If we specify a model and it's in the root directory, we need to move
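Note on the collapsed call above: os.path.basename alone returns an empty string when the path ends with a separator, so normalizing first is what makes the folder-name extraction robust. A minimal illustration (paths are made up):

    import os

    path = "models/gpt-neo-2.7B/"                        # hypothetical local model folder
    print(os.path.basename(path))                        # "" -- trailing slash defeats basename
    print(os.path.basename(os.path.normpath(path)))      # "gpt-neo-2.7B"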
@@ -63,13 +61,18 @@ class model_backend(HFTorchInferenceModel):
 
         # If we're using torch_lazy_loader, we need to get breakmodel config
         # early so that it knows where to load the individual model tensors
-        logger.debug("lazy_load: {} hascuda: {} breakmodel: {} nobreakmode: {}".format(self.lazy_load, utils.koboldai_vars.hascuda, self.breakmodel, self.nobreakmodel))
+        logger.debug(
+            "lazy_load: {} hascuda: {} breakmodel: {} nobreakmode: {}".format(
+                self.lazy_load,
+                utils.koboldai_vars.hascuda,
+                self.breakmodel,
+                self.nobreakmodel,
+            )
+        )
 
         if self.lazy_load:
             # If we're using lazy loader, we need to figure out what the model's hidden layers are called
-            with lazy_loader.use_lazy_load(
-                dematerialized_modules=True, use_accelerate_init_empty_weights=True
-            ):
+            with lazy_loader.use_lazy_load(dematerialized_modules=True):
                 try:
                     metamodel = AutoModelForCausalLM.from_config(self.model_config)
                     utils.layers_module_names = utils.get_layers_module_names(metamodel)
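Note: building the model from its config alone is what keeps this step cheap; under the lazy loader the parameters never receive real storage, and the metamodel is only walked to learn the names of the per-layer modules. A rough sketch of the idea, assuming accelerate's init_empty_weights and an example model id (the real code wraps this in use_lazy_load and utils.get_layers_module_names):

    from accelerate import init_empty_weights
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-125m")  # example model id
    with init_empty_weights():
        # Parameters are created on the "meta" device: shapes and dtypes only, no memory.
        metamodel = AutoModelForCausalLM.from_config(config)

    # Per-layer transformer blocks show up as numbered module names such as "transformer.h.0".
    layer_names = [n for n, _ in metamodel.named_modules() if n.split(".")[-1].isdigit()]
    print(layer_names[:4])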
@@ -195,7 +198,9 @@ class model_backend(HFTorchInferenceModel):
                             pass
 
                 if not any_success:
-                    raise RuntimeError(f"Couldn't find any of {possible_checkpoint_names} in cache for {self.model_name} @ '{utils.koboldai_vars.revisison}'")
+                    raise RuntimeError(
+                        f"Couldn't find any of {possible_checkpoint_names} in cache for {self.model_name} @ '{utils.koboldai_vars.revisison}'"
+                    )
             else:
                 # Handle saving sharded models
 
@@ -234,10 +239,23 @@ class model_backend(HFTorchInferenceModel):
 
         self.patch_embedding()
 
 
        self.model.kai_model = self
        utils.koboldai_vars.modeldim = self.get_hidden_size()
 
    def _save_settings(self):
-        with open("settings/{}.generic_hf_torch.model_backend.settings".format(self.model_name.replace("/", "_")), "w") as f:
-            json.dump({"layers": self.layers if 'layers' in vars(self) else [], "disk_layers": self.disk_layers if 'disk_layers' in vars(self) else 0}, f, indent="")
+        with open(
+            "settings/{}.generic_hf_torch.model_backend.settings".format(
+                self.model_name.replace("/", "_")
+            ),
+            "w",
+        ) as f:
+            json.dump(
+                {
+                    "layers": self.layers if "layers" in vars(self) else [],
+                    "disk_layers": self.disk_layers
+                    if "disk_layers" in vars(self)
+                    else 0,
+                },
+                f,
+                indent="",
+            )
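Note: the settings file written by _save_settings is plain JSON keyed by the flattened model name, so reading the layer split back is symmetric. A small sketch of the reverse direction (model name is illustrative; keys match the dump above):

    import json

    model_name = "EleutherAI/gpt-neo-2.7B"  # example
    path = "settings/{}.generic_hf_torch.model_backend.settings".format(
        model_name.replace("/", "_")
    )
    with open(path) as f:
        saved = json.load(f)
    layers = saved.get("layers", [])           # per-GPU layer split
    disk_layers = saved.get("disk_layers", 0)  # layers offloaded to disk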
@@ -61,6 +61,7 @@ from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
 # support it yet.
 try:
     import safetensors
+
     HAS_SAFETENSORS = True
 except ModuleNotFoundError:
     HAS_SAFETENSORS = False
@@ -68,9 +69,6 @@ except ModuleNotFoundError:
 import utils
 
 
-_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
-
-
 STORAGE_TYPE_MAP = {
     torch.float64: torch.DoubleStorage,
     torch.float32: torch.FloatStorage,
@@ -84,6 +82,8 @@ STORAGE_TYPE_MAP = {
     torch.bfloat16: torch.BFloat16Storage,
 }
 
+# Storage of zipfile handles for each shard
+torch_checkpoint_file_handles = {}
 
 class LazyTensor:
     pass
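Note: a PyTorch .bin checkpoint is a zip archive, and opening it once per tensor would be wasteful, so the new module-level dict keeps one open zipfile.ZipFile per shard for every LazyTensor coming from that shard to share. A minimal sketch of the caching pattern (the helper name is made up; the real code does this inline in the patched torch.load):

    import zipfile

    torch_checkpoint_file_handles = {}

    def get_shard_handle(path: str) -> zipfile.ZipFile:
        # Open each checkpoint shard once, then reuse the handle for all of its tensors.
        if path not in torch_checkpoint_file_handles:
            torch_checkpoint_file_handles[path] = zipfile.ZipFile(path, "r")
        return torch_checkpoint_file_handles[path]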
@@ -101,7 +101,6 @@ class TorchLazyTensor(LazyTensor):
         stride: Optional[Tuple[int, ...]] = None,
         requires_grad=False,
         backward_hooks: Any = None,
-        file_handle: Any = None
     ):
         self.storage_type = storage_type
         self.key = key
@@ -112,7 +111,7 @@ class TorchLazyTensor(LazyTensor):
         self.stride = stride
         self.requires_grad = requires_grad
         self.backward_hooks = backward_hooks
-        self.file_handle = file_handle
+        self.file_name = None
 
     def __view(self, f: Callable):
         return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, dtype={f(self.dtype)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
@@ -127,7 +126,29 @@ class TorchLazyTensor(LazyTensor):
         no_grad=True,
         filename="pytorch_model.bin",
     ) -> torch.Tensor:
-        checkpoint = checkpoint or self.file_handle
+
+
+        # if f not in torch_tensor_container_file_map:
+        # torch_tensor_container_file_map[f] = []
+
+        # with zipfile.ZipFile(f, "r") as z:
+        # paths = z.namelist()
+
+        # for name in paths:
+        # val = name.split("/data/")[-1]
+        # if not val.isdecimal():
+        # continue
+        # torch_tensor_container_file_map[f].append(int(val))
+        # torch_tensor_container_file_map[f].sort()
+        # print(torch_tensor_container_file_map)
+
+
+
+
+        if not checkpoint:
+            checkpoint = torch_checkpoint_file_handles[self.file_name]
+            filename = self.file_name
 
         filename = os.path.basename(os.path.normpath(filename)).split(".")[0]
         size = reduce(lambda x, y: x * y, self.shape, 1)
@@ -141,19 +162,23 @@ class TorchLazyTensor(LazyTensor):
                 >> 3
             )
         )
 
         if isinstance(checkpoint, zipfile.ZipFile):
             try:
                 f = checkpoint.open(f"archive/data/{self.key}", "r")
             except:
                 f = checkpoint.open(f"{filename}/data/{self.key}", "r")
-            f.read(self.seek_offset)
+            f.seek(self.seek_offset, os.SEEK_CUR)
+            # f.read(self.seek_offset)
         else:
             f = checkpoint
 
         try:
             storage = STORAGE_TYPE_MAP[dtype].from_buffer(f.read(nbytes), "little")
         finally:
             if isinstance(checkpoint, zipfile.ZipFile):
                 f.close()
 
         storage = torch.serialization._get_restore_location(map_location)(
             storage, self.location
         )
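Note on the seek change: f.read(self.seek_offset) advances the position by decompressing the skipped bytes into one throwaway bytes object, while f.seek(offset, os.SEEK_CUR) repositions the member without building that buffer (zipfile.ZipExtFile has been seekable since Python 3.7; compressed members are still decompressed internally while skipping). A tiny sketch of the pattern, with member standing for an open zip member file:

    import os

    def skip_to_storage(member, offset: int) -> None:
        # member.read(offset)              # old: allocates `offset` bytes just to discard them
        member.seek(offset, os.SEEK_CUR)   # new: moves the cursor relative to the current position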
@@ -242,7 +267,6 @@ class _LazyUnpickler(RestrictedUnpickler):
 
     def __init__(self, *args, **kwargs):
         # print(args, kwargs)
-        self.file_handle = args[0]
         self.lazy_loaded_storages = {}
         return super().__init__(*args, **kwargs)
 
@@ -253,7 +277,9 @@ class _LazyUnpickler(RestrictedUnpickler):
             typename == "storage"
         ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
         storage_type, key, location, _ = saved_id[1:]
-        return TorchLazyTensor(storage_type, key, location, file_handle=self.file_handle)
+        return TorchLazyTensor(
+            storage_type, key, location
+        )
 
     def load(self, *args, **kwargs):
         retval = super().load(*args, **kwargs)
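Note: torch pickles reference every storage through pickle's persistent-ID mechanism, so overriding persistent_load is enough to turn each tensor in the file into a cheap placeholder; no storage bytes are touched during unpickling. A stripped-down sketch of that hook (placeholder type is made up; the record layout matches the saved_id unpacking above):

    import pickle

    class SketchLazyUnpickler(pickle.Unpickler):
        def persistent_load(self, saved_id):
            # torch writes ("storage", storage_type, key, location, numel) records.
            assert saved_id[0] == "storage"
            storage_type, key, location = saved_id[1:4]
            return ("lazy-placeholder", storage_type, key, location)  # stand-in, no real data read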
@@ -277,117 +303,6 @@ def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
     return lazy_storage
 
 
-# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
-def _load_from_state_dict(
-    self,
-    state_dict,
-    prefix,
-    local_metadata,
-    strict,
-    missing_keys,
-    unexpected_keys,
-    error_msgs,
-):
-    for hook in self._load_state_dict_pre_hooks.values():
-        hook(
-            state_dict,
-            prefix,
-            local_metadata,
-            strict,
-            missing_keys,
-            unexpected_keys,
-            error_msgs,
-        )
-
-    persistent_buffers = {
-        k: v
-        for k, v in self._buffers.items()
-        if k not in self._non_persistent_buffers_set
-    }
-    local_name_params = itertools.chain(
-        self._parameters.items(), persistent_buffers.items()
-    )
-    local_state = {k: v for k, v in local_name_params if v is not None}
-
-    for name, param in local_state.items():
-        key = prefix + name
-        if key in state_dict:
-            input_param = state_dict[key]
-            if not torch.overrides.is_tensor_like(input_param):
-                error_msgs.append(
-                    'While copying the parameter named "{}", '
-                    "expected torch.Tensor or Tensor-like object from checkpoint but "
-                    "received {}".format(key, type(input_param))
-                )
-                continue
-
-            # This is used to avoid copying uninitialized parameters into
-            # non-lazy modules, since they dont have the hook to do the checks
-            # in such case, it will error when accessing the .shape attribute.
-            is_param_lazy = torch.nn.parameter.is_lazy(param)
-            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
-            if (
-                not is_param_lazy
-                and len(param.shape) == 0
-                and len(input_param.shape) == 1
-            ):
-                input_param = input_param[0]
-
-            if not is_param_lazy and input_param.shape != param.shape:
-                # local shape should match the one in checkpoint
-                error_msgs.append(
-                    "size mismatch for {}: copying a param with shape {} from checkpoint, "
-                    "the shape in current model is {}.".format(
-                        key, input_param.shape, param.shape
-                    )
-                )
-                continue
-            try:
-                with torch.no_grad():
-                    # param.copy_(input_param)
-                    new_param = torch.nn.Parameter(
-                        input_param, requires_grad=param.requires_grad
-                    )  # This line is new
-                    if name in self._parameters:  # This line is new
-                        self._parameters[name] = new_param  # This line is new
-                    if name in persistent_buffers:  # This line is new
-                        self._buffers[name] = new_param  # This line is new
-            except Exception as ex:
-                error_msgs.append(
-                    'While copying the parameter named "{}", '
-                    "whose dimensions in the model are {} and "
-                    "whose dimensions in the checkpoint are {}, "
-                    "an exception occurred : {}.".format(
-                        key, param.size(), input_param.size(), ex.args
-                    )
-                )
-        elif strict:
-            missing_keys.append(key)
-
-    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-    if (
-        hasattr(Module, "set_extra_state")
-        and getattr(self.__class__, "set_extra_state", Module.set_extra_state)
-        is not Module.set_extra_state
-    ):  # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
-        if extra_state_key in state_dict:
-            self.set_extra_state(state_dict[extra_state_key])
-        elif strict:
-            missing_keys.append(extra_state_key)
-    elif strict and (extra_state_key in state_dict):
-        unexpected_keys.append(extra_state_key)
-
-    if strict:
-        for key in state_dict.keys():
-            if key.startswith(prefix) and key != extra_state_key:
-                input_name = key[len(prefix) :]
-                input_name = input_name.split(".", 1)[
-                    0
-                ]  # get the name of param/buffer/child
-                if input_name not in self._modules and input_name not in local_state:
-                    unexpected_keys.append(key)
-
-
 def safetensors_load_tensor_independently(
     checkpoint_file: str, tensor_key: str, device: Any
 ) -> torch.Tensor:
@@ -418,7 +333,9 @@ def patch_safetensors(callback):
         tensors = {}
 
         with safetensors.safe_open(
-            checkpoint_file, framework="pt", device=intermediary_device,
+            checkpoint_file,
+            framework="pt",
+            device=intermediary_device,
         ) as f:
             for key in f.keys():
                 tensors[key] = None
@@ -426,7 +343,9 @@ def patch_safetensors(callback):
         for key in tensors.keys():
 
            tensors[key] = SafetensorsLazyTensor(
-                checkpoint_file=checkpoint_file, key=key, location=intermediary_device,
+                checkpoint_file=checkpoint_file,
+                key=key,
+                location=intermediary_device,
            )
 
        if callback is not None:
@@ -442,6 +361,13 @@ def patch_safetensors(callback):
 
     transformers.modeling_utils.safe_load_file = safetensors_load
 
+def get_torch_tensor_file(file: str, lazy_tensor: TorchLazyTensor):
+    with zipfile.ZipFile(file, "r") as z:
+        storage_key = lazy_tensor.key
+        ziproot = z.namelist()[0].split("/")[0]
+        f = z.open(f"{ziproot}/data/{storage_key}")
+        # TODO: Maybe some file seeking
+        return f
 
 @contextlib.contextmanager
 def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
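Note: inside a torch zip checkpoint each storage is a member named <archive root>/data/<storage key>; the root folder is "archive" for plain torch.save files but can differ for other checkpoints, which is why the new helper derives ziproot from the first name in the archive rather than hard-coding it. A hedged sketch for inspecting that layout:

    import zipfile

    def list_storage_keys(checkpoint_path: str):
        # Illustrative only: returns the storage keys ("0", "1", ...) found in a torch zip checkpoint.
        with zipfile.ZipFile(checkpoint_path, "r") as z:
            root = z.namelist()[0].split("/")[0]
            return [n.rsplit("/", 1)[-1] for n in z.namelist() if n.startswith(f"{root}/data/")]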
@@ -468,7 +394,6 @@ def use_lazy_load(
     enable=True,
     callback: Optional[Callable] = None,
     dematerialized_modules=False,
-    use_accelerate_init_empty_weights=False,
 ):
     if not enable:
         with use_custom_unpickler(RestrictedUnpickler):
@@ -483,22 +408,30 @@ def use_lazy_load(
         old_torch_load = torch.load
 
         def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
-            retval = old_torch_load(
+            model_dict = old_torch_load(
                 f=f,
                 map_location=map_location,
                 pickle_module=pickle_module,
                 **pickle_load_args,
             )
 
+            if f not in torch_checkpoint_file_handles:
+                torch_checkpoint_file_handles[f] = zipfile.ZipFile(f, "r")
+
+            for k,v in model_dict.items():
+                v.file_name = f
+
             if callback is not None:
                 callback(
-                    retval,
+                    model_dict,
                     f=f,
                     map_location=map_location,
                     pickle_module=pickle_module,
                     is_safetensors=False,
                     **pickle_load_args,
                 )
-            return retval
+
+            return model_dict
 
         torch.load = torch_load
 
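Note: with torch.load patched this way, loading a shard yields a dict of LazyTensor placeholders, the shard's ZipFile is registered in torch_checkpoint_file_handles, and every placeholder is stamped with the file it came from, so a later materialize() can look its bytes up by name. A condensed sketch of the flow under the context manager (import path and checkpoint name are assumptions):

    import torch
    from modeling import lazy_loader  # assumed module path for the file patched above

    with lazy_loader.use_lazy_load(enable=True):
        state_dict = torch.load("pytorch_model.bin")          # LazyTensor placeholders, no weights read
        for key, lazy in state_dict.items():
            tensor = lazy.materialize(map_location="cpu")     # bytes for this one tensor are read here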
@@ -506,30 +439,10 @@ def use_lazy_load(
             patch_safetensors(callback)
 
         if dematerialized_modules:
-            if use_accelerate_init_empty_weights:
-                import accelerate
-
-                init_empty_weights = accelerate.init_empty_weights()
-                init_empty_weights.__enter__()
-            else:
-                old_linear_init = torch.nn.Linear.__init__
-                old_embedding_init = torch.nn.Embedding.__init__
-                old_layernorm_init = torch.nn.LayerNorm.__init__
-
-                def linear_init(self, *args, device=None, **kwargs):
-                    return old_linear_init(self, *args, device="meta", **kwargs)
-
-                def embedding_init(self, *args, device=None, **kwargs):
-                    return old_embedding_init(self, *args, device="meta", **kwargs)
-
-                def layernorm_init(self, *args, device=None, **kwargs):
-                    return old_layernorm_init(self, *args, device="meta", **kwargs)
-
-                torch.nn.Linear.__init__ = linear_init
-                torch.nn.Embedding.__init__ = embedding_init
-                torch.nn.LayerNorm.__init__ = layernorm_init
-                old_load_from_state_dict = torch.nn.Module._load_from_state_dict
-                torch.nn.Module._load_from_state_dict = _load_from_state_dict
+            import accelerate
+
+            init_empty_weights = accelerate.init_empty_weights()
+            init_empty_weights.__enter__()
 
         with use_custom_unpickler(_LazyUnpickler):
             yield True
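Note: accelerate.init_empty_weights() makes every parameter created inside the block land on the meta device, which is the same effect the removed hand-rolled Linear/Embedding/LayerNorm patching approximated; dropping the flag just makes it the only dematerialization path. A small sketch of what the context manager does:

    import torch
    from accelerate import init_empty_weights

    with init_empty_weights():
        layer = torch.nn.Linear(4096, 4096)

    print(layer.weight.device)   # meta -- shape and dtype exist, but no memory is allocated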
@@ -538,10 +451,4 @@ def use_lazy_load(
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
         if dematerialized_modules:
-            if use_accelerate_init_empty_weights:
-                init_empty_weights.__exit__(None, None, None)
-            else:
-                torch.nn.Linear.__init__ = old_linear_init
-                torch.nn.Embedding.__init__ = old_embedding_init
-                torch.nn.LayerNorm.__init__ = old_layernorm_init
-                torch.nn.Module._load_from_state_dict = old_load_from_state_dict
+            init_empty_weights.__exit__(None, None, None)
@@ -190,7 +190,7 @@ def patch_transformers_for_lazyload() -> None:
 
         # BEGIN PATCH
         if isinstance(param, LazyTensor):
-            print("Materializing", param_name)
+            print(".", end="", flush=True)
             param = param.materialize()
         # END PATCH
 
|
@@ -855,7 +855,7 @@ class TrainerBase(abc.ABC):
|
|||||||
lazy_load_callback.nested = False
|
lazy_load_callback.nested = False
|
||||||
|
|
||||||
# Since we're using lazy loader, we need to figure out what the model's hidden layers are called
|
# Since we're using lazy loader, we need to figure out what the model's hidden layers are called
|
||||||
with lazy_loader.use_lazy_load(dematerialized_modules=True, use_accelerate_init_empty_weights=True):
|
with lazy_loader.use_lazy_load(dematerialized_modules=True):
|
||||||
try:
|
try:
|
||||||
metamodel = AutoModelForCausalLM.from_config(model_config)
|
metamodel = AutoModelForCausalLM.from_config(model_config)
|
||||||
except Exception as e:
|
except Exception as e:
|