Always convert soft prompt to float32 if using TPU backend
TPUs do not support float16. Attempting to use a float16 soft prompt throws an error.
This commit is contained in:
parent
e068aa9f26
commit
9b8bcb5516
|
@ -2572,7 +2572,7 @@ def spRequest(filename):
|
|||
-1,
|
||||
tpu_mtj_backend.params["d_model"],
|
||||
)
|
||||
vars.sp = tensor
|
||||
vars.sp = np.float32(tensor)
|
||||
else:
|
||||
vars.sp = torch.from_numpy(tensor)
|
||||
|
||||
|
|
Loading…
Reference in New Issue