This commit is contained in:
somebody
2023-05-28 13:03:24 -05:00
parent 6f93150e4d
commit 14241fc156
2 changed files with 58 additions and 55 deletions

View File

@@ -46,7 +46,6 @@ POSSIBILITY OF SUCH DAMAGE.
import contextlib
from functools import reduce
import itertools
import zipfile
import pickle
import torch
@@ -54,8 +53,7 @@ import numpy as np
import collections
import _codecs
import os
from torch.nn import Module
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Tuple, Type
# Safetensors is a dependency for the local version, TPU/Colab doesn't
# support it yet.
@@ -85,6 +83,22 @@ STORAGE_TYPE_MAP = {
# Storage of zipfile handles for each shard
torch_checkpoint_file_handles = {}
class CheckpointChunkCache:
"""Storage for common checkpoint weight files to speed up loading. In order
for this to be effective at all, weights must be loaded in ascending order
of (key, seek_offset)."""
file_name = None
key = None
handle = None
@classmethod
def clear(cls) -> None:
cls.file_name = None
cls.key = None
cls.handle = None
class LazyTensor:
pass
@@ -121,36 +135,37 @@ class TorchLazyTensor(LazyTensor):
def materialize(
self,
checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile] = None,
map_location=None,
no_grad=True,
filename="pytorch_model.bin",
) -> torch.Tensor:
checkpoint = torch_checkpoint_file_handles[self.file_name]
filename = os.path.basename(os.path.normpath(self.file_name)).split(".")[0]
# if f not in torch_tensor_container_file_map:
# torch_tensor_container_file_map[f] = []
# Most of the operations are just seeks, let's see if we can optimize that.
if (
CheckpointChunkCache.file_name != filename
or CheckpointChunkCache.key != self.key
or not CheckpointChunkCache.handle
):
# Flush cache if invalid
print("!", end="", flush=True)
# with zipfile.ZipFile(f, "r") as z:
# paths = z.namelist()
if CheckpointChunkCache.handle:
CheckpointChunkCache.handle.close()
# for name in paths:
# val = name.split("/data/")[-1]
# if not val.isdecimal():
# continue
# torch_tensor_container_file_map[f].append(int(val))
# torch_tensor_container_file_map[f].sort()
# print(torch_tensor_container_file_map)
CheckpointChunkCache.file_name = filename
CheckpointChunkCache.key = self.key
try:
CheckpointChunkCache.handle = checkpoint.open(
f"archive/data/{self.key}", "r"
)
except KeyError:
CheckpointChunkCache.handle = checkpoint.open(
f"{filename}/data/{self.key}", "r"
)
if not checkpoint:
checkpoint = torch_checkpoint_file_handles[self.file_name]
filename = self.file_name
filename = os.path.basename(os.path.normpath(filename)).split(".")[0]
size = reduce(lambda x, y: x * y, self.shape, 1)
dtype = self.dtype
nbytes = (
@@ -163,21 +178,12 @@ class TorchLazyTensor(LazyTensor):
)
)
if isinstance(checkpoint, zipfile.ZipFile):
try:
f = checkpoint.open(f"archive/data/{self.key}", "r")
except:
f = checkpoint.open(f"{filename}/data/{self.key}", "r")
f.seek(self.seek_offset, os.SEEK_CUR)
# f.read(self.seek_offset)
else:
f = checkpoint
assert isinstance(checkpoint, zipfile.ZipFile)
try:
storage = STORAGE_TYPE_MAP[dtype].from_buffer(f.read(nbytes), "little")
finally:
if isinstance(checkpoint, zipfile.ZipFile):
f.close()
CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
storage = STORAGE_TYPE_MAP[dtype].from_buffer(
CheckpointChunkCache.handle.read(nbytes), "little"
)
storage = torch.serialization._get_restore_location(map_location)(
storage, self.location
@@ -277,9 +283,7 @@ class _LazyUnpickler(RestrictedUnpickler):
typename == "storage"
), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
storage_type, key, location, _ = saved_id[1:]
return TorchLazyTensor(
storage_type, key, location
)
return TorchLazyTensor(storage_type, key, location)
def load(self, *args, **kwargs):
retval = super().load(*args, **kwargs)
@@ -361,13 +365,6 @@ def patch_safetensors(callback):
transformers.modeling_utils.safe_load_file = safetensors_load
def get_torch_tensor_file(file: str, lazy_tensor: TorchLazyTensor):
with zipfile.ZipFile(file, "r") as z:
storage_key = lazy_tensor.key
ziproot = z.namelist()[0].split("/")[0]
f = z.open(f"{ziproot}/data/{storage_key}")
# TODO: Maybe some file seeking
return f
@contextlib.contextmanager
def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
@@ -418,7 +415,7 @@ def use_lazy_load(
if f not in torch_checkpoint_file_handles:
torch_checkpoint_file_handles[f] = zipfile.ZipFile(f, "r")
for k,v in model_dict.items():
for k, v in model_dict.items():
v.file_name = f
if callback is not None:

View File

@@ -126,7 +126,6 @@ def patch_transformers_generation() -> None:
transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
CURRENT_CHECKPOINT = None
def patch_transformers_for_lazyload() -> None:
import torch
import inspect
@@ -158,8 +157,6 @@ def patch_transformers_for_lazyload() -> None:
"""
print("DEVMAP", device_map)
# XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
# - deepspeed zero 3 support
# - need to copy metadata if any - see _load_state_dict_into_model
@@ -186,13 +183,22 @@ def patch_transformers_for_lazyload() -> None:
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
for param_name, param in state_dict.items():
# BEGIN PATCH
for param_name, param in sorted(
state_dict.items(),
# State dict must be ordered in this manner to make the caching in
# lazy_loader.py effective
key=lambda x: (
# NOTE: Assuming key is just decimal
int(x[1].key),
x[1].seek_offset,
),
):
# BEGIN PATCH
if isinstance(param, LazyTensor):
print(".", end="", flush=True)
param = param.materialize()
# END PATCH
# END PATCH
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
if (