mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-17 20:20:45 +01:00
Patch GPTJForCausalLM, if it exists, to support soft prompting
This commit is contained in:
parent
40b4631f6c
commit
bf4e7742ac
@ -534,6 +534,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
|||||||
cls.forward = new_causallm_forward
|
cls.forward = new_causallm_forward
|
||||||
for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
|
for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
|
||||||
patch_causallm(cls)
|
patch_causallm(cls)
|
||||||
|
try:
|
||||||
|
from transformers import GPTJForCausalLM
|
||||||
|
patch_causallm(GPTJForCausalLM)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
# If custom GPT Neo model was chosen
|
# If custom GPT Neo model was chosen
|
||||||
if(vars.model == "NeoCustom"):
|
if(vars.model == "NeoCustom"):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user