From 78f52063c7d9c2d5aab41099e188e19f956f3cc8 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Wed, 2 Feb 2022 22:45:16 -0500
Subject: [PATCH] Fix XGLM soft prompts

---
 aiserver.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index ed83fdb6..f12d53e0 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -752,7 +752,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             if(vars.sp is not None):
                 shifted_input_ids = input_ids - self.config.vocab_size
             input_ids.clamp_(max=self.config.vocab_size-1)
-            inputs_embeds = self.transformer.wte(input_ids)
+            if(hasattr(self, "transformer")):
+                inputs_embeds = self.transformer.wte(input_ids)
+            else:
+                inputs_embeds = self.model.embed_tokens(input_ids) * self.embed_scale
             if(vars.sp is not None):
                 vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
                 inputs_embeds = torch.where(
@@ -765,11 +768,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             cls.forward = new_causallm_forward
     for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
         patch_causallm(cls)
-    try:
-        from transformers import GPTJForCausalLM
-        patch_causallm(GPTJForCausalLM)
-    except:
-        pass
+    for c in ("GPTJForCausalLM", "XGLMForCausalLM"):
+        try:
+            patch_causallm(getattr(__import__("transformers"), c))
+        except:
+            pass
 
     # Patch transformers to use our custom logit warpers
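
Note: XGLM-style models expose their token embeddings at self.model.embed_tokens (with an embedding scale factor) rather than self.transformer.wte, which is why the patched forward now checks hasattr(self, "transformer") before choosing an embedding lookup. Below is a minimal, self-contained sketch of the soft-prompt substitution that new_causallm_forward performs; the vocabulary size, embedding dimension, and randomly initialized tensors are illustrative stand-ins for the model's real embedding matrix and the loaded soft prompt (vars.sp), not values taken from this patch.

import torch

# Illustrative stand-ins (assumed values, not from aiserver.py).
vocab_size, embed_dim, sp_tokens = 50257, 768, 20
embedding = torch.nn.Embedding(vocab_size, embed_dim)  # stands in for wte / embed_tokens
soft_prompt = torch.randn(sp_tokens, embed_dim)        # stands in for vars.sp

# Token IDs at or above vocab_size address rows of the soft prompt instead of the vocabulary.
input_ids = torch.tensor([[vocab_size + 0, vocab_size + 1, 42, 1337]])

shifted_input_ids = input_ids - vocab_size         # soft-prompt row index; negative for real tokens
clamped_ids = input_ids.clamp(max=vocab_size - 1)  # keep the embedding lookup in range
inputs_embeds = embedding(clamped_ids)

# Wherever the shifted ID is non-negative, substitute the soft-prompt row for the vocab embedding.
inputs_embeds = torch.where(
    (shifted_input_ids >= 0)[..., None],
    soft_prompt[shifted_input_ids.clamp(min=0)],
    inputs_embeds,
)
print(inputs_embeds.shape)  # torch.Size([1, 4, 768])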