Support for multiple gens per action with dynamic scan

This commit is contained in:
Gnome Ann
2021-11-17 16:17:59 -05:00
parent ffdc5fc276
commit a1bc10246c
2 changed files with 29 additions and 20 deletions

View File

@ -14,7 +14,7 @@ from tkinter import messagebox
import json
import collections
import zipfile
from typing import Union, Dict, Set
from typing import Union, Dict, Set, List
import requests
import html
@ -565,8 +565,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
def __init__(
self,
tokenizer,
excluded_world_info: set,
#head_length: torch.LongTensor,
excluded_world_info: List[Set],
head_length: int,
):
self.any_new_entries = False
@ -580,15 +579,15 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
**kwargs,
) -> bool:
assert input_ids.ndim == 2
#assert input_ids.shape[:-1] == self.head_length.shape
assert len(self.excluded_world_info) == input_ids.shape[0]
self.any_new_entries = False
if(not vars.dynamicscan):
return False
tail = input_ids[..., self.head_length:]
for t in tail:
for i, t in enumerate(tail):
decoded = tokenizer.decode(t)
_, found = checkworldinfo(decoded, force_use_txt=True)
found -= self.excluded_world_info
found -= self.excluded_world_info[i]
if(len(found) != 0):
self.any_new_entries = True
break
@ -1423,8 +1422,12 @@ def calcsubmit(txt):
#==================================================================#
# Send text to generator and deal with output
#==================================================================#
def generate(txt, min, max, found_entries=set()):
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))
def generate(txt, minimum, maximum, found_entries=None):
if(found_entries is None):
found_entries = set()
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
# Store context in memory to use it for comparison with generated content
vars.lastctx = txt
@ -1466,13 +1469,13 @@ def generate(txt, min, max, found_entries=set()):
with torch.no_grad():
already_generated = 0
numseqs = vars.numseqs if not vars.dynamicscan else 1
numseqs = vars.numseqs
while True:
genout = generator(
gen_in,
do_sample=True,
min_length=min,
max_length=max-already_generated,
min_length=minimum,
max_length=maximum-already_generated,
repetition_penalty=vars.rep_pen,
top_p=top_p,
top_k=top_k,
@ -1485,11 +1488,17 @@ def generate(txt, min, max, found_entries=set()):
already_generated += len(genout[0]) - len(gen_in[0])
if(not model.kai_scanner.any_new_entries):
break
txt = tokenizer.decode(genout[0, -already_generated:])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries |= _found_entries
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
assert genout.ndim >= 2
assert genout.shape[0] == vars.numseqs
encoded = []
for i in range(vars.numseqs):
txt = tokenizer.decode(genout[i, -already_generated:])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
encoded.append(tokenizer.encode(txt, return_tensors="pt", truncation=True)[0].long().to(genout.device))
max_length = len(max(encoded, key=len))
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
genout = torch.cat(
(
encoded,
@ -1503,10 +1512,10 @@ def generate(txt, min, max, found_entries=set()):
model.config.vocab_size + vars.sp.shape[0],
device=genout.device,
)
genout = torch.cat((soft_tokens[None], genout), dim=-1)
genout = torch.cat((soft_tokens.tile(vars.numseqs, 1), genout), dim=-1)
diff = genout.shape[-1] - gen_in.shape[-1]
min += diff
max += diff
minimum += diff
maximum += diff
gen_in = genout
model.kai_scanner_head_length = encoded.shape[-1]
numseqs = 1