Loading a sharded model will now display only one progress bar

This commit is contained in:
Gnome Ann
2022-05-13 23:32:16 -04:00
parent f9f1a5f3a9
commit 0c5ca5261e
3 changed files with 67 additions and 2 deletions

View File

@ -1160,6 +1160,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
import functools
def callback(model_dict, f, **_):
if callback.nested:
return
callback.nested = True
with zipfile.ZipFile(f, "r") as z:
try:
last_storage_key = None
@ -1167,9 +1170,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
current_offset = 0
if utils.current_shard == 0:
print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
if utils.num_shards is None or utils.current_shard == 0:
if utils.num_shards is not None:
num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
else:
num_tensors = len(model_dict)
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
if utils.num_shards is not None:
utils.current_shard += 1
for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
for key in sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
# Some model weights are used by transformers but not by MTJ.
# We have to materialize these weights anyways because
@ -1178,6 +1189,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
# tensors, which don't take up any actual CPU or TPU memory.
if key not in model_spec:
model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
utils.bar.update(1)
continue
storage_key = model_dict[key].key
@ -1230,6 +1242,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
np.empty(params["cores_per_replica"]),
)
utils.bar.update(1)
if utils.num_shards is not None and utils.current_shard < utils.num_shards:
return
@ -1250,9 +1264,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
error = f"{mk} {pk} could not be found in the model checkpoint"
print("\n\nERROR: " + error, file=sys.stderr)
raise RuntimeError(error)
except:
import traceback
traceback.print_exc()
finally:
if utils.num_shards is None or utils.current_shard >= utils.num_shards:
utils.bar.close()
utils.bar = None
callback.nested = False
if isinstance(f, zipfile.ZipExtFile):
f.close()
callback.nested = False
if os.path.isdir(vars.model.replace('/', '_')):
import shutil