Show parameter count when loading GPT-NeoX in Colab TPU instance
This commit is contained in:
parent
9dc48b15f0
commit
9e2848e48f
|
@@ -1063,6 +1063,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
|||
return
|
||||
|
||||
if vars.model == "TPUMeshTransformerGPTNeoX":
|
||||
print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
|
||||
read_neox_checkpoint(network.state, path, params)
|
||||
return
|
||||
|
||||
|
|
Loading…
Reference in New Issue