Move display of Colab link to later in the load process for TPUs

This commit is contained in:
ebolam
2022-10-25 09:41:59 -04:00
parent c83642dbbc
commit c3180fb06f
2 changed files with 22 additions and 5 deletions

View File

@@ -3081,7 +3081,12 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
koboldai_vars.allowsp = True
loadmodelsettings()
loadsettings()
tpu_mtj_backend.load_model(koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, socketio_queue=koboldai_settings.queue, **koboldai_vars.modelconfig)
tpu_mtj_backend.load_model(koboldai_vars.custmodpth,
hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
and koboldai_vars.use_colab_tpu,
socketio_queue=koboldai_settings.queue,
initial_load=initial_load, logger=logger, cloudflare=cloudflare,
**koboldai_vars.modelconfig)
#tpool.execute(tpu_mtj_backend.load_model, koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, **koboldai_vars.modelconfig)
koboldai_vars.modeldim = int(tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]))
tokenizer = tpu_mtj_backend.tokenizer
@@ -12240,8 +12245,11 @@ def run():
with open('cloudflare.log', 'w') as cloudflarelog:
cloudflarelog.write("KoboldAI has finished loading and is available at the following link : " + cloudflare)
logger.init_ok("Webserver", status="OK")
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {cloudflare}")
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {cloudflare}/new_ui")
if not koboldai_vars.use_colab_tpu:
# If we're using a TPU our UI will freeze during the connection to the TPU. To prevent this from showing to the user we
# delay the display of this message until after that step
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {cloudflare}")
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {cloudflare}/new_ui")
else:
logger.init_ok("Webserver", status="OK")
logger.message(f"Webserver has started, you can now connect to this machine at port: {port}")

View File

@@ -1080,7 +1080,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
koboldai_vars.status_message = ""
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, **kwargs) -> None:
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, cloudflare="", **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params, pad_token_id
if "pad_token_id" in kwargs:
@@ -1239,13 +1239,21 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
koboldai_vars.aibusy = old_ai_busy
print()
start_time = time.time()
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]
params["optimizer"] = _DummyOptimizer()
print("to line 1246 {}s".format(time.time()-start_time))
start_time = time.time()
mesh_shape = (1, cores_per_replica)
devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
devices = jax.devices()
devices = np.array(devices[:cores_per_replica]).reshape(mesh_shape)
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
maps.thread_resources.env = thread_resources_env
if initial_load:
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {cloudflare}")
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {cloudflare}/new_ui")
global shard_xmap, batch_xmap
shard_xmap = __shard_xmap()
@@ -1296,6 +1304,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
from tqdm.auto import tqdm
import functools
def callback(model_dict, f, **_):
if callback.nested:
return