From 8c594c6869e0875523819c2310d370bee4b6c552 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 21 Jun 2022 19:37:43 -0400 Subject: [PATCH] Correct the padding token for GPT-NeoX --- aiserver.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aiserver.py b/aiserver.py index 08ee1d18..8779a70a 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1884,8 +1884,10 @@ else: if vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)): raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder") import tpu_mtj_backend - if(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "opt"): + if(vars.model_type == "opt"): tpu_mtj_backend.pad_token_id = 1 + elif(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "neox"): + tpu_mtj_backend.pad_token_id = 2 tpu_mtj_backend.vars = vars tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback