mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Better Colab status. Disconnects due to long running something still
This commit is contained in:
@@ -56,7 +56,6 @@ import time
|
||||
|
||||
|
||||
socketio = None
|
||||
queue = None
|
||||
|
||||
params: Dict[str, Any] = {}
|
||||
|
||||
@@ -116,16 +115,14 @@ def compiling_callback() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def show_spinner():
|
||||
def show_spinner(queue):
|
||||
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
|
||||
i = 0
|
||||
global run_spinner
|
||||
while run_spinner:
|
||||
if i % 10 == 0:
|
||||
if i % 20 == 0:
|
||||
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, broadcast=True, room="UI_1")
|
||||
else:
|
||||
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, broadcast=True, room="UI_1")
|
||||
while True:
|
||||
if i % 2 == 0:
|
||||
queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, {"broadcast":True, "room":"UI_1"}])
|
||||
else:
|
||||
queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, {"broadcast":True, "room":"UI_1"}])
|
||||
bar.update(i)
|
||||
time.sleep(0.1)
|
||||
i += 1
|
||||
@@ -1083,7 +1080,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
|
||||
|
||||
koboldai_vars.status_message = ""
|
||||
|
||||
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
|
||||
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, **kwargs) -> None:
|
||||
global thread_resources_env, seq, tokenizer, network, params, pad_token_id
|
||||
|
||||
if "pad_token_id" in kwargs:
|
||||
@@ -1207,10 +1204,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
||||
print("Connecting to your Colab instance's TPU", flush=True)
|
||||
old_ai_busy = koboldai_vars.aibusy
|
||||
koboldai_vars.status_message = "Connecting to TPU"
|
||||
global run_spinner
|
||||
run_spinner=True
|
||||
spinner = threading.Thread(target=show_spinner, args=())
|
||||
spinner.start()
|
||||
if os.environ.get('COLAB_TPU_ADDR', '') != '':
|
||||
tpu_address = os.environ['COLAB_TPU_ADDR'] # Colab
|
||||
else:
|
||||
@@ -1218,10 +1211,31 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
||||
tpu_address = tpu_address.replace("grpc://", "")
|
||||
tpu_address_without_port = tpu_address.split(':', 1)[0]
|
||||
url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}'
|
||||
requests.post(url)
|
||||
def check_status(url, queue):
|
||||
requests.post(url)
|
||||
queue.put("Done")
|
||||
|
||||
queue = multiprocessing.Queue()
|
||||
spinner = multiprocessing.Process(target=check_status, args=(url, queue))
|
||||
spinner.start()
|
||||
i = 0
|
||||
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
|
||||
while True:
|
||||
if not queue.empty():
|
||||
queue.get()
|
||||
break
|
||||
if i % 20 == 0:
|
||||
# socketio.emit("from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, broadcast=True, room="UI_1")
|
||||
socketio_queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, {"broadcast":True, "room":"UI_1"}])
|
||||
elif i % 10 == 0:
|
||||
# socketio.emit("from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, broadcast=True, room="UI_1")
|
||||
socketio_queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, {"broadcast":True, "room":"UI_1"}])
|
||||
bar.update(i)
|
||||
time.sleep(0.1)
|
||||
i += 1
|
||||
|
||||
config.FLAGS.jax_xla_backend = "tpu_driver"
|
||||
config.FLAGS.jax_backend_target = "grpc://" + tpu_address
|
||||
run_spinner=False
|
||||
koboldai_vars.aibusy = old_ai_busy
|
||||
print()
|
||||
|
||||
|
Reference in New Issue
Block a user