Merge pull request #30 from VE-FORBRYDERNE/dynamic-scan
Support for multiple gens per action with dynamic scan
This commit is contained in:
commit
26eb2cb6ce
43
aiserver.py
43
aiserver.py
|
@ -14,7 +14,7 @@ from tkinter import messagebox
|
|||
import json
|
||||
import collections
|
||||
import zipfile
|
||||
from typing import Union, Dict, Set
|
||||
from typing import Union, Dict, Set, List
|
||||
|
||||
import requests
|
||||
import html
|
||||
|
@ -565,8 +565,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
|||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
excluded_world_info: set,
|
||||
#head_length: torch.LongTensor,
|
||||
excluded_world_info: List[Set],
|
||||
head_length: int,
|
||||
):
|
||||
self.any_new_entries = False
|
||||
|
@ -580,15 +579,15 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
|||
**kwargs,
|
||||
) -> bool:
|
||||
assert input_ids.ndim == 2
|
||||
#assert input_ids.shape[:-1] == self.head_length.shape
|
||||
assert len(self.excluded_world_info) == input_ids.shape[0]
|
||||
self.any_new_entries = False
|
||||
if(not vars.dynamicscan):
|
||||
return False
|
||||
tail = input_ids[..., self.head_length:]
|
||||
for t in tail:
|
||||
for i, t in enumerate(tail):
|
||||
decoded = tokenizer.decode(t)
|
||||
_, found = checkworldinfo(decoded, force_use_txt=True)
|
||||
found -= self.excluded_world_info
|
||||
found -= self.excluded_world_info[i]
|
||||
if(len(found) != 0):
|
||||
self.any_new_entries = True
|
||||
break
|
||||
|
@ -1423,8 +1422,12 @@ def calcsubmit(txt):
|
|||
#==================================================================#
|
||||
# Send text to generator and deal with output
|
||||
#==================================================================#
|
||||
def generate(txt, min, max, found_entries=set()):
|
||||
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))
|
||||
def generate(txt, minimum, maximum, found_entries=None):
|
||||
if(found_entries is None):
|
||||
found_entries = set()
|
||||
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
|
||||
|
||||
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
|
||||
|
||||
# Store context in memory to use it for comparison with generated content
|
||||
vars.lastctx = txt
|
||||
|
@ -1466,13 +1469,13 @@ def generate(txt, min, max, found_entries=set()):
|
|||
|
||||
with torch.no_grad():
|
||||
already_generated = 0
|
||||
numseqs = vars.numseqs if not vars.dynamicscan else 1
|
||||
numseqs = vars.numseqs
|
||||
while True:
|
||||
genout = generator(
|
||||
gen_in,
|
||||
do_sample=True,
|
||||
min_length=min,
|
||||
max_length=max-already_generated,
|
||||
min_length=minimum,
|
||||
max_length=maximum-already_generated,
|
||||
repetition_penalty=vars.rep_pen,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
|
@ -1485,11 +1488,17 @@ def generate(txt, min, max, found_entries=set()):
|
|||
already_generated += len(genout[0]) - len(gen_in[0])
|
||||
if(not model.kai_scanner.any_new_entries):
|
||||
break
|
||||
txt = tokenizer.decode(genout[0, -already_generated:])
|
||||
assert genout.ndim >= 2
|
||||
assert genout.shape[0] == vars.numseqs
|
||||
encoded = []
|
||||
for i in range(vars.numseqs):
|
||||
txt = tokenizer.decode(genout[i, -already_generated:])
|
||||
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
|
||||
found_entries |= _found_entries
|
||||
found_entries[i].update(_found_entries)
|
||||
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
|
||||
encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
|
||||
encoded.append(tokenizer.encode(txt, return_tensors="pt", truncation=True)[0].long().to(genout.device))
|
||||
max_length = len(max(encoded, key=len))
|
||||
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
|
||||
genout = torch.cat(
|
||||
(
|
||||
encoded,
|
||||
|
@ -1503,10 +1512,10 @@ def generate(txt, min, max, found_entries=set()):
|
|||
model.config.vocab_size + vars.sp.shape[0],
|
||||
device=genout.device,
|
||||
)
|
||||
genout = torch.cat((soft_tokens[None], genout), dim=-1)
|
||||
genout = torch.cat((soft_tokens.tile(vars.numseqs, 1), genout), dim=-1)
|
||||
diff = genout.shape[-1] - gen_in.shape[-1]
|
||||
min += diff
|
||||
max += diff
|
||||
minimum += diff
|
||||
maximum += diff
|
||||
gen_in = genout
|
||||
model.kai_scanner_head_length = encoded.shape[-1]
|
||||
numseqs = 1
|
||||
|
|
|
@ -128,7 +128,7 @@ gensettingstf = [{
|
|||
"max": 1,
|
||||
"step": 1,
|
||||
"default": 0,
|
||||
"tooltip": "Scan the AI's output for world info keys as it's generating the output. Turning this on will set Gens Per Action to 1, as these two features are not currently compatible with each other."
|
||||
"tooltip": "Scan the AI's output for world info keys as it's generating the output."
|
||||
}]
|
||||
|
||||
gensettingsik =[{
|
||||
|
|
Loading…
Reference in New Issue