Always convert soft prompt to float32 if using TPU backend

TPUs do not support float16. Attempting to use a float16 soft prompt
throws an error.
This commit is contained in:
Gnome Ann 2021-11-21 18:22:10 -05:00
parent e068aa9f26
commit 9b8bcb5516
1 changed files with 1 additions and 1 deletions

View File

@ -2572,7 +2572,7 @@ def spRequest(filename):
-1, -1,
tpu_mtj_backend.params["d_model"], tpu_mtj_backend.params["d_model"],
) )
vars.sp = tensor vars.sp = np.float32(tensor)
else: else:
vars.sp = torch.from_numpy(tensor) vars.sp = torch.from_numpy(tensor)