From c3180fb06fc5a2931e214367d2d759b14ebea176 Mon Sep 17 00:00:00 2001
From: ebolam
Date: Tue, 25 Oct 2022 09:41:59 -0400
Subject: [PATCH] Move display of Colab link to later in the load process for
 TPUs

---
 aiserver.py        | 14 +++++++++++---
 tpu_mtj_backend.py | 13 +++++++++++--
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 47e7aea0..bad05ada 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -3081,7 +3081,12 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
         koboldai_vars.allowsp = True
         loadmodelsettings()
         loadsettings()
-        tpu_mtj_backend.load_model(koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, socketio_queue=koboldai_settings.queue, **koboldai_vars.modelconfig)
+        tpu_mtj_backend.load_model(koboldai_vars.custmodpth,
+                                   hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
+                                   and koboldai_vars.use_colab_tpu,
+                                   socketio_queue=koboldai_settings.queue,
+                                   initial_load=initial_load, logger=logger, cloudflare=cloudflare,
+                                   **koboldai_vars.modelconfig)
         #tpool.execute(tpu_mtj_backend.load_model, koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, **koboldai_vars.modelconfig)
         koboldai_vars.modeldim = int(tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]))
         tokenizer = tpu_mtj_backend.tokenizer
@@ -12240,8 +12245,11 @@ def run():
             with open('cloudflare.log', 'w') as cloudflarelog:
                 cloudflarelog.write("KoboldAI has finished loading and is available at the following link : " + cloudflare)
             logger.init_ok("Webserver", status="OK")
-            logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {cloudflare}")
-            logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {cloudflare}/new_ui")
+            if not koboldai_vars.use_colab_tpu:
+                # If we're using a TPU, the UI will freeze while we connect to the TPU. To avoid pointing the user
+                # at a link that is not yet responsive, we delay this message until after that step.
+                logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {cloudflare}")
+                logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {cloudflare}/new_ui")
         else:
             logger.init_ok("Webserver", status="OK")
             logger.message(f"Webserver has started, you can now connect to this machine at port: {port}")
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index a4be70e4..88a0ef78 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1080,7 +1080,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
     koboldai_vars.status_message = ""
 
 
-def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, **kwargs) -> None:
+def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, cloudflare="", **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id
 
     if "pad_token_id" in kwargs:
@@ -1239,13 +1239,21 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         koboldai_vars.aibusy = old_ai_busy
     print()
 
+    start_time = time.time()
     cores_per_replica = params["cores_per_replica"]
     seq = params["seq"]
     params["optimizer"] = _DummyOptimizer()
+    print("to line 1246 {}s".format(time.time()-start_time))
+    start_time = time.time()
     mesh_shape = (1, cores_per_replica)
-    devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
+    devices = jax.devices()
+    devices = np.array(devices[:cores_per_replica]).reshape(mesh_shape)
     thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
     maps.thread_resources.env = thread_resources_env
+    if initial_load:
+        logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {cloudflare}")
+        logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {cloudflare}/new_ui")
+
 
     global shard_xmap, batch_xmap
     shard_xmap = __shard_xmap()
@@ -1296,6 +1304,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     from tqdm.auto import tqdm
     import functools
 
+
     def callback(model_dict, f, **_):
         if callback.nested:
             return
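The patch boils down to one pattern: instead of printing the "finished loading" links as soon as the webserver is up, the logger, the cloudflare URL, and an initial_load flag are threaded down into tpu_mtj_backend.load_model(), which emits the links only after the blocking TPU device-mesh setup completes, so the user is never pointed at a UI that is still frozen. The stand-alone sketch below illustrates that pattern; it is a minimal illustration, and every name in it (slow_backend_load, announce, SERVER_URL) is a hypothetical stand-in rather than KoboldAI's actual API.

import time

SERVER_URL = "https://example.trycloudflare.com"  # hypothetical stand-in for the cloudflare link


def announce(url: str) -> None:
    # Stand-in for logger.message(): prints the ready-to-use links.
    print(f"Finished loading, UI 1: {url}")
    print(f"Finished loading, UI 2: {url}/new_ui")


def slow_backend_load(initial_load: bool = False, announce_url: str = "") -> None:
    # Mimics the patched tpu_mtj_backend.load_model(): the announcement is
    # emitted *after* the step that would otherwise freeze the UI.
    time.sleep(2)  # stand-in for the blocking TPU driver / device-mesh setup
    if initial_load and announce_url:
        announce(announce_url)  # only now is the link actually usable


def run(use_colab_tpu: bool) -> None:
    if not use_colab_tpu:
        # Non-TPU path: the server is responsive immediately, announce up front.
        announce(SERVER_URL)
        slow_backend_load()
    else:
        # TPU path: defer the announcement into the loader, as the patch does.
        slow_backend_load(initial_load=True, announce_url=SERVER_URL)


if __name__ == "__main__":
    run(use_colab_tpu=True)

Passing the announcement state as plain keyword arguments, as the patch does, keeps the backend free of webserver globals: the loader stays usable on its own, and callers that are not the initial load simply omit the flag.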