diff --git a/aiserver.py b/aiserver.py
index c77caa0f..444ea843 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1005,7 +1005,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             if(hasattr(self, "transformer")):
                 inputs_embeds = self.transformer.wte(input_ids)
             else:
-                inputs_embeds = self.model.embed_tokens(input_ids) * self.model.embed_scale
+                inputs_embeds = self.model.embed_tokens(input_ids)
             if(vars.sp is not None):
                 vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
                 inputs_embeds = torch.where(
@@ -1013,6 +1013,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
                     vars.sp[shifted_input_ids.clamp(min=0)],
                     inputs_embeds,
                 )
+            if(not hasattr(self, "transformer")):
+                inputs_embeds *= self.model.embed_scale
             kwargs['inputs_embeds'] = inputs_embeds
             return old_forward(self, *args, **kwargs)
         cls.forward = new_causallm_forward