mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Added status bar for downloading models
This commit is contained in:
12
utils.py
12
utils.py
@ -172,6 +172,16 @@ def num_layers(config):
|
||||
#==================================================================#
|
||||
# Downloads huggingface checkpoints using aria2c if possible
|
||||
#==================================================================#
|
||||
from flask_socketio import emit
|
||||
class Send_to_socketio(object):
|
||||
def write(self, bar):
|
||||
print("should be emitting: ", bar, end="")
|
||||
time.sleep(0.01)
|
||||
try:
|
||||
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True)
|
||||
except:
|
||||
pass
|
||||
|
||||
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
|
||||
import transformers
|
||||
import transformers.modeling_utils
|
||||
@ -268,7 +278,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
||||
done = True
|
||||
break
|
||||
if bar is None:
|
||||
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
|
||||
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio())
|
||||
visited = set()
|
||||
for x in r:
|
||||
filename = x["files"][0]["path"]
|
||||
|
Reference in New Issue
Block a user