diff --git a/modeling/lazy_loader.py b/modeling/lazy_loader.py
index b6ea1623..c14c7967 100644
--- a/modeling/lazy_loader.py
+++ b/modeling/lazy_loader.py
@@ -159,7 +159,7 @@ class TorchLazyTensor(LazyTensor):
         ):
             # Cache miss. Assuming weights are loaded in order of
             # (key, seek_offset), this means we need to invalidate the cache.
-            print("!", end="", flush=True)
+            # print("!", end="", flush=True)
             CheckpointChunkCache.hit_data["misses"] += 1
             CheckpointChunkCache.clear()
 
@@ -176,7 +176,7 @@ class TorchLazyTensor(LazyTensor):
             )
         else:
             # Cache hit. Hip hip hooray! :^)
-            print(".", end="", flush=True)
+            # print(".", end="", flush=True)
             CheckpointChunkCache.hit_data["hits"] += 1
 
         size = reduce(lambda x, y: x * y, self.shape, 1)
@@ -481,6 +481,11 @@ def post_load_cleanup() -> None:
     logger.debug(f"[lazy_load] CheckpointChunkCache Hit Data: {CheckpointChunkCache.hit_data}")
     CheckpointChunkCache.clear(unload_model=True)
 
+    # Bar is initialized in
+    # patches.patch_transformers_for_lazyload._load_state_dict_into_meta_model,
+    # as it has access to the state dict (for getting tensor count)
+    utils.bar = None
+
     for v in torch_checkpoint_file_handles.values():
         v.close()
     torch_checkpoint_file_handles = {}
diff --git a/modeling/patches.py b/modeling/patches.py
index bd5a6325..71319ef8 100644
--- a/modeling/patches.py
+++ b/modeling/patches.py
@@ -1,9 +1,8 @@
 from __future__ import annotations
 
 import copy
-from ctypes import Union
 import requests
-from typing import Iterable, List, Optional
+from typing import List
 from tqdm.auto import tqdm
 
 import transformers
@@ -14,7 +13,6 @@ from transformers import (
 )
 from modeling.lazy_loader import LazyTensor
 import utils
-from logger import logger
 
 
 def patch_transformers_download():
@@ -203,6 +201,8 @@ def patch_transformers_for_lazyload() -> None:
                 state_dict[new_key] = state_dict.pop(old_key)
 
         # BEGIN PATCH
+        utils.bar = tqdm(total=len(state_dict), desc="Loading model tensors", file=utils.UIProgressBarFile())
+
         for param_name, param in sorted(
             state_dict.items(),
             # State dict must be ordered in this manner to make the caching in
@@ -217,6 +217,7 @@ def patch_transformers_for_lazyload() -> None:
             if isinstance(param, LazyTensor):
                 # Should always be true
                 param = param.materialize()
+                utils.bar.update(1)
             # END PATCH
 
             # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
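
For reference, below is a minimal, self-contained sketch of the progress-bar lifecycle this diff wires up: the bar is created at the one place that knows the total tensor count (the patched _load_state_dict_into_meta_model, which holds the state dict), ticked once per materialized tensor, and dropped again during post-load cleanup. The module layout, FakeLazyTensor class, and plain-name sort here are illustrative stand-ins, not KoboldAI's actual code; the real patch routes output through utils.UIProgressBarFile() and orders the state dict for checkpoint-chunk cache locality, both of which this sketch omits.

from __future__ import annotations

from tqdm.auto import tqdm

bar = None  # module-level handle, analogous to utils.bar in the diff


class FakeLazyTensor:
    """Illustrative stand-in for modeling.lazy_loader.LazyTensor."""

    def __init__(self, key: str) -> None:
        self.key = key

    def materialize(self) -> str:
        # The real class reads the tensor's data out of the checkpoint file here.
        return f"tensor<{self.key}>"


def load_state_dict(state_dict: dict) -> None:
    global bar
    # Create the bar where the total is known: the loader has the state
    # dict, so len(state_dict) gives the tensor count up front.
    bar = tqdm(total=len(state_dict), desc="Loading model tensors")
    for param_name, param in sorted(state_dict.items()):
        if isinstance(param, FakeLazyTensor):
            param = param.materialize()
            bar.update(1)  # one tick per materialized tensor


def post_load_cleanup() -> None:
    global bar
    # Drop the handle so the next load starts a fresh bar; the diff does the
    # same with `utils.bar = None` (this sketch also closes the bar first).
    if bar is not None:
        bar.close()
        bar = None


load_state_dict({"w1": FakeLazyTensor("w1"), "w2": FakeLazyTensor("w2")})
post_load_cleanup()

The split in the diff follows from ownership of information: only the loader in modeling/patches.py can size the bar, while modeling/lazy_loader.py's post_load_cleanup() already owns end-of-load teardown, so the reset lives there, with the cross-module handle parked in utils.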