Mirror of https://github.com/KoboldAI/KoboldAI-Client.git

Lazyload: Safetensors
@@ -20,15 +20,13 @@ from transformers import (
GPTNeoForCausalLM,
AutoModelForCausalLM,
LogitsProcessorList,
LogitsProcessor,
)

import utils
-import torch_lazy_loader
+import modeling.lazy_loader as lazy_loader
from logger import logger, Colors

from modeling import warpers
from modeling import inference_model
from modeling.warpers import Warper
from modeling.stoppers import Stoppers
from modeling.post_token_hooks import PostTokenHooks
@@ -274,7 +272,7 @@ class HFTorchInferenceModel(HFInferenceModel):
**tf_kwargs,
)
except Exception as e:
-print("Fell back for model due to", e)
+logger.warning(f"Fell back to GPTNeoForCausalLM due to {e}")

if "out of memory" in traceback.format_exc().lower():
raise RuntimeError(
@@ -291,6 +289,18 @@ class HFTorchInferenceModel(HFInferenceModel):
def get_hidden_size(self) -> int:
return self.model.get_input_embeddings().embedding_dim

+def _will_load_with_safetensors(self) -> bool:
+path = self.get_local_model_path()
+
+# TODO: This might mess up download to run
+if not path:
+return False
+
+if not os.path.exists(os.path.join(path, "model.safetensors")):
+return False
+
+return True
+
def _move_to_devices(self) -> None:
if not utils.koboldai_vars.breakmodel:
if utils.koboldai_vars.usegpu:
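The new `_will_load_with_safetensors` helper only checks whether a single-file `model.safetensors` sits next to the local model. For context, here is a minimal sketch of what lazy access to such a file typically looks like with the `safetensors` library; the function name and path below are illustrative and not part of this commit:

```python
# Illustrative sketch (not from this diff): reading tensors lazily from a
# safetensors checkpoint. safe_open only parses the header, so individual
# tensors can be pulled out on demand instead of loading everything at once.
from safetensors import safe_open

def iter_safetensors(path: str = "model.safetensors"):
    # framework="pt" returns torch tensors; device="cpu" keeps weights off
    # the GPU until the caller decides where each one should live.
    with safe_open(path, framework="pt", device="cpu") as checkpoint:
        for key in checkpoint.keys():
            yield key, checkpoint.get_tensor(key)
```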
@@ -391,8 +401,9 @@ class HFTorchInferenceModel(HFInferenceModel):
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))

def lazy_load_callback(
-model_dict: Dict[str, Union[torch_lazy_loader.LazyTensor, torch.Tensor]],
+model_dict: Dict[str, Union[lazy_loader.LazyTensor, torch.Tensor]],
f,
+is_safetensors: bool = False,
**_,
):
if lazy_load_callback.nested:
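The callback relies on a small surface of the lazy-tensor placeholders in `model_dict`: a storage `key`, a `seek_offset`, `shape` and `dtype` metadata, and a `materialize()` call. A hypothetical sketch of that assumed interface, for orientation only (the real class lives in `modeling/lazy_loader.py`):

```python
# Hypothetical sketch of the interface the callback assumes; not the actual
# implementation, which is defined in modeling/lazy_loader.py.
from dataclasses import dataclass
from typing import BinaryIO, Tuple
import torch

@dataclass
class LazyTensorSketch:
    key: str                # which storage/record holds the tensor data
    seek_offset: int        # byte offset of the data within that storage
    shape: Tuple[int, ...]  # used to compute the element count before loading
    dtype: torch.dtype      # used to compute the byte size before loading

    def materialize(self, f: BinaryIO, map_location="cpu") -> torch.Tensor:
        # Read exactly this tensor's bytes from `f` and build a real tensor.
        raise NotImplementedError
```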
@@ -414,7 +425,8 @@ class HFTorchInferenceModel(HFInferenceModel):

for key, value in model_dict.items():
original_key = get_original_key(key)
-if isinstance(value, torch_lazy_loader.LazyTensor) and not any(

+if isinstance(value, lazy_loader.LazyTensor) and not any(
original_key.startswith(n) for n in utils.layers_module_names
):
device_map[key] = (
@@ -483,59 +495,173 @@ class HFTorchInferenceModel(HFInferenceModel):
file=utils.UIProgressBarFile(),
)

-with zipfile.ZipFile(f, "r") as z:
+if not is_safetensors:
+# Torch lazyload
+with zipfile.ZipFile(f, "r") as z:
try:
last_storage_key = None
zipfolder = os.path.basename(os.path.normpath(f)).split(".")[0]
f = None
current_offset = 0
able_to_pin_layers = True
if utils.num_shards is not None:
utils.current_shard += 1
for key in sorted(
device_map.keys(),
key=lambda k: (
model_dict[k].key,
model_dict[k].seek_offset,
),
):
storage_key = model_dict[key].key
if (
storage_key != last_storage_key
or model_dict[key].seek_offset < current_offset
):
last_storage_key = storage_key
if isinstance(f, zipfile.ZipExtFile):
f.close()
try:
f = z.open(f"archive/data/{storage_key}")
except:
f = z.open(f"{zipfolder}/data/{storage_key}")
current_offset = 0
if current_offset != model_dict[key].seek_offset:
f.read(model_dict[key].seek_offset - current_offset)
current_offset = model_dict[key].seek_offset
device = device_map[key]
size = functools.reduce(
lambda x, y: x * y, model_dict[key].shape, 1
)
dtype = model_dict[key].dtype
nbytes = (
size
if dtype is torch.bool
else size
* (
(
torch.finfo
if dtype.is_floating_point
else torch.iinfo
)(dtype).bits
>> 3
)
)
# print(f"Transferring <{key}> to {f'({device.upper()})' if isinstance(device, str) else '[device ' + str(device) + ']'} ... ", end="", flush=True)
model_dict[key] = model_dict[key].materialize(
f, map_location="cpu"
)
if model_dict[key].dtype is torch.float32:
utils.koboldai_vars.fp32_model = True
if (
convert_to_float16
and breakmodel.primary_device != "cpu"
and utils.koboldai_vars.hascuda
and (
utils.koboldai_vars.breakmodel
or utils.koboldai_vars.usegpu
)
and model_dict[key].dtype is torch.float32
):
model_dict[key] = model_dict[key].to(torch.float16)
if breakmodel.primary_device == "cpu" or (
not utils.koboldai_vars.usegpu
and not utils.koboldai_vars.breakmodel
and model_dict[key].dtype is torch.float16
):
model_dict[key] = model_dict[key].to(torch.float32)
if device == "shared":
model_dict[key] = model_dict[key].to("cpu").detach_()
if able_to_pin_layers:
try:
model_dict[key] = model_dict[key].pin_memory()
except:
able_to_pin_layers = False
elif device == "disk":
accelerate.utils.offload_weight(
model_dict[key],
get_original_key(key),
"accelerate-disk-cache",
index=utils.offload_index,
)
model_dict[key] = model_dict[key].to("meta")
else:
model_dict[key] = model_dict[key].to(device)
# print("OK", flush=True)
current_offset += nbytes
utils.bar.update(1)
utils.koboldai_vars.loaded_layers += 1
finally:
if (
utils.num_shards is None
or utils.current_shard >= utils.num_shards
):
if utils.offload_index:
for name, tensor in utils.named_buffers:
dtype = tensor.dtype
if (
convert_to_float16
and breakmodel.primary_device != "cpu"
and utils.koboldai_vars.hascuda
and (
utils.koboldai_vars.breakmodel
or utils.koboldai_vars.usegpu
)
):
dtype = torch.float16
if breakmodel.primary_device == "cpu" or (
not utils.koboldai_vars.usegpu
and not utils.koboldai_vars.breakmodel
):
dtype = torch.float32
if (
name in model_dict
and model_dict[name].dtype is not dtype
):
model_dict[name] = model_dict[name].to(dtype)
if tensor.dtype is not dtype:
tensor = tensor.to(dtype)
if name not in utils.offload_index:
accelerate.utils.offload_weight(
tensor,
name,
"accelerate-disk-cache",
index=utils.offload_index,
)
accelerate.utils.save_offload_index(
utils.offload_index, "accelerate-disk-cache"
)
utils.bar.close()
utils.bar = None
utils.koboldai_vars.status_message = ""
lazy_load_callback.nested = False
if isinstance(f, zipfile.ZipExtFile):
f.close()
else:
# Loading with safetensors
try:
last_storage_key = None
zipfolder = os.path.basename(os.path.normpath(f)).split(".")[0]
f = None
current_offset = 0
able_to_pin_layers = True

if utils.num_shards is not None:
utils.current_shard += 1

for key in sorted(
device_map.keys(),
-key=lambda k: (model_dict[k].key, model_dict[k].seek_offset),
+key=lambda k: model_dict[k].key,
):
storage_key = model_dict[key].key
if (
storage_key != last_storage_key
or model_dict[key].seek_offset < current_offset
):
last_storage_key = storage_key
if isinstance(f, zipfile.ZipExtFile):
f.close()
try:
f = z.open(f"archive/data/{storage_key}")
except:
f = z.open(f"{zipfolder}/data/{storage_key}")
current_offset = 0
if current_offset != model_dict[key].seek_offset:
f.read(model_dict[key].seek_offset - current_offset)
current_offset = model_dict[key].seek_offset

device = device_map[key]
size = functools.reduce(
lambda x, y: x * y, model_dict[key].shape, 1
)
dtype = model_dict[key].dtype
nbytes = (
size
if dtype is torch.bool
else size
* (
(
torch.finfo
if dtype.is_floating_point
else torch.iinfo
)(dtype).bits
>> 3
)
)

# print(f"Transferring <{key}> to {f'({device.upper()})' if isinstance(device, str) else '[device ' + str(device) + ']'} ... ", end="", flush=True)

model_dict[key] = model_dict[key].materialize(
f, map_location="cpu"
)

if model_dict[key].dtype is torch.float32:
utils.koboldai_vars.fp32_model = True

if (
convert_to_float16
and breakmodel.primary_device != "cpu"
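The `nbytes` expression that appears in both branches above converts an element count into a byte count using the dtype's bit width, counting booleans as one byte each. A quick worked check of that arithmetic, as a standalone sketch:

```python
# Worked example of the nbytes arithmetic used above: element count times
# bytes per element, with torch.bool treated as one byte per element.
import functools
import torch

shape = (4, 3)  # example shape; any tensor shape works the same way
size = functools.reduce(lambda x, y: x * y, shape, 1)  # 12 elements

dtype = torch.float16
bits = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits
nbytes = size if dtype is torch.bool else size * (bits >> 3)
print(nbytes)  # 12 elements * 2 bytes per float16 element -> 24
```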
@@ -547,12 +673,14 @@ class HFTorchInferenceModel(HFInferenceModel):
and model_dict[key].dtype is torch.float32
):
model_dict[key] = model_dict[key].to(torch.float16)

if breakmodel.primary_device == "cpu" or (
not utils.koboldai_vars.usegpu
and not utils.koboldai_vars.breakmodel
and model_dict[key].dtype is torch.float16
):
model_dict[key] = model_dict[key].to(torch.float32)

if device == "shared":
model_dict[key] = model_dict[key].to("cpu").detach_()
if able_to_pin_layers:
@@ -570,10 +698,10 @@ class HFTorchInferenceModel(HFInferenceModel):
model_dict[key] = model_dict[key].to("meta")
else:
model_dict[key] = model_dict[key].to(device)
# print("OK", flush=True)
current_offset += nbytes

utils.bar.update(1)
utils.koboldai_vars.loaded_layers += 1

finally:
if (
utils.num_shards is None
@@ -617,9 +745,8 @@ class HFTorchInferenceModel(HFInferenceModel):
utils.bar.close()
utils.bar = None
utils.koboldai_vars.status_message = ""

lazy_load_callback.nested = False
if isinstance(f, zipfile.ZipExtFile):
f.close()

lazy_load_callback.nested = False
return lazy_load_callback
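Whichever branch materializes a tensor, placement follows the same pattern seen in the hunks above: "shared" tensors go to pinned CPU memory when pinning is possible, "disk" tensors are offloaded through accelerate's disk cache and replaced by meta tensors, and everything else moves straight to its assigned device. A condensed, hedged sketch of that decision (names such as `place_tensor` and `offload_index` are simplified stand-ins; the real callback also handles dtype conversion and progress reporting):

```python
# Condensed sketch of the per-tensor placement logic in the callback above;
# illustrative only, not a copy of the diff.
import torch
from accelerate.utils import offload_weight

def place_tensor(tensor: torch.Tensor, device, name: str,
                 offload_index: dict, able_to_pin: bool) -> torch.Tensor:
    if device == "shared":
        tensor = tensor.to("cpu").detach_()
        if able_to_pin:
            try:
                tensor = tensor.pin_memory()  # speeds up later CPU->GPU copies
            except Exception:
                pass  # pinning can fail under memory pressure; keep unpinned
        return tensor
    if device == "disk":
        offload_weight(tensor, name, "accelerate-disk-cache", index=offload_index)
        return tensor.to("meta")  # keep only metadata in RAM
    return tensor.to(device)  # a CUDA index or explicit device string
```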