diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index b36b88ba..36859d24 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -1063,6 +1063,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo return if vars.model == "TPUMeshTransformerGPTNeoX": + print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n") read_neox_checkpoint(network.state, path, params) return