Patch GPTJForCausalLM, if it exists, to support soft prompting
parent 40b4631f6c
commit bf4e7742ac
@@ -534,6 +534,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
         cls.forward = new_causallm_forward
     for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
         patch_causallm(cls)
+    try:
+        from transformers import GPTJForCausalLM
+        patch_causallm(GPTJForCausalLM)
+    except:
+        pass
 
 # If custom GPT Neo model was chosen
 if(vars.model == "NeoCustom"):
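The hunk wraps the import in try/except because GPTJForCausalLM only exists in newer transformers releases; when it is missing, the patch is silently skipped. The hunk does not show the body of new_causallm_forward, so the following is a minimal sketch of how a soft-prompt patch like this typically works, not the actual KoboldAI implementation: the original forward is saved, and a replacement prepends learned prompt embeddings to the token embeddings before delegating. The soft_prompt attribute is hypothetical, and attention-mask adjustment is omitted for brevity.

import torch

def patch_causallm(cls):
    old_forward = cls.forward  # keep a reference to the unpatched forward

    def new_causallm_forward(self, *args, **kwargs):
        input_ids = kwargs.pop("input_ids", None)
        soft_prompt = getattr(self, "soft_prompt", None)  # hypothetical attribute
        if input_ids is not None and soft_prompt is not None:
            # Embed the tokens ourselves so the soft prompt can be
            # concatenated in embedding space rather than token space.
            inputs_embeds = self.get_input_embeddings()(input_ids)
            soft = soft_prompt.unsqueeze(0).expand(inputs_embeds.size(0), -1, -1)
            kwargs["inputs_embeds"] = torch.cat((soft, inputs_embeds), dim=1)
            return old_forward(self, *args, **kwargs)
        return old_forward(self, *args, input_ids=input_ids, **kwargs)

    cls.forward = new_causallm_forward

# Guarded patching, as in the commit; catching ImportError alone would be
# slightly tighter than the bare except used there.
try:
    from transformers import GPTJForCausalLM
    patch_causallm(GPTJForCausalLM)
except ImportError:
    pass

Patching the class attribute (rather than each instance) means every model of that class constructed afterwards picks up the soft-prompt-aware forward automatically.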