Correct the padding token for GPT-NeoX

This commit is contained in:
Gnome Ann 2022-06-21 19:37:43 -04:00
parent a7f667c34c
commit 8c594c6869
1 changed file with 3 additions and 1 deletion

View File

@@ -1884,8 +1884,10 @@ else:
if vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)): if vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder") raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
import tpu_mtj_backend import tpu_mtj_backend
if(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "opt"): if(vars.model_type == "opt"):
tpu_mtj_backend.pad_token_id = 1 tpu_mtj_backend.pad_token_id = 1
elif(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "neox"):
tpu_mtj_backend.pad_token_id = 2
tpu_mtj_backend.vars = vars tpu_mtj_backend.vars = vars
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback