Better Colab status. Still disconnects due to some long-running operation

This commit is contained in:
ebolam
2022-10-24 20:43:34 -04:00
parent 405578f2b3
commit c83642dbbc

View File

@@ -56,7 +56,6 @@ import time
socketio = None
queue = None
params: Dict[str, Any] = {}
@@ -116,16 +115,14 @@ def compiling_callback() -> None:
pass
def show_spinner():
def show_spinner(queue):
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='')])
i = 0
global run_spinner
while run_spinner:
if i % 10 == 0:
if i % 20 == 0:
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, broadcast=True, room="UI_1")
else:
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, broadcast=True, room="UI_1")
while True:
if i % 2 == 0:
queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, {"broadcast":True, "room":"UI_1"}])
else:
queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, {"broadcast":True, "room":"UI_1"}])
bar.update(i)
time.sleep(0.1)
i += 1
@@ -1083,7 +1080,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
koboldai_vars.status_message = ""
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params, pad_token_id
if "pad_token_id" in kwargs:
@@ -1207,10 +1204,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
print("Connecting to your Colab instance's TPU", flush=True)
old_ai_busy = koboldai_vars.aibusy
koboldai_vars.status_message = "Connecting to TPU"
global run_spinner
run_spinner=True
spinner = threading.Thread(target=show_spinner, args=())
spinner.start()
if os.environ.get('COLAB_TPU_ADDR', '') != '':
tpu_address = os.environ['COLAB_TPU_ADDR'] # Colab
else:
@@ -1218,10 +1211,31 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
tpu_address = tpu_address.replace("grpc://", "")
tpu_address_without_port = tpu_address.split(':', 1)[0]
url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}'
requests.post(url)
def check_status(url, queue):
requests.post(url)
queue.put("Done")
queue = multiprocessing.Queue()
spinner = multiprocessing.Process(target=check_status, args=(url, queue))
spinner.start()
i = 0
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='')])
while True:
if not queue.empty():
queue.get()
break
if i % 20 == 0:
# socketio.emit("from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, broadcast=True, room="UI_1")
socketio_queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, {"broadcast":True, "room":"UI_1"}])
elif i % 10 == 0:
# socketio.emit("from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, broadcast=True, room="UI_1")
socketio_queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, {"broadcast":True, "room":"UI_1"}])
bar.update(i)
time.sleep(0.1)
i += 1
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + tpu_address
run_spinner=False
koboldai_vars.aibusy = old_ai_busy
print()