diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 83846205..c4920ae1 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1247,6 +1247,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         else:
             num_tensors = len(model_dict)
         utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
+        koboldai_vars.total_layers = num_tensors
 
         if utils.num_shards is not None:
             utils.current_shard += 1
@@ -1261,6 +1262,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             if model_spec_key is None:
                 model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
                 utils.bar.update(1)
+                koboldai_vars.loaded_layers += 1
                 continue
 
             storage_key = model_dict[key].key
@@ -1312,7 +1314,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                     )).copy(),
                     np.empty(params["cores_per_replica"]),
                 )
-
+
+            koboldai_vars.loaded_layers += 1
             utils.bar.update(1)
 
     if utils.num_shards is not None and utils.current_shard < utils.num_shards: