Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Add status bar message and status bar for TPU loading
@@ -992,6 +992,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
 
     tqdm_length = len(static_mapping) + config["layers"]*len(layer_mapping)
     bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint")
+    koboldai_vars.status_message = "Loading TPU"
+    koboldai_vars.total_layers = tqdm_length
+    koboldai_vars.loaded_layers = 0
 
     for checkpoint_layer in range(config["layers"] + 5):
         if checkpoint_layer in (1, config["layers"] + 2):
@@ -1042,6 +1045,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
                 np.zeros(config["cores_per_replica"]),
             )
             bar.update(1)
+            koboldai_vars.loaded_layers+=1
     for mk, mv in state["params"].items():
         for pk, pv in mv.items():
             if isinstance(pv, PlaceholderTensor):
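
The pattern in the two hunks above: alongside the console tqdm bar, the total and current layer counts are mirrored into plain attributes on the shared koboldai_vars object so the web UI can render its own progress bar. A minimal, self-contained sketch of that pattern (ProgressState is a hypothetical stand-in for koboldai_vars, which the diff does not define here):

from dataclasses import dataclass
from tqdm import tqdm

@dataclass
class ProgressState:              # stand-in for koboldai_vars
    status_message: str = ""
    total_layers: int = 0
    loaded_layers: int = 0

def load_layers(state: ProgressState, num_layers: int) -> None:
    state.status_message = "Loading TPU"
    state.total_layers = num_layers
    state.loaded_layers = 0
    bar = tqdm(total=num_layers, desc="Loading from NeoX checkpoint")
    for _ in range(num_layers):
        ...                       # load one layer here
        bar.update(1)             # console-side progress
        state.loaded_layers += 1  # UI-visible progress
    bar.close()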
@@ -1049,6 +1053,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
             print("\n\nERROR: " + error, file=sys.stderr)
             raise RuntimeError(error)
 
+    koboldai_vars.status_message = ""
 
 def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id
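
Note that the hunk above clears status_message only on the success path; the RuntimeError raised just before it would leave the "Loading TPU" message in place. One defensive variant (not what this commit does) wraps the work in try/finally so the reset always runs:

def read_checkpoint_with_reset(state) -> None:
    # state is any object with a status_message attribute,
    # e.g. the ProgressState sketch above.
    state.status_message = "Loading TPU"
    try:
        ...  # checkpoint-reading work that may raise
    finally:
        state.status_message = ""  # cleared even if an exception escapes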
@@ -1172,6 +1177,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     jax.host_id = jax.process_index
 
     print("Connecting to your Colab instance's TPU", flush=True)
+    old_ai_busy = koboldai_vars.aibusy
+    koboldai_vars.status_message = "Connecting to TPU"
     spinner = multiprocessing.Process(target=show_spinner, args=())
     spinner.start()
     if os.environ.get('COLAB_TPU_ADDR', '') != '':
@@ -1185,6 +1192,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         config.FLAGS.jax_xla_backend = "tpu_driver"
         config.FLAGS.jax_backend_target = "grpc://" + tpu_address
     spinner.terminate()
+    koboldai_vars.aibusy = old_ai_busy
     print()
 
     cores_per_replica = params["cores_per_replica"]
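
The two hunks above bracket the TPU connection: aibusy is saved before the spinner subprocess starts and restored once it terminates, while status_message tells the UI what is happening. A context manager is one way to express that save/restore pairing; this is an illustrative sketch under the same assumptions as above, not the repository's code:

import multiprocessing
from contextlib import contextmanager

@contextmanager
def tpu_connect_status(state, spinner_target):
    old_ai_busy = state.aibusy            # save the busy flag
    state.status_message = "Connecting to TPU"
    spinner = multiprocessing.Process(target=spinner_target, args=())
    spinner.start()
    try:
        yield                             # perform the actual TPU connection here
    finally:
        spinner.terminate()
        state.aibusy = old_ai_busy        # restore regardless of outcome

# Usage: with tpu_connect_status(koboldai_vars, show_spinner): connect()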
@@ -1262,6 +1270,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             else:
                 num_tensors = len(model_dict)
             utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
+            koboldai_vars.status_message = "Loading model"
+            koboldai_vars.loaded_layers = 0
+            koboldai_vars.total_layers = num_tensors
 
             if utils.num_shards is not None:
@@ -1362,6 +1372,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         if utils.num_shards is None or utils.current_shard >= utils.num_shards:
             utils.bar.close()
             utils.bar = None
+            koboldai_vars.status_message = ""
             callback.nested = False
             if isinstance(f, zipfile.ZipExtFile):
                 f.close()
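
How the front end consumes status_message, loaded_layers, and total_layers is not shown in this diff; assuming it simply polls those fields, rendering a readout reduces to a small pure function, sketched here:

def progress_text(status_message: str, loaded: int, total: int) -> str:
    if not status_message:
        return ""                          # empty message signals "done"
    pct = 100 * loaded // total if total else 0
    return f"{status_message}: {loaded}/{total} layers ({pct}%)"

# progress_text("Loading model", 12, 48) -> 'Loading model: 12/48 layers (25%)'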