Add status bar message and progress bar for TPU loading

ebolam
2022-10-24 18:34:45 -04:00
parent 85beef719f
commit b01713875c
7 changed files with 56 additions and 4 deletions
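
Below is a hypothetical sketch (not part of this commit) of how the new koboldai_vars fields (status_message, loaded_layers, total_layers) could be consumed to drive the UI status bar. The function name emit_load_status, the "load_status" event name, and the payload shape are all assumptions for illustration:

def emit_load_status(koboldai_vars, socketio):
    """Push the current model-load status to connected clients."""
    total = koboldai_vars.total_layers
    done = koboldai_vars.loaded_layers
    socketio.emit("load_status", {
        "message": koboldai_vars.status_message,  # e.g. "Loading TPU"
        "done": done,
        "total": total,
        # Guard against division by zero before total_layers is set.
        "percent": int(100 * done / total) if total else 0,
    })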

tpu_mtj_backend.py

@@ -992,6 +992,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
     tqdm_length = len(static_mapping) + config["layers"]*len(layer_mapping)
     bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint")
+    koboldai_vars.status_message = "Loading TPU"
+    koboldai_vars.total_layers = tqdm_length
+    koboldai_vars.loaded_layers = 0
     for checkpoint_layer in range(config["layers"] + 5):
         if checkpoint_layer in (1, config["layers"] + 2):
@@ -1042,6 +1045,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
                 np.zeros(config["cores_per_replica"]),
             )
             bar.update(1)
+            koboldai_vars.loaded_layers+=1
     for mk, mv in state["params"].items():
         for pk, pv in mv.items():
             if isinstance(pv, PlaceholderTensor):
@@ -1049,6 +1053,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
                 print("\n\nERROR: " + error, file=sys.stderr)
                 raise RuntimeError(error)
+    koboldai_vars.status_message = ""
 
 def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id
@@ -1172,6 +1177,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
     jax.host_id = jax.process_index
     print("Connecting to your Colab instance's TPU", flush=True)
+    old_ai_busy = koboldai_vars.aibusy
+    koboldai_vars.status_message = "Connecting to TPU"
     spinner = multiprocessing.Process(target=show_spinner, args=())
     spinner.start()
     if os.environ.get('COLAB_TPU_ADDR', '') != '':
@@ -1185,6 +1192,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
         config.FLAGS.jax_xla_backend = "tpu_driver"
         config.FLAGS.jax_backend_target = "grpc://" + tpu_address
     spinner.terminate()
+    koboldai_vars.aibusy = old_ai_busy
     print()
 
     cores_per_replica = params["cores_per_replica"]
@@ -1262,6 +1270,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
             else:
                 num_tensors = len(model_dict)
             utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
+            koboldai_vars.status_message = "Loading model"
+            koboldai_vars.loaded_layers = 0
             koboldai_vars.total_layers = num_tensors
 
         if utils.num_shards is not None:
@@ -1362,6 +1372,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
             if utils.num_shards is None or utils.current_shard >= utils.num_shards:
                 utils.bar.close()
                 utils.bar = None
+                koboldai_vars.status_message = ""
                 callback.nested = False
                 if isinstance(f, zipfile.ZipExtFile):
                     f.close()
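
For context, the old_ai_busy save/restore around the spinner is a stash-and-restore pattern: the UI busy flag is captured before the TPU connection attempt and put back once the spinner is terminated, so loading cannot leave the UI stuck in a busy state. A minimal sketch of the same pattern, with connect_to_tpu() as a hypothetical placeholder for the jax_backend_target setup (the commit itself does this inline, without try/finally):

old_ai_busy = koboldai_vars.aibusy      # stash the current busy flag
koboldai_vars.status_message = "Connecting to TPU"
try:
    connect_to_tpu()                    # hypothetical placeholder for the JAX TPU driver setup
finally:
    koboldai_vars.aibusy = old_ai_busy  # restore the flag even if connecting fails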