GPT-NeoX HF model badwords fix

This commit is contained in:
vfbd 2022-06-23 15:02:43 -04:00
parent 1d41966d88
commit 3da885d408
1 changed files with 2 additions and 2 deletions

View File

@ -976,7 +976,7 @@ if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMe
if(vars.model_type == "opt"): if(vars.model_type == "opt"):
vars.badwordsids = vars.badwordsids_opt vars.badwordsids = vars.badwordsids_opt
if(vars.model_type == "neox"): if(vars.model_type == "gpt_neox"):
vars.badwordsids = vars.badwordsids_neox vars.badwordsids = vars.badwordsids_neox
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]): if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
@ -1886,7 +1886,7 @@ else:
import tpu_mtj_backend import tpu_mtj_backend
if(vars.model_type == "opt"): if(vars.model_type == "opt"):
tpu_mtj_backend.pad_token_id = 1 tpu_mtj_backend.pad_token_id = 1
elif(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "neox"): elif(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "gpt_neox"):
tpu_mtj_backend.pad_token_id = 2 tpu_mtj_backend.pad_token_id = 2
tpu_mtj_backend.vars = vars tpu_mtj_backend.vars = vars
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback