This commit is contained in:
somebody
2023-05-28 13:03:24 -05:00
parent 6f93150e4d
commit 14241fc156
2 changed files with 58 additions and 55 deletions

View File

@@ -126,7 +126,6 @@ def patch_transformers_generation() -> None:
transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
CURRENT_CHECKPOINT = None
def patch_transformers_for_lazyload() -> None:
import torch
import inspect
@@ -158,8 +157,6 @@ def patch_transformers_for_lazyload() -> None:
"""
print("DEVMAP", device_map)
# XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
# - deepspeed zero 3 support
# - need to copy metadata if any - see _load_state_dict_into_model
@@ -186,13 +183,22 @@ def patch_transformers_for_lazyload() -> None:
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
for param_name, param in state_dict.items():
# BEGIN PATCH
for param_name, param in sorted(
state_dict.items(),
# State dict must be ordered in this manner to make the caching in
# lazy_loader.py effective
key=lambda x: (
# NOTE: Assuming key is just decimal
int(x[1].key),
x[1].seek_offset,
),
):
# BEGIN PATCH
if isinstance(param, LazyTensor):
print(".", end="", flush=True)
param = param.materialize()
# END PATCH
# END PATCH
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
if (