Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Move display of Colab link to later in the load process for TPUs
diff --git a/aiserver.py b/aiserver.py
--- a/aiserver.py
+++ b/aiserver.py
@@ -3081,7 +3081,12 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
             koboldai_vars.allowsp = True
             loadmodelsettings()
             loadsettings()
-            tpu_mtj_backend.load_model(koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, socketio_queue=koboldai_settings.queue, **koboldai_vars.modelconfig)
+            tpu_mtj_backend.load_model(koboldai_vars.custmodpth,
+                                       hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
+                                                     and koboldai_vars.use_colab_tpu,
+                                       socketio_queue=koboldai_settings.queue,
+                                       initial_load=initial_load, logger=logger, cloudflare=cloudflare,
+                                       **koboldai_vars.modelconfig)
             #tpool.execute(tpu_mtj_backend.load_model, koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, **koboldai_vars.modelconfig)
             koboldai_vars.modeldim = int(tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]))
             tokenizer = tpu_mtj_backend.tokenizer
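The call above now threads initial_load, the logger, and the cloudflare link into tpu_mtj_backend.load_model so the backend can announce the link itself once its slow setup is done. A minimal sketch of that flow, assuming hypothetical names (load_backend, public_url) rather than KoboldAI's real API:

    import logging
    import time

    logging.basicConfig(level=logging.INFO, format="%(message)s")
    log = logging.getLogger("loader")


    def load_backend(initial_load: bool = False, logger=None, public_url: str = "") -> None:
        time.sleep(0.1)  # stand-in for the blocking TPU/device-mesh setup
        if initial_load and logger is not None:
            # Announce only after the blocking step has finished.
            logger.info(f"Finished loading, UI available at: {public_url}")


    load_backend(initial_load=True, logger=log, public_url="https://example.trycloudflare.com")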
@@ -12240,8 +12245,11 @@ def run():
         with open('cloudflare.log', 'w') as cloudflarelog:
             cloudflarelog.write("KoboldAI has finished loading and is available at the following link : " + cloudflare)
         logger.init_ok("Webserver", status="OK")
-        logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {cloudflare}")
-        logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {cloudflare}/new_ui")
+        if not koboldai_vars.use_colab_tpu:
+            # If we're using a TPU our UI will freeze during the connection to the TPU. To prevent this from showing to the user we
+            # delay the display of this message until after that step
+            logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {cloudflare}")
+            logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {cloudflare}/new_ui")
     else:
         logger.init_ok("Webserver", status="OK")
         logger.message(f"Webserver has started, you can now connect to this machine at port: {port}")
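The new branch only prints the link immediately when the TPU backend is not in use; on the TPU path the banner is deferred, as the added comment explains. A hedged sketch of that gate, with hypothetical names (announce, use_tpu) in place of the real aiserver.py code:

    def announce(use_tpu: bool, public_url: str) -> None:
        if not use_tpu:
            # Nothing slow remains on this path, so advertise the link right away.
            print(f"UI 1: {public_url}")
            print(f"UI 2: {public_url}/new_ui")
        # On the TPU path the banner is intentionally left to the backend loader.


    announce(use_tpu=False, public_url="https://example.trycloudflare.com")  # prints both links
    announce(use_tpu=True, public_url="https://example.trycloudflare.com")   # prints nothing yet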
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1080,7 +1080,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
 
     koboldai_vars.status_message = ""
 
-def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, **kwargs) -> None:
+def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, cloudflare="", **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id
 
     if "pad_token_id" in kwargs:
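The new parameters are added with defaults, so call sites written before this commit keep working unchanged. A stripped-down illustration of that back-compatibility, not the real tpu_mtj_backend.load_model (PrintLogger and the function body here are hypothetical):

    class PrintLogger:
        def message(self, text: str) -> None:
            print(text)


    def load_model(path: str, initial_load: bool = False, logger=None, cloudflare: str = "", **kwargs) -> None:
        if initial_load and logger is not None:
            logger.message(f"Loaded {path}; UI available at {cloudflare}")


    load_model("models/example")  # pre-commit style call still works
    load_model("models/example", initial_load=True, logger=PrintLogger(),
               cloudflare="https://example.trycloudflare.com")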
@@ -1239,13 +1239,21 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         koboldai_vars.aibusy = old_ai_busy
         print()
 
+    start_time = time.time()
     cores_per_replica = params["cores_per_replica"]
     seq = params["seq"]
     params["optimizer"] = _DummyOptimizer()
+    print("to line 1246 {}s".format(time.time()-start_time))
+    start_time = time.time()
     mesh_shape = (1, cores_per_replica)
-    devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
+    devices = jax.devices()
+    devices = np.array(devices[:cores_per_replica]).reshape(mesh_shape)
     thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
     maps.thread_resources.env = thread_resources_env
+    if initial_load:
+        logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {cloudflare}")
+        logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {cloudflare}/new_ui")
+
 
     global shard_xmap, batch_xmap
     shard_xmap = __shard_xmap()
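The hunk above measures elapsed time by hand with paired time.time() calls and a print(). A small context manager, shown purely as an illustrative alternative and not as KoboldAI code, expresses the same measurement with less repetition:

    import time
    from contextlib import contextmanager


    @contextmanager
    def timed(label: str):
        start = time.time()
        try:
            yield
        finally:
            print("{} took {:.3f}s".format(label, time.time() - start))


    with timed("mesh setup"):
        time.sleep(0.2)  # stand-in for building the JAX device mesh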
@@ -1296,6 +1304,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         from tqdm.auto import tqdm
         import functools
 
+
         def callback(model_dict, f, **_):
             if callback.nested:
                 return