Merge pull request #30 from VE-FORBRYDERNE/dynamic-scan
Support for multiple gens per action with dynamic scan
This commit is contained in:
commit
26eb2cb6ce
43
aiserver.py
43
aiserver.py
|
@ -14,7 +14,7 @@ from tkinter import messagebox
|
||||||
import json
|
import json
|
||||||
import collections
|
import collections
|
||||||
import zipfile
|
import zipfile
|
||||||
from typing import Union, Dict, Set
|
from typing import Union, Dict, Set, List
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import html
|
import html
|
||||||
|
@ -565,8 +565,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
excluded_world_info: set,
|
excluded_world_info: List[Set],
|
||||||
#head_length: torch.LongTensor,
|
|
||||||
head_length: int,
|
head_length: int,
|
||||||
):
|
):
|
||||||
self.any_new_entries = False
|
self.any_new_entries = False
|
||||||
|
@ -580,15 +579,15 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
assert input_ids.ndim == 2
|
assert input_ids.ndim == 2
|
||||||
#assert input_ids.shape[:-1] == self.head_length.shape
|
assert len(self.excluded_world_info) == input_ids.shape[0]
|
||||||
self.any_new_entries = False
|
self.any_new_entries = False
|
||||||
if(not vars.dynamicscan):
|
if(not vars.dynamicscan):
|
||||||
return False
|
return False
|
||||||
tail = input_ids[..., self.head_length:]
|
tail = input_ids[..., self.head_length:]
|
||||||
for t in tail:
|
for i, t in enumerate(tail):
|
||||||
decoded = tokenizer.decode(t)
|
decoded = tokenizer.decode(t)
|
||||||
_, found = checkworldinfo(decoded, force_use_txt=True)
|
_, found = checkworldinfo(decoded, force_use_txt=True)
|
||||||
found -= self.excluded_world_info
|
found -= self.excluded_world_info[i]
|
||||||
if(len(found) != 0):
|
if(len(found) != 0):
|
||||||
self.any_new_entries = True
|
self.any_new_entries = True
|
||||||
break
|
break
|
||||||
|
@ -1423,8 +1422,12 @@ def calcsubmit(txt):
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Send text to generator and deal with output
|
# Send text to generator and deal with output
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
def generate(txt, min, max, found_entries=set()):
|
def generate(txt, minimum, maximum, found_entries=None):
|
||||||
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))
|
if(found_entries is None):
|
||||||
|
found_entries = set()
|
||||||
|
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
|
||||||
|
|
||||||
|
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
|
||||||
|
|
||||||
# Store context in memory to use it for comparison with generated content
|
# Store context in memory to use it for comparison with generated content
|
||||||
vars.lastctx = txt
|
vars.lastctx = txt
|
||||||
|
@ -1466,13 +1469,13 @@ def generate(txt, min, max, found_entries=set()):
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
already_generated = 0
|
already_generated = 0
|
||||||
numseqs = vars.numseqs if not vars.dynamicscan else 1
|
numseqs = vars.numseqs
|
||||||
while True:
|
while True:
|
||||||
genout = generator(
|
genout = generator(
|
||||||
gen_in,
|
gen_in,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
min_length=min,
|
min_length=minimum,
|
||||||
max_length=max-already_generated,
|
max_length=maximum-already_generated,
|
||||||
repetition_penalty=vars.rep_pen,
|
repetition_penalty=vars.rep_pen,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
@ -1485,11 +1488,17 @@ def generate(txt, min, max, found_entries=set()):
|
||||||
already_generated += len(genout[0]) - len(gen_in[0])
|
already_generated += len(genout[0]) - len(gen_in[0])
|
||||||
if(not model.kai_scanner.any_new_entries):
|
if(not model.kai_scanner.any_new_entries):
|
||||||
break
|
break
|
||||||
txt = tokenizer.decode(genout[0, -already_generated:])
|
assert genout.ndim >= 2
|
||||||
|
assert genout.shape[0] == vars.numseqs
|
||||||
|
encoded = []
|
||||||
|
for i in range(vars.numseqs):
|
||||||
|
txt = tokenizer.decode(genout[i, -already_generated:])
|
||||||
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
|
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
|
||||||
found_entries |= _found_entries
|
found_entries[i].update(_found_entries)
|
||||||
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
|
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
|
||||||
encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
|
encoded.append(tokenizer.encode(txt, return_tensors="pt", truncation=True)[0].long().to(genout.device))
|
||||||
|
max_length = len(max(encoded, key=len))
|
||||||
|
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
|
||||||
genout = torch.cat(
|
genout = torch.cat(
|
||||||
(
|
(
|
||||||
encoded,
|
encoded,
|
||||||
|
@ -1503,10 +1512,10 @@ def generate(txt, min, max, found_entries=set()):
|
||||||
model.config.vocab_size + vars.sp.shape[0],
|
model.config.vocab_size + vars.sp.shape[0],
|
||||||
device=genout.device,
|
device=genout.device,
|
||||||
)
|
)
|
||||||
genout = torch.cat((soft_tokens[None], genout), dim=-1)
|
genout = torch.cat((soft_tokens.tile(vars.numseqs, 1), genout), dim=-1)
|
||||||
diff = genout.shape[-1] - gen_in.shape[-1]
|
diff = genout.shape[-1] - gen_in.shape[-1]
|
||||||
min += diff
|
minimum += diff
|
||||||
max += diff
|
maximum += diff
|
||||||
gen_in = genout
|
gen_in = genout
|
||||||
model.kai_scanner_head_length = encoded.shape[-1]
|
model.kai_scanner_head_length = encoded.shape[-1]
|
||||||
numseqs = 1
|
numseqs = 1
|
||||||
|
|
|
@ -128,7 +128,7 @@ gensettingstf = [{
|
||||||
"max": 1,
|
"max": 1,
|
||||||
"step": 1,
|
"step": 1,
|
||||||
"default": 0,
|
"default": 0,
|
||||||
"tooltip": "Scan the AI's output for world info keys as it's generating the output. Turning this on will set Gens Per Action to 1, as these two features are not currently compatible with each other."
|
"tooltip": "Scan the AI's output for world info keys as it's generating the output."
|
||||||
}]
|
}]
|
||||||
|
|
||||||
gensettingsik =[{
|
gensettingsik =[{
|
||||||
|
|
Loading…
Reference in New Issue