mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Colab Fix
This commit is contained in:
@@ -118,7 +118,8 @@ def compiling_callback() -> None:
|
|||||||
def show_spinner():
|
def show_spinner():
|
||||||
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
|
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
|
||||||
i = 0
|
i = 0
|
||||||
while True:
|
global run_spinner
|
||||||
|
while run_spinner=True:
|
||||||
print("Sending to client")
|
print("Sending to client")
|
||||||
if i % 2 == 0:
|
if i % 2 == 0:
|
||||||
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, broadcast=True, room="UI_1")
|
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, broadcast=True, room="UI_1")
|
||||||
@@ -1205,6 +1206,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
|||||||
print("Connecting to your Colab instance's TPU", flush=True)
|
print("Connecting to your Colab instance's TPU", flush=True)
|
||||||
old_ai_busy = koboldai_vars.aibusy
|
old_ai_busy = koboldai_vars.aibusy
|
||||||
koboldai_vars.status_message = "Connecting to TPU"
|
koboldai_vars.status_message = "Connecting to TPU"
|
||||||
|
global run_spinner
|
||||||
|
run_spinner=True
|
||||||
spinner = threading.Thread(target=show_spinner, args=())
|
spinner = threading.Thread(target=show_spinner, args=())
|
||||||
spinner.start()
|
spinner.start()
|
||||||
if os.environ.get('COLAB_TPU_ADDR', '') != '':
|
if os.environ.get('COLAB_TPU_ADDR', '') != '':
|
||||||
@@ -1217,7 +1220,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
|||||||
requests.post(url)
|
requests.post(url)
|
||||||
config.FLAGS.jax_xla_backend = "tpu_driver"
|
config.FLAGS.jax_xla_backend = "tpu_driver"
|
||||||
config.FLAGS.jax_backend_target = "grpc://" + tpu_address
|
config.FLAGS.jax_backend_target = "grpc://" + tpu_address
|
||||||
spinner.terminate()
|
run_spinner=False
|
||||||
koboldai_vars.aibusy = old_ai_busy
|
koboldai_vars.aibusy = old_ai_busy
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user