Lazyload: Safetensors

somebody
2023-04-02 15:40:34 -05:00
parent 91bb433b5f
commit 9d70646e4d
6 changed files with 419 additions and 105 deletions

View File

@@ -172,9 +172,13 @@ class InferenceModel:
def load(self, save_model: bool = False, initial_load: bool = False) -> None:
"""User-facing load function. Do not override this; try `_load()` instead."""
self._pre_load()
self._load(save_model=save_model, initial_load=initial_load)
self._post_load()
def _pre_load(self) -> None:
"""Pre load hook. Called before `_load()`."""
def _post_load(self) -> None:
"""Post load hook. Called after `_load()`."""

View File

@@ -9,8 +9,9 @@ from typing import Union
from transformers import AutoModelForCausalLM, GPTNeoForCausalLM
import utils
import torch_lazy_loader
import modeling.lazy_loader as lazy_loader
import koboldai_settings
from logger import logger, set_logger_verbosity, quiesce_logger
try:
import breakmodel
@@ -73,20 +74,20 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
if self.lazy_load:
# If we're using lazy loader, we need to figure out what the model's hidden layers are called
with torch_lazy_loader.use_lazy_torch_load(
with lazy_loader.use_lazy_load(
dematerialized_modules=True, use_accelerate_init_empty_weights=True
):
try:
metamodel = AutoModelForCausalLM.from_config(self.model_config)
except Exception as e:
print("Fell back to neo for metamodel")
logger.error(f"Fell back to neo for metamodel due to {e}")
metamodel = GPTNeoForCausalLM.from_config(self.model_config)
utils.layers_module_names = utils.get_layers_module_names(metamodel)
utils.module_names = list(metamodel.state_dict().keys())
utils.named_buffers = list(metamodel.named_buffers(recurse=True))
# Download model from Huggingface if it does not exist, otherwise load locally
with self._maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(
with self._maybe_use_float16(), lazy_loader.use_lazy_load(
enable=self.lazy_load,
callback=self._get_lazy_load_callback(utils.num_layers(self.model_config))
if self.lazy_load
@@ -108,12 +109,12 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
old_rebuild_tensor = torch._utils._rebuild_tensor
def new_rebuild_tensor(
storage: Union[torch_lazy_loader.LazyTensor, torch.Storage],
storage: Union[lazy_loader.LazyTensor, torch.Storage],
storage_offset,
shape,
stride,
):
if not isinstance(storage, torch_lazy_loader.LazyTensor):
if not isinstance(storage, lazy_loader.LazyTensor):
dtype = storage.dtype
else:
dtype = storage.storage_type.dtype
@@ -256,8 +257,6 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
elif breakmodel.disk_blocks > 0:
# Use disk
self._move_to_devices()
elif breakmodel.disk_blocks > 0:
self._move_to_devices()
else:
# Use CPU
self.model = self.model.to("cpu").float()
@@ -265,5 +264,6 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
self._move_to_devices()
else:
self.model = self.model.to("cpu").float()
self.model.kai_model = self
utils.koboldai_vars.modeldim = self.get_hidden_size()
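For context, a minimal standalone sketch of the "metamodel" trick used earlier in this file's diff: the model is instantiated with empty (meta) weights so its tensor and buffer names can be enumerated without allocating real memory. The model id here is only an example; the real code uses `self.model_config`.

    import accelerate
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-125M")  # example model id
    with accelerate.init_empty_weights():
        # Parameters are created on the "meta" device, so no real storage is allocated.
        metamodel = AutoModelForCausalLM.from_config(config)
    tensor_keys = list(metamodel.state_dict().keys())
    named_buffers = list(metamodel.named_buffers(recurse=True))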

View File

@@ -20,15 +20,13 @@ from transformers import (
GPTNeoForCausalLM,
AutoModelForCausalLM,
LogitsProcessorList,
LogitsProcessor,
)
import utils
import torch_lazy_loader
import modeling.lazy_loader as lazy_loader
from logger import logger, Colors
from modeling import warpers
from modeling import inference_model
from modeling.warpers import Warper
from modeling.stoppers import Stoppers
from modeling.post_token_hooks import PostTokenHooks
@@ -274,7 +272,7 @@ class HFTorchInferenceModel(HFInferenceModel):
**tf_kwargs,
)
except Exception as e:
print("Fell back for model due to", e)
logger.warning(f"Fell back to GPTNeoForCausalLM due to {e}")
if "out of memory" in traceback.format_exc().lower():
raise RuntimeError(
@@ -291,6 +289,18 @@ class HFTorchInferenceModel(HFInferenceModel):
def get_hidden_size(self) -> int:
return self.model.get_input_embeddings().embedding_dim
def _will_load_with_safetensors(self) -> bool:
path = self.get_local_model_path()
# TODO: This might mess up the download-then-run flow
if not path:
return False
if not os.path.exists(os.path.join(path, "model.safetensors")):
return False
return True
def _move_to_devices(self) -> None:
if not utils.koboldai_vars.breakmodel:
if utils.koboldai_vars.usegpu:
@@ -391,8 +401,9 @@ class HFTorchInferenceModel(HFInferenceModel):
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
def lazy_load_callback(
model_dict: Dict[str, Union[torch_lazy_loader.LazyTensor, torch.Tensor]],
model_dict: Dict[str, Union[lazy_loader.LazyTensor, torch.Tensor]],
f,
is_safetensors: bool = False,
**_,
):
if lazy_load_callback.nested:
@@ -414,7 +425,8 @@ class HFTorchInferenceModel(HFInferenceModel):
for key, value in model_dict.items():
original_key = get_original_key(key)
if isinstance(value, torch_lazy_loader.LazyTensor) and not any(
if isinstance(value, lazy_loader.LazyTensor) and not any(
original_key.startswith(n) for n in utils.layers_module_names
):
device_map[key] = (
@@ -483,59 +495,173 @@ class HFTorchInferenceModel(HFInferenceModel):
file=utils.UIProgressBarFile(),
)
with zipfile.ZipFile(f, "r") as z:
if not is_safetensors:
# Torch lazyload
with zipfile.ZipFile(f, "r") as z:
try:
last_storage_key = None
zipfolder = os.path.basename(os.path.normpath(f)).split(".")[0]
f = None
current_offset = 0
able_to_pin_layers = True
if utils.num_shards is not None:
utils.current_shard += 1
for key in sorted(
device_map.keys(),
key=lambda k: (
model_dict[k].key,
model_dict[k].seek_offset,
),
):
storage_key = model_dict[key].key
if (
storage_key != last_storage_key
or model_dict[key].seek_offset < current_offset
):
last_storage_key = storage_key
if isinstance(f, zipfile.ZipExtFile):
f.close()
try:
f = z.open(f"archive/data/{storage_key}")
except:
f = z.open(f"{zipfolder}/data/{storage_key}")
current_offset = 0
if current_offset != model_dict[key].seek_offset:
f.read(model_dict[key].seek_offset - current_offset)
current_offset = model_dict[key].seek_offset
device = device_map[key]
size = functools.reduce(
lambda x, y: x * y, model_dict[key].shape, 1
)
dtype = model_dict[key].dtype
nbytes = (
size
if dtype is torch.bool
else size
* (
(
torch.finfo
if dtype.is_floating_point
else torch.iinfo
)(dtype).bits
>> 3
)
)
# print(f"Transferring <{key}> to {f'({device.upper()})' if isinstance(device, str) else '[device ' + str(device) + ']'} ... ", end="", flush=True)
model_dict[key] = model_dict[key].materialize(
f, map_location="cpu"
)
if model_dict[key].dtype is torch.float32:
utils.koboldai_vars.fp32_model = True
if (
convert_to_float16
and breakmodel.primary_device != "cpu"
and utils.koboldai_vars.hascuda
and (
utils.koboldai_vars.breakmodel
or utils.koboldai_vars.usegpu
)
and model_dict[key].dtype is torch.float32
):
model_dict[key] = model_dict[key].to(torch.float16)
if breakmodel.primary_device == "cpu" or (
not utils.koboldai_vars.usegpu
and not utils.koboldai_vars.breakmodel
and model_dict[key].dtype is torch.float16
):
model_dict[key] = model_dict[key].to(torch.float32)
if device == "shared":
model_dict[key] = model_dict[key].to("cpu").detach_()
if able_to_pin_layers:
try:
model_dict[key] = model_dict[key].pin_memory()
except:
able_to_pin_layers = False
elif device == "disk":
accelerate.utils.offload_weight(
model_dict[key],
get_original_key(key),
"accelerate-disk-cache",
index=utils.offload_index,
)
model_dict[key] = model_dict[key].to("meta")
else:
model_dict[key] = model_dict[key].to(device)
# print("OK", flush=True)
current_offset += nbytes
utils.bar.update(1)
utils.koboldai_vars.loaded_layers += 1
finally:
if (
utils.num_shards is None
or utils.current_shard >= utils.num_shards
):
if utils.offload_index:
for name, tensor in utils.named_buffers:
dtype = tensor.dtype
if (
convert_to_float16
and breakmodel.primary_device != "cpu"
and utils.koboldai_vars.hascuda
and (
utils.koboldai_vars.breakmodel
or utils.koboldai_vars.usegpu
)
):
dtype = torch.float16
if breakmodel.primary_device == "cpu" or (
not utils.koboldai_vars.usegpu
and not utils.koboldai_vars.breakmodel
):
dtype = torch.float32
if (
name in model_dict
and model_dict[name].dtype is not dtype
):
model_dict[name] = model_dict[name].to(dtype)
if tensor.dtype is not dtype:
tensor = tensor.to(dtype)
if name not in utils.offload_index:
accelerate.utils.offload_weight(
tensor,
name,
"accelerate-disk-cache",
index=utils.offload_index,
)
accelerate.utils.save_offload_index(
utils.offload_index, "accelerate-disk-cache"
)
utils.bar.close()
utils.bar = None
utils.koboldai_vars.status_message = ""
lazy_load_callback.nested = False
if isinstance(f, zipfile.ZipExtFile):
f.close()
else:
# Loading with safetensors
try:
last_storage_key = None
zipfolder = os.path.basename(os.path.normpath(f)).split(".")[0]
f = None
current_offset = 0
able_to_pin_layers = True
if utils.num_shards is not None:
utils.current_shard += 1
for key in sorted(
device_map.keys(),
key=lambda k: (model_dict[k].key, model_dict[k].seek_offset),
key=lambda k: model_dict[k].key,
):
storage_key = model_dict[key].key
if (
storage_key != last_storage_key
or model_dict[key].seek_offset < current_offset
):
last_storage_key = storage_key
if isinstance(f, zipfile.ZipExtFile):
f.close()
try:
f = z.open(f"archive/data/{storage_key}")
except:
f = z.open(f"{zipfolder}/data/{storage_key}")
current_offset = 0
if current_offset != model_dict[key].seek_offset:
f.read(model_dict[key].seek_offset - current_offset)
current_offset = model_dict[key].seek_offset
device = device_map[key]
size = functools.reduce(
lambda x, y: x * y, model_dict[key].shape, 1
)
dtype = model_dict[key].dtype
nbytes = (
size
if dtype is torch.bool
else size
* (
(
torch.finfo
if dtype.is_floating_point
else torch.iinfo
)(dtype).bits
>> 3
)
)
# print(f"Transferring <{key}> to {f'({device.upper()})' if isinstance(device, str) else '[device ' + str(device) + ']'} ... ", end="", flush=True)
model_dict[key] = model_dict[key].materialize(
f, map_location="cpu"
)
if model_dict[key].dtype is torch.float32:
utils.koboldai_vars.fp32_model = True
if (
convert_to_float16
and breakmodel.primary_device != "cpu"
@@ -547,12 +673,14 @@ class HFTorchInferenceModel(HFInferenceModel):
and model_dict[key].dtype is torch.float32
):
model_dict[key] = model_dict[key].to(torch.float16)
if breakmodel.primary_device == "cpu" or (
not utils.koboldai_vars.usegpu
and not utils.koboldai_vars.breakmodel
and model_dict[key].dtype is torch.float16
):
model_dict[key] = model_dict[key].to(torch.float32)
if device == "shared":
model_dict[key] = model_dict[key].to("cpu").detach_()
if able_to_pin_layers:
@@ -570,10 +698,10 @@ class HFTorchInferenceModel(HFInferenceModel):
model_dict[key] = model_dict[key].to("meta")
else:
model_dict[key] = model_dict[key].to(device)
# print("OK", flush=True)
current_offset += nbytes
utils.bar.update(1)
utils.koboldai_vars.loaded_layers += 1
finally:
if (
utils.num_shards is None
@@ -617,9 +745,8 @@ class HFTorchInferenceModel(HFInferenceModel):
utils.bar.close()
utils.bar = None
utils.koboldai_vars.status_message = ""
lazy_load_callback.nested = False
if isinstance(f, zipfile.ZipExtFile):
f.close()
lazy_load_callback.nested = False
return lazy_load_callback
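For orientation, the contract built above is small: `use_lazy_load()` patches the checkpoint readers so they hand the callback a dict of LazyTensor handles (plus `is_safetensors`) before anything is materialized. A minimal, purely illustrative callback and caller might look like this; the names and model path are placeholders, not part of the commit:

    import modeling.lazy_loader as lazy_loader
    from transformers import AutoModelForCausalLM

    def minimal_lazy_load_callback(model_dict, f, is_safetensors=False, **_):
        # model_dict maps tensor names to LazyTensor handles (or plain tensors);
        # a real callback decides device placement and then materializes them.
        source = "safetensors" if is_safetensors else "torch pickle"
        print(f"{len(model_dict)} tensors pending from {source} checkpoint: {f}")

    with lazy_loader.use_lazy_load(enable=True, callback=minimal_lazy_load_callback):
        model = AutoModelForCausalLM.from_pretrained("models/example-model")  # placeholder path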

View File

@@ -1,4 +1,4 @@
'''
"""
This file is AGPL-licensed.
Some of the code in this file is copied from PyTorch.
@@ -41,7 +41,7 @@ INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
'''
"""
import contextlib
@@ -53,13 +53,15 @@ import torch
import numpy as np
import collections
import _codecs
import utils
import os
import safetensors
from torch.nn import Module
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
import utils
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
STORAGE_TYPE_MAP = {
@@ -77,7 +79,22 @@ STORAGE_TYPE_MAP = {
class LazyTensor:
def __init__(self, storage_type, key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
pass
class TorchLazyTensor(LazyTensor):
def __init__(
self,
storage_type,
key: str,
location: str,
dtype: Optional[torch.dtype] = None,
seek_offset: Optional[int] = None,
shape: Optional[Tuple[int, ...]] = None,
stride: Optional[Tuple[int, ...]] = None,
requires_grad=False,
backward_hooks: Any = None,
):
self.storage_type = storage_type
self.key = key
self.location = location
@@ -94,11 +111,25 @@ class LazyTensor:
def __repr__(self):
return self.__view(repr)
def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True, filename="pytorch_model.bin") -> torch.Tensor:
filename = os.path.basename(os.path.normpath(filename)).split('.')[0]
def materialize(
self,
checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile],
map_location=None,
no_grad=True,
filename="pytorch_model.bin",
) -> torch.Tensor:
filename = os.path.basename(os.path.normpath(filename)).split(".")[0]
size = reduce(lambda x, y: x * y, self.shape, 1)
dtype = self.dtype
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
nbytes = (
size
if dtype is torch.bool
else size
* (
(torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits
>> 3
)
)
if isinstance(checkpoint, zipfile.ZipFile):
try:
f = checkpoint.open(f"archive/data/{self.key}", "r")
@@ -112,13 +143,38 @@ class LazyTensor:
finally:
if isinstance(checkpoint, zipfile.ZipFile):
f.close()
storage = torch.serialization._get_restore_location(map_location)(storage, self.location)
storage = torch.serialization._get_restore_location(map_location)(
storage, self.location
)
tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
tensor.set_(storage, 0, self.shape, self.stride)
tensor.requires_grad = not no_grad and self.requires_grad
tensor._backward_hooks = self.backward_hooks
return tensor
class SafetensorsLazyTensor(LazyTensor):
def __init__(self, checkpoint_file: str, key: str, location: str):
self.checkpoint_file = checkpoint_file
self.key = key
self.location = location
def __view(self, f: Callable):
return f"{type(self).__name__}(checkpoint_file={f(self.checkpoint_file)}, key={f(self.key)}, location={f(self.location)})"
def __repr__(self):
return self.__view(repr)
def materialize(
self,
*args,
**kwargs,
) -> torch.Tensor:
return safetensors_load_tensor_independently(
self.checkpoint_file, tensor_key=self.key, device=self.location
)
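A hedged usage sketch (the path and tensor key below are hypothetical): the handle stores only the file, key, and target device, and the tensor is read from disk only when `materialize()` is called.

    lazy = SafetensorsLazyTensor(
        checkpoint_file="model.safetensors",  # hypothetical checkpoint path
        key="transformer.wte.weight",         # hypothetical tensor name
        location="cpu",
    )
    weight = lazy.materialize()  # reads just this one tensor via safetensors.safe_open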
class RestrictedUnpickler(pickle.Unpickler):
def original_persistent_load(self, saved_id):
return super().persistent_load(saved_id)
@@ -155,13 +211,18 @@ class RestrictedUnpickler(pickle.Unpickler):
else:
# Forbid everything else.
qualified_name = name if module == "__builtin__" else f"{module}.{name}"
raise pickle.UnpicklingError(f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code")
raise pickle.UnpicklingError(
f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code"
)
def load(self, *args, **kwargs):
self.original_persistent_load = getattr(self, "persistent_load", pickle.Unpickler.persistent_load)
self.original_persistent_load = getattr(
self, "persistent_load", pickle.Unpickler.persistent_load
)
self.persistent_load = self.forced_persistent_load
return super().load(*args, **kwargs)
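As the TPU backend later in this commit does, callers install this unpickler around `torch.load` so a checkpoint whose pickle references anything outside the allowlist is rejected. A hedged sketch with a hypothetical file:

    with use_custom_unpickler(RestrictedUnpickler):
        state_dict = torch.load("pytorch_model.bin", map_location="cpu")  # hypothetical path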
class _LazyUnpickler(RestrictedUnpickler):
lazy_loaded_storages: Dict[str, LazyTensor]
@@ -172,9 +233,11 @@ class _LazyUnpickler(RestrictedUnpickler):
def forced_persistent_load(self, saved_id):
assert isinstance(saved_id, tuple)
typename = saved_id[0]
assert typename == "storage", f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
assert (
typename == "storage"
), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
storage_type, key, location, _ = saved_id[1:]
return LazyTensor(storage_type, key, location)
return TorchLazyTensor(storage_type, key, location)
def load(self, *args, **kwargs):
retval = super().load(*args, **kwargs)
@@ -189,17 +252,45 @@ def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
if not isinstance(dtype, torch.dtype):
dtype = lazy_storage.storage_type(0).dtype
lazy_storage.dtype = dtype
lazy_storage.seek_offset = storage_offset if dtype is torch.bool else storage_offset * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
lazy_storage.seek_offset = (
storage_offset
if dtype is torch.bool
else storage_offset
* ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
)
return lazy_storage
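The seek-offset arithmetic above (and the matching `nbytes` computations in the loader callbacks) follows one rule: bool storages use one byte per element, everything else uses the element count times the dtype's bit width divided by eight. A small standalone helper expressing the same formula, for illustration only:

    import torch

    def storage_nbytes(num_elements: int, dtype: torch.dtype) -> int:
        """Bytes occupied by num_elements of dtype in a torch checkpoint storage."""
        if dtype is torch.bool:
            return num_elements  # bools are stored one byte per element
        info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
        return num_elements * (info.bits >> 3)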
# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
persistent_buffers = {
k: v
for k, v in self._buffers.items()
if k not in self._non_persistent_buffers_set
}
local_name_params = itertools.chain(
self._parameters.items(), persistent_buffers.items()
)
local_state = {k: v for k, v in local_name_params if v is not None}
for name, param in local_state.items():
@@ -207,10 +298,11 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
if key in state_dict:
input_param = state_dict[key]
if not torch.overrides.is_tensor_like(input_param):
error_msgs.append('While copying the parameter named "{}", '
'expected torch.Tensor or Tensor-like object from checkpoint but '
'received {}'
.format(key, type(input_param)))
error_msgs.append(
'While copying the parameter named "{}", '
"expected torch.Tensor or Tensor-like object from checkpoint but "
"received {}".format(key, type(input_param))
)
continue
# This is used to avoid copying uninitialized parameters into
@@ -218,34 +310,50 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
# in such case, it will error when accessing the .shape attribute.
is_param_lazy = torch.nn.parameter.is_lazy(param)
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
if (
not is_param_lazy
and len(param.shape) == 0
and len(input_param.shape) == 1
):
input_param = input_param[0]
if not is_param_lazy and input_param.shape != param.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'
.format(key, input_param.shape, param.shape))
error_msgs.append(
"size mismatch for {}: copying a param with shape {} from checkpoint, "
"the shape in current model is {}.".format(
key, input_param.shape, param.shape
)
)
continue
try:
with torch.no_grad():
#param.copy_(input_param)
new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad) # This line is new
# param.copy_(input_param)
new_param = torch.nn.Parameter(
input_param, requires_grad=param.requires_grad
) # This line is new
if name in self._parameters: # This line is new
self._parameters[name] = new_param # This line is new
if name in persistent_buffers: # This line is new
self._buffers[name] = new_param # This line is new
except Exception as ex:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'
.format(key, param.size(), input_param.size(), ex.args))
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(
key, param.size(), input_param.size(), ex.args
)
)
elif strict:
missing_keys.append(key)
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if hasattr(Module, "set_extra_state") and getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
if (
hasattr(Module, "set_extra_state")
and getattr(self.__class__, "set_extra_state", Module.set_extra_state)
is not Module.set_extra_state
): # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key])
elif strict:
@@ -256,12 +364,24 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
if strict:
for key in state_dict.keys():
if key.startswith(prefix) and key != extra_state_key:
input_name = key[len(prefix):]
input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
input_name = key[len(prefix) :]
input_name = input_name.split(".", 1)[
0
] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key)
def safetensors_load_tensor_independently(
checkpoint_file: str, tensor_key: str, device: Any
) -> torch.Tensor:
"""A hacky way to load a tensor by itself and not mmap every single tensor
or whatever is causing that big memory spike"""
with safetensors.safe_open(checkpoint_file, framework="pt", device=device) as f:
return f.get_tensor(tensor_key)
@contextlib.contextmanager
def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
try:
@@ -281,8 +401,14 @@ def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler
pickle.Unpickler = old_unpickler
pickle.load = old_pickle_load
@contextlib.contextmanager
def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
def use_lazy_load(
enable=True,
callback: Optional[Callable] = None,
dematerialized_modules=False,
use_accelerate_init_empty_weights=False,
):
if not enable:
with use_custom_unpickler(RestrictedUnpickler):
yield False
@@ -292,19 +418,76 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
old_rebuild_tensor = torch._utils._rebuild_tensor
torch._utils._rebuild_tensor = _rebuild_tensor
# Torch load patch
old_torch_load = torch.load
def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
retval = old_torch_load(
f=f,
map_location=map_location,
pickle_module=pickle_module,
**pickle_load_args,
)
if callback is not None:
callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
callback(
retval,
f=f,
map_location=map_location,
pickle_module=pickle_module,
is_safetensors=False,
**pickle_load_args,
)
return retval
torch.load = torch_load
# Safetensors load patch
import transformers
def safetensors_load(checkpoint_file: str) -> dict:
# Monkeypatch applied to safetensors.torch.load_file
if utils.koboldai_vars.hascuda:
# Use GPU as intermediary whenever possible, lowers RAM usage
# by a significant amount while making loading slightly slower
# (70 tensors/s -> 65 tensors/s). The memory savings probably
# shouldn't be happening; maybe there's a memory leak
# somewhere in our pipeline with CPU tensors.
intermediary_device = "cuda"
else:
intermediary_device = "cpu"
tensors = {}
with safetensors.safe_open(
checkpoint_file, framework="pt", device=intermediary_device,
) as f:
for key in f.keys():
tensors[key] = None
for key in tensors.keys():
tensors[key] = SafetensorsLazyTensor(
checkpoint_file=checkpoint_file, key=key, location=intermediary_device,
)
if callback is not None:
callback(
tensors,
f=checkpoint_file,
map_location=None,
pickle_module=pickle,
is_safetensors=True,
)
return tensors
transformers.modeling_utils.safe_load_file = safetensors_load
if dematerialized_modules:
if use_accelerate_init_empty_weights:
import accelerate
init_empty_weights = accelerate.init_empty_weights()
init_empty_weights.__enter__()
else:

View File

@@ -38,7 +38,7 @@ import logging
logging.getLogger("urllib3").setLevel(logging.ERROR)
import breakmodel
import torch_lazy_loader
import modeling.lazy_loader as lazy_loader
import utils
use_breakmodel = True
@@ -755,7 +755,7 @@ class TrainerBase(abc.ABC):
device_list(ram_blocks, primary=breakmodel.primary_device)
def lazy_load_callback(model_dict: Dict[str, Union[torch_lazy_loader.LazyTensor, torch.Tensor]], f, **_):
def lazy_load_callback(model_dict: Dict[str, Union[lazy_loader.LazyTensor, torch.Tensor]], f, **_):
if lazy_load_callback.nested:
return
lazy_load_callback.nested = True
@@ -768,7 +768,7 @@ class TrainerBase(abc.ABC):
for key, value in model_dict.items():
original_key = get_original_key(key)
if isinstance(value, torch_lazy_loader.LazyTensor) and not any(original_key.startswith(n) for n in utils.layers_module_names):
if isinstance(value, lazy_loader.LazyTensor) and not any(original_key.startswith(n) for n in utils.layers_module_names):
device_map[key] = gpu_device if hascuda and usegpu else "cpu" if not hascuda or not use_breakmodel else breakmodel.primary_device
else:
layer = int(max((n for n in utils.layers_module_names if original_key.startswith(n)), key=len).rsplit(".", 1)[1])
@@ -855,7 +855,7 @@ class TrainerBase(abc.ABC):
lazy_load_callback.nested = False
# Since we're using lazy loader, we need to figure out what the model's hidden layers are called
with torch_lazy_loader.use_lazy_torch_load(dematerialized_modules=True, use_accelerate_init_empty_weights=True):
with lazy_loader.use_lazy_load(dematerialized_modules=True, use_accelerate_init_empty_weights=True):
try:
metamodel = AutoModelForCausalLM.from_config(model_config)
except Exception as e:
@@ -864,7 +864,7 @@ class TrainerBase(abc.ABC):
utils.module_names = list(metamodel.state_dict().keys())
utils.named_buffers = list(metamodel.named_buffers(recurse=True))
with torch_lazy_loader.use_lazy_torch_load(callback=lazy_load_callback, dematerialized_modules=True):
with lazy_loader.use_lazy_load(callback=lazy_load_callback, dematerialized_modules=True):
if(os.path.isdir(self.data.ckpt_path)):
try:
model = AutoPromptTuningLM.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")

View File

@@ -674,7 +674,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
import torch
import torch.utils.dlpack
import torch_lazy_loader
import modeling.lazy_loader as lazy_loader
from tqdm.auto import tqdm
move_xmap = jax.experimental.maps.xmap(
@@ -722,7 +722,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
continue
layer = checkpoint_layer - 2
shards = []
with torch_lazy_loader.use_custom_unpickler(torch_lazy_loader.RestrictedUnpickler):
with lazy_loader.use_custom_unpickler(lazy_loader.RestrictedUnpickler):
for checkpoint_shard in range(checkpoint_shards):
shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
for key in shards[0]:
@@ -997,7 +997,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
model_spec[key] = spec["mtj"].copy()
model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"].format(layer=layer)
import torch_lazy_loader
import modeling.lazy_loader as lazy_loader
import torch
from tqdm.auto import tqdm
import functools
@@ -1061,7 +1061,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
current_offset = model_dict[key].seek_offset
spec = model_spec[model_spec_key]
transforms = set(spec.get("transforms", ()))
if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
if not isinstance(model_dict[key], lazy_loader.LazyTensor):
error = f"Duplicate key {repr(key)}"
print("\n\nERROR: " + error, file=sys.stderr)
raise RuntimeError(error)
@@ -1141,7 +1141,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
import shutil
shutil.move(koboldai_vars.model.replace('/', '_'), "models/{}".format(koboldai_vars.model.replace('/', '_')))
print("\n", flush=True)
with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True):
with lazy_loader.use_lazy_load(callback=callback, dematerialized_modules=True):
if(os.path.isdir(koboldai_vars.custmodpth)):
try:
tokenizer = AutoTokenizer.from_pretrained(koboldai_vars.custmodpth, revision=koboldai_vars.revision, cache_dir="cache", use_fast=False)