Added status bar for downloading models

ebolam 2022-07-22 13:58:20 -04:00
parent a0475ba049
commit 907cf74b13
2 changed files with 63 additions and 1 deletion


@@ -1305,9 +1305,60 @@ def patch_causallm(model):
    Embedding._koboldai_patch_causallm_model = model
    return model
def patch_transformers_download():
    global transformers
    import copy, requests, tqdm, time
    class Send_to_socketio(object):
        def write(self, bar):
            bar = bar.replace("\r", "")
            try:
                emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True)
                eventlet.sleep(seconds=0)
            except:
                pass
    def http_get(
        url: str,
        temp_file: transformers.utils.hub.BinaryIO,
        proxies=None,
        resume_size=0,
        headers: transformers.utils.hub.Optional[transformers.utils.hub.Dict[str, str]] = None,
        file_name: transformers.utils.hub.Optional[str] = None,
    ):
        """
        Download remote file. Do not gobble up errors.
        """
        headers = copy.deepcopy(headers)
        if resume_size > 0:
            headers["Range"] = f"bytes={resume_size}-"
        r = requests.get(url, stream=True, proxies=proxies, headers=headers)
        transformers.utils.hub._raise_for_status(r)
        content_length = r.headers.get("Content-Length")
        total = resume_size + int(content_length) if content_length is not None else None
        # `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
        # and can be set using `utils.logging.enable/disable_progress_bar()`
        progress = tqdm.tqdm(
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            total=total,
            initial=resume_size,
            desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
            file=Send_to_socketio(),
        )
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive new chunks
                progress.update(len(chunk))
                temp_file.write(chunk)
        progress.close()
    transformers.utils.hub.http_get = http_get
def patch_transformers():
    global transformers
    patch_transformers_download()
    old_from_pretrained = PreTrainedModel.from_pretrained.__func__
    @classmethod
    def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
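
The mechanism both Send_to_socketio classes rely on is that tqdm renders its progress line to whatever object is passed as file=, calling only write() on it (and flush(), if present). A minimal, self-contained sketch of that behaviour, assuming nothing beyond tqdm itself; DemoWriter and the fake download loop are illustrative stand-ins, not code from this commit:

import time
import tqdm

class DemoWriter(object):
    def write(self, bar):
        bar = bar.replace("\r", "")       # tqdm redraws the bar using carriage returns
        if bar.strip():
            print("[status]", bar)        # stand-in for the socketio emit in Send_to_socketio
    def flush(self):
        pass                              # optional; tqdm falls back to a no-op if flush is missing

progress = tqdm.tqdm(total=1024, unit="B", unit_scale=True, unit_divisor=1024,
                     desc="Downloading (demo)", file=DemoWriter())
for _ in range(16):
    time.sleep(0.05)
    progress.update(64)                   # each update re-renders the bar into DemoWriter.write
progress.close()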
@@ -6377,6 +6428,7 @@ if __name__ == "__main__":
            vars.flaskwebgui = True
            FlaskUI(app, socketio=socketio, start_server="flask-socketio", maximized=True, close_server_on_exit=True).run()
        except:
            pass
            import webbrowser
            webbrowser.open_new('http://localhost:{0}'.format(port))
            print("{0}Server started!\nYou may now connect with a browser at http://127.0.0.1:{1}/{2}"


@@ -172,6 +172,16 @@ def num_layers(config):
#==================================================================#
# Downloads huggingface checkpoints using aria2c if possible
#==================================================================#
from flask_socketio import emit
class Send_to_socketio(object):
    def write(self, bar):
        print("should be emitting: ", bar, end="")
        time.sleep(0.01)
        try:
            emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True)
        except:
            pass
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
    import transformers
    import transformers.modeling_utils
@@ -268,7 +278,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
                    done = True
                    break
                if bar is None:
                    bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
                    bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio())
                visited = set()
                for x in r:
                    filename = x["files"][0]["path"]
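
On the Flask side, the emit() used by both Send_to_socketio classes is flask_socketio's context-aware helper: inside a Socket.IO event handler it targets the current client, and broadcast=True pushes the message to every connected client instead; outside such a context it can raise, which appears to be what the bare try/except guards against. A minimal server-side sketch of the event shape (the app setup, port, and payload text are illustrative assumptions; only the 'from_server' event name and 'model_load_status' command come from the commit):

from flask import Flask
from flask_socketio import SocketIO, emit

app = Flask(__name__)
socketio = SocketIO(app)

@socketio.on('connect')
def on_connect():
    # broadcast=True sends to every connected client, mirroring how the
    # download status lines are pushed to the browser UI
    emit('from_server', {'cmd': 'model_load_status', 'data': 'connected'}, broadcast=True)

if __name__ == '__main__':
    socketio.run(app, port=5000)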