diff --git a/aiserver.py b/aiserver.py index f48981b0..163e0bc9 100644 --- a/aiserver.py +++ b/aiserver.py @@ -526,7 +526,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): inputs_embeds = self.transformer.wte(input_ids) input_ids -= self.config.vocab_size if(vars.sp is not None): - vars.sp = vars.sp.to(inputs_embeds.device) + vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device) inputs_embeds = torch.where( (input_ids >= 0)[:, :, None], vars.sp[input_ids.clamp(min=0)],