Support for multiple gens per action with dynamic scan

This commit is contained in:
Gnome Ann 2021-11-17 16:17:59 -05:00
parent ffdc5fc276
commit a1bc10246c
2 changed files with 29 additions and 20 deletions

View File

@ -14,7 +14,7 @@ from tkinter import messagebox
import json import json
import collections import collections
import zipfile import zipfile
from typing import Union, Dict, Set from typing import Union, Dict, Set, List
import requests import requests
import html import html
@ -565,8 +565,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
def __init__( def __init__(
self, self,
tokenizer, tokenizer,
excluded_world_info: set, excluded_world_info: List[Set],
#head_length: torch.LongTensor,
head_length: int, head_length: int,
): ):
self.any_new_entries = False self.any_new_entries = False
@ -580,15 +579,15 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
**kwargs, **kwargs,
) -> bool: ) -> bool:
assert input_ids.ndim == 2 assert input_ids.ndim == 2
#assert input_ids.shape[:-1] == self.head_length.shape assert len(self.excluded_world_info) == input_ids.shape[0]
self.any_new_entries = False self.any_new_entries = False
if(not vars.dynamicscan): if(not vars.dynamicscan):
return False return False
tail = input_ids[..., self.head_length:] tail = input_ids[..., self.head_length:]
for t in tail: for i, t in enumerate(tail):
decoded = tokenizer.decode(t) decoded = tokenizer.decode(t)
_, found = checkworldinfo(decoded, force_use_txt=True) _, found = checkworldinfo(decoded, force_use_txt=True)
found -= self.excluded_world_info found -= self.excluded_world_info[i]
if(len(found) != 0): if(len(found) != 0):
self.any_new_entries = True self.any_new_entries = True
break break
@ -1423,8 +1422,12 @@ def calcsubmit(txt):
#==================================================================# #==================================================================#
# Send text to generator and deal with output # Send text to generator and deal with output
#==================================================================# #==================================================================#
def generate(txt, min, max, found_entries=set()): def generate(txt, minimum, maximum, found_entries=None):
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END)) if(found_entries is None):
found_entries = set()
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
# Store context in memory to use it for comparison with generated content # Store context in memory to use it for comparison with generated content
vars.lastctx = txt vars.lastctx = txt
@ -1466,13 +1469,13 @@ def generate(txt, min, max, found_entries=set()):
with torch.no_grad(): with torch.no_grad():
already_generated = 0 already_generated = 0
numseqs = vars.numseqs if not vars.dynamicscan else 1 numseqs = vars.numseqs
while True: while True:
genout = generator( genout = generator(
gen_in, gen_in,
do_sample=True, do_sample=True,
min_length=min, min_length=minimum,
max_length=max-already_generated, max_length=maximum-already_generated,
repetition_penalty=vars.rep_pen, repetition_penalty=vars.rep_pen,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
@ -1485,11 +1488,17 @@ def generate(txt, min, max, found_entries=set()):
already_generated += len(genout[0]) - len(gen_in[0]) already_generated += len(genout[0]) - len(gen_in[0])
if(not model.kai_scanner.any_new_entries): if(not model.kai_scanner.any_new_entries):
break break
txt = tokenizer.decode(genout[0, -already_generated:]) assert genout.ndim >= 2
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True) assert genout.shape[0] == vars.numseqs
found_entries |= _found_entries encoded = []
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions) for i in range(vars.numseqs):
encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device) txt = tokenizer.decode(genout[i, -already_generated:])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
encoded.append(tokenizer.encode(txt, return_tensors="pt", truncation=True)[0].long().to(genout.device))
max_length = len(max(encoded, key=len))
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
genout = torch.cat( genout = torch.cat(
( (
encoded, encoded,
@ -1503,10 +1512,10 @@ def generate(txt, min, max, found_entries=set()):
model.config.vocab_size + vars.sp.shape[0], model.config.vocab_size + vars.sp.shape[0],
device=genout.device, device=genout.device,
) )
genout = torch.cat((soft_tokens[None], genout), dim=-1) genout = torch.cat((soft_tokens.tile(vars.numseqs, 1), genout), dim=-1)
diff = genout.shape[-1] - gen_in.shape[-1] diff = genout.shape[-1] - gen_in.shape[-1]
min += diff minimum += diff
max += diff maximum += diff
gen_in = genout gen_in = genout
model.kai_scanner_head_length = encoded.shape[-1] model.kai_scanner_head_length = encoded.shape[-1]
numseqs = 1 numseqs = 1

View File

@ -128,7 +128,7 @@ gensettingstf = [{
"max": 1, "max": 1,
"step": 1, "step": 1,
"default": 0, "default": 0,
"tooltip": "Scan the AI's output for world info keys as it's generating the output. Turning this on will set Gens Per Action to 1, as these two features are not currently compatible with each other." "tooltip": "Scan the AI's output for world info keys as it's generating the output."
}] }]
gensettingsik =[{ gensettingsik =[{