diff --git a/aiserver.py b/aiserver.py
index 295a0b7d..52cd7b28 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1314,9 +1314,64 @@ def patch_causallm(model):
     Embedding._koboldai_patch_causallm_model = model
     return model
 
+def patch_transformers_download():
+    global transformers
+    import copy, requests, tqdm, time
+    class Send_to_socketio(object):
+        def write(self, bar):
+            bar = bar.replace("\r", "")
+            try:
+                print(bar, end="\r")
+                emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", "&nbsp;")}, broadcast=True)
+                eventlet.sleep(seconds=0)
+            except:
+                pass
+    def http_get(
+        url: str,
+        temp_file: transformers.utils.hub.BinaryIO,
+        proxies=None,
+        resume_size=0,
+        headers: transformers.utils.hub.Optional[transformers.utils.hub.Dict[str, str]] = None,
+        file_name: transformers.utils.hub.Optional[str] = None,
+    ):
+        """
+        Download remote file. Do not gobble up errors.
+        """
+        headers = copy.deepcopy(headers)
+        if resume_size > 0:
+            headers["Range"] = f"bytes={resume_size}-"
+        r = requests.get(url, stream=True, proxies=proxies, headers=headers)
+        transformers.utils.hub._raise_for_status(r)
+        content_length = r.headers.get("Content-Length")
+        total = resume_size + int(content_length) if content_length is not None else None
+        # `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
+        # and can be set using `utils.logging.enable/disable_progress_bar()`
+        if url[-11:] != 'config.json':
+            progress = tqdm.tqdm(
+                unit="B",
+                unit_scale=True,
+                unit_divisor=1024,
+                total=total,
+                initial=resume_size,
+                desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
+                file=Send_to_socketio(),
+            )
+        for chunk in r.iter_content(chunk_size=1024):
+            if chunk:  # filter out keep-alive new chunks
+                if url[-11:] != 'config.json':
+                    progress.update(len(chunk))
+                temp_file.write(chunk)
+        if url[-11:] != 'config.json':
+            progress.close()
+
+    transformers.utils.hub.http_get = http_get
+
 def patch_transformers():
     global transformers
+
+    patch_transformers_download()
+
     old_from_pretrained = PreTrainedModel.from_pretrained.__func__
     @classmethod
     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
@@ -6386,6 +6441,7 @@ if __name__ == "__main__":
                 vars.flaskwebgui = True
                 FlaskUI(app, socketio=socketio, start_server="flask-socketio", maximized=True, close_server_on_exit=True).run()
             except:
+                pass
                 import webbrowser
                 webbrowser.open_new('http://localhost:{0}'.format(port))
                 print("{0}Server started!\nYou may now connect with a browser at http://127.0.0.1:{1}/{2}"
diff --git a/utils.py b/utils.py
index dc2c97d0..d20d376d 100644
--- a/utils.py
+++ b/utils.py
@@ -172,6 +172,16 @@ def num_layers(config):
 #==================================================================#
 # Downloads huggingface checkpoints using aria2c if possible
 #==================================================================#
+from flask_socketio import emit
+class Send_to_socketio(object):
+    def write(self, bar):
+        print("should be emitting: ", bar, end="")
+        time.sleep(0.01)
+        try:
+            emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", "&nbsp;")}, broadcast=True)
+        except:
+            pass
+
 def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
     import transformers
     import transformers.modeling_utils
@@ -268,7 +278,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
                 done = True
                 break
             if bar is None:
-                bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
+                bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio())
             visited = set()
             for x in r:
                 filename = x["files"][0]["path"]
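Note (not part of the diff): both files lean on the same mechanism, namely that tqdm writes its progress line to any object exposing a write() method passed via the file= parameter, so a custom writer can forward each redraw to another channel; aiserver.py additionally swaps the patched downloader into transformers by reassigning transformers.utils.hub.http_get. The sketch below is a minimal, self-contained illustration of that mechanism under assumed names (BarForwarder and fake_download are hypothetical, not from this patch), with print() standing in for the socketio emit() the real code uses.

import time
import tqdm

class BarForwarder(object):
    # Hypothetical stand-in for the patch's Send_to_socketio writer.
    def write(self, bar):
        # tqdm redraws one line using carriage returns; strip them before
        # forwarding, as the patch does before emitting to the browser.
        bar = bar.replace("\r", "")
        if bar.strip():
            print("forwarded:", bar)

    def flush(self):
        # tqdm calls flush() when the target provides it; nothing is buffered here.
        pass

def fake_download(total=1024 * 1024, chunk_size=4096):
    # Stand-in for the r.iter_content() loop in the patched http_get.
    progress = tqdm.tqdm(unit="B", unit_scale=True, unit_divisor=1024,
                         total=total, desc="Downloading", file=BarForwarder())
    sent = 0
    while sent < total:
        time.sleep(0.001)
        progress.update(chunk_size)
        sent += chunk_size
    progress.close()

if __name__ == "__main__":
    fake_download()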