diff --git a/aiserver.py b/aiserver.py
index 868a5755..800e3449 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1210,8 +1210,37 @@ def get_oai_models(key):
         print("{0}ERROR!{1}".format(colors.RED, colors.END))
         print(req.json())
         emit('from_server', {'cmd': 'errmsg', 'data': req.json()})
-
-
+
+
+# Function to patch transformers to use our soft prompt
+def patch_causallm(cls):
+    if(getattr(cls, "_koboldai_patch_causallm_patched", False)):
+        return
+    old_forward = cls.forward
+    def new_causallm_forward(self, *args, **kwargs):
+        input_ids = kwargs.get('input_ids').to(self.device)
+        assert input_ids is not None
+        kwargs['input_ids'] = None
+        if(vars.sp is not None):
+            shifted_input_ids = input_ids - self.config.vocab_size
+        input_ids.clamp_(max=self.config.vocab_size-1)
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+        if(vars.sp is not None):
+            vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
+            inputs_embeds = torch.where(
+                (shifted_input_ids >= 0)[..., None],
+                vars.sp[shifted_input_ids.clamp(min=0)],
+                inputs_embeds,
+            )
+        if(hasattr(self, "model") and hasattr(self.model, "embed_scale")):
+            inputs_embeds *= self.model.embed_scale
+        kwargs['inputs_embeds'] = inputs_embeds
+        return old_forward(self, *args, **kwargs)
+    cls.forward = new_causallm_forward
+    cls._koboldai_patch_causallm_patched = True
+    return cls
+
+
 def patch_transformers():
     global transformers
     old_from_pretrained = PreTrainedModel.from_pretrained.__func__
@@ -1259,42 +1288,6 @@ def patch_transformers():
             return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
         XGLMSinusoidalPositionalEmbedding.forward = new_forward
 
-    # Patch transformers to use our soft prompt
-    def patch_causallm(cls):
-        old_forward = cls.forward
-        def new_causallm_forward(self, *args, **kwargs):
-            input_ids = kwargs.get('input_ids').to(self.device)
-            assert input_ids is not None
-            kwargs['input_ids'] = None
-            if(vars.sp is not None):
-                shifted_input_ids = input_ids - self.config.vocab_size
-            input_ids.clamp_(max=self.config.vocab_size-1)
-            if(hasattr(self, "transformer")):
-                inputs_embeds = self.transformer.wte(input_ids)
-            elif(not hasattr(self.model, "decoder")):
-                inputs_embeds = self.model.embed_tokens(input_ids)
-            else:
-                inputs_embeds = self.model.decoder.embed_tokens(input_ids)
-            if(vars.sp is not None):
-                vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
-                inputs_embeds = torch.where(
-                    (shifted_input_ids >= 0)[..., None],
-                    vars.sp[shifted_input_ids.clamp(min=0)],
-                    inputs_embeds,
-                )
-            if(hasattr(self, "model") and hasattr(self.model, "embed_scale")):
-                inputs_embeds *= self.model.embed_scale
-            kwargs['inputs_embeds'] = inputs_embeds
-            return old_forward(self, *args, **kwargs)
-        cls.forward = new_causallm_forward
-    for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
-        patch_causallm(cls)
-    for c in ("GPTJForCausalLM", "XGLMForCausalLM", "OPTForCausalLM"):
-        try:
-            patch_causallm(getattr(__import__("transformers"), c))
-        except:
-            pass
-
 
     # Fix a bug in OPTForCausalLM where self.lm_head is the wrong size
     if(packaging.version.parse("4.19.0.dev0") <= packaging.version.parse(transformers_version) < packaging.version.parse("4.20.0")):
@@ -1796,6 +1789,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                 else:
                     model = model.to('cpu').float()
                     generator = model.generate
+                patch_causallm(model.__class__)
             # Use the Generic implementation
             else:
                 lowmem = maybe_low_cpu_mem_usage()
@@ -1923,7 +1917,9 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                            for filename in filenames:
                                shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, filename, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
                        shutil.rmtree("cache/")
-
+
+            patch_causallm(model.__class__)
+
             if(vars.hascuda):
                 if(vars.usegpu):
                     vars.modeldim = get_hidden_size_from_model(model)
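For context on what the relocated `patch_causallm` does: any token id at or above `config.vocab_size` is treated as an index into the soft-prompt tensor (`vars.sp`), and that soft-prompt row is spliced in place of the ordinary token embedding before the wrapped `forward` runs. Below is a minimal standalone sketch of just that splice, not KoboldAI code; the sizes and the `embedding`/`sp` tensors are made-up stand-ins for the model's input embedding and the loaded soft prompt.

```python
import torch

# Stand-in sizes; the real values come from the loaded model and soft prompt.
vocab_size, embed_dim, sp_len = 50257, 768, 20
embedding = torch.nn.Embedding(vocab_size, embed_dim)   # model's token embedding
sp = torch.randn(sp_len, embed_dim)                     # soft-prompt matrix (plays the role of vars.sp)

def embed_with_soft_prompt(input_ids: torch.LongTensor) -> torch.Tensor:
    # Ids >= vocab_size address rows of the soft prompt; shift them into [0, sp_len).
    shifted = input_ids - vocab_size
    # Clamp so the ordinary embedding lookup never sees an out-of-range id.
    clamped = input_ids.clamp(max=vocab_size - 1)
    inputs_embeds = embedding(clamped)
    # Where the original id pointed at the soft prompt, substitute that row.
    return torch.where(
        (shifted >= 0)[..., None],      # broadcast the mask over the embedding dim
        sp[shifted.clamp(min=0)],       # soft-prompt rows (row 0 is a dummy where unused)
        inputs_embeds,
    )

# Example: two soft-prompt tokens followed by three ordinary tokens.
ids = torch.tensor([[vocab_size, vocab_size + 1, 11, 42, 7]])
print(embed_with_soft_prompt(ids).shape)  # torch.Size([1, 5, 768])
```

Embedding the clamped ids first and then overwriting the masked positions with `torch.where` is what lets the new version reuse whatever module `get_input_embeddings()` returns, instead of special-casing `transformer.wte`, `model.embed_tokens`, and `model.decoder.embed_tokens` as the removed copy inside `patch_transformers` did.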