Correct the padding token for GPT-NeoX
This commit is contained in:
parent
a7f667c34c
commit
8c594c6869
|
@ -1884,8 +1884,10 @@ else:
|
||||||
if vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
|
if vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
|
||||||
raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
|
raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
|
||||||
import tpu_mtj_backend
|
import tpu_mtj_backend
|
||||||
if(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "opt"):
|
if(vars.model_type == "opt"):
|
||||||
tpu_mtj_backend.pad_token_id = 1
|
tpu_mtj_backend.pad_token_id = 1
|
||||||
|
elif(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "neox"):
|
||||||
|
tpu_mtj_backend.pad_token_id = 2
|
||||||
tpu_mtj_backend.vars = vars
|
tpu_mtj_backend.vars = vars
|
||||||
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
|
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
|
||||||
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
|
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
|
||||||
|
|
Loading…
Reference in New Issue