diff --git a/prompt_tuner.py b/prompt_tuner.py
index 1c381f2b..b0886741 100644
--- a/prompt_tuner.py
+++ b/prompt_tuner.py
@@ -493,7 +493,7 @@ class TrainerBase(abc.ABC):
         params["max_batch_size"] = 2048
         with tokenizer._kai_no_prefix():
             params["eos_token"] = (
-                [50259, 50259] if model_config.model_type == "xglm" and model_config.eos_token_id == 50259 else tokenizer.encode(model_config.eos_token_id)
+                [50259, 50259] if model_config.model_type == "xglm" and model_config.eos_token_id == 50259 else [model_config.eos_token_id]
             )
         params["seq"] = 2048
         self.data.params = params
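
For context: the old right-hand branch handed an integer token ID to tokenizer.encode(), which operates on text rather than on an existing ID; the new branch simply wraps the ID in a one-element list. A minimal sketch of the distinction, assuming a Hugging Face-style tokenizer; "gpt2" and the variable names below are illustrative placeholders, not taken from prompt_tuner.py:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    eos_id = tokenizer.eos_token_id   # already an int token ID, e.g. 50256 for GPT-2

    # encode() expects text; passing a bare int does not look up an existing
    # token ID (it errors or tokenizes something unintended), so it is left
    # commented out here.
    # bad = tokenizer.encode(eos_id)

    # The ID is already a token ID, so the correct value is just a list around it.
    eos_token = [eos_id]              # e.g. [50256]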