mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Fix loading bar
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from ctypes import Union
|
||||
import requests
|
||||
from typing import Iterable, List, Optional
|
||||
from typing import List
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
@@ -14,7 +13,6 @@ from transformers import (
|
||||
from modeling.lazy_loader import LazyTensor
|
||||
|
||||
import utils
|
||||
from logger import logger
|
||||
|
||||
|
||||
def patch_transformers_download():
|
||||
@@ -203,6 +201,8 @@ def patch_transformers_for_lazyload() -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
# BEGIN PATCH
|
||||
utils.bar = tqdm(total=len(state_dict), desc="Loading model tensors", file=utils.UIProgressBarFile())
|
||||
|
||||
for param_name, param in sorted(
|
||||
state_dict.items(),
|
||||
# State dict must be ordered in this manner to make the caching in
|
||||
@@ -217,6 +217,7 @@ def patch_transformers_for_lazyload() -> None:
|
||||
if isinstance(param, LazyTensor):
|
||||
# Should always be true
|
||||
param = param.materialize()
|
||||
utils.bar.update(1)
|
||||
# END PATCH
|
||||
|
||||
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
|
||||
|
Reference in New Issue
Block a user