Add TPU support for OPT-350M

The 350M model seems to have a different structure than the other ones ???
This commit is contained in:
Gnome Ann
2022-05-12 22:21:15 -04:00
parent dfa2aa7314
commit 4fa5f1cd6a
3 changed files with 9 additions and 6 deletions

View File

@ -1054,7 +1054,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
# by the number of TPU cores, and fall back to one core if an even
# number of TPU cores is not possible.
for c in (8, 6, 4, 2, 1):
if 0 == params["n_heads"] % c == params["d_model"] % c:
if 0 == params["n_heads"] % c == params.get("d_embed", params["d_model"]) % c:
params["cores_per_replica"] = c
break