mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Move display of Colab link to later in the load process for TPUs
This commit is contained in:
14
aiserver.py
14
aiserver.py
@@ -3081,7 +3081,12 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||
koboldai_vars.allowsp = True
|
||||
loadmodelsettings()
|
||||
loadsettings()
|
||||
tpu_mtj_backend.load_model(koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, socketio_queue=koboldai_settings.queue, **koboldai_vars.modelconfig)
|
||||
tpu_mtj_backend.load_model(koboldai_vars.custmodpth,
|
||||
hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
|
||||
and koboldai_vars.use_colab_tpu,
|
||||
socketio_queue=koboldai_settings.queue,
|
||||
initial_load=initial_load, logger=logger, cloudflare=cloudflare,
|
||||
**koboldai_vars.modelconfig)
|
||||
#tpool.execute(tpu_mtj_backend.load_model, koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, **koboldai_vars.modelconfig)
|
||||
koboldai_vars.modeldim = int(tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]))
|
||||
tokenizer = tpu_mtj_backend.tokenizer
|
||||
@@ -12240,8 +12245,11 @@ def run():
|
||||
with open('cloudflare.log', 'w') as cloudflarelog:
|
||||
cloudflarelog.write("KoboldAI has finished loading and is available at the following link : " + cloudflare)
|
||||
logger.init_ok("Webserver", status="OK")
|
||||
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {cloudflare}")
|
||||
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {cloudflare}/new_ui")
|
||||
if not koboldai_vars.use_colab_tpu:
|
||||
# If we're using a TPU our UI will freeze during the connection to the TPU. To prevent this from showing to the user we
|
||||
# delay the display of this message until after that step
|
||||
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {cloudflare}")
|
||||
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {cloudflare}/new_ui")
|
||||
else:
|
||||
logger.init_ok("Webserver", status="OK")
|
||||
logger.message(f"Webserver has started, you can now connect to this machine at port: {port}")
|
||||
|
@@ -1080,7 +1080,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
|
||||
|
||||
koboldai_vars.status_message = ""
|
||||
|
||||
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, **kwargs) -> None:
|
||||
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, cloudflare="", **kwargs) -> None:
|
||||
global thread_resources_env, seq, tokenizer, network, params, pad_token_id
|
||||
|
||||
if "pad_token_id" in kwargs:
|
||||
@@ -1239,13 +1239,21 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
||||
koboldai_vars.aibusy = old_ai_busy
|
||||
print()
|
||||
|
||||
start_time = time.time()
|
||||
cores_per_replica = params["cores_per_replica"]
|
||||
seq = params["seq"]
|
||||
params["optimizer"] = _DummyOptimizer()
|
||||
print("to line 1246 {}s".format(time.time()-start_time))
|
||||
start_time = time.time()
|
||||
mesh_shape = (1, cores_per_replica)
|
||||
devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
|
||||
devices = jax.devices()
|
||||
devices = np.array(devices[:cores_per_replica]).reshape(mesh_shape)
|
||||
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
|
||||
maps.thread_resources.env = thread_resources_env
|
||||
if initial_load:
|
||||
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {cloudflare}")
|
||||
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {cloudflare}/new_ui")
|
||||
|
||||
|
||||
global shard_xmap, batch_xmap
|
||||
shard_xmap = __shard_xmap()
|
||||
@@ -1296,6 +1304,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
||||
from tqdm.auto import tqdm
|
||||
import functools
|
||||
|
||||
|
||||
def callback(model_dict, f, **_):
|
||||
if callback.nested:
|
||||
return
|
||||
|
Reference in New Issue
Block a user