Fix generator output having the wrong length

This commit is contained in:
Gnome Ann
2021-11-03 16:10:12 -04:00
parent 5b3ce4510f
commit b8c3d8c12e

View File

@ -1440,8 +1440,8 @@ def generate(txt, min, max):
genout = generator( genout = generator(
gen_in, gen_in,
do_sample=True, do_sample=True,
min_length=min+already_generated, min_length=min,
max_length=max, max_length=max-already_generated,
repetition_penalty=vars.rep_pen, repetition_penalty=vars.rep_pen,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
@ -1460,7 +1460,7 @@ def generate(txt, min, max):
found_entries |= _found_entries found_entries |= _found_entries
txt, _, _ = calcsubmitbudget(len(vars.actions), winfo, mem, anotetxt) txt, _, _ = calcsubmitbudget(len(vars.actions), winfo, mem, anotetxt)
encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device) encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
gen_in = torch.cat( genout = torch.cat(
( (
encoded, encoded,
genout[..., -already_generated:], genout[..., -already_generated:],
@ -1473,8 +1473,8 @@ def generate(txt, min, max):
model.config.vocab_size + vars.sp.shape[0], model.config.vocab_size + vars.sp.shape[0],
device=genout.device, device=genout.device,
) )
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1) genout = torch.cat((soft_tokens[None], genout), dim=-1)
diff = gen_in.shape[-1] - genout.shape[-1] diff = genout.shape[-1] - gen_in.shape[-1]
min += diff min += diff
max += diff max += diff
gen_in = genout gen_in = genout