Automatically calculate badwords and pad_token_id

This commit is contained in:
Gnome Ann
2022-06-21 14:35:52 -04:00
parent ea7d278ff4
commit 0ea4fa9c87
2 changed files with 13 additions and 6 deletions

View File

@ -1018,7 +1018,12 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params
global thread_resources_env, seq, tokenizer, network, params, pad_token_id
if "pad_token_id" in kwargs:
pad_token_id = kwargs["pad_token_id"]
elif "eos_token_id" in kwargs:
pad_token_id = kwargs["eos_token_id"]
if not hasattr(vars, "sampler_order") or not vars.sampler_order:
vars.sampler_order = utils.default_sampler_order.copy()