From c1df2c786235041b4897d0797dcec6bc4fb18650 Mon Sep 17 00:00:00 2001 From: ebolam Date: Mon, 24 Oct 2022 19:32:17 -0400 Subject: [PATCH] Colab Fix --- tpu_mtj_backend.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index cf2cbbe2..0f30367a 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -118,7 +118,8 @@ def compiling_callback() -> None: def show_spinner(): bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')]) i = 0 - while True: + global run_spinner + while run_spinner=True: print("Sending to client") if i % 2 == 0: socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, broadcast=True, room="UI_1") @@ -1205,6 +1206,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo print("Connecting to your Colab instance's TPU", flush=True) old_ai_busy = koboldai_vars.aibusy koboldai_vars.status_message = "Connecting to TPU" + global run_spinner + run_spinner=True spinner = threading.Thread(target=show_spinner, args=()) spinner.start() if os.environ.get('COLAB_TPU_ADDR', '') != '': @@ -1217,7 +1220,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo requests.post(url) config.FLAGS.jax_xla_backend = "tpu_driver" config.FLAGS.jax_backend_target = "grpc://" + tpu_address - spinner.terminate() + run_spinner=False koboldai_vars.aibusy = old_ai_busy print()