diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 0441a14c..316dcc69 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -945,7 +945,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo "n_vocab": 50432, "n_vocab_padding": 0, "norm": "doublelayernorm", - "pe": "rotary", + "pe": "neox_rotary", "pe_rotary_dims": 24, "seq": 2048, "cores_per_replica": 8,