somebody
2023-05-28 13:03:24 -05:00
parent 6f93150e4d
commit 14241fc156
2 changed files with 58 additions and 55 deletions

View File

@@ -46,7 +46,6 @@ POSSIBILITY OF SUCH DAMAGE.
 import contextlib
 from functools import reduce
-import itertools
 import zipfile
 import pickle
 import torch
@@ -54,8 +53,7 @@ import numpy as np
 import collections
 import _codecs
 import os
-from torch.nn import Module
-from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Type

 # Safetensors is a dependency for the local version, TPU/Colab doesn't
 # support it yet.
@@ -85,6 +83,22 @@ STORAGE_TYPE_MAP = {
 # Storage of zipfile handles for each shard
 torch_checkpoint_file_handles = {}

+
+class CheckpointChunkCache:
+    """Storage for common checkpoint weight files to speed up loading. In order
+    for this to be effective at all, weights must be loaded in ascending order
+    of (key, seek_offset)."""
+
+    file_name = None
+    key = None
+    handle = None
+
+    @classmethod
+    def clear(cls) -> None:
+        cls.file_name = None
+        cls.key = None
+        cls.handle = None
+

 class LazyTensor:
     pass
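The docstring's ordering requirement is what makes a single cached handle worthwhile: as long as consecutive reads hit the same storage key, the handle is reused and only the offset changes. A minimal standalone illustration of the effect (the count_opens helper is hypothetical and not part of the patch), counting how often a handle would have to be reopened for a given access order:

# Hypothetical helper: one reopen per cache miss, mirroring the
# (file_name, key) check that CheckpointChunkCache performs below.
def count_opens(accesses):
    opens = 0
    cached_key = None
    for key, seek_offset in accesses:
        if key != cached_key:
            opens += 1          # cache miss -> open a new ZipExtFile
            cached_key = key
    return opens

unsorted_order = [("0", 0), ("1", 0), ("0", 4096), ("1", 4096)]
print(count_opens(unsorted_order))          # 4 reopens
print(count_opens(sorted(unsorted_order)))  # 2 reopens, one per storage key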
@@ -121,36 +135,37 @@ class TorchLazyTensor(LazyTensor):
     def materialize(
         self,
-        checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile] = None,
         map_location=None,
         no_grad=True,
-        filename="pytorch_model.bin",
     ) -> torch.Tensor:
+        checkpoint = torch_checkpoint_file_handles[self.file_name]
+        filename = os.path.basename(os.path.normpath(self.file_name)).split(".")[0]

-        # if f not in torch_tensor_container_file_map:
-        #     torch_tensor_container_file_map[f] = []
-
-        #     with zipfile.ZipFile(f, "r") as z:
-        #         paths = z.namelist()
-        #         for name in paths:
-        #             val = name.split("/data/")[-1]
-        #             if not val.isdecimal():
-        #                 continue
-        #             torch_tensor_container_file_map[f].append(int(val))
-        #         torch_tensor_container_file_map[f].sort()
-        # print(torch_tensor_container_file_map)
-
-        if not checkpoint:
-            checkpoint = torch_checkpoint_file_handles[self.file_name]
-            filename = self.file_name
-
-        filename = os.path.basename(os.path.normpath(filename)).split(".")[0]
+        # Most of the operations are just seeks, let's see if we can optimize that.
+        if (
+            CheckpointChunkCache.file_name != filename
+            or CheckpointChunkCache.key != self.key
+            or not CheckpointChunkCache.handle
+        ):
+            # Flush cache if invalid
+            print("!", end="", flush=True)
+            if CheckpointChunkCache.handle:
+                CheckpointChunkCache.handle.close()
+
+            CheckpointChunkCache.file_name = filename
+            CheckpointChunkCache.key = self.key
+            try:
+                CheckpointChunkCache.handle = checkpoint.open(
+                    f"archive/data/{self.key}", "r"
+                )
+            except KeyError:
+                CheckpointChunkCache.handle = checkpoint.open(
+                    f"{filename}/data/{self.key}", "r"
+                )
+
         size = reduce(lambda x, y: x * y, self.shape, 1)
         dtype = self.dtype
         nbytes = (
@@ -163,21 +178,12 @@ class TorchLazyTensor(LazyTensor):
             )
         )

-        if isinstance(checkpoint, zipfile.ZipFile):
-            try:
-                f = checkpoint.open(f"archive/data/{self.key}", "r")
-            except:
-                f = checkpoint.open(f"{filename}/data/{self.key}", "r")
-            f.seek(self.seek_offset, os.SEEK_CUR)
-            # f.read(self.seek_offset)
-        else:
-            f = checkpoint
-
-        try:
-            storage = STORAGE_TYPE_MAP[dtype].from_buffer(f.read(nbytes), "little")
-        finally:
-            if isinstance(checkpoint, zipfile.ZipFile):
-                f.close()
+        assert isinstance(checkpoint, zipfile.ZipFile)
+
+        CheckpointChunkCache.handle.seek(self.seek_offset, os.SEEK_SET)
+        storage = STORAGE_TYPE_MAP[dtype].from_buffer(
+            CheckpointChunkCache.handle.read(nbytes), "little"
+        )

         storage = torch.serialization._get_restore_location(map_location)(
             storage, self.location
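The cached handle is seeked to the tensor's byte offset and exactly nbytes are read. Note that the seek is now SEEK_SET against the long-lived per-key handle rather than SEEK_CUR against a freshly opened file, which is what lets one handle serve many tensors in a row. The body of the nbytes expression falls between these two hunks and is not shown; assuming it is the usual element-count times element-size calculation, a rough worked example:

from functools import reduce

shape = (4096, 4096)                          # e.g. one attention weight matrix
size = reduce(lambda x, y: x * y, shape, 1)   # 16,777,216 elements
element_size = 2                              # float16 storage: 2 bytes per element
nbytes = size * element_size                  # 33,554,432 bytes read after the seek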
@@ -277,9 +283,7 @@ class _LazyUnpickler(RestrictedUnpickler):
typename == "storage" typename == "storage"
), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
storage_type, key, location, _ = saved_id[1:] storage_type, key, location, _ = saved_id[1:]
return TorchLazyTensor( return TorchLazyTensor(storage_type, key, location)
storage_type, key, location
)
def load(self, *args, **kwargs): def load(self, *args, **kwargs):
retval = super().load(*args, **kwargs) retval = super().load(*args, **kwargs)
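For orientation, the tuple unpacked above is the persistent id that torch.save records for each storage in the zipfile format, of which only the storage type, key, and location are kept for the lazy tensor. The concrete values below are purely illustrative:

# Illustrative only: a persistent id of the shape the unpickler expects.
saved_id = ("storage", torch.HalfStorage, "0", "cpu", 16777216)
storage_type, key, location, _ = saved_id[1:]
# -> TorchLazyTensor(torch.HalfStorage, "0", "cpu")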
@@ -361,13 +365,6 @@ def patch_safetensors(callback):
     transformers.modeling_utils.safe_load_file = safetensors_load

-
-def get_torch_tensor_file(file: str, lazy_tensor: TorchLazyTensor):
-    with zipfile.ZipFile(file, "r") as z:
-        storage_key = lazy_tensor.key
-        ziproot = z.namelist()[0].split("/")[0]
-        f = z.open(f"{ziproot}/data/{storage_key}")
-        # TODO: Maybe some file seeking
-        return f

 @contextlib.contextmanager
 def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
@@ -418,7 +415,7 @@ def use_lazy_load(
         if f not in torch_checkpoint_file_handles:
             torch_checkpoint_file_handles[f] = zipfile.ZipFile(f, "r")
-        for k,v in model_dict.items():
+        for k, v in model_dict.items():
             v.file_name = f

     if callback is not None:

View File

@@ -126,7 +126,6 @@ def patch_transformers_generation() -> None:
     transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init

-CURRENT_CHECKPOINT = None

 def patch_transformers_for_lazyload() -> None:
     import torch
     import inspect
@@ -158,8 +157,6 @@ def patch_transformers_for_lazyload() -> None:
""" """
print("DEVMAP", device_map)
# XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
# - deepspeed zero 3 support # - deepspeed zero 3 support
# - need to copy metadata if any - see _load_state_dict_into_model # - need to copy metadata if any - see _load_state_dict_into_model
@@ -186,13 +183,22 @@ def patch_transformers_for_lazyload() -> None:
         for old_key, new_key in zip(old_keys, new_keys):
             state_dict[new_key] = state_dict.pop(old_key)

-        for param_name, param in state_dict.items():
-            # BEGIN PATCH
+        # BEGIN PATCH
+        for param_name, param in sorted(
+            state_dict.items(),
+            # State dict must be ordered in this manner to make the caching in
+            # lazy_loader.py effective
+            key=lambda x: (
+                # NOTE: Assuming key is just decimal
+                int(x[1].key),
+                x[1].seek_offset,
+            ),
+        ):
             if isinstance(param, LazyTensor):
                 print(".", end="", flush=True)
                 param = param.materialize()
             # END PATCH

             # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
             if (
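The sort key mirrors the (key, seek_offset) ordering that CheckpointChunkCache's docstring asks for, and the int() conversion matters because storage keys are decimal strings that would otherwise sort lexicographically. A small illustration with made-up entries (Lazy is a hypothetical stand-in for LazyTensor, used only for this sketch):

from collections import namedtuple

Lazy = namedtuple("Lazy", ["key", "seek_offset"])  # stand-in for LazyTensor

state_dict = {
    "lm_head.weight": Lazy("10", 0),
    "wte.weight": Lazy("2", 0),
    "h.0.attn.weight": Lazy("2", 8192),
}

ordered = sorted(
    state_dict.items(),
    key=lambda x: (int(x[1].key), x[1].seek_offset),
)
print([name for name, _ in ordered])
# ['wte.weight', 'h.0.attn.weight', 'lm_head.weight']
# Sorting the raw strings would put "10" before "2" and break the chunk cache's
# assumption that reads arrive in ascending storage-key order.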