Fix loading bar

This commit is contained in:
somebody
2023-06-21 16:27:22 -05:00
parent aca2b532d7
commit c56214c275
2 changed files with 11 additions and 5 deletions

View File

@@ -159,7 +159,7 @@ class TorchLazyTensor(LazyTensor):
): ):
# Cache miss. Assuming weights are loaded in order of # Cache miss. Assuming weights are loaded in order of
# (key, seek_offset), this means we need to invalidate the cache. # (key, seek_offset), this means we need to invalidate the cache.
print("!", end="", flush=True) # print("!", end="", flush=True)
CheckpointChunkCache.hit_data["misses"] += 1 CheckpointChunkCache.hit_data["misses"] += 1
CheckpointChunkCache.clear() CheckpointChunkCache.clear()
@@ -176,7 +176,7 @@ class TorchLazyTensor(LazyTensor):
) )
else: else:
# Cache hit. Hip hip hooray! :^) # Cache hit. Hip hip hooray! :^)
print(".", end="", flush=True) # print(".", end="", flush=True)
CheckpointChunkCache.hit_data["hits"] += 1 CheckpointChunkCache.hit_data["hits"] += 1
size = reduce(lambda x, y: x * y, self.shape, 1) size = reduce(lambda x, y: x * y, self.shape, 1)
@@ -481,6 +481,11 @@ def post_load_cleanup() -> None:
logger.debug(f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}") logger.debug(f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}")
CheckpointChunkCache.clear(unload_model=True) CheckpointChunkCache.clear(unload_model=True)
# Bar is initialized in
# patches.patch_transformers_for_lazyload._load_state_dict_into_meta_model,
# as it has access to the state dict (for getting tensor count)
utils.bar = None
for v in torch_checkpoint_file_handles.values(): for v in torch_checkpoint_file_handles.values():
v.close() v.close()
torch_checkpoint_file_handles = {} torch_checkpoint_file_handles = {}

View File

@@ -1,9 +1,8 @@
from __future__ import annotations from __future__ import annotations
import copy import copy
from ctypes import Union
import requests import requests
from typing import Iterable, List, Optional from typing import List
from tqdm.auto import tqdm from tqdm.auto import tqdm
import transformers import transformers
@@ -14,7 +13,6 @@ from transformers import (
from modeling.lazy_loader import LazyTensor from modeling.lazy_loader import LazyTensor
import utils import utils
from logger import logger
def patch_transformers_download(): def patch_transformers_download():
@@ -203,6 +201,8 @@ def patch_transformers_for_lazyload() -> None:
state_dict[new_key] = state_dict.pop(old_key) state_dict[new_key] = state_dict.pop(old_key)
# BEGIN PATCH # BEGIN PATCH
utils.bar = tqdm(total=len(state_dict), desc="Loading model tensors", file=utils.UIProgressBarFile())
for param_name, param in sorted( for param_name, param in sorted(
state_dict.items(), state_dict.items(),
# State dict must be ordered in this manner to make the caching in # State dict must be ordered in this manner to make the caching in
@@ -217,6 +217,7 @@ def patch_transformers_for_lazyload() -> None:
if isinstance(param, LazyTensor): if isinstance(param, LazyTensor):
# Should always be true # Should always be true
param = param.materialize() param = param.materialize()
utils.bar.update(1)
# END PATCH # END PATCH
# First part of the test is always true as load_state_dict_keys always contains state_dict keys. # First part of the test is always true as load_state_dict_keys always contains state_dict keys.