From 91643be10a49b5b9362dfcef86576f57e99ba383 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Tue, 21 Jun 2022 15:03:43 -0400
Subject: [PATCH] Change soft prompt implementation to a more universal one

---
 aiserver.py | 37 ++++++++++++++++++-------------------
 1 file changed, 18 insertions(+), 19 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index dc16660a..8f543865 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1247,18 +1247,20 @@ def get_oai_models(key):
 
 
 # Function to patch transformers to use our soft prompt
-def patch_causallm(cls):
-    if(getattr(cls, "_koboldai_patch_causallm_patched", False)):
-        return
-    old_forward = cls.forward
-    def new_causallm_forward(self, *args, **kwargs):
-        input_ids = kwargs.get('input_ids').to(self.device)
+def patch_causallm(model):
+    from torch.nn import Embedding
+    if(getattr(Embedding, "_koboldai_patch_causallm_model", None)):
+        Embedding._koboldai_patch_causallm_model = model
+        return model
+    old_embedding_call = Embedding.__call__
+    def new_embedding_call(self, input_ids, *args, **kwargs):
+        if(Embedding._koboldai_patch_causallm_model.get_input_embeddings() is not self):
+            return old_embedding_call(self, input_ids, *args, **kwargs)
         assert input_ids is not None
-        kwargs['input_ids'] = None
         if(vars.sp is not None):
-            shifted_input_ids = input_ids - self.config.vocab_size
-            input_ids.clamp_(max=self.config.vocab_size-1)
-        inputs_embeds = self.get_input_embeddings()(input_ids)
+            shifted_input_ids = input_ids - model.config.vocab_size
+            input_ids.clamp_(max=model.config.vocab_size-1)
+        inputs_embeds = old_embedding_call(self, input_ids, *args, **kwargs)
         if(vars.sp is not None):
             vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
             inputs_embeds = torch.where(
@@ -1266,13 +1268,10 @@ def patch_causallm(cls):
                 vars.sp[shifted_input_ids.clamp(min=0)],
                 inputs_embeds,
             )
-        if(hasattr(self, "model") and hasattr(self.model, "embed_scale")):
-            inputs_embeds *= self.model.embed_scale
-        kwargs['inputs_embeds'] = inputs_embeds
-        return old_forward(self, *args, **kwargs)
-    cls.forward = new_causallm_forward
-    cls._koboldai_patch_causallm_patched = True
-    return cls
+        return inputs_embeds
+    Embedding.__call__ = new_embedding_call
+    Embedding._koboldai_patch_causallm_model = model
+    return model
 
 
 def patch_transformers():
@@ -1864,7 +1863,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
         else:
             model = model.to('cpu').float()
         generator = model.generate
-        patch_causallm(model.__class__)
+        patch_causallm(model)
     # Use the Generic implementation
     else:
         lowmem = maybe_low_cpu_mem_usage()
@@ -1998,7 +1997,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
             if(vars.badwordsids is vars.badwordsids_default and vars.model_type not in ("gpt2", "gpt_neo", "gptj", "xglm")):
                 vars.badwordsids = [[v] for k, v in tokenizer.get_vocab().items() if any(c in k for c in "<>[]")]
 
-            patch_causallm(model.__class__)
+            patch_causallm(model)
 
             if(vars.hascuda):
                 if(vars.usegpu):
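
For reference, below is a minimal standalone sketch of the technique this patch applies: monkey-patching torch.nn.Embedding.__call__ once, remembering which model was patched, and splicing soft-prompt vectors in whenever that model's own input embedding layer is invoked. It is not part of the patch; names such as patch_embedding, TinyModel and soft_prompt are illustrative stand-ins for patch_causallm, the KoboldAI model object and vars.sp, and the vocabulary size and embedding width are made up.

import torch
from torch import nn
from torch.nn import Embedding
from types import SimpleNamespace

soft_prompt = None  # stand-in for vars.sp: tensor of shape (n_soft_tokens, embed_dim), or None

class TinyModel(nn.Module):
    # Hypothetical model exposing the two things the patch relies on:
    # .config.vocab_size and .get_input_embeddings().
    def __init__(self):
        super().__init__()
        self.config = SimpleNamespace(vocab_size=100)
        self.embed = Embedding(self.config.vocab_size, 8)
    def get_input_embeddings(self):
        return self.embed
    def forward(self, input_ids):
        return self.embed(input_ids)

def patch_embedding(model):
    # Patch the Embedding class only once; on later calls just re-point it at the new model.
    if getattr(Embedding, "_patched_model", None):
        Embedding._patched_model = model
        return model
    old_call = Embedding.__call__
    def new_call(self, input_ids, *args, **kwargs):
        # Leave every embedding layer except the patched model's input layer untouched.
        if Embedding._patched_model.get_input_embeddings() is not self:
            return old_call(self, input_ids, *args, **kwargs)
        if soft_prompt is not None:
            # Token IDs >= vocab_size are treated as virtual soft-prompt tokens.
            shifted = input_ids - model.config.vocab_size
            input_ids = input_ids.clamp(max=model.config.vocab_size - 1)
        embeds = old_call(self, input_ids, *args, **kwargs)
        if soft_prompt is not None:
            # Replace the embeddings of virtual tokens with the soft-prompt rows.
            embeds = torch.where(
                (shifted >= 0)[..., None],
                soft_prompt[shifted.clamp(min=0)].to(embeds.dtype),
                embeds,
            )
        return embeds
    Embedding.__call__ = new_call
    Embedding._patched_model = model
    return model

# Usage: token IDs 100 and 101 select rows 0 and 1 of the soft prompt.
model = patch_embedding(TinyModel())
soft_prompt = torch.randn(4, 8)
out = model(torch.tensor([[5, 100, 101, 7]]))
print(out.shape)  # torch.Size([1, 4, 8])

Because only __call__ on the Embedding class is replaced, the approach works for any transformers causal LM regardless of how its forward method is written, which is what makes it "more universal" than wrapping each model class's forward.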