Patch GPTJForCausalLM, if it exists, to support soft prompting

This commit is contained in:
Gnome Ann 2021-10-28 17:18:28 -04:00
parent 40b4631f6c
commit bf4e7742ac
1 changed files with 5 additions and 0 deletions

View File

@ -534,6 +534,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
cls.forward = new_causallm_forward cls.forward = new_causallm_forward
for cls in (GPT2LMHeadModel, GPTNeoForCausalLM): for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
patch_causallm(cls) patch_causallm(cls)
try:
from transformers import GPTJForCausalLM
patch_causallm(GPTJForCausalLM)
except:
pass
# If custom GPT Neo model was chosen # If custom GPT Neo model was chosen
if(vars.model == "NeoCustom"): if(vars.model == "NeoCustom"):