Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Commit: Basic
@@ -358,16 +358,19 @@ def safetensors_load_tensor_independently(
) -> torch.Tensor:
    """A hacky way to load a tensor by itself and not mmap every single tensor
    or whatever is causing that big memory spike"""
    print("[ld]", tensor_key)

    with safetensors.safe_open(checkpoint_file, framework="pt", device=device) as f:
        return f.get_tensor(tensor_key)


def patch_safetensors(callback):
    print("Hi! We are patching safetensors")
    # Safetensors load patch
    import transformers

    def safetensors_load(checkpoint_file: str) -> dict:
        print("LOAD NOW", safetensors_load)
        # Monkeypatch applied to safetensors.torch.load_file

        if utils.koboldai_vars.hascuda:
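For context: the per-tensor path above leans on safetensors' safe_open, which memory-maps the checkpoint and materializes only the tensor that is asked for, so peak memory stays near the size of one tensor. A minimal self-contained sketch of the same idea; the helper name, checkpoint path, and tensor key are hypothetical placeholders:

# Sketch: load one tensor from a .safetensors checkpoint without
# materializing the rest of the file. safe_open mmaps the file and
# get_tensor copies out only the bytes backing the requested key.
import safetensors
import torch

def load_single_tensor(
    checkpoint_file: str, tensor_key: str, device: str = "cpu"
) -> torch.Tensor:
    with safetensors.safe_open(checkpoint_file, framework="pt", device=device) as f:
        return f.get_tensor(tensor_key)

# Hypothetical usage:
# weight = load_single_tensor("model.safetensors", "transformer.wte.weight")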
@@ -409,6 +412,7 @@ def patch_safetensors(callback):
        return tensors

    transformers.modeling_utils.safe_load_file = safetensors_load
    safetensors.torch.load_file = safetensors_load


@contextlib.contextmanager
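The patch itself is plain monkeypatching: rebinding the module-level names that safetensors and transformers resolve at call time. A reduced sketch of that pattern, with an explicit restore step added for illustration; the logging wrapper is a hypothetical stand-in for the real replacement loader:

# Sketch of the monkeypatch pattern: swap safetensors' module-level
# loader (and transformers' alias to it) for a wrapper, keeping the
# original so it can be restored.
import safetensors.torch
import transformers.modeling_utils

_original_load_file = safetensors.torch.load_file

def logging_load_file(checkpoint_file, device="cpu"):
    # Hypothetical wrapper: log, then delegate to the real loader.
    print("intercepted load:", checkpoint_file)
    return _original_load_file(checkpoint_file, device=device)

def apply_patch():
    # transformers imports load_file under the name safe_load_file,
    # so both bindings must be replaced for the patch to stick.
    safetensors.torch.load_file = logging_load_file
    transformers.modeling_utils.safe_load_file = logging_load_file

def remove_patch():
    safetensors.torch.load_file = _original_load_file
    transformers.modeling_utils.safe_load_file = _original_load_file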
@@ -520,6 +524,7 @@ def use_lazy_load(
    old_torch_load = torch.load

    def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
        print("TORCHLOAD", f)
        model_dict = old_torch_load(
            f=f,
            map_location=map_location,
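use_lazy_load installs its torch.load replacement only for the lifetime of a context manager. A minimal sketch of that save/patch/restore shape, assuming the original is restored on exit; the print and the usage path are hypothetical stand-ins:

# Sketch of the context-manager shape used by use_lazy_load: save
# torch.load, install a wrapper, and restore the original on exit
# even if loading raises.
import contextlib
import pickle
import torch

@contextlib.contextmanager
def patched_torch_load():
    old_torch_load = torch.load

    def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
        print("intercepted torch.load:", f)
        return old_torch_load(
            f=f,
            map_location=map_location,
            pickle_module=pickle_module,
            **pickle_load_args,
        )

    torch.load = torch_load
    try:
        yield
    finally:
        # Always put the original back, even on error.
        torch.load = old_torch_load

# Hypothetical usage:
# with patched_torch_load():
#     state = torch.load("model.bin", map_location="cpu")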