From 3da885d408a2727bfca5d454b60ca856d4f5b19a Mon Sep 17 00:00:00 2001 From: vfbd Date: Thu, 23 Jun 2022 15:02:43 -0400 Subject: [PATCH] GPT-NeoX HF model badwords fix --- aiserver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aiserver.py b/aiserver.py index 4081bd79..ebaa0da5 100644 --- a/aiserver.py +++ b/aiserver.py @@ -976,7 +976,7 @@ if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMe if(vars.model_type == "opt"): vars.badwordsids = vars.badwordsids_opt - if(vars.model_type == "neox"): + if(vars.model_type == "gpt_neox"): vars.badwordsids = vars.badwordsids_neox if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]): @@ -1886,7 +1886,7 @@ else: import tpu_mtj_backend if(vars.model_type == "opt"): tpu_mtj_backend.pad_token_id = 1 - elif(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "neox"): + elif(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "gpt_neox"): tpu_mtj_backend.pad_token_id = 2 tpu_mtj_backend.vars = vars tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback