Support multiple generations per action ("Gens Per Action") while dynamic scan is enabled

This commit is contained in:
Gnome Ann 2021-11-17 16:17:59 -05:00
parent ffdc5fc276
commit a1bc10246c
2 changed files with 29 additions and 20 deletions

View File

@ -14,7 +14,7 @@ from tkinter import messagebox
import json
import collections
import zipfile
from typing import Union, Dict, Set
from typing import Union, Dict, Set, List
import requests
import html
@ -565,8 +565,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
def __init__(
self,
tokenizer,
excluded_world_info: set,
#head_length: torch.LongTensor,
excluded_world_info: List[Set],
head_length: int,
):
self.any_new_entries = False
@ -580,15 +579,15 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
**kwargs,
) -> bool:
assert input_ids.ndim == 2
#assert input_ids.shape[:-1] == self.head_length.shape
assert len(self.excluded_world_info) == input_ids.shape[0]
self.any_new_entries = False
if(not vars.dynamicscan):
return False
tail = input_ids[..., self.head_length:]
for t in tail:
for i, t in enumerate(tail):
decoded = tokenizer.decode(t)
_, found = checkworldinfo(decoded, force_use_txt=True)
found -= self.excluded_world_info
found -= self.excluded_world_info[i]
if(len(found) != 0):
self.any_new_entries = True
break
@ -1423,8 +1422,12 @@ def calcsubmit(txt):
#==================================================================#
# Send text to generator and deal with output
#==================================================================#
def generate(txt, min, max, found_entries=set()):
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))
def generate(txt, minimum, maximum, found_entries=None):
if(found_entries is None):
found_entries = set()
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
# Store context in memory to use it for comparison with generated content
vars.lastctx = txt
@ -1466,13 +1469,13 @@ def generate(txt, min, max, found_entries=set()):
with torch.no_grad():
already_generated = 0
numseqs = vars.numseqs if not vars.dynamicscan else 1
numseqs = vars.numseqs
while True:
genout = generator(
gen_in,
do_sample=True,
min_length=min,
max_length=max-already_generated,
min_length=minimum,
max_length=maximum-already_generated,
repetition_penalty=vars.rep_pen,
top_p=top_p,
top_k=top_k,
@ -1485,11 +1488,17 @@ def generate(txt, min, max, found_entries=set()):
already_generated += len(genout[0]) - len(gen_in[0])
if(not model.kai_scanner.any_new_entries):
break
txt = tokenizer.decode(genout[0, -already_generated:])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries |= _found_entries
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
assert genout.ndim >= 2
assert genout.shape[0] == vars.numseqs
encoded = []
for i in range(vars.numseqs):
txt = tokenizer.decode(genout[i, -already_generated:])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
encoded.append(tokenizer.encode(txt, return_tensors="pt", truncation=True)[0].long().to(genout.device))
max_length = len(max(encoded, key=len))
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
genout = torch.cat(
(
encoded,
@ -1503,10 +1512,10 @@ def generate(txt, min, max, found_entries=set()):
model.config.vocab_size + vars.sp.shape[0],
device=genout.device,
)
genout = torch.cat((soft_tokens[None], genout), dim=-1)
genout = torch.cat((soft_tokens.tile(vars.numseqs, 1), genout), dim=-1)
diff = genout.shape[-1] - gen_in.shape[-1]
min += diff
max += diff
minimum += diff
maximum += diff
gen_in = genout
model.kai_scanner_head_length = encoded.shape[-1]
numseqs = 1

View File

@ -128,7 +128,7 @@ gensettingstf = [{
"max": 1,
"step": 1,
"default": 0,
"tooltip": "Scan the AI's output for world info keys as it's generating the output. Turning this on will set Gens Per Action to 1, as these two features are not currently compatible with each other."
"tooltip": "Scan the AI's output for world info keys as it's generating the output."
}]
gensettingsik =[{