Merge pull request #130 from VE-FORBRYDERNE/neox-badwords
GPT-NeoX HF model badwords fix
This commit is contained in:
commit
ec0bc1cc17
|
@@ -976,7 +976,7 @@ if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMe
|
||||||
|
|
||||||
if(vars.model_type == "opt"):
|
if(vars.model_type == "opt"):
|
||||||
vars.badwordsids = vars.badwordsids_opt
|
vars.badwordsids = vars.badwordsids_opt
|
||||||
if(vars.model_type == "neox"):
|
if(vars.model_type == "gpt_neox"):
|
||||||
vars.badwordsids = vars.badwordsids_neox
|
vars.badwordsids = vars.badwordsids_neox
|
||||||
|
|
||||||
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
||||||
|
@@ -1886,7 +1886,7 @@ else:
|
||||||
import tpu_mtj_backend
|
import tpu_mtj_backend
|
||||||
if(vars.model_type == "opt"):
|
if(vars.model_type == "opt"):
|
||||||
tpu_mtj_backend.pad_token_id = 1
|
tpu_mtj_backend.pad_token_id = 1
|
||||||
elif(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "neox"):
|
elif(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "gpt_neox"):
|
||||||
tpu_mtj_backend.pad_token_id = 2
|
tpu_mtj_backend.pad_token_id = 2
|
||||||
tpu_mtj_backend.vars = vars
|
tpu_mtj_backend.vars = vars
|
||||||
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
|
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
|
||||||
|
|
Loading…
Reference in New Issue