TPUMTJ: Fix loading bar

I don't know why it works but I know it works
This commit is contained in:
somebody
2023-03-13 19:48:34 -05:00
parent 938c97b75a
commit adc11fdbc9
2 changed files with 8 additions and 4 deletions

View File

@@ -712,7 +712,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
if socketio is None:
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint")
else:
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint", file=utils.UIProgressBarFile())
bar = tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint", file=utils.UIProgressBarFile(socketio.emit))
koboldai_vars.status_message = "Loading TPU"
koboldai_vars.total_layers = tqdm_length
koboldai_vars.loaded_layers = 0
@@ -1021,10 +1021,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
else:
num_tensors = len(model_dict)
if socketio is None:
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
else:
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=utils.UIProgressBarFile())
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=utils.UIProgressBarFile(socketio.emit))
koboldai_vars.status_message = "Loading model"
koboldai_vars.loaded_layers = 0
koboldai_vars.total_layers = num_tensors