mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Speed
@@ -46,7 +46,6 @@ POSSIBILITY OF SUCH DAMAGE.
 
 import contextlib
 from functools import reduce
-import itertools
 import zipfile
 import pickle
 import torch
@@ -54,8 +53,7 @@ import numpy as np
 import collections
 import _codecs
 import os
-from torch.nn import Module
-from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Type
 
 # Safetensors is a dependency for the local version, TPU/Colab doesn't
 # support it yet.
@@ -85,6 +83,22 @@ STORAGE_TYPE_MAP = {
 # Storage of zipfile handles for each shard
 torch_checkpoint_file_handles = {}
 
 
+class CheckpointChunkCache:
+    """Storage for common checkpoint weight files to speed up loading. In order
+    for this to be effective at all, weights must be loaded in ascending order
+    of (key, seek_offset)."""
+
+    file_name = None
+    key = None
+    handle = None
+
+    @classmethod
+    def clear(cls) -> None:
+        cls.file_name = None
+        cls.key = None
+        cls.handle = None
+
+
 class LazyTensor:
     pass
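The class above keeps a single open handle into the checkpoint zip, keyed by the shard's base file name and the storage key, so consecutive tensors that live in the same data record reuse one handle instead of reopening it each time. A minimal standalone sketch of the same pattern (the names here are illustrative, not part of the commit):

import zipfile


class ChunkHandleCache:
    """Toy version of the idea: keep one open member handle per
    (archive name, member) pair and reuse it while consecutive
    reads target the same member."""

    archive_name = None
    member = None
    handle = None

    @classmethod
    def get(cls, archive: zipfile.ZipFile, archive_name: str, member: str):
        # Reopen only when the target member changes or nothing is cached yet.
        if (cls.archive_name, cls.member) != (archive_name, member) or not cls.handle:
            if cls.handle:
                cls.handle.close()
            cls.archive_name = archive_name
            cls.member = member
            cls.handle = archive.open(member, "r")
        return cls.handle

CheckpointChunkCache plays the same role below, with the shard's base file name as the archive identity and the pickle storage key selecting the data/{key} member.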
@@ -121,36 +135,37 @@ class TorchLazyTensor(LazyTensor):
 
     def materialize(
         self,
-        checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile] = None,
         map_location=None,
         no_grad=True,
-        filename="pytorch_model.bin",
     ) -> torch.Tensor:
-
-
-        # if f not in torch_tensor_container_file_map:
-        #     torch_tensor_container_file_map[f] = []
-
-        #     with zipfile.ZipFile(f, "r") as z:
-        #         paths = z.namelist()
-
-        #         for name in paths:
-        #             val = name.split("/data/")[-1]
-        #             if not val.isdecimal():
-        #                 continue
-        #             torch_tensor_container_file_map[f].append(int(val))
-        #         torch_tensor_container_file_map[f].sort()
-        # print(torch_tensor_container_file_map)
-
-
-        if not checkpoint:
-            checkpoint = torch_checkpoint_file_handles[self.file_name]
-            filename = self.file_name
+        checkpoint = torch_checkpoint_file_handles[self.file_name]
+        filename = os.path.basename(os.path.normpath(self.file_name)).split(".")[0]
 
-        filename = os.path.basename(os.path.normpath(filename)).split(".")[0]
+        # Most of the operations are just seeks, let's see if we can optimize that.
+        if (
+            CheckpointChunkCache.file_name != filename
+            or CheckpointChunkCache.key != self.key
+            or not CheckpointChunkCache.handle
+        ):
+            # Flush cache if invalid
+            print("!", end="", flush=True)
+
+            if CheckpointChunkCache.handle:
+                CheckpointChunkCache.handle.close()
+
+            CheckpointChunkCache.file_name = filename
+            CheckpointChunkCache.key = self.key
+            try:
+                CheckpointChunkCache.handle = checkpoint.open(
+                    f"archive/data/{self.key}", "r"
+                )
+            except KeyError:
+                CheckpointChunkCache.handle = checkpoint.open(
+                    f"{filename}/data/{self.key}", "r"
+                )
 
         size = reduce(lambda x, y: x * y, self.shape, 1)
         dtype = self.dtype
         nbytes = (
@@ -163,21 +178,12 @@ class TorchLazyTensor(LazyTensor):
             )
         )
 
-        if isinstance(checkpoint, zipfile.ZipFile):
-            try:
-                f = checkpoint.open(f"archive/data/{self.key}", "r")
-            except:
-                f = checkpoint.open(f"{filename}/data/{self.key}", "r")
-            f.seek(self.seek_offset, os.SEEK_CUR)
-            # f.read(self.seek_offset)
-        else:
-            f = checkpoint
+        assert isinstance(checkpoint, zipfile.ZipFile)
 
-        try:
-            storage = STORAGE_TYPE_MAP[dtype].from_buffer(f.read(nbytes), "little")
-        finally:
-            if isinstance(checkpoint, zipfile.ZipFile):
-                f.close()
+        CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
+        storage = STORAGE_TYPE_MAP[dtype].from_buffer(
+            CheckpointChunkCache.handle.read(nbytes), "little"
+        )
 
         storage = torch.serialization._get_restore_location(map_location)(
             storage, self.location
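With the cache in place, materialize() no longer opens a fresh zip member per tensor: it seeks the cached ZipExtFile to self.seek_offset and reads nbytes from there. In CPython's zipfile, ZipExtFile.seek() (available since Python 3.7) serves a forward seek on a compressed member by reading and discarding the gap, while a backward seek restarts decompression from the start of the member; that is why the rest of the commit works to keep reads in ascending offset order. A small sketch of the read pattern, with placeholder archive and member names:

import os
import zipfile


def read_spans_in_order(archive_path: str, member: str, spans):
    """Read (offset, length) spans from one zip member through a single
    handle. Sorting the spans keeps every seek a forward seek, which is
    the cheap case for a compressed member."""
    chunks = []
    with zipfile.ZipFile(archive_path, "r") as archive:
        with archive.open(member, "r") as handle:
            for offset, length in sorted(spans):
                handle.seek(offset, os.SEEK_SET)
                chunks.append(handle.read(length))
    return chunks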
@@ -277,9 +283,7 @@ class _LazyUnpickler(RestrictedUnpickler):
             typename == "storage"
         ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
         storage_type, key, location, _ = saved_id[1:]
-        return TorchLazyTensor(
-            storage_type, key, location
-        )
+        return TorchLazyTensor(storage_type, key, location)
 
     def load(self, *args, **kwargs):
         retval = super().load(*args, **kwargs)
@@ -361,13 +365,6 @@ def patch_safetensors(callback):
 
     transformers.modeling_utils.safe_load_file = safetensors_load
 
-def get_torch_tensor_file(file: str, lazy_tensor: TorchLazyTensor):
-    with zipfile.ZipFile(file, "r") as z:
-        storage_key = lazy_tensor.key
-        ziproot = z.namelist()[0].split("/")[0]
-        f = z.open(f"{ziproot}/data/{storage_key}")
-        # TODO: Maybe some file seeking
-        return f
 
 @contextlib.contextmanager
 def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
@@ -418,7 +415,7 @@ def use_lazy_load(
         if f not in torch_checkpoint_file_handles:
             torch_checkpoint_file_handles[f] = zipfile.ZipFile(f, "r")
 
-        for k,v in model_dict.items():
+        for k, v in model_dict.items():
             v.file_name = f
 
         if callback is not None:
@@ -126,7 +126,6 @@ def patch_transformers_generation() -> None:
     transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
 
 
-CURRENT_CHECKPOINT = None
 def patch_transformers_for_lazyload() -> None:
     import torch
     import inspect
@@ -158,8 +157,6 @@ def patch_transformers_for_lazyload() -> None:
 
         """
 
-        print("DEVMAP", device_map)
-
         # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
         # - deepspeed zero 3 support
         # - need to copy metadata if any - see _load_state_dict_into_model
@@ -186,13 +183,22 @@ def patch_transformers_for_lazyload() -> None:
         for old_key, new_key in zip(old_keys, new_keys):
             state_dict[new_key] = state_dict.pop(old_key)
 
-        for param_name, param in state_dict.items():
+        # BEGIN PATCH
+        for param_name, param in sorted(
+            state_dict.items(),
+            # State dict must be ordered in this manner to make the caching in
+            # lazy_loader.py effective
+            key=lambda x: (
+                # NOTE: Assuming key is just decimal
+                int(x[1].key),
+                x[1].seek_offset,
+            ),
+        ):
 
-            # BEGIN PATCH
             if isinstance(param, LazyTensor):
                 print(".", end="", flush=True)
                 param = param.materialize()
             # END PATCH
 
             # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
             if (
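This reordering is what makes the handle cache in lazy_loader.py effective: tensors are grouped by their numeric storage key, and within a key the reads advance by seek_offset, so the cached handle only ever seeks forward. A toy check of the sort key, using a stand-in record instead of the real LazyTensor:

from types import SimpleNamespace

# Stand-ins for LazyTensor entries: only the two fields the sort key reads.
state_dict = {
    "b.weight": SimpleNamespace(key="10", seek_offset=0),
    "a.weight": SimpleNamespace(key="2", seek_offset=256),
    "a.bias": SimpleNamespace(key="2", seek_offset=0),
}

ordered = sorted(
    state_dict.items(),
    key=lambda x: (int(x[1].key), x[1].seek_offset),
)
print([name for name, _ in ordered])
# ['a.bias', 'a.weight', 'b.weight'] -- "2" sorts before "10" numerically,
# and within key "2" the smaller seek_offset comes first.

As the NOTE in the patch says, this assumes every storage key is a plain decimal string; int() would raise on anything else.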