Commit "Speed" in KoboldAI/KoboldAI-Client (mirror of https://github.com/KoboldAI/KoboldAI-Client.git)
@@ -46,7 +46,6 @@ POSSIBILITY OF SUCH DAMAGE.
import contextlib
from functools import reduce
import itertools
import zipfile
import pickle
import torch
@@ -54,8 +53,7 @@ import numpy as np
import collections
import _codecs
import os
from torch.nn import Module
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Tuple, Type

# Safetensors is a dependency for the local version, TPU/Colab doesn't
# support it yet.
@@ -85,6 +83,22 @@ STORAGE_TYPE_MAP = {
# Storage of zipfile handles for each shard
torch_checkpoint_file_handles = {}


class CheckpointChunkCache:
    """Storage for common checkpoint weight files to speed up loading. In order
    for this to be effective at all, weights must be loaded in ascending order
    of (key, seek_offset)."""

    file_name = None
    key = None
    handle = None

    @classmethod
    def clear(cls) -> None:
        cls.file_name = None
        cls.key = None
        cls.handle = None


class LazyTensor:
    pass
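Aside (not part of this commit; names are illustrative): CheckpointChunkCache above is a single-slot cache. A minimal sketch of the idea, where consecutive tensors stored in the same zip member reuse one open handle instead of reopening the member for every read:

import zipfile

class ChunkCache:
    # Single-slot cache: at most one zip member is open at a time.
    file_name = None
    member = None
    handle = None

def read_chunk(zf: zipfile.ZipFile, file_name: str, member: str, offset: int, nbytes: int) -> bytes:
    # Reopen only when the (shard, member) pair changes; otherwise keep the handle.
    if (ChunkCache.file_name, ChunkCache.member) != (file_name, member) or not ChunkCache.handle:
        if ChunkCache.handle:
            ChunkCache.handle.close()
        ChunkCache.file_name, ChunkCache.member = file_name, member
        ChunkCache.handle = zf.open(member, "r")
    ChunkCache.handle.seek(offset)  # forward seeks are cheap when reads arrive in order
    return ChunkCache.handle.read(nbytes)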
@@ -121,36 +135,37 @@ class TorchLazyTensor(LazyTensor):

    def materialize(
        self,
        checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile] = None,
        map_location=None,
        no_grad=True,
        filename="pytorch_model.bin",
    ) -> torch.Tensor:

        checkpoint = torch_checkpoint_file_handles[self.file_name]

        filename = os.path.basename(os.path.normpath(self.file_name)).split(".")[0]

        # if f not in torch_tensor_container_file_map:
        #     torch_tensor_container_file_map[f] = []
        # Most of the operations are just seeks, let's see if we can optimize that.
        if (
            CheckpointChunkCache.file_name != filename
            or CheckpointChunkCache.key != self.key
            or not CheckpointChunkCache.handle
        ):
            # Flush cache if invalid
            print("!", end="", flush=True)

            # with zipfile.ZipFile(f, "r") as z:
            #     paths = z.namelist()
            if CheckpointChunkCache.handle:
                CheckpointChunkCache.handle.close()

            # for name in paths:
            #     val = name.split("/data/")[-1]
            #     if not val.isdecimal():
            #         continue
            #     torch_tensor_container_file_map[f].append(int(val))
            # torch_tensor_container_file_map[f].sort()
            # print(torch_tensor_container_file_map)
            CheckpointChunkCache.file_name = filename
            CheckpointChunkCache.key = self.key
            try:
                CheckpointChunkCache.handle = checkpoint.open(
                    f"archive/data/{self.key}", "r"
                )
            except KeyError:
                CheckpointChunkCache.handle = checkpoint.open(
                    f"{filename}/data/{self.key}", "r"
                )

        if not checkpoint:
            checkpoint = torch_checkpoint_file_handles[self.file_name]
            filename = self.file_name

        filename = os.path.basename(os.path.normpath(filename)).split(".")[0]
        size = reduce(lambda x, y: x * y, self.shape, 1)
        dtype = self.dtype
        nbytes = (
@@ -163,21 +178,12 @@ class TorchLazyTensor(LazyTensor):
            )
        )

        if isinstance(checkpoint, zipfile.ZipFile):
            try:
                f = checkpoint.open(f"archive/data/{self.key}", "r")
            except:
                f = checkpoint.open(f"{filename}/data/{self.key}", "r")
            f.seek(self.seek_offset, os.SEEK_CUR)
            # f.read(self.seek_offset)
        else:
            f = checkpoint
        assert isinstance(checkpoint, zipfile.ZipFile)

        try:
            storage = STORAGE_TYPE_MAP[dtype].from_buffer(f.read(nbytes), "little")
        finally:
            if isinstance(checkpoint, zipfile.ZipFile):
                f.close()
        CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
        storage = STORAGE_TYPE_MAP[dtype].from_buffer(
            CheckpointChunkCache.handle.read(nbytes), "little"
        )

        storage = torch.serialization._get_restore_location(map_location)(
            storage, self.location
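Aside (assumed equivalent, not the repo's code): materialize() ultimately turns a little-endian byte range into a typed tensor. A rough standalone illustration using public PyTorch APIs; torch.frombuffer and the reshape stand in for the storage-based path above:

import torch

def bytes_to_tensor(raw: bytes, dtype: torch.dtype, shape: tuple) -> torch.Tensor:
    # torch.frombuffer needs a writable buffer and shares its memory, so copy
    # into a bytearray and clone the result to detach it from that buffer.
    flat = torch.frombuffer(bytearray(raw), dtype=dtype)
    return flat.reshape(shape).clone()

# e.g. bytes_to_tensor(handle.read(nbytes), torch.float32, (2, 3)) for a 2x3 fp32 tensor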
@@ -277,9 +283,7 @@ class _LazyUnpickler(RestrictedUnpickler):
            typename == "storage"
        ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        storage_type, key, location, _ = saved_id[1:]
        return TorchLazyTensor(
            storage_type, key, location
        )
        return TorchLazyTensor(storage_type, key, location)

    def load(self, *args, **kwargs):
        retval = super().load(*args, **kwargs)
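Aside (a sketch under assumed names, not the project's exact classes): the hunk above relies on pickle's persistent_load hook, which lets an Unpickler hand back lightweight placeholders instead of real torch storages, so the byte reads are deferred until materialize() runs:

import pickle

class LazyPlaceholder:
    def __init__(self, storage_type, key, location):
        self.storage_type, self.key, self.location = storage_type, key, location

class PlaceholderUnpickler(pickle.Unpickler):
    def persistent_load(self, saved_id):
        # torch checkpoints record persistent ids as
        # ("storage", storage_type, key, location, numel)
        typename, storage_type, key, location, _ = saved_id
        assert typename == "storage"
        return LazyPlaceholder(storage_type, key, location)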
@@ -361,13 +365,6 @@ def patch_safetensors(callback):

    transformers.modeling_utils.safe_load_file = safetensors_load


def get_torch_tensor_file(file: str, lazy_tensor: TorchLazyTensor):
    with zipfile.ZipFile(file, "r") as z:
        storage_key = lazy_tensor.key
        ziproot = z.namelist()[0].split("/")[0]
        f = z.open(f"{ziproot}/data/{storage_key}")
        # TODO: Maybe some file seeking
        return f


@contextlib.contextmanager
def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
@@ -418,7 +415,7 @@ def use_lazy_load(
        if f not in torch_checkpoint_file_handles:
            torch_checkpoint_file_handles[f] = zipfile.ZipFile(f, "r")

        for k,v in model_dict.items():
        for k, v in model_dict.items():
            v.file_name = f

        if callback is not None:
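Aside (illustrative names, not the repo's API): the hunk above opens each checkpoint shard once and tags every lazy tensor with the shard it lives in, so materialize() can look the handle up later instead of reopening the archive per tensor. A minimal sketch of that bookkeeping:

import zipfile
from typing import Dict

checkpoint_handles: Dict[str, zipfile.ZipFile] = {}

def register_shard(path: str, model_dict: dict) -> None:
    # Open each shard once and remember which shard every lazy tensor belongs to.
    if path not in checkpoint_handles:
        checkpoint_handles[path] = zipfile.ZipFile(path, "r")
    for value in model_dict.values():
        value.file_name = path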
@@ -126,7 +126,6 @@ def patch_transformers_generation() -> None:

    transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init


CURRENT_CHECKPOINT = None
def patch_transformers_for_lazyload() -> None:
    import torch
    import inspect
@@ -158,8 +157,6 @@ def patch_transformers_for_lazyload() -> None:

        """

        print("DEVMAP", device_map)

        # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
        # - deepspeed zero 3 support
        # - need to copy metadata if any - see _load_state_dict_into_model
@@ -186,13 +183,22 @@ def patch_transformers_for_lazyload() -> None:
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        for param_name, param in state_dict.items():
        # BEGIN PATCH
        for param_name, param in sorted(
            state_dict.items(),
            # State dict must be ordered in this manner to make the caching in
            # lazy_loader.py effective
            key=lambda x: (
                # NOTE: Assuming key is just decimal
                int(x[1].key),
                x[1].seek_offset,
            ),
        ):

            # BEGIN PATCH
            if isinstance(param, LazyTensor):
                print(".", end="", flush=True)
                param = param.materialize()
            # END PATCH
            # END PATCH

            # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
            if (
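Aside (toy numbers, assumed for illustration): the sorted() iteration above matters because the single-slot chunk cache only survives while consecutive parameters share a storage key. Counting how often the cache would have to reopen a member makes the effect visible:

def reopen_count(visits):
    # Count cache misses for a single-slot cache keyed on the storage key.
    current, reopens = None, 0
    for key, _offset in visits:
        if key != current:
            current, reopens = key, reopens + 1
    return reopens

params = [("0", 4096), ("1", 0), ("0", 0), ("1", 2048)]
print(reopen_count(params))                                            # 4 reopens unsorted
print(reopen_count(sorted(params, key=lambda p: (int(p[0]), p[1]))))   # 2 reopens sorted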