mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Colab Fix
This commit is contained in:
@@ -118,7 +118,8 @@ def compiling_callback() -> None:
|
||||
def show_spinner():
|
||||
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
|
||||
i = 0
|
||||
while True:
|
||||
global run_spinner
|
||||
while run_spinner=True:
|
||||
print("Sending to client")
|
||||
if i % 2 == 0:
|
||||
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': "Connecting to TPU..." }, broadcast=True, room="UI_1")
|
||||
@@ -1205,6 +1206,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
||||
print("Connecting to your Colab instance's TPU", flush=True)
|
||||
old_ai_busy = koboldai_vars.aibusy
|
||||
koboldai_vars.status_message = "Connecting to TPU"
|
||||
global run_spinner
|
||||
run_spinner=True
|
||||
spinner = threading.Thread(target=show_spinner, args=())
|
||||
spinner.start()
|
||||
if os.environ.get('COLAB_TPU_ADDR', '') != '':
|
||||
@@ -1217,7 +1220,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
||||
requests.post(url)
|
||||
config.FLAGS.jax_xla_backend = "tpu_driver"
|
||||
config.FLAGS.jax_backend_target = "grpc://" + tpu_address
|
||||
spinner.terminate()
|
||||
run_spinner=False
|
||||
koboldai_vars.aibusy = old_ai_busy
|
||||
print()
|
||||
|
||||
|
Reference in New Issue
Block a user