Added status bar for downloading models
This commit is contained in:
parent
a0475ba049
commit
907cf74b13
52
aiserver.py
52
aiserver.py
|
@ -1305,9 +1305,60 @@ def patch_causallm(model):
|
||||||
Embedding._koboldai_patch_causallm_model = model
|
Embedding._koboldai_patch_causallm_model = model
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def patch_transformers_download():
|
||||||
|
global transformers
|
||||||
|
import copy, requests, tqdm, time
|
||||||
|
class Send_to_socketio(object):
|
||||||
|
def write(self, bar):
|
||||||
|
bar = bar.replace("\r", "")
|
||||||
|
try:
|
||||||
|
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True)
|
||||||
|
eventlet.sleep(seconds=0)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
def http_get(
|
||||||
|
url: str,
|
||||||
|
temp_file: transformers.utils.hub.BinaryIO,
|
||||||
|
proxies=None,
|
||||||
|
resume_size=0,
|
||||||
|
headers: transformers.utils.hub.Optional[transformers.utils.hub.Dict[str, str]] = None,
|
||||||
|
file_name: transformers.utils.hub.Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Download remote file. Do not gobble up errors.
|
||||||
|
"""
|
||||||
|
headers = copy.deepcopy(headers)
|
||||||
|
if resume_size > 0:
|
||||||
|
headers["Range"] = f"bytes={resume_size}-"
|
||||||
|
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
||||||
|
transformers.utils.hub._raise_for_status(r)
|
||||||
|
content_length = r.headers.get("Content-Length")
|
||||||
|
total = resume_size + int(content_length) if content_length is not None else None
|
||||||
|
# `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
|
||||||
|
# and can be set using `utils.logging.enable/disable_progress_bar()`
|
||||||
|
progress = tqdm.tqdm(
|
||||||
|
unit="B",
|
||||||
|
unit_scale=True,
|
||||||
|
unit_divisor=1024,
|
||||||
|
total=total,
|
||||||
|
initial=resume_size,
|
||||||
|
desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
|
||||||
|
file=Send_to_socketio(),
|
||||||
|
)
|
||||||
|
for chunk in r.iter_content(chunk_size=1024):
|
||||||
|
if chunk: # filter out keep-alive new chunks
|
||||||
|
progress.update(len(chunk))
|
||||||
|
temp_file.write(chunk)
|
||||||
|
progress.close()
|
||||||
|
|
||||||
|
transformers.utils.hub.http_get = http_get
|
||||||
|
|
||||||
|
|
||||||
def patch_transformers():
|
def patch_transformers():
|
||||||
global transformers
|
global transformers
|
||||||
|
|
||||||
|
patch_transformers_download()
|
||||||
|
|
||||||
old_from_pretrained = PreTrainedModel.from_pretrained.__func__
|
old_from_pretrained = PreTrainedModel.from_pretrained.__func__
|
||||||
@classmethod
|
@classmethod
|
||||||
def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
@ -6377,6 +6428,7 @@ if __name__ == "__main__":
|
||||||
vars.flaskwebgui = True
|
vars.flaskwebgui = True
|
||||||
FlaskUI(app, socketio=socketio, start_server="flask-socketio", maximized=True, close_server_on_exit=True).run()
|
FlaskUI(app, socketio=socketio, start_server="flask-socketio", maximized=True, close_server_on_exit=True).run()
|
||||||
except:
|
except:
|
||||||
|
pass
|
||||||
import webbrowser
|
import webbrowser
|
||||||
webbrowser.open_new('http://localhost:{0}'.format(port))
|
webbrowser.open_new('http://localhost:{0}'.format(port))
|
||||||
print("{0}Server started!\nYou may now connect with a browser at http://127.0.0.1:{1}/{2}"
|
print("{0}Server started!\nYou may now connect with a browser at http://127.0.0.1:{1}/{2}"
|
||||||
|
|
12
utils.py
12
utils.py
|
@ -172,6 +172,16 @@ def num_layers(config):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Downloads huggingface checkpoints using aria2c if possible
|
# Downloads huggingface checkpoints using aria2c if possible
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
from flask_socketio import emit
|
||||||
|
class Send_to_socketio(object):
|
||||||
|
def write(self, bar):
|
||||||
|
print("should be emitting: ", bar, end="")
|
||||||
|
time.sleep(0.01)
|
||||||
|
try:
|
||||||
|
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
|
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
|
||||||
import transformers
|
import transformers
|
||||||
import transformers.modeling_utils
|
import transformers.modeling_utils
|
||||||
|
@ -268,7 +278,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
||||||
done = True
|
done = True
|
||||||
break
|
break
|
||||||
if bar is None:
|
if bar is None:
|
||||||
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
|
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio())
|
||||||
visited = set()
|
visited = set()
|
||||||
for x in r:
|
for x in r:
|
||||||
filename = x["files"][0]["path"]
|
filename = x["files"][0]["path"]
|
||||||
|
|
Loading…
Reference in New Issue