From adc11fdbc9486781fa3eb9d27d7d997982f6fcb7 Mon Sep 17 00:00:00 2001 From: somebody Date: Mon, 13 Mar 2023 19:48:34 -0500 Subject: [PATCH] TPUMTJ: Fix loading bar I don't know why it works but I know it works --- tpu_mtj_backend.py | 5 +++-- utils.py | 7 +++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 55f382f7..b83ebaac 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -712,7 +712,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2): if socketio is None: bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint") else: - bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint", file=utils.UIProgressBarFile()) + bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint", file=utils.UIProgressBarFile(socketio.emit)) koboldai_vars.status_message = "Loading TPU" koboldai_vars.total_layers = tqdm_length koboldai_vars.loaded_layers = 0 @@ -1021,10 +1021,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs)) else: num_tensors = len(model_dict) + if socketio is None: utils.bar = tqdm(total=num_tensors, desc="Loading model tensors") else: - utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=utils.UIProgressBarFile()) + utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=utils.UIProgressBarFile(socketio.emit)) koboldai_vars.status_message = "Loading model" koboldai_vars.loaded_layers = 0 koboldai_vars.total_layers = num_tensors diff --git a/utils.py b/utils.py index 651269af..7c141d4c 100644 --- a/utils.py +++ b/utils.py @@ -630,6 +630,8 @@ def get_missing_module_names(model: PreTrainedModel, names: List[str]) -> List[s class UIProgressBarFile(object): """Write TQDM progress to the UI.""" + def __init__(self, emit_func=emit) -> None: + self.emit_func = emit_func def write(self, bar): bar = bar.replace("\r", "").replace("\n", "").replace(chr(0), "") @@ -638,8 +640,9 @@ class UIProgressBarFile(object): print('\r' + bar, end='') time.sleep(0.01) try: - emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1") - except: + self.emit_func('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1") + except Exception as e: + print(e) pass def flush(self):