# Patch summary. This text sits ahead of the "diff --git" header and is not
# part of the patch itself.
#
# Touches a single hunk of aiserver.py, inside the `else:` branch named by the
# hunk header, immediately after `import tpu_mtj_backend`:
#   * before: tpu_mtj_backend.pad_token_id = 1 when vars.model is
#     "TPUMeshTransformerGPTNeoX" OR vars.model_type is "opt";
#   * after:  pad_token_id = 1 only when vars.model_type is "opt"; otherwise
#     pad_token_id = 2 when vars.model is "TPUMeshTransformerGPTNeoX" or
#     vars.model_type is "neox".
# The new test is an if/elif chain, so the "opt" check wins when both
# conditions hold. When neither holds, this hunk assigns no pad_token_id.
#
# NOTE(review): the line breaks and leading whitespace below were rebuilt from
# a copy of this patch that had been flattened onto one line. The 8/12-space
# indentation of the hunk lines is an assumption -- confirm it against
# aiserver.py before running `git apply`.
diff --git a/aiserver.py b/aiserver.py
index 08ee1d18..8779a70a 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1884,8 +1884,10 @@ else:
         if vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
             raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
         import tpu_mtj_backend
-        if(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "opt"):
+        if(vars.model_type == "opt"):
             tpu_mtj_backend.pad_token_id = 1
+        elif(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "neox"):
+            tpu_mtj_backend.pad_token_id = 2
         tpu_mtj_backend.vars = vars
         tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
         tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback