Make sure that soft_tokens is on the correct device

This commit is contained in:
Gnome Ann 2021-11-03 16:07:50 -04:00
parent 90fd5a538a
commit 5b3ce4510f
1 changed files with 1 additions and 0 deletions

View File

@ -1471,6 +1471,7 @@ def generate(txt, min, max):
soft_tokens = torch.arange(
model.config.vocab_size,
model.config.vocab_size + vars.sp.shape[0],
device=genout.device,
)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
diff = gen_in.shape[-1] - genout.shape[-1]