From c83642dbbcde1c7677ac40383421b07f2c012783 Mon Sep 17 00:00:00 2001 From: ebolam Date: Mon, 24 Oct 2022 20:43:34 -0400 Subject: [PATCH] Better Colab status. Disconnects due to long running something still --- tpu_mtj_backend.py | 46 ++++++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 45073296..a4be70e4 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -56,7 +56,6 @@ import time socketio = None -queue = None params: Dict[str, Any] = {} @@ -116,16 +115,14 @@ def compiling_callback() -> None: pass -def show_spinner(): +def show_spinner(queue): bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')]) i = 0 - global run_spinner - while run_spinner: - if i % 10 == 0: - if i % 20 == 0: - socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, broadcast=True, room="UI_1") - else: - socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, broadcast=True, room="UI_1") + while True: + if i % 2 == 0: + queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, {"broadcast":True, "room":"UI_1"}]) + else: + queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, {"broadcast":True, "room":"UI_1"}]) bar.update(i) time.sleep(0.1) i += 1 @@ -1083,7 +1080,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2): koboldai_vars.status_message = "" -def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None: +def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, **kwargs) -> None: global thread_resources_env, seq, tokenizer, network, params, pad_token_id if "pad_token_id" in kwargs: @@ -1207,10 +1204,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo print("Connecting to your Colab instance's TPU", flush=True) old_ai_busy = koboldai_vars.aibusy koboldai_vars.status_message = "Connecting to TPU" - global run_spinner - run_spinner=True - spinner = threading.Thread(target=show_spinner, args=()) - spinner.start() if os.environ.get('COLAB_TPU_ADDR', '') != '': tpu_address = os.environ['COLAB_TPU_ADDR'] # Colab else: @@ -1218,10 +1211,31 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo tpu_address = tpu_address.replace("grpc://", "") tpu_address_without_port = tpu_address.split(':', 1)[0] url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}' - requests.post(url) + def check_status(url, queue): + requests.post(url) + queue.put("Done") + + queue = multiprocessing.Queue() + spinner = multiprocessing.Process(target=check_status, args=(url, queue)) + spinner.start() + i = 0 + bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')]) + while True: + if not queue.empty(): + queue.get() + break + if i % 20 == 0: + # socketio.emit("from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, broadcast=True, room="UI_1") + socketio_queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, {"broadcast":True, "room":"UI_1"}]) + elif i % 10 == 0: + # socketio.emit("from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, broadcast=True, room="UI_1") + socketio_queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, {"broadcast":True, "room":"UI_1"}]) + bar.update(i) + time.sleep(0.1) + i += 1 + config.FLAGS.jax_xla_backend = "tpu_driver" config.FLAGS.jax_backend_target = "grpc://" + tpu_address - run_spinner=False koboldai_vars.aibusy = old_ai_busy print()