From 511817132a3d6d6257e175e2828abc468dc32260 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Thu, 28 Oct 2021 15:39:59 -0400
Subject: [PATCH] Don't change the shape of transformer.wte

---
 aiserver.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 163e0bc9..8b842c6d 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -515,16 +515,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
     def patch_causallm(cls):
         old_forward = cls.forward
         def new_causallm_forward(self, *args, **kwargs):
-            num_embeddings = self.config.vocab_size
-            if(vars.sp is not None):
-                num_embeddings += vars.sp.shape[0]
-            if(self.transformer.wte.num_embeddings != num_embeddings):
-                self.resize_token_embeddings(num_embeddings)
             input_ids = kwargs.get('input_ids').to(self.device)
             assert input_ids is not None
             kwargs['input_ids'] = None
-            inputs_embeds = self.transformer.wte(input_ids)
-            input_ids -= self.config.vocab_size
+            inputs_embeds = self.transformer.wte(input_ids.clamp(max=self.config.vocab_size-1))
+            input_ids = input_ids - self.config.vocab_size  # Don't use the -= operator here, you'll get a cryptic error message
             if(vars.sp is not None):
                 vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
                 inputs_embeds = torch.where(
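
The added lines follow a clamp-then-substitute pattern: token IDs at or above config.vocab_size are clamped so the wte lookup stays in range, then shifted down by vocab_size so they address rows of the soft-prompt tensor, which torch.where swaps in. Below is a minimal standalone sketch of that pattern with made-up sizes and names (wte, soft_prompt), not the surrounding KoboldAI code:

import torch

vocab_size = 8          # hypothetical vocabulary size
embed_dim = 4           # hypothetical embedding width
num_soft_tokens = 3     # hypothetical number of soft-prompt vectors

wte = torch.nn.Embedding(vocab_size, embed_dim)        # stands in for transformer.wte
soft_prompt = torch.randn(num_soft_tokens, embed_dim)  # stands in for vars.sp

# IDs 8, 9, 10 are "virtual" tokens that address the soft prompt;
# 2 and 5 are ordinary vocabulary tokens.
input_ids = torch.tensor([[8, 9, 10, 2, 5]])

# Clamp so out-of-range IDs still produce a valid (if placeholder) lookup,
# leaving the embedding matrix at its original shape.
inputs_embeds = wte(input_ids.clamp(max=vocab_size - 1))

# Shift the IDs so virtual tokens index into soft_prompt (0, 1, 2)
# and real tokens become negative.
shifted = input_ids - vocab_size

# Keep the wte embedding for real tokens; substitute the matching
# soft-prompt row for virtual tokens.
inputs_embeds = torch.where(
    (shifted >= 0).unsqueeze(-1),
    soft_prompt[shifted.clamp(min=0)],
    inputs_embeds,
)

print(inputs_embeds.shape)  # torch.Size([1, 5, 4])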