diff --git a/aiserver.py b/aiserver.py index 4081bd79..ebaa0da5 100644 --- a/aiserver.py +++ b/aiserver.py @@ -976,7 +976,7 @@ if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMe if(vars.model_type == "opt"): vars.badwordsids = vars.badwordsids_opt - if(vars.model_type == "neox"): + if(vars.model_type == "gpt_neox"): vars.badwordsids = vars.badwordsids_neox if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]): @@ -1886,7 +1886,7 @@ else: import tpu_mtj_backend if(vars.model_type == "opt"): tpu_mtj_backend.pad_token_id = 1 - elif(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "neox"): + elif(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "gpt_neox"): tpu_mtj_backend.pad_token_id = 2 tpu_mtj_backend.vars = vars tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback