From 55771c1edfea03782f309e9a709c209b73b528c7 Mon Sep 17 00:00:00 2001 From: ebolam Date: Sun, 25 Sep 2022 20:03:15 -0400 Subject: [PATCH] Giving proper TPU load progress in UI --- tpu_mtj_backend.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 83846205..c4920ae1 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -1247,6 +1247,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo else: num_tensors = len(model_dict) utils.bar = tqdm(total=num_tensors, desc="Loading model tensors") + koboldai_vars.total_layers = num_tensors if utils.num_shards is not None: utils.current_shard += 1 @@ -1261,6 +1262,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo if model_spec_key is None: model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta") utils.bar.update(1) + koboldai_vars.loaded_layers += 1 continue storage_key = model_dict[key].key @@ -1312,7 +1314,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo )).copy(), np.empty(params["cores_per_replica"]), ) - + + koboldai_vars.loaded_layers += 1 utils.bar.update(1) if utils.num_shards is not None and utils.current_shard < utils.num_shards: