Fix for getting "model download status" when downloading config to figure out layer counts

This commit is contained in:
ebolam 2022-07-25 18:29:14 -04:00
parent 907cf74b13
commit 12acb50ee0
1 changed files with 15 additions and 11 deletions

View File

@ -1312,6 +1312,7 @@ def patch_transformers_download():
def write(self, bar): def write(self, bar):
bar = bar.replace("\r", "") bar = bar.replace("\r", "")
try: try:
print(bar, end="\r")
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True) emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True)
eventlet.sleep(seconds=0) eventlet.sleep(seconds=0)
except: except:
@ -1336,20 +1337,23 @@ def patch_transformers_download():
total = resume_size + int(content_length) if content_length is not None else None total = resume_size + int(content_length) if content_length is not None else None
# `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()` # `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
# and can be set using `utils.logging.enable/disable_progress_bar()` # and can be set using `utils.logging.enable/disable_progress_bar()`
progress = tqdm.tqdm( if url[-11:] != 'config.json':
unit="B", progress = tqdm.tqdm(
unit_scale=True, unit="B",
unit_divisor=1024, unit_scale=True,
total=total, unit_divisor=1024,
initial=resume_size, total=total,
desc=f"Downloading {file_name}" if file_name is not None else "Downloading", initial=resume_size,
file=Send_to_socketio(), desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
) file=Send_to_socketio(),
)
for chunk in r.iter_content(chunk_size=1024): for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks if chunk: # filter out keep-alive new chunks
progress.update(len(chunk)) if url[-11:] != 'config.json':
progress.update(len(chunk))
temp_file.write(chunk) temp_file.write(chunk)
progress.close() if url[-11:] != 'config.json':
progress.close()
transformers.utils.hub.http_get = http_get transformers.utils.hub.http_get = http_get