Merge pull request #30 from VE-FORBRYDERNE/dynamic-scan

Support for multiple gens per action with dynamic scan
commit 26eb2cb6ce
henk717 committed 2021-11-17 22:30:12 +01:00 (committed by GitHub)
2 changed files with 29 additions and 20 deletions

aiserver.py

@@ -14,7 +14,7 @@ from tkinter import messagebox
 import json
 import collections
 import zipfile
-from typing import Union, Dict, Set
+from typing import Union, Dict, Set, List
 import requests
 import html
@@ -565,8 +565,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
         def __init__(
             self,
             tokenizer,
-            excluded_world_info: set,
-            #head_length: torch.LongTensor,
+            excluded_world_info: List[Set],
             head_length: int,
         ):
             self.any_new_entries = False
@@ -580,15 +579,15 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
             **kwargs,
         ) -> bool:
             assert input_ids.ndim == 2
-            #assert input_ids.shape[:-1] == self.head_length.shape
+            assert len(self.excluded_world_info) == input_ids.shape[0]
             self.any_new_entries = False
             if(not vars.dynamicscan):
                 return False
             tail = input_ids[..., self.head_length:]
-            for t in tail:
+            for i, t in enumerate(tail):
                 decoded = tokenizer.decode(t)
                 _, found = checkworldinfo(decoded, force_use_txt=True)
-                found -= self.excluded_world_info
+                found -= self.excluded_world_info[i]
                 if(len(found) != 0):
                     self.any_new_entries = True
                     break
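The substance of this hunk is that `excluded_world_info` changes from a single set to one set per sequence, so each of the candidate continuations is scanned against its own exclusion list. A minimal standalone sketch of the per-sequence tail scan, with `decode` and `check_world_info` as hypothetical stand-ins for the tokenizer and KoboldAI's `checkworldinfo`:

    from typing import Callable, List, Set

    import torch

    def scan_tails(
        input_ids: torch.Tensor,              # (num_seqs, seq_len) tokens generated so far
        head_length: int,                     # length of the part that was already scanned
        excluded_world_info: List[Set[str]],  # one exclusion set per sequence
        decode: Callable[[torch.Tensor], str],
        check_world_info: Callable[[str], Set[str]],
    ) -> bool:
        assert input_ids.ndim == 2
        assert len(excluded_world_info) == input_ids.shape[0]
        # Only the freshly generated tail of each sequence needs scanning.
        tail = input_ids[..., head_length:]
        for i, t in enumerate(tail):
            found = check_world_info(decode(t))
            # Entries this sequence already pulled into context don't count as new.
            if found - excluded_world_info[i]:
                return True  # signal the generator to stop and rebuild the prompt
        return False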
@@ -1423,8 +1422,12 @@ def calcsubmit(txt):
 #==================================================================#
 # Send text to generator and deal with output
 #==================================================================#
-def generate(txt, min, max, found_entries=set()):
-    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))
+def generate(txt, minimum, maximum, found_entries=None):
+    if(found_entries is None):
+        found_entries = set()
+    found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
+
+    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
 
     # Store context in memory to use it for comparison with generated content
     vars.lastctx = txt
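Beyond renaming `min`/`max` (which shadowed the Python builtins), switching the default from `found_entries=set()` to a `None` sentinel fixes the classic mutable-default pitfall: a default `set()` is created once at function definition and shared by every call. A self-contained illustration:

    def buggy(entries=set()):        # one set, created at definition time
        entries.add(len(entries))
        return len(entries)

    print(buggy(), buggy())          # 1 2 -- state leaks between calls

    def fixed(entries=None):
        if entries is None:
            entries = set()          # a fresh set on every call
        entries.add(len(entries))
        return len(entries)

    print(fixed(), fixed())          # 1 1

The `tuple(found_entries.copy() for _ in range(vars.numseqs))` line then gives each sequence an independent copy, so one sequence's discovered entries don't mask another's.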
@@ -1466,13 +1469,13 @@ def generate(txt, min, max, found_entries=set()):
     with torch.no_grad():
         already_generated = 0
-        numseqs = vars.numseqs if not vars.dynamicscan else 1
+        numseqs = vars.numseqs
         while True:
             genout = generator(
                 gen_in,
                 do_sample=True,
-                min_length=min,
-                max_length=max-already_generated,
+                min_length=minimum,
+                max_length=maximum-already_generated,
                 repetition_penalty=vars.rep_pen,
                 top_p=top_p,
                 top_k=top_k,
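This loop is the dynamic-scan restart strategy: generate until the scanner reports a newly triggered world-info entry, rebuild the prompt with that entry in context, and resume with the remaining token budget (hence `max_length=maximum-already_generated`). A schematic of the control flow, with `generate_step`, `scanner_found_new`, and `rebuild_prompts` as hypothetical stand-ins:

    def dynamic_scan_generate(gen_in, budget, generate_step, scanner_found_new, rebuild_prompts):
        """Generate with restarts, keeping the total new-token budget constant."""
        already_generated = 0
        while True:
            genout = generate_step(gen_in, max_new=budget - already_generated)
            already_generated += genout.shape[-1] - gen_in.shape[-1]
            if not scanner_found_new():
                return genout  # no new world info was triggered: finished
            # A new entry fired mid-generation: rebuild each prompt so the entry
            # is in context, then continue generating from the rebuilt prompts.
            gen_in = rebuild_prompts(genout)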
@@ -1485,11 +1488,17 @@ def generate(txt, min, max, found_entries=set()):
             already_generated += len(genout[0]) - len(gen_in[0])
             if(not model.kai_scanner.any_new_entries):
                 break
-            txt = tokenizer.decode(genout[0, -already_generated:])
-            winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
-            found_entries |= _found_entries
-            txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
-            encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
+            assert genout.ndim >= 2
+            assert genout.shape[0] == vars.numseqs
+            encoded = []
+            for i in range(vars.numseqs):
+                txt = tokenizer.decode(genout[i, -already_generated:])
+                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
+                found_entries[i].update(_found_entries)
+                txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
+                encoded.append(tokenizer.encode(txt, return_tensors="pt", truncation=True)[0].long().to(genout.device))
+            max_length = len(max(encoded, key=len))
+            encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
             genout = torch.cat(
                 (
                     encoded,
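Each rebuilt prompt can tokenize to a different length, so the per-sequence encodings are left-padded to a common length before `torch.stack` batches them; padding on the left keeps the generated tails right-aligned across the batch. A minimal sketch of that step, assuming a pad token id of 0:

    import torch

    def left_pad_stack(seqs, pad_id=0):
        """Left-pad 1-D LongTensors to the longest length, then stack into a batch."""
        max_len = max(len(s) for s in seqs)
        # F.pad takes (left, right) padding amounts for the last dimension.
        return torch.stack([
            torch.nn.functional.pad(s, (max_len - len(s), 0), value=pad_id)
            for s in seqs
        ])

    batch = left_pad_stack([torch.tensor([5, 6, 7]), torch.tensor([8, 9])])
    # tensor([[5, 6, 7],
    #         [0, 8, 9]])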
@@ -1503,10 +1512,10 @@ def generate(txt, min, max, found_entries=set()):
                     model.config.vocab_size + vars.sp.shape[0],
                     device=genout.device,
                 )
-                genout = torch.cat((soft_tokens[None], genout), dim=-1)
+                genout = torch.cat((soft_tokens.tile(vars.numseqs, 1), genout), dim=-1)
             diff = genout.shape[-1] - gen_in.shape[-1]
-            min += diff
-            max += diff
+            minimum += diff
+            maximum += diff
             gen_in = genout
             model.kai_scanner_head_length = encoded.shape[-1]
             numseqs = 1
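With one sequence, `soft_tokens[None]` merely added a batch dimension before the soft-prompt ids were prepended; with a batch of `vars.numseqs` rows, `.tile(vars.numseqs, 1)` repeats the ids once per row. A small sketch of the tiling:

    import torch

    soft_tokens = torch.tensor([100, 101, 102])  # placeholder soft-prompt token ids
    numseqs = 2
    genout = torch.tensor([[5, 6], [7, 8]])      # (numseqs, seq_len) batch

    # One copy of the soft tokens per row, prepended along the token dimension.
    batch = torch.cat((soft_tokens.tile(numseqs, 1), genout), dim=-1)
    # tensor([[100, 101, 102, 5, 6],
    #         [100, 101, 102, 7, 8]])

Note that `numseqs = 1` remains, presumably because after the first pass `gen_in` already holds one row per sequence, so each subsequent generator call needs only one return sequence per row.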

gensettings.py

@@ -128,7 +128,7 @@ gensettingstf = [{
     "max": 1,
     "step": 1,
     "default": 0,
-    "tooltip": "Scan the AI's output for world info keys as it's generating the output. Turning this on will set Gens Per Action to 1, as these two features are not currently compatible with each other."
+    "tooltip": "Scan the AI's output for world info keys as it's generating the output."
     }]
 gensettingsik =[{