From 90b5dab721caa57162e50c2e11bb2583ed9d1c2b Mon Sep 17 00:00:00 2001 From: ebolam Date: Mon, 19 Sep 2022 08:05:08 -0400 Subject: [PATCH] Fix --- aiserver.py | 10 ++++++---- koboldai_settings.py | 3 +++ utils.py | 6 ++++-- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/aiserver.py b/aiserver.py index c65a9e84..480bbf37 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1618,8 +1618,8 @@ def get_cluster_models(msg): # If the client settings file doesn't exist, create it # Write API key to file os.makedirs('settings', exist_ok=True) - if path.exists(get_config_filename(koboldai_vars.model_selected)): - with open(get_config_filename(koboldai_vars.model_selected), "r") as file: + if path.exists(get_config_filename(model)): + with open(get_config_filename(model), "r") as file: js = json.load(file) if 'online_model' in js: online_model = js['online_model'] @@ -1630,7 +1630,7 @@ def get_cluster_models(msg): changed=True if changed: js={} - with open(get_config_filename(koboldai_vars.model_selected), "w") as file: + with open(get_config_filename(model), "w") as file: js["apikey"] = koboldai_vars.oaiapikey file.write(json.dumps(js, indent=3)) @@ -1674,7 +1674,7 @@ def patch_transformers_download(): if bar != "": try: - print(bar, end="\r") + print(bar, end="") emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1") eventlet.sleep(seconds=0) except: @@ -1712,10 +1712,12 @@ def patch_transformers_download(): desc=f"Downloading {file_name}" if file_name is not None else "Downloading", file=Send_to_socketio(), ) + koboldai_vars.total_download_chunks = total for chunk in r.iter_content(chunk_size=1024): if chunk: # filter out keep-alive new chunks if url[-11:] != 'config.json': progress.update(len(chunk)) + koboldai_vars.downloaded_chunks += len(chunk) temp_file.write(chunk) if url[-11:] != 'config.json': progress.close() diff --git a/koboldai_settings.py b/koboldai_settings.py index a68ecbbe..c91253cd 100644 --- a/koboldai_settings.py +++ b/koboldai_settings.py @@ -493,6 +493,9 @@ class model_settings(settings): if self.tqdm.format_dict['rate'] is not None: self.tqdm_rem_time = str(datetime.timedelta(seconds=int(float(self.total_layers-self.loaded_layers)/self.tqdm.format_dict['rate']))) #Setup TQDP for model downloading + elif name == "total_download_chunks" and 'tqdm' in self.__dict__: + self.tqdm.reset(total=value) + self.tqdm_progress = 0 elif name == "downloaded_chunks" and 'tqdm' in self.__dict__: if value == 0: self.tqdm.reset(total=self.total_download_chunks) diff --git a/utils.py b/utils.py index c49a76b0..d4af0b65 100644 --- a/utils.py +++ b/utils.py @@ -211,7 +211,8 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str = done = True break if bar is None: - bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000) + bar = tqdm.tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000) + koboldai_vars.total_download_chunks = total_length visited = set() for x in r: filename = x["files"][0]["path"] @@ -220,7 +221,8 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str = for k, v in lengths.items(): if k not in visited: lengths[k] = (v[1], v[1]) - bar.n = sum(v[0] for v in lengths.values()) + koboldai_vars.downloaded_chunks = sum(v[0] for v in lengths.values()) + bar.n = koboldai_vars.downloaded_chunks bar.update() time.sleep(0.1) path = f.name