Fix loading bar

This commit is contained in:
somebody
2023-06-21 16:27:22 -05:00
parent aca2b532d7
commit c56214c275
2 changed files with 11 additions and 5 deletions

View File

@@ -1,9 +1,8 @@
from __future__ import annotations
import copy
from ctypes import Union
import requests
from typing import Iterable, List, Optional
from typing import List
from tqdm.auto import tqdm
import transformers
@@ -14,7 +13,6 @@ from transformers import (
from modeling.lazy_loader import LazyTensor
import utils
from logger import logger
def patch_transformers_download():
@@ -203,6 +201,8 @@ def patch_transformers_for_lazyload() -> None:
state_dict[new_key] = state_dict.pop(old_key)
# BEGIN PATCH
utils.bar = tqdm(total=len(state_dict), desc="Loading model tensors", file=utils.UIProgressBarFile())
for param_name, param in sorted(
state_dict.items(),
# State dict must be ordered in this manner to make the caching in
@@ -217,6 +217,7 @@ def patch_transformers_for_lazyload() -> None:
if isinstance(param, LazyTensor):
# Should always be true
param = param.materialize()
utils.bar.update(1)
# END PATCH
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.