Patch GPTJForCausalLM, if it exists, to support soft prompting
This commit is contained in:
parent
40b4631f6c
commit
bf4e7742ac
|
@ -534,6 +534,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
|||
cls.forward = new_causallm_forward
|
||||
for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
|
||||
patch_causallm(cls)
|
||||
try:
|
||||
from transformers import GPTJForCausalLM
|
||||
patch_causallm(GPTJForCausalLM)
|
||||
except:
|
||||
pass
|
||||
|
||||
# If custom GPT Neo model was chosen
|
||||
if(vars.model == "NeoCustom"):
|
||||
|
|
Loading…
Reference in New Issue