Show parameter count when loading GPT-NeoX in Colab TPU instance

This commit is contained in:
Gnome Ann 2022-03-15 13:55:27 -04:00
parent 9dc48b15f0
commit 9e2848e48f
1 changed file with 1 addition and 0 deletions

View File

@@ -1063,6 +1063,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
return
if vars.model == "TPUMeshTransformerGPTNeoX":
print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
read_neox_checkpoint(network.state, path, params)
return