Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Fix generator output having the wrong length
aiserver.py: 10 changed lines
@@ -1440,8 +1440,8 @@ def generate(txt, min, max):
     genout = generator(
         gen_in,
         do_sample=True,
-        min_length=min+already_generated,
-        max_length=max,
+        min_length=min,
+        max_length=max-already_generated,
         repetition_penalty=vars.rep_pen,
         top_p=top_p,
         top_k=top_k,
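For readers unfamiliar with the API: in Hugging Face Transformers, max_length bounds the whole sequence, prompt included, which is why the resumed call has to shave already_generated off the cap. Below is a minimal sketch of that behaviour; it uses the public gpt2 checkpoint and model.generate directly rather than KoboldAI's generator pipeline, and the prompt and numbers are invented:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Small public checkpoint, used purely for illustration.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

gen_in = tokenizer.encode("The wizard opened the door and", return_tensors="pt")

# max_length counts the prompt too, so this call yields at most 10 new tokens.
# Feeding the whole output back in without raising max_length would leave
# little or no room for more tokens -- hence max_length=max-already_generated
# together with the max += diff adjustment later in this commit.
genout = model.generate(gen_in, do_sample=True, max_length=gen_in.shape[-1] + 10)
print(genout.shape[-1] - gen_in.shape[-1])  # newly generated tokens, at most 10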
@@ -1460,7 +1460,7 @@ def generate(txt, min, max):
     found_entries |= _found_entries
     txt, _, _ = calcsubmitbudget(len(vars.actions), winfo, mem, anotetxt)
     encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
-    gen_in = torch.cat(
+    genout = torch.cat(
         (
             encoded,
             genout[..., -already_generated:],
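For context: after a dynamic world info hit, the surrounding code re-runs calcsubmitbudget, re-encodes the new context, and splices the already generated tokens back onto it, as shown in the hunk above. A toy sketch of that splice with made-up tensors (shapes and values are invented; only the torch.cat pattern mirrors the diff):

import torch

# Pretend the rebuilt context encodes to 5 tokens and the previous generator
# call returned a 9-token sequence whose last 3 tokens were newly generated.
encoded = torch.arange(100, 105)[None]   # shape (1, 5): re-encoded prompt
genout = torch.arange(200, 209)[None]    # shape (1, 9): previous generator output
already_generated = 3

# Keep only the freshly generated tail and attach it to the new prompt;
# the fixed code assembles this into genout rather than gen_in.
new_input = torch.cat((encoded, genout[..., -already_generated:]), dim=-1)
print(new_input.shape)  # torch.Size([1, 8]) == 5 prompt tokens + 3 generated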
@@ -1473,8 +1473,8 @@ def generate(txt, min, max):
             model.config.vocab_size + vars.sp.shape[0],
             device=genout.device,
         )
-        gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
-    diff = gen_in.shape[-1] - genout.shape[-1]
+        genout = torch.cat((soft_tokens[None], genout), dim=-1)
+    diff = genout.shape[-1] - gen_in.shape[-1]
     min += diff
     max += diff
     gen_in = genout
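Taken together, the three hunks make the length bookkeeping consistent: the rebuilt input is now assembled in genout, so the trailing gen_in = genout hands it to the next generator call instead of overwriting it with the previous output, and diff measures how much longer the new input is than the old one rather than the reverse. The arithmetic below is a hedged sketch with invented token counts (it assumes max is roughly prompt length plus the requested number of new tokens, which the diff itself does not state):

# Invented token counts, purely to check the bookkeeping (not KoboldAI values).
old_input_len = 60        # length of gen_in for the interrupted call
gen_budget = 20           # total new tokens requested
maximum = old_input_len + gen_budget   # assumed meaning of `max` here

already_generated = 8     # tokens produced before world info forced a rebuild
new_prompt_len = 63       # the re-encoded context may differ in length

# The rebuilt input carries the new prompt plus the generated tail, and
# `max` is shifted by how much the input grew (diff, min += diff, max += diff):
new_input_len = new_prompt_len + already_generated
diff = new_input_len - old_input_len
maximum += diff           # == new_prompt_len + gen_budget + already_generated

# The resumed call passes max_length=maximum - already_generated, so the room
# left for new tokens is exactly the unspent part of the budget:
room = (maximum - already_generated) - new_input_len
print(room)               # 12 == gen_budget - already_generated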