Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Better Colab status. Disconnects due to long running something still
@@ -56,7 +56,6 @@ import time
 
 
 socketio = None
-queue = None
 
 params: Dict[str, Any] = {}
 
@@ -116,16 +115,14 @@ def compiling_callback() -> None:
     pass
 
 
-def show_spinner():
+def show_spinner(queue):
     bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
     i = 0
-    global run_spinner
-    while run_spinner:
-        if i % 10 == 0:
-            if i % 20 == 0:
-                socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, broadcast=True, room="UI_1")
+    while True:
+        if i % 2 == 0:
+            queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, {"broadcast":True, "room":"UI_1"}])
         else:
-            socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, broadcast=True, room="UI_1")
+            queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, {"broadcast":True, "room":"UI_1"}])
         bar.update(i)
         time.sleep(0.1)
         i += 1
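Since the run_spinner flag is gone and show_spinner(queue) now loops forever, whoever starts it has to stop it. A minimal sketch of how it might be driven, assuming it is launched in a child process by its caller (that caller is not part of this diff, and the names below are illustrative):

# Illustrative only: drive show_spinner() from a parent process and stop it
# once the blocking work is finished; nothing inside the loop ends it anymore.
import multiprocessing
import time

status_queue = multiprocessing.Queue()
spinner = multiprocessing.Process(target=show_spinner, args=(status_queue,), daemon=True)
spinner.start()
try:
    time.sleep(5)        # stand-in for the blocking model-load work
finally:
    spinner.terminate()  # the infinite loop has to be ended externally
    spinner.join()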
@@ -1083,7 +1080,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
 
     koboldai_vars.status_message = ""
 
-def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
+def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id
 
     if "pad_token_id" in kwargs:
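Because socketio_queue defaults to None, existing callers of load_model() keep working unchanged; a caller that wants the new "Connecting to TPU" status messages passes in a multiprocessing queue. A minimal call-site sketch, assuming the usual import name for this file (the real call site, presumably in aiserver.py, is not part of this diff):

# Illustrative call, not taken from this commit.
import multiprocessing
import tpu_mtj_backend  # assumed module name for the file being patched

socketio_queue = multiprocessing.Queue()
tpu_mtj_backend.load_model(
    "models/your-model",            # placeholder model path
    socketio_queue=socketio_queue,  # the keyword added in this commit
)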
@@ -1207,10 +1204,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
     print("Connecting to your Colab instance's TPU", flush=True)
     old_ai_busy = koboldai_vars.aibusy
     koboldai_vars.status_message = "Connecting to TPU"
-    global run_spinner
-    run_spinner=True
-    spinner = threading.Thread(target=show_spinner, args=())
-    spinner.start()
     if os.environ.get('COLAB_TPU_ADDR', '') != '':
         tpu_address = os.environ['COLAB_TPU_ADDR'] # Colab
     else:
@@ -1218,10 +1211,31 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
     tpu_address = tpu_address.replace("grpc://", "")
     tpu_address_without_port = tpu_address.split(':', 1)[0]
     url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}'
-    requests.post(url)
+    def check_status(url, queue):
+        requests.post(url)
+        queue.put("Done")
+
+    queue = multiprocessing.Queue()
+    spinner = multiprocessing.Process(target=check_status, args=(url, queue))
+    spinner.start()
+    i = 0
+    bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
+    while True:
+        if not queue.empty():
+            queue.get()
+            break
+        if i % 20 == 0:
+            # socketio.emit("from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, broadcast=True, room="UI_1")
+            socketio_queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, {"broadcast":True, "room":"UI_1"}])
+        elif i % 10 == 0:
+            # socketio.emit("from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, broadcast=True, room="UI_1")
+            socketio_queue.put(["from_server", {'cmd': 'model_load_status', 'data': "Connecting to TPU...." }, {"broadcast":True, "room":"UI_1"}])
+        bar.update(i)
+        time.sleep(0.1)
+        i += 1
+
     config.FLAGS.jax_xla_backend = "tpu_driver"
     config.FLAGS.jax_backend_target = "grpc://" + tpu_address
-    run_spinner=False
     koboldai_vars.aibusy = old_ai_busy
     print()
 
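Each queued item is a plain [event, payload, emit-kwargs] list, so whichever process owns the Socket.IO connection can replay it later. A minimal consumer sketch for the server side, assuming a Flask-SocketIO socketio object and a threading.Event to stop the relay (none of this is part of the commit; the names are illustrative):

# Illustrative consumer, not part of this commit: drain the queue and turn
# each item back into the socketio.emit() call that the commented-out lines
# above used to make directly.
import queue as stdlib_queue

def relay_status_messages(socketio, socketio_queue, stop_event):
    """Run in a server-side thread until stop_event is set."""
    while not stop_event.is_set():
        try:
            event, payload, emit_kwargs = socketio_queue.get(timeout=0.5)
        except stdlib_queue.Empty:
            continue
        # e.g. socketio.emit("from_server", {'cmd': 'model_load_status', ...}, broadcast=True, room="UI_1")
        socketio.emit(event, payload, **emit_kwargs)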