diff --git a/aiserver.py b/aiserver.py index 3bd15e91..e5e3ca85 100644 --- a/aiserver.py +++ b/aiserver.py @@ -534,6 +534,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): cls.forward = new_causallm_forward for cls in (GPT2LMHeadModel, GPTNeoForCausalLM): patch_causallm(cls) + try: + from transformers import GPTJForCausalLM + patch_causallm(GPTJForCausalLM) + except: + pass # If custom GPT Neo model was chosen if(vars.model == "NeoCustom"):