Mirror of https://github.com/KoboldAI/KoboldAI-Client.git

Commit: Lazyload: Safetensors
@@ -172,9 +172,13 @@ class InferenceModel:
    def load(self, save_model: bool = False, initial_load: bool = False) -> None:
        """User-facing load function. Do not override this; try `_load()` instead."""

        self._pre_load()
        self._load(save_model=save_model, initial_load=initial_load)
        self._post_load()

    def _pre_load(self) -> None:
        """Pre load hook. Called before `_load()`."""

    def _post_load(self) -> None:
        """Post load hook. Called after `_load()`."""
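For illustration, a minimal sketch of how a backend might use the new hooks. `MyBackend` and its print statement are hypothetical and not part of this commit; only the hook names come from the diff above.

from modeling.inference_model import InferenceModel

class MyBackend(InferenceModel):
    # Hypothetical subclass: _load() does the real work, the hooks wrap it.
    def _load(self, save_model: bool = False, initial_load: bool = False) -> None:
        ...  # load weights here

    def _post_load(self) -> None:
        # Runs automatically after _load() because load() calls
        # _pre_load() -> _load() -> _post_load() in order.
        print("model ready")
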
@@ -9,8 +9,9 @@ from typing import Union
from transformers import AutoModelForCausalLM, GPTNeoForCausalLM

import utils
import torch_lazy_loader
import modeling.lazy_loader as lazy_loader
import koboldai_settings
from logger import logger, set_logger_verbosity, quiesce_logger

try:
    import breakmodel
@@ -73,20 +74,20 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):

        if self.lazy_load:
            # If we're using lazy loader, we need to figure out what the model's hidden layers are called
            with torch_lazy_loader.use_lazy_torch_load(
            with lazy_loader.use_lazy_load(
                dematerialized_modules=True, use_accelerate_init_empty_weights=True
            ):
                try:
                    metamodel = AutoModelForCausalLM.from_config(self.model_config)
                except Exception as e:
                    print("Fell back to neo for metamodel")
                    logger.error(f"Fell back to neo for metamodel due to {e}")
                    metamodel = GPTNeoForCausalLM.from_config(self.model_config)
                utils.layers_module_names = utils.get_layers_module_names(metamodel)
                utils.module_names = list(metamodel.state_dict().keys())
                utils.named_buffers = list(metamodel.named_buffers(recurse=True))

        # Download model from Huggingface if it does not exist, otherwise load locally
        with self._maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(
        with self._maybe_use_float16(), lazy_loader.use_lazy_load(
            enable=self.lazy_load,
            callback=self._get_lazy_load_callback(utils.num_layers(self.model_config))
            if self.lazy_load
@@ -108,12 +109,12 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
        old_rebuild_tensor = torch._utils._rebuild_tensor

        def new_rebuild_tensor(
            storage: Union[torch_lazy_loader.LazyTensor, torch.Storage],
            storage: Union[lazy_loader.LazyTensor, torch.Storage],
            storage_offset,
            shape,
            stride,
        ):
            if not isinstance(storage, torch_lazy_loader.LazyTensor):
            if not isinstance(storage, lazy_loader.LazyTensor):
                dtype = storage.dtype
            else:
                dtype = storage.storage_type.dtype
@@ -256,8 +257,6 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
        elif breakmodel.disk_blocks > 0:
            # Use disk
            self._move_to_devices()
        elif breakmodel.disk_blocks > 0:
            self._move_to_devices()
        else:
            # Use CPU
            self.model = self.model.to("cpu").float()
@@ -265,5 +264,6 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
            self._move_to_devices()
        else:
            self.model = self.model.to("cpu").float()

        self.model.kai_model = self
        utils.koboldai_vars.modeldim = self.get_hidden_size()
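To make the metamodel trick in the hunks above concrete, here is a minimal, hedged sketch: a model is instantiated from its config inside an empty-weights context so its layer and module names can be listed without allocating real tensors. The model id below is only a placeholder.

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-125m")  # placeholder model id

# No real weights are allocated; every parameter lives on the "meta" device.
with init_empty_weights():
    metamodel = AutoModelForCausalLM.from_config(config)

module_names = list(metamodel.state_dict().keys())
print(module_names[:3])  # e.g. embedding and first-layer parameter names
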
@@ -20,15 +20,13 @@ from transformers import (
    GPTNeoForCausalLM,
    AutoModelForCausalLM,
    LogitsProcessorList,
    LogitsProcessor,
)

import utils
import torch_lazy_loader
import modeling.lazy_loader as lazy_loader
from logger import logger, Colors

from modeling import warpers
from modeling import inference_model
from modeling.warpers import Warper
from modeling.stoppers import Stoppers
from modeling.post_token_hooks import PostTokenHooks
@@ -274,7 +272,7 @@ class HFTorchInferenceModel(HFInferenceModel):
                **tf_kwargs,
            )
        except Exception as e:
            print("Fell back for model due to", e)
            logger.warning(f"Fell back to GPTNeoForCausalLM due to {e}")

            if "out of memory" in traceback.format_exc().lower():
                raise RuntimeError(
@@ -291,6 +289,18 @@ class HFTorchInferenceModel(HFInferenceModel):
    def get_hidden_size(self) -> int:
        return self.model.get_input_embeddings().embedding_dim

    def _will_load_with_safetensors(self) -> bool:
        path = self.get_local_model_path()

        # TODO: This might mess up download to run
        if not path:
            return False

        if not os.path.exists(os.path.join(path, "model.safetensors")):
            return False

        return True

    def _move_to_devices(self) -> None:
        if not utils.koboldai_vars.breakmodel:
            if utils.koboldai_vars.usegpu:
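A brief, hypothetical sketch of how a caller might branch on the new helper; the `model` variable and the printed messages are illustrative and not code from this commit.

# Hypothetical caller: pick the checkpoint format before wiring up lazy loading.
if model._will_load_with_safetensors():
    # A local model.safetensors exists, so the safetensors branch of the
    # lazy-load callback will be exercised.
    print("loading from model.safetensors")
else:
    # Fall back to the pickle/zip path (pytorch_model.bin).
    print("loading from pytorch_model.bin")
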
@@ -391,8 +401,9 @@ class HFTorchInferenceModel(HFInferenceModel):
        cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))

        def lazy_load_callback(
            model_dict: Dict[str, Union[torch_lazy_loader.LazyTensor, torch.Tensor]],
            model_dict: Dict[str, Union[lazy_loader.LazyTensor, torch.Tensor]],
            f,
            is_safetensors: bool = False,
            **_,
        ):
            if lazy_load_callback.nested:
@@ -414,7 +425,8 @@ class HFTorchInferenceModel(HFInferenceModel):

            for key, value in model_dict.items():
                original_key = get_original_key(key)
                if isinstance(value, torch_lazy_loader.LazyTensor) and not any(
                if isinstance(value, lazy_loader.LazyTensor) and not any(
                    original_key.startswith(n) for n in utils.layers_module_names
                ):
                    device_map[key] = (
@@ -483,59 +495,173 @@ class HFTorchInferenceModel(HFInferenceModel):
                file=utils.UIProgressBarFile(),
            )

            with zipfile.ZipFile(f, "r") as z:
            if not is_safetensors:
                # Torch lazyload
                with zipfile.ZipFile(f, "r") as z:
                    try:
                        last_storage_key = None
                        zipfolder = os.path.basename(os.path.normpath(f)).split(".")[0]
                        f = None
                        current_offset = 0
                        able_to_pin_layers = True
                        if utils.num_shards is not None:
                            utils.current_shard += 1
                        for key in sorted(
                            device_map.keys(),
                            key=lambda k: (
                                model_dict[k].key,
                                model_dict[k].seek_offset,
                            ),
                        ):
                            storage_key = model_dict[key].key
                            if (
                                storage_key != last_storage_key
                                or model_dict[key].seek_offset < current_offset
                            ):
                                last_storage_key = storage_key
                                if isinstance(f, zipfile.ZipExtFile):
                                    f.close()
                                try:
                                    f = z.open(f"archive/data/{storage_key}")
                                except:
                                    f = z.open(f"{zipfolder}/data/{storage_key}")
                                current_offset = 0
                            if current_offset != model_dict[key].seek_offset:
                                f.read(model_dict[key].seek_offset - current_offset)
                                current_offset = model_dict[key].seek_offset
                            device = device_map[key]
                            size = functools.reduce(
                                lambda x, y: x * y, model_dict[key].shape, 1
                            )
                            dtype = model_dict[key].dtype
                            nbytes = (
                                size
                                if dtype is torch.bool
                                else size
                                * (
                                    (
                                        torch.finfo
                                        if dtype.is_floating_point
                                        else torch.iinfo
                                    )(dtype).bits
                                    >> 3
                                )
                            )
                            # print(f"Transferring <{key}> to {f'({device.upper()})' if isinstance(device, str) else '[device ' + str(device) + ']'} ... ", end="", flush=True)
                            model_dict[key] = model_dict[key].materialize(
                                f, map_location="cpu"
                            )
                            if model_dict[key].dtype is torch.float32:
                                utils.koboldai_vars.fp32_model = True
                            if (
                                convert_to_float16
                                and breakmodel.primary_device != "cpu"
                                and utils.koboldai_vars.hascuda
                                and (
                                    utils.koboldai_vars.breakmodel
                                    or utils.koboldai_vars.usegpu
                                )
                                and model_dict[key].dtype is torch.float32
                            ):
                                model_dict[key] = model_dict[key].to(torch.float16)
                            if breakmodel.primary_device == "cpu" or (
                                not utils.koboldai_vars.usegpu
                                and not utils.koboldai_vars.breakmodel
                                and model_dict[key].dtype is torch.float16
                            ):
                                model_dict[key] = model_dict[key].to(torch.float32)
                            if device == "shared":
                                model_dict[key] = model_dict[key].to("cpu").detach_()
                                if able_to_pin_layers:
                                    try:
                                        model_dict[key] = model_dict[key].pin_memory()
                                    except:
                                        able_to_pin_layers = False
                            elif device == "disk":
                                accelerate.utils.offload_weight(
                                    model_dict[key],
                                    get_original_key(key),
                                    "accelerate-disk-cache",
                                    index=utils.offload_index,
                                )
                                model_dict[key] = model_dict[key].to("meta")
                            else:
                                model_dict[key] = model_dict[key].to(device)
                            # print("OK", flush=True)
                            current_offset += nbytes
                            utils.bar.update(1)
                            utils.koboldai_vars.loaded_layers += 1
                    finally:
                        if (
                            utils.num_shards is None
                            or utils.current_shard >= utils.num_shards
                        ):
                            if utils.offload_index:
                                for name, tensor in utils.named_buffers:
                                    dtype = tensor.dtype
                                    if (
                                        convert_to_float16
                                        and breakmodel.primary_device != "cpu"
                                        and utils.koboldai_vars.hascuda
                                        and (
                                            utils.koboldai_vars.breakmodel
                                            or utils.koboldai_vars.usegpu
                                        )
                                    ):
                                        dtype = torch.float16
                                    if breakmodel.primary_device == "cpu" or (
                                        not utils.koboldai_vars.usegpu
                                        and not utils.koboldai_vars.breakmodel
                                    ):
                                        dtype = torch.float32
                                    if (
                                        name in model_dict
                                        and model_dict[name].dtype is not dtype
                                    ):
                                        model_dict[name] = model_dict[name].to(dtype)
                                    if tensor.dtype is not dtype:
                                        tensor = tensor.to(dtype)
                                    if name not in utils.offload_index:
                                        accelerate.utils.offload_weight(
                                            tensor,
                                            name,
                                            "accelerate-disk-cache",
                                            index=utils.offload_index,
                                        )
                                accelerate.utils.save_offload_index(
                                    utils.offload_index, "accelerate-disk-cache"
                                )
                            utils.bar.close()
                            utils.bar = None
                            utils.koboldai_vars.status_message = ""
                        lazy_load_callback.nested = False
                        if isinstance(f, zipfile.ZipExtFile):
                            f.close()
            else:
                # Loading with safetensors
                try:
                    last_storage_key = None
                    zipfolder = os.path.basename(os.path.normpath(f)).split(".")[0]
                    f = None
                    current_offset = 0
                    able_to_pin_layers = True

                    if utils.num_shards is not None:
                        utils.current_shard += 1

                    for key in sorted(
                        device_map.keys(),
                        key=lambda k: (model_dict[k].key, model_dict[k].seek_offset),
                        key=lambda k: model_dict[k].key,
                    ):
                        storage_key = model_dict[key].key
                        if (
                            storage_key != last_storage_key
                            or model_dict[key].seek_offset < current_offset
                        ):
                            last_storage_key = storage_key
                            if isinstance(f, zipfile.ZipExtFile):
                                f.close()
                            try:
                                f = z.open(f"archive/data/{storage_key}")
                            except:
                                f = z.open(f"{zipfolder}/data/{storage_key}")
                            current_offset = 0
                        if current_offset != model_dict[key].seek_offset:
                            f.read(model_dict[key].seek_offset - current_offset)
                            current_offset = model_dict[key].seek_offset

                        device = device_map[key]
                        size = functools.reduce(
                            lambda x, y: x * y, model_dict[key].shape, 1
                        )
                        dtype = model_dict[key].dtype
                        nbytes = (
                            size
                            if dtype is torch.bool
                            else size
                            * (
                                (
                                    torch.finfo
                                    if dtype.is_floating_point
                                    else torch.iinfo
                                )(dtype).bits
                                >> 3
                            )
                        )

                        # print(f"Transferring <{key}> to {f'({device.upper()})' if isinstance(device, str) else '[device ' + str(device) + ']'} ... ", end="", flush=True)

                        model_dict[key] = model_dict[key].materialize(
                            f, map_location="cpu"
                        )

                        if model_dict[key].dtype is torch.float32:
                            utils.koboldai_vars.fp32_model = True

                        if (
                            convert_to_float16
                            and breakmodel.primary_device != "cpu"
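Since the nbytes expression above is easy to misread, here is the same arithmetic spelled out for one concrete, hypothetical tensor shape.

import functools
import torch

# Hypothetical tensor: a 4096 x 4096 float16 weight matrix.
shape = (4096, 4096)
dtype = torch.float16

size = functools.reduce(lambda x, y: x * y, shape, 1)  # 16_777_216 elements
bits = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits  # 16
nbytes = size if dtype is torch.bool else size * (bits >> 3)  # 16 bits -> 2 bytes each

print(nbytes)  # 33_554_432 bytes, i.e. how far to advance in the checkpoint stream
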
@@ -547,12 +673,14 @@ class HFTorchInferenceModel(HFInferenceModel):
                            and model_dict[key].dtype is torch.float32
                        ):
                            model_dict[key] = model_dict[key].to(torch.float16)

                        if breakmodel.primary_device == "cpu" or (
                            not utils.koboldai_vars.usegpu
                            and not utils.koboldai_vars.breakmodel
                            and model_dict[key].dtype is torch.float16
                        ):
                            model_dict[key] = model_dict[key].to(torch.float32)

                        if device == "shared":
                            model_dict[key] = model_dict[key].to("cpu").detach_()
                            if able_to_pin_layers:
@@ -570,10 +698,10 @@ class HFTorchInferenceModel(HFInferenceModel):
                            model_dict[key] = model_dict[key].to("meta")
                        else:
                            model_dict[key] = model_dict[key].to(device)
                        # print("OK", flush=True)
                        current_offset += nbytes

                        utils.bar.update(1)
                        utils.koboldai_vars.loaded_layers += 1

                finally:
                    if (
                        utils.num_shards is None
@@ -617,9 +745,8 @@ class HFTorchInferenceModel(HFInferenceModel):
                        utils.bar.close()
                        utils.bar = None
                        utils.koboldai_vars.status_message = ""

                    lazy_load_callback.nested = False
                    if isinstance(f, zipfile.ZipExtFile):
                        f.close()

        lazy_load_callback.nested = False
        return lazy_load_callback
@@ -1,4 +1,4 @@
'''
"""
This file is AGPL-licensed.

Some of the code in this file is copied from PyTorch.
@@ -41,7 +41,7 @@ INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
'''
"""


import contextlib
@@ -53,13 +53,15 @@ import torch
import numpy as np
import collections
import _codecs
import utils
import os
import safetensors
from torch.nn import Module
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

import utils

_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"


STORAGE_TYPE_MAP = {
@@ -77,7 +79,22 @@ STORAGE_TYPE_MAP = {


class LazyTensor:
    def __init__(self, storage_type, key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
    pass


class TorchLazyTensor(LazyTensor):
    def __init__(
        self,
        storage_type,
        key: str,
        location: str,
        dtype: Optional[torch.dtype] = None,
        seek_offset: Optional[int] = None,
        shape: Optional[Tuple[int, ...]] = None,
        stride: Optional[Tuple[int, ...]] = None,
        requires_grad=False,
        backward_hooks: Any = None,
    ):
        self.storage_type = storage_type
        self.key = key
        self.location = location
@@ -94,11 +111,25 @@ class LazyTensor:
    def __repr__(self):
        return self.__view(repr)

    def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True, filename="pytorch_model.bin") -> torch.Tensor:
        filename = os.path.basename(os.path.normpath(filename)).split('.')[0]
    def materialize(
        self,
        checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile],
        map_location=None,
        no_grad=True,
        filename="pytorch_model.bin",
    ) -> torch.Tensor:
        filename = os.path.basename(os.path.normpath(filename)).split(".")[0]
        size = reduce(lambda x, y: x * y, self.shape, 1)
        dtype = self.dtype
        nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
        nbytes = (
            size
            if dtype is torch.bool
            else size
            * (
                (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits
                >> 3
            )
        )
        if isinstance(checkpoint, zipfile.ZipFile):
            try:
                f = checkpoint.open(f"archive/data/{self.key}", "r")
@@ -112,13 +143,38 @@ class LazyTensor:
        finally:
            if isinstance(checkpoint, zipfile.ZipFile):
                f.close()
        storage = torch.serialization._get_restore_location(map_location)(storage, self.location)
        storage = torch.serialization._get_restore_location(map_location)(
            storage, self.location
        )
        tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
        tensor.set_(storage, 0, self.shape, self.stride)
        tensor.requires_grad = not no_grad and self.requires_grad
        tensor._backward_hooks = self.backward_hooks
        return tensor


class SafetensorsLazyTensor(LazyTensor):
    def __init__(self, checkpoint_file: str, key: str, location: str):
        self.checkpoint_file = checkpoint_file
        self.key = key
        self.location = location

    def __view(self, f: Callable):
        return f"{type(self).__name__}(checkpoint_file={f(self.checkpoint_file)}, key={f(self.key)}, location={f(self.location)})"

    def __repr__(self):
        return self.__view(repr)

    def materialize(
        self,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        return safetensors_load_tensor_independently(
            self.checkpoint_file, tensor_key=self.key, device=self.location
        )


class RestrictedUnpickler(pickle.Unpickler):
    def original_persistent_load(self, saved_id):
        return super().persistent_load(saved_id)
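As a quick, hedged illustration of the lazy-tensor idea added above: the placeholder object only records where a tensor lives, and `materialize()` reads it when asked. The checkpoint path and tensor key below are made up.

import modeling.lazy_loader as lazy_loader

# Hypothetical checkpoint path and tensor key, purely for illustration.
lazy = lazy_loader.SafetensorsLazyTensor(
    checkpoint_file="models/example/model.safetensors",
    key="transformer.wte.weight",
    location="cpu",
)
print(lazy)                  # repr shows file, key and target device; no data read yet
tensor = lazy.materialize()  # only now is the tensor actually loaded
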
@@ -155,13 +211,18 @@ class RestrictedUnpickler(pickle.Unpickler):
        else:
            # Forbid everything else.
            qualified_name = name if module == "__builtin__" else f"{module}.{name}"
            raise pickle.UnpicklingError(f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code")
            raise pickle.UnpicklingError(
                f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code"
            )

    def load(self, *args, **kwargs):
        self.original_persistent_load = getattr(self, "persistent_load", pickle.Unpickler.persistent_load)
        self.original_persistent_load = getattr(
            self, "persistent_load", pickle.Unpickler.persistent_load
        )
        self.persistent_load = self.forced_persistent_load
        return super().load(*args, **kwargs)


class _LazyUnpickler(RestrictedUnpickler):
    lazy_loaded_storages: Dict[str, LazyTensor]

@@ -172,9 +233,11 @@ class _LazyUnpickler(RestrictedUnpickler):
    def forced_persistent_load(self, saved_id):
        assert isinstance(saved_id, tuple)
        typename = saved_id[0]
        assert typename == "storage", f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        assert (
            typename == "storage"
        ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        storage_type, key, location, _ = saved_id[1:]
        return LazyTensor(storage_type, key, location)
        return TorchLazyTensor(storage_type, key, location)

    def load(self, *args, **kwargs):
        retval = super().load(*args, **kwargs)
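For context, a small hedged sketch of what the lazy unpickler achieves: inside the `use_lazy_load` context, `torch.load` hands back placeholder tensors instead of real data, and the callback then decides where each one should go. The checkpoint path and the `report` callback are illustrative only.

import torch
import modeling.lazy_loader as lazy_loader

def report(model_dict, f, **_):
    # Called by the patched torch.load with a dict of lazy placeholders.
    for name, value in list(model_dict.items())[:3]:
        print(name, type(value).__name__)  # e.g. TorchLazyTensor

with lazy_loader.use_lazy_load(enable=True, callback=report):
    # Illustrative path; nothing is fully materialized into RAM at this point.
    state_dict = torch.load("models/example/pytorch_model.bin", map_location="cpu")
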
@@ -189,17 +252,45 @@ def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
    if not isinstance(dtype, torch.dtype):
        dtype = lazy_storage.storage_type(0).dtype
    lazy_storage.dtype = dtype
    lazy_storage.seek_offset = storage_offset if dtype is torch.bool else storage_offset * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
    lazy_storage.seek_offset = (
        storage_offset
        if dtype is torch.bool
        else storage_offset
        * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
    )
    return lazy_storage


# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
def _load_from_state_dict(
    self,
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    for hook in self._load_state_dict_pre_hooks.values():
        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        hook(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
    local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
    persistent_buffers = {
        k: v
        for k, v in self._buffers.items()
        if k not in self._non_persistent_buffers_set
    }
    local_name_params = itertools.chain(
        self._parameters.items(), persistent_buffers.items()
    )
    local_state = {k: v for k, v in local_name_params if v is not None}

    for name, param in local_state.items():
@@ -207,10 +298,11 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
        if key in state_dict:
            input_param = state_dict[key]
            if not torch.overrides.is_tensor_like(input_param):
                error_msgs.append('While copying the parameter named "{}", '
                                  'expected torch.Tensor or Tensor-like object from checkpoint but '
                                  'received {}'
                                  .format(key, type(input_param)))
                error_msgs.append(
                    'While copying the parameter named "{}", '
                    "expected torch.Tensor or Tensor-like object from checkpoint but "
                    "received {}".format(key, type(input_param))
                )
                continue

            # This is used to avoid copying uninitialized parameters into
@@ -218,34 +310,50 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
            # in such case, it will error when accessing the .shape attribute.
            is_param_lazy = torch.nn.parameter.is_lazy(param)
            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
            if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
            if (
                not is_param_lazy
                and len(param.shape) == 0
                and len(input_param.shape) == 1
            ):
                input_param = input_param[0]

            if not is_param_lazy and input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                  'the shape in current model is {}.'
                                  .format(key, input_param.shape, param.shape))
                error_msgs.append(
                    "size mismatch for {}: copying a param with shape {} from checkpoint, "
                    "the shape in current model is {}.".format(
                        key, input_param.shape, param.shape
                    )
                )
                continue
            try:
                with torch.no_grad():
                    #param.copy_(input_param)
                    new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad) # This line is new
                    # param.copy_(input_param)
                    new_param = torch.nn.Parameter(
                        input_param, requires_grad=param.requires_grad
                    ) # This line is new
                    if name in self._parameters: # This line is new
                        self._parameters[name] = new_param # This line is new
                    if name in persistent_buffers: # This line is new
                        self._buffers[name] = new_param # This line is new
            except Exception as ex:
                error_msgs.append('While copying the parameter named "{}", '
                                  'whose dimensions in the model are {} and '
                                  'whose dimensions in the checkpoint are {}, '
                                  'an exception occurred : {}.'
                                  .format(key, param.size(), input_param.size(), ex.args))
                error_msgs.append(
                    'While copying the parameter named "{}", '
                    "whose dimensions in the model are {} and "
                    "whose dimensions in the checkpoint are {}, "
                    "an exception occurred : {}.".format(
                        key, param.size(), input_param.size(), ex.args
                    )
                )
        elif strict:
            missing_keys.append(key)

    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
    if hasattr(Module, "set_extra_state") and getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
    if (
        hasattr(Module, "set_extra_state")
        and getattr(self.__class__, "set_extra_state", Module.set_extra_state)
        is not Module.set_extra_state
    ): # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        elif strict:
@@ -256,12 +364,24 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
    if strict:
        for key in state_dict.keys():
            if key.startswith(prefix) and key != extra_state_key:
                input_name = key[len(prefix):]
                input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
                input_name = key[len(prefix) :]
                input_name = input_name.split(".", 1)[
                    0
                ] # get the name of param/buffer/child
                if input_name not in self._modules and input_name not in local_state:
                    unexpected_keys.append(key)


def safetensors_load_tensor_independently(
    checkpoint_file: str, tensor_key: str, device: Any
) -> torch.Tensor:
    """A hacky way to load a tensor by itself and not mmap every single tensor
    or whatever is causing that big memory spike"""

    with safetensors.safe_open(checkpoint_file, framework="pt", device=device) as f:
        return f.get_tensor(tensor_key)


@contextlib.contextmanager
def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
    try:
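A short, hedged usage sketch of the helper above: `safe_open` exposes the checkpoint's tensor names without reading their data, and `get_tensor` pulls exactly one tensor onto the requested device. The path is a placeholder.

import safetensors

# Placeholder checkpoint path, for illustration only.
ckpt = "models/example/model.safetensors"

with safetensors.safe_open(ckpt, framework="pt", device="cpu") as f:
    names = list(f.keys())          # tensor names only; no weight data yet
    first = f.get_tensor(names[0])  # loads just this one tensor

print(names[0], tuple(first.shape))
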
@@ -281,8 +401,14 @@ def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler
        pickle.Unpickler = old_unpickler
        pickle.load = old_pickle_load


@contextlib.contextmanager
def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
def use_lazy_load(
    enable=True,
    callback: Optional[Callable] = None,
    dematerialized_modules=False,
    use_accelerate_init_empty_weights=False,
):
    if not enable:
        with use_custom_unpickler(RestrictedUnpickler):
            yield False
@@ -292,19 +418,76 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
        old_rebuild_tensor = torch._utils._rebuild_tensor
        torch._utils._rebuild_tensor = _rebuild_tensor

        # Torch load patch
        old_torch_load = torch.load

        def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
            retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
            retval = old_torch_load(
                f=f,
                map_location=map_location,
                pickle_module=pickle_module,
                **pickle_load_args,
            )
            if callback is not None:
                callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
                callback(
                    retval,
                    f=f,
                    map_location=map_location,
                    pickle_module=pickle_module,
                    is_safetensors=False,
                    **pickle_load_args,
                )
            return retval

        torch.load = torch_load

        # Safetensors load patch
        import transformers

        def safetensors_load(checkpoint_file: str) -> dict:
            # Monkeypatch applied to safetensors.torch.load_file

            if utils.koboldai_vars.hascuda:
                # Use GPU as intermediary whenever possible, lowers RAM usage
                # by a significant amount while making loading slightly slower
                # (70 tensors/s -> 65 tensor/s). The memory savings probably
                # shouldn't be the happening, maybe there's a memory leak
                # somewhere in our pipeline with CPU tensors.
                intermediary_device = "cuda"
            else:
                intermediary_device = "cpu"

            tensors = {}

            with safetensors.safe_open(
                checkpoint_file, framework="pt", device=intermediary_device,
            ) as f:
                for key in f.keys():
                    tensors[key] = None

            for key in tensors.keys():

                tensors[key] = SafetensorsLazyTensor(
                    checkpoint_file=checkpoint_file, key=key, location=intermediary_device,
                )

            if callback is not None:
                callback(
                    tensors,
                    f=checkpoint_file,
                    map_location=None,
                    pickle_module=pickle,
                    is_safetensors=True,
                )

            return tensors

        transformers.modeling_utils.safe_load_file = safetensors_load

        if dematerialized_modules:
            if use_accelerate_init_empty_weights:
                import accelerate

                init_empty_weights = accelerate.init_empty_weights()
                init_empty_weights.__enter__()
            else:
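A hedged sketch of how the whole patch is meant to be driven, mirroring what the HF backends above do: load a model inside `use_lazy_load` so that both pickle and safetensors checkpoints surface as lazy tensors, with `is_safetensors` telling the callback which path fired. The model directory is a placeholder, and the callback here only counts tensors instead of doing real device placement.

from transformers import AutoModelForCausalLM

import modeling.lazy_loader as lazy_loader

def counting_callback(model_dict, f, is_safetensors=False, **_):
    # Stand-in for the real device-placement callback: it only reports
    # how many lazy tensors were handed over and from which format.
    fmt = "safetensors" if is_safetensors else "pickle/zip"
    print(f"{len(model_dict)} tensors from {fmt} checkpoint {f}")

with lazy_loader.use_lazy_load(
    enable=True,
    callback=counting_callback,
    dematerialized_modules=True,
):
    # Placeholder path; from_pretrained goes through the patched loaders.
    model = AutoModelForCausalLM.from_pretrained("models/example")
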
@@ -38,7 +38,7 @@ import logging
logging.getLogger("urllib3").setLevel(logging.ERROR)

import breakmodel
import torch_lazy_loader
import modeling.lazy_loader as lazy_loader
import utils

use_breakmodel = True
@@ -755,7 +755,7 @@ class TrainerBase(abc.ABC):

        device_list(ram_blocks, primary=breakmodel.primary_device)

        def lazy_load_callback(model_dict: Dict[str, Union[torch_lazy_loader.LazyTensor, torch.Tensor]], f, **_):
        def lazy_load_callback(model_dict: Dict[str, Union[lazy_loader.LazyTensor, torch.Tensor]], f, **_):
            if lazy_load_callback.nested:
                return
            lazy_load_callback.nested = True
@@ -768,7 +768,7 @@ class TrainerBase(abc.ABC):

            for key, value in model_dict.items():
                original_key = get_original_key(key)
                if isinstance(value, torch_lazy_loader.LazyTensor) and not any(original_key.startswith(n) for n in utils.layers_module_names):
                if isinstance(value, lazy_loader.LazyTensor) and not any(original_key.startswith(n) for n in utils.layers_module_names):
                    device_map[key] = gpu_device if hascuda and usegpu else "cpu" if not hascuda or not use_breakmodel else breakmodel.primary_device
                else:
                    layer = int(max((n for n in utils.layers_module_names if original_key.startswith(n)), key=len).rsplit(".", 1)[1])
@@ -855,7 +855,7 @@ class TrainerBase(abc.ABC):
            lazy_load_callback.nested = False

        # Since we're using lazy loader, we need to figure out what the model's hidden layers are called
        with torch_lazy_loader.use_lazy_torch_load(dematerialized_modules=True, use_accelerate_init_empty_weights=True):
        with lazy_loader.use_lazy_load(dematerialized_modules=True, use_accelerate_init_empty_weights=True):
            try:
                metamodel = AutoModelForCausalLM.from_config(model_config)
            except Exception as e:
@@ -864,7 +864,7 @@ class TrainerBase(abc.ABC):
            utils.module_names = list(metamodel.state_dict().keys())
            utils.named_buffers = list(metamodel.named_buffers(recurse=True))

        with torch_lazy_loader.use_lazy_torch_load(callback=lazy_load_callback, dematerialized_modules=True):
        with lazy_loader.use_lazy_load(callback=lazy_load_callback, dematerialized_modules=True):
            if(os.path.isdir(self.data.ckpt_path)):
                try:
                    model = AutoPromptTuningLM.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
@@ -674,7 +674,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):

    import torch
    import torch.utils.dlpack
    import torch_lazy_loader
    import modeling.lazy_loader as lazy_loader
    from tqdm.auto import tqdm

    move_xmap = jax.experimental.maps.xmap(
@@ -722,7 +722,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
            continue
        layer = checkpoint_layer - 2
        shards = []
        with torch_lazy_loader.use_custom_unpickler(torch_lazy_loader.RestrictedUnpickler):
        with lazy_loader.use_custom_unpickler(lazy_loader.RestrictedUnpickler):
            for checkpoint_shard in range(checkpoint_shards):
                shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
        for key in shards[0]:
@@ -997,7 +997,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
            model_spec[key] = spec["mtj"].copy()
            model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"].format(layer=layer)

    import torch_lazy_loader
    import modeling.lazy_loader as lazy_loader
    import torch
    from tqdm.auto import tqdm
    import functools
@@ -1061,7 +1061,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
            current_offset = model_dict[key].seek_offset
            spec = model_spec[model_spec_key]
            transforms = set(spec.get("transforms", ()))
            if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
            if not isinstance(model_dict[key], lazy_loader.LazyTensor):
                error = f"Duplicate key {repr(key)}"
                print("\n\nERROR: " + error, file=sys.stderr)
                raise RuntimeError(error)
@@ -1141,7 +1141,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
        import shutil
        shutil.move(koboldai_vars.model.replace('/', '_'), "models/{}".format(koboldai_vars.model.replace('/', '_')))
    print("\n", flush=True)
    with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True):
    with lazy_loader.use_lazy_load(callback=callback, dematerialized_modules=True):
        if(os.path.isdir(koboldai_vars.custmodpth)):
            try:
                tokenizer = AutoTokenizer.from_pretrained(koboldai_vars.custmodpth, revision=koboldai_vars.revision, cache_dir="cache", use_fast=False)