From 5b3ce4510f1319ce756ffdd43d3e814180936097 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Wed, 3 Nov 2021 16:07:50 -0400 Subject: [PATCH] Make sure that soft_tokens is on the correct device --- aiserver.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aiserver.py b/aiserver.py index 07d5c8cb..60ecf203 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1471,6 +1471,7 @@ def generate(txt, min, max): soft_tokens = torch.arange( model.config.vocab_size, model.config.vocab_size + vars.sp.shape[0], + device=genout.device, ) gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1) diff = gen_in.shape[-1] - genout.shape[-1]