mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-06 04:14:18 +01:00
Patch GPTJForCausalLM, if it exists, to support soft prompting
This commit is contained in:
parent
40b4631f6c
commit
bf4e7742ac
@ -534,6 +534,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||
cls.forward = new_causallm_forward
|
||||
for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
|
||||
patch_causallm(cls)
|
||||
try:
|
||||
from transformers import GPTJForCausalLM
|
||||
patch_causallm(GPTJForCausalLM)
|
||||
except:
|
||||
pass
|
||||
|
||||
# If custom GPT Neo model was chosen
|
||||
if(vars.model == "NeoCustom"):
|
||||
|
Loading…
x
Reference in New Issue
Block a user