Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Use safetensors only when available
@@ -54,10 +54,17 @@ import numpy as np
 import collections
 import _codecs
 import os
-import safetensors
 from torch.nn import Module
 from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
 
+# Safetensors is a dependency for the local version, TPU/Colab doesn't
+# support it yet.
+try:
+    import safetensors
+    HAS_SAFETENSORS = True
+except ModuleNotFoundError:
+    HAS_SAFETENSORS = False
+
 import utils
 
 
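Note: the hunk above replaces the unconditional `import safetensors` with a guarded import, so environments without the package (currently TPU/Colab) can still import the module. A minimal, self-contained sketch of the same guarded-import pattern, assuming only that torch is installed; `load_weights` is an illustrative helper, not part of the commit:

    import torch

    # Probe for the optional dependency once, at import time.
    try:
        import safetensors.torch
        HAS_SAFETENSORS = True
    except ModuleNotFoundError:
        HAS_SAFETENSORS = False

    def load_weights(path: str) -> dict:
        # Take the safetensors path only when the import succeeded;
        # otherwise fall back to the pickle-based torch format.
        if HAS_SAFETENSORS and path.endswith(".safetensors"):
            return safetensors.torch.load_file(path, device="cpu")
        return torch.load(path, map_location="cpu")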
@@ -382,6 +389,51 @@ def safetensors_load_tensor_independently(
         return f.get_tensor(tensor_key)
 
 
+def patch_safetensors(callback):
+    # Safetensors load patch
+    import transformers
+
+    def safetensors_load(checkpoint_file: str) -> dict:
+        # Monkeypatch applied to safetensors.torch.load_file
+
+        if utils.koboldai_vars.hascuda:
+            # Use GPU as intermediary whenever possible, lowers RAM usage
+            # by a significant amount while making loading slightly slower
+            # (70 tensors/s -> 65 tensors/s). The memory savings probably
+            # shouldn't be happening, maybe there's a memory leak
+            # somewhere in our pipeline with CPU tensors.
+            intermediary_device = "cuda"
+        else:
+            intermediary_device = "cpu"
+
+        tensors = {}
+
+        with safetensors.safe_open(
+            checkpoint_file, framework="pt", device=intermediary_device,
+        ) as f:
+            for key in f.keys():
+                tensors[key] = None
+
+        for key in tensors.keys():
+            tensors[key] = SafetensorsLazyTensor(
+                checkpoint_file=checkpoint_file, key=key, location=intermediary_device,
+            )
+
+        if callback is not None:
+            callback(
+                tensors,
+                f=checkpoint_file,
+                map_location=None,
+                pickle_module=pickle,
+                is_safetensors=True,
+            )
+
+        return tensors
+
+    transformers.modeling_utils.safe_load_file = safetensors_load
+
+
 @contextlib.contextmanager
 def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
     try:
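Note: `patch_safetensors` swaps `transformers.modeling_utils.safe_load_file` for a loader that returns lazy placeholders instead of materialized tensors: one pass over `safetensors.safe_open` reads the tensor names from the file header, then each name is mapped to a `SafetensorsLazyTensor` that only records where the data lives. A standalone sketch of that two-pass idea, assuming a local .safetensors file; `LazyRef` and `lazy_index` are illustrative stand-ins, not the commit's classes:

    import safetensors

    class LazyRef:
        # Cheap placeholder: remembers where a tensor lives, loads nothing yet.
        def __init__(self, checkpoint_file: str, key: str, device: str):
            self.checkpoint_file = checkpoint_file
            self.key = key
            self.device = device

        def materialize(self):
            # Reopen the checkpoint and pull just this one tensor.
            with safetensors.safe_open(
                self.checkpoint_file, framework="pt", device=self.device
            ) as f:
                return f.get_tensor(self.key)

    def lazy_index(checkpoint_file: str, device: str = "cpu") -> dict:
        # Pass 1: read only the header to enumerate tensor names.
        with safetensors.safe_open(
            checkpoint_file, framework="pt", device=device
        ) as f:
            keys = list(f.keys())
        # Pass 2: map every name to a placeholder, materialized on demand.
        return {key: LazyRef(checkpoint_file, key, device) for key in keys}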
@@ -441,48 +493,8 @@ def use_lazy_load(
 
         torch.load = torch_load
 
-        # Safetensors load patch
-        import transformers
-
-        def safetensors_load(checkpoint_file: str) -> dict:
-            # Monkeypatch applied to safetensors.torch.load_file
-
-            if utils.koboldai_vars.hascuda:
-                # Use GPU as intermediary whenever possible, lowers RAM usage
-                # by a significant amount while making loading slightly slower
-                # (70 tensors/s -> 65 tensors/s). The memory savings probably
-                # shouldn't be happening, maybe there's a memory leak
-                # somewhere in our pipeline with CPU tensors.
-                intermediary_device = "cuda"
-            else:
-                intermediary_device = "cpu"
-
-            tensors = {}
-
-            with safetensors.safe_open(
-                checkpoint_file, framework="pt", device=intermediary_device,
-            ) as f:
-                for key in f.keys():
-                    tensors[key] = None
-
-            for key in tensors.keys():
-                tensors[key] = SafetensorsLazyTensor(
-                    checkpoint_file=checkpoint_file, key=key, location=intermediary_device,
-                )
-
-            if callback is not None:
-                callback(
-                    tensors,
-                    f=checkpoint_file,
-                    map_location=None,
-                    pickle_module=pickle,
-                    is_safetensors=True,
-                )
-
-            return tensors
-
-        transformers.modeling_utils.safe_load_file = safetensors_load
-
+        if HAS_SAFETENSORS:
+            patch_safetensors(callback)
 
         if dematerialized_modules:
             if use_accelerate_init_empty_weights:
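Note: with the patch factored out, `use_lazy_load` installs it only when the guarded import succeeded, so environments without safetensors skip the monkeypatch entirely. The mechanics are plain attribute assignment on a module object; a toy demonstration, with `types.SimpleNamespace` standing in for `transformers.modeling_utils`:

    import types

    # Stand-in for the module whose attribute gets patched.
    toy_module = types.SimpleNamespace()

    def original_loader(path: str) -> dict:
        return {"source": "original", "path": path}

    toy_module.safe_load_file = original_loader

    def patched_loader(path: str) -> dict:
        # Installed at runtime; callers that resolve the function through
        # the module attribute pick this version up transparently.
        return {"source": "patched", "path": path}

    toy_module.safe_load_file = patched_loader

    assert toy_module.safe_load_file("model.safetensors")["source"] == "patched"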