From 47c832fde4b03e365384179f692987d70399570f Mon Sep 17 00:00:00 2001 From: ebolam Date: Mon, 24 Oct 2022 18:52:53 -0400 Subject: [PATCH] Add Original UI status for TPU connection/loading --- aiserver.py | 7 ++++--- tpu_mtj_backend.py | 31 +++++++++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/aiserver.py b/aiserver.py index 38daa4d7..bfeb35b4 100644 --- a/aiserver.py +++ b/aiserver.py @@ -283,7 +283,7 @@ class Send_to_socketio(object): print('\r' + bar, end='') time.sleep(0.01) try: - emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1") + socketio.emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1") except: pass @@ -1799,7 +1799,7 @@ def patch_transformers_download(): if bar != "" and [ord(num) for num in bar] != [27, 91, 65]: #No idea why we're getting the 27, 1, 65 character set, just killing to so we can move on try: print('\r' + bar, end='') - emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1") + socketio.emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1") eventlet.sleep(seconds=0) except: pass @@ -3061,12 +3061,13 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal elif(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")): global tpu_mtj_backend import tpu_mtj_backend + + tpu_mtj_backend.socketio = socketio if(koboldai_vars.model == "TPUMeshTransformerGPTNeoX"): koboldai_vars.badwordsids = koboldai_vars.badwordsids_neox print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END)) if koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not koboldai_vars.custmodpth or not os.path.isdir(koboldai_vars.custmodpth)): raise FileNotFoundError(f"The specified model path {repr(koboldai_vars.custmodpth)} is not the path to a valid folder") - import tpu_mtj_backend if(koboldai_vars.model == "TPUMeshTransformerGPTNeoX"): tpu_mtj_backend.pad_token_id = 2 tpu_mtj_backend.koboldai_vars = koboldai_vars diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index e78a300e..aaa22154 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -54,6 +54,8 @@ from mesh_transformer.util import to_bf16 import time +tqdm_print = None + params: Dict[str, Any] = {} __seed = random.randrange(2**64) @@ -116,10 +118,29 @@ def show_spinner(): bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')]) i = 0 while True: + if i % 2 == 0: + socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, broadcast=True, room="UI_1") + else: + socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, broadcast=True, room="UI_1") bar.update(i) time.sleep(0.1) i += 1 +class Send_to_socketio(object): + def write(self, bar): + bar = bar.replace("\r", "").replace("\n", "").replace(chr(0), "") + if bar != "" and [ord(num) for num in bar] != [27, 91, 65]: #No idea why we're getting the 27, 1, 65 character set, just killing to so we can move on + #logger.info(bar) + print('\r' + bar, end='') + time.sleep(0.01) + try: + socketio.emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1") + except: + pass + + def flush(self): + pass + __F = TypeVar("__F", bound=Callable) __T = TypeVar("__T") @@ -991,7 +1012,10 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2): } tqdm_length = len(static_mapping) + config["layers"]*len(layer_mapping) - bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint") + if socketio is None: + bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint") + else: + bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint", file=Send_to_socketio) koboldai_vars.status_message = "Loading TPU" koboldai_vars.total_layers = tqdm_length koboldai_vars.loaded_layers = 0 @@ -1269,7 +1293,10 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs)) else: num_tensors = len(model_dict) - utils.bar = tqdm(total=num_tensors, desc="Loading model tensors") + if socketio is None: + utils.bar = tqdm(total=num_tensors, desc="Loading model tensors") + else: + utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=Send_to_socketio) koboldai_vars.status_message = "Loading model" koboldai_vars.loaded_layers = 0 koboldai_vars.total_layers = num_tensors