Merge pull request #79 from VE-FORBRYDERNE/xglm-eos
Prevent transformers XGLM from stopping generation on `</s>` token
This commit is contained in: 27cf59bb94
@@ -875,6 +875,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
                     kwargs["logits_warper"] = new_get_logits_warper(
                         beams=1,
                     )
+                    if(vars.newlinemode == "s"):
+                        kwargs["eos_token_id"] = -1
+                    kwargs.setdefault("pad_token_id", 2)
                     return new_sample.old_sample(self, *args, **kwargs)
                 new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
                 transformers.generation_utils.GenerationMixin.sample = new_sample
|
@@ -2928,7 +2931,6 @@ def _generate(txt, minimum, maximum, found_entries):
             genout = generator(
                 gen_in,
                 do_sample=True,
-                min_length=minimum,
                 max_length=int(2e9),
                 repetition_penalty=1.1,
                 bad_words_ids=vars.badwordsids,
|
|
Loading…
Reference in New Issue