Change soft prompt implementation to a more universal one

Gnome Ann 2022-06-21 15:03:43 -04:00
parent 0ea4fa9c87
commit 91643be10a
1 changed file with 18 additions and 19 deletions

@@ -1247,18 +1247,20 @@ def get_oai_models(key):
 # Function to patch transformers to use our soft prompt
-def patch_causallm(cls):
-    if(getattr(cls, "_koboldai_patch_causallm_patched", False)):
-        return
-    old_forward = cls.forward
-    def new_causallm_forward(self, *args, **kwargs):
-        input_ids = kwargs.get('input_ids').to(self.device)
+def patch_causallm(model):
+    from torch.nn import Embedding
+    if(getattr(Embedding, "_koboldai_patch_causallm_model", None)):
+        Embedding._koboldai_patch_causallm_model = model
+        return model
+    old_embedding_call = Embedding.__call__
+    def new_embedding_call(self, input_ids, *args, **kwargs):
+        if(Embedding._koboldai_patch_causallm_model.get_input_embeddings() is not self):
+            return old_embedding_call(self, input_ids, *args, **kwargs)
         assert input_ids is not None
-        kwargs['input_ids'] = None
         if(vars.sp is not None):
-            shifted_input_ids = input_ids - self.config.vocab_size
-        input_ids.clamp_(max=self.config.vocab_size-1)
-        inputs_embeds = self.get_input_embeddings()(input_ids)
+            shifted_input_ids = input_ids - model.config.vocab_size
+        input_ids.clamp_(max=model.config.vocab_size-1)
+        inputs_embeds = old_embedding_call(self, input_ids, *args, **kwargs)
         if(vars.sp is not None):
             vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
             inputs_embeds = torch.where(
@@ -1266,13 +1268,10 @@ def patch_causallm(cls):
                 vars.sp[shifted_input_ids.clamp(min=0)],
                 inputs_embeds,
             )
-        if(hasattr(self, "model") and hasattr(self.model, "embed_scale")):
-            inputs_embeds *= self.model.embed_scale
-        kwargs['inputs_embeds'] = inputs_embeds
-        return old_forward(self, *args, **kwargs)
-    cls.forward = new_causallm_forward
-    cls._koboldai_patch_causallm_patched = True
-    return cls
+        return inputs_embeds
+    Embedding.__call__ = new_embedding_call
+    Embedding._koboldai_patch_causallm_model = model
+    return model
 
 def patch_transformers():
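
The gist of the new approach, in standalone form: instead of wrapping the model's forward and passing inputs_embeds (which required every supported architecture to accept that argument and forced a special case for models that apply an embed_scale after lookup), the lookup itself is intercepted at torch.nn.Embedding, and token IDs at or past config.vocab_size are treated as row indices into the soft-prompt tensor. The sketch below is illustrative, not the code in this commit: patch_embedding and soft_prompt are made-up names, and the re-patch guard, the in-place clamp, and the vars.sp bookkeeping from aiserver.py are simplified away.

```python
import torch
from torch.nn import Embedding

def patch_embedding(model, soft_prompt):
    """Route out-of-vocab token IDs to rows of `soft_prompt` (a sketch).

    Embedding.__call__ is patched class-wide, but only calls going through
    `model`'s own input-embedding module are intercepted; every other
    Embedding instance (positional embeddings, other models) falls through
    to the stock implementation.
    """
    vocab_size = model.config.vocab_size
    old_call = Embedding.__call__

    def new_call(self, input_ids, *args, **kwargs):
        if self is not model.get_input_embeddings():
            return old_call(self, input_ids, *args, **kwargs)
        # IDs >= vocab_size address the soft prompt; shift them to 0-based
        # rows of `soft_prompt` and clamp the copy fed to the real table.
        shifted = input_ids - vocab_size
        embeds = old_call(self, input_ids.clamp(max=vocab_size - 1), *args, **kwargs)
        sp = soft_prompt.to(embeds.dtype).to(embeds.device)
        return torch.where(
            (shifted >= 0)[..., None],  # broadcast over the embedding dim
            sp[shifted.clamp(min=0)],   # soft-prompt rows
            embeds,                     # ordinary vocabulary rows
        )

    Embedding.__call__ = new_call
    return model
```

Because the patch is applied to the Embedding class rather than to one model class, re-patching for a different model only needs to swap the recorded model reference, which is what the `_koboldai_patch_causallm_model` attribute in the commit is for.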
@@ -1864,7 +1863,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
         else:
             model = model.to('cpu').float()
         generator = model.generate
-        patch_causallm(model.__class__)
+        patch_causallm(model)
     # Use the Generic implementation
     else:
         lowmem = maybe_low_cpu_mem_usage()
@@ -1998,7 +1997,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
         if(vars.badwordsids is vars.badwordsids_default and vars.model_type not in ("gpt2", "gpt_neo", "gptj", "xglm")):
             vars.badwordsids = [[v] for k, v in tokenizer.get_vocab().items() if any(c in k for c in "<>[]")]
 
-        patch_causallm(model.__class__)
+        patch_causallm(model)
 
         if(vars.hascuda):
             if(vars.usegpu):
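
For context, a hypothetical way to drive a model patched like this, reusing the patch_embedding sketch above: soft-prompt positions are addressed by prepending IDs just past the end of the real vocabulary, which is how the rest of aiserver.py feeds vars.sp into generation. The model name is only an example, and the zero-filled tensor stands in for a trained soft prompt.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Example model; any causal LM whose input embedding is an nn.Embedding works.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")

n_soft = 20
# Stand-in for a trained soft prompt of shape (n_soft, hidden_size).
soft_prompt = torch.zeros(n_soft, model.config.hidden_size)
patch_embedding(model, soft_prompt)

# IDs vocab_size .. vocab_size + n_soft - 1 select rows of soft_prompt.
vocab = model.config.vocab_size
soft_ids = torch.arange(vocab, vocab + n_soft)
text_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids[0]
input_ids = torch.cat([soft_ids, text_ids]).unsqueeze(0)

out = model.generate(input_ids, max_new_tokens=20)
# Skip the soft-prompt positions when decoding; they have no text form.
print(tokenizer.decode(out[0, n_soft:], skip_special_tokens=True))
```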