Merge pull request #23 from VE-FORBRYDERNE/scan-test

Dynamic world info scan
henk717 2021-11-10 03:31:42 +01:00 committed by GitHub
commit c2371cf801
3 changed files with 258 additions and 133 deletions

aiserver.py

@@ -13,7 +13,7 @@ from tkinter import messagebox
import json
import collections
import zipfile
-from typing import Union, Tuple
+from typing import Union, Dict, Set
import requests
import html
@@ -124,6 +124,7 @@ class vars:
acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
actionmode = 1
adventure = False
dynamicscan = False
remote = False
#==================================================================#
@@ -512,7 +513,8 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END))
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
-from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
+from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
+import transformers.generation_utils
# Patch transformers to use our soft prompt
def patch_causallm(cls):
@@ -528,7 +530,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(vars.sp is not None):
vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
inputs_embeds = torch.where(
-(shifted_input_ids >= 0)[:, :, None],
+(shifted_input_ids >= 0)[..., None],
vars.sp[shifted_input_ids.clamp(min=0)],
inputs_embeds,
)
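# shifted_input_ids is the input ids with the real vocabulary size
# subtracted, so a non-negative value marks a soft-prompt position and
# torch.where substitutes the learned soft embedding there. Indexing with
# [..., None] instead of [:, :, None] broadcasts over the embedding
# dimension without assuming a fixed number of batch dimensions.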
@@ -543,6 +545,52 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
except:
pass
# Sets up dynamic world info scanner
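# transformers invokes every StoppingCriteria after each decoding step,
# passing the token ids generated so far; returning True halts
# generation. This criteria never matches a fixed token pattern: it
# decodes the freshly generated text and reports when it triggers a
# world info entry that is not yet in the context, so generation can be
# restarted with that entry injected.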
class DynamicWorldInfoScanCriteria(StoppingCriteria):
def __init__(
self,
tokenizer,
excluded_world_info: set,
#head_length: torch.LongTensor,
head_length: int,
):
self.any_new_entries = False
self.tokenizer = tokenizer
self.excluded_world_info = excluded_world_info
self.head_length = head_length
def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
**kwargs,
) -> bool:
assert input_ids.ndim == 2
#assert input_ids.shape[:-1] == self.head_length.shape
self.any_new_entries = False
if(not vars.dynamicscan):
return False
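# Tokens up to head_length belong to the already-scanned prompt; only
# the newly generated tail needs checking.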
tail = input_ids[..., self.head_length:]
for t in tail:
decoded = tokenizer.decode(t)
_, found = checkworldinfo(decoded, force_use_txt=True)
found -= self.excluded_world_info
if(len(found) != 0):
self.any_new_entries = True
break
return self.any_new_entries
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
def new_get_stopping_criteria(self, *args, **kwargs):
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
global tokenizer
self.kai_scanner = DynamicWorldInfoScanCriteria(
tokenizer=tokenizer,
excluded_world_info=self.kai_scanner_excluded_world_info,
head_length=self.kai_scanner_head_length,
)
stopping_criteria.append(self.kai_scanner)
return stopping_criteria
transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
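# Patching the private _get_stopping_criteria hook attaches the scanner
# to every generate() call without forking transformers; generate()
# below supplies kai_scanner_head_length and
# kai_scanner_excluded_world_info as attributes on the model. Since this
# relies on a private API, it may need updating across transformers
# versions.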
# If custom GPT Neo model was chosen
if(vars.model == "NeoCustom"):
model_config = open(vars.custmodpth + "/config.json", "r")
@@ -901,6 +949,10 @@ def get_message(msg):
vars.adventure = msg['data']
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setdynamicscan'):
vars.dynamicscan = msg['data']
settingschanged()
refresh_settings()
elif(not vars.remote and msg['cmd'] == 'importwi'):
wiimportrequest()
@@ -958,6 +1010,7 @@ def savesettings():
js["widepth"] = vars.widepth
js["useprompt"] = vars.useprompt
js["adventure"] = vars.adventure
js["dynamicscan"] = vars.dynamicscan
# Write it
if not os.path.exists('settings'):
@@ -1008,6 +1061,8 @@ def loadsettings():
vars.useprompt = js["useprompt"]
if("adventure" in js):
vars.adventure = js["adventure"]
if("dynamicscan" in js):
vars.dynamicscan = js["dynamicscan"]
file.close()
@@ -1032,6 +1087,8 @@ def loadmodelsettings():
vars.rep_pen = js["rep_pen"]
if("adventure" in js):
vars.adventure = js["adventure"]
if("dynamicscan" in js):
vars.dynamicscan = js["dynamicscan"]
if("formatoptns" in js):
vars.formatoptns = js["formatoptns"]
model_config.close()
@@ -1148,17 +1205,11 @@ def actionback():
vars.genseqs = []
-#==================================================================#
-# Take submitted text and build the text to be given to generator
-#
-#==================================================================#
-def calcsubmit(txt):
-anotetxt = "" # Placeholder for Author's Note text
-lnanote = 0 # Placeholder for Author's Note length
-forceanote = False # In case we don't have enough actions to hit A.N. depth
-anoteadded = False # In case our budget runs out before we hit A.N. depth
-actionlen = len(vars.actions)
+def calcsubmitbudgetheader(txt, **kwargs):
# Scan for WorldInfo matches
-winfo = checkworldinfo(txt)
+winfo, found_entries = checkworldinfo(txt, **kwargs)
# Add a newline to the end of memory
if(vars.memory != "" and vars.memory[-1] != "\n"):
@@ -1169,10 +1220,16 @@ def calcsubmit(txt):
# Build Author's Note if set
if(vars.authornote != ""):
anotetxt = "\n[Author's note: "+vars.authornote+"]\n"
+else:
+anotetxt = ""
-# For all transformers models
-if(vars.model != "InferKit"):
+return winfo, mem, anotetxt, found_entries
+def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
+forceanote = False # In case we don't have enough actions to hit A.N. depth
+anoteadded = False # In case our budget runs out before we hit A.N. depth
+anotetkns = [] # Placeholder for Author's Note tokens
+lnanote = 0 # Placeholder for Author's Note length
# Calculate token budget
prompttkns = tokenizer.encode(vars.prompt)
@@ -1199,13 +1256,7 @@ def calcsubmit(txt):
# First/Prompt action
subtxt = vars.memory + winfo + anotetxt + vars.prompt
lnsub = lnsp + lnmem + lnwi + lnprompt + lnanote
-if(not vars.model in ["Colab", "OAI"]):
-generate(subtxt, lnsub+1, lnsub+vars.genamt)
-elif(vars.model == "Colab"):
-sendtocolab(subtxt, lnsub+1, lnsub+vars.genamt)
-elif(vars.model == "OAI"):
-oairequest(subtxt, lnsub+1, lnsub+vars.genamt)
+return subtxt, lnsub+1, lnsub+vars.genamt
else:
tokens = []
@@ -1215,8 +1266,8 @@ def calcsubmit(txt):
# Get most recent action tokens up to our budget
n = 0
-for key in reversed(vars.actions):
-chunk = vars.actions[key]
+for key in reversed(actions):
+chunk = actions[key]
if(budget <= 0):
break
@@ -1258,25 +1309,36 @@ def calcsubmit(txt):
# Send completed bundle to generator
ln = len(tokens) + lnsp
return tokenizer.decode(tokens), ln+1, ln+vars.genamt
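# calcsubmitbudgetheader() and calcsubmitbudget() above are the two
# halves of the old calcsubmit(): the header gathers world info, memory
# and the Author's Note, while the budget half trims actions to the
# token budget and returns the prompt text plus its min/max lengths.
# Splitting them lets generate() rebuild the prompt mid-generation when
# the dynamic scanner fires.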
#==================================================================#
# Take submitted text and build the text to be given to generator
#==================================================================#
def calcsubmit(txt):
anotetxt = "" # Placeholder for Author's Note text
forceanote = False # In case we don't have enough actions to hit A.N. depth
anoteadded = False # In case our budget runs out before we hit A.N. depth
actionlen = len(vars.actions)
winfo, mem, anotetxt, found_entries = calcsubmitbudgetheader(txt)
# For all transformers models
if(vars.model != "InferKit"):
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions)
if(actionlen == 0):
if(not vars.model in ["Colab", "OAI"]):
-generate (
-tokenizer.decode(tokens),
-ln+1,
-ln+vars.genamt
-)
+generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
-sendtocolab(
-tokenizer.decode(tokens),
-ln+1,
-ln+vars.genamt
-)
+sendtocolab(subtxt, min, max)
elif(vars.model == "OAI"):
-oairequest(
-tokenizer.decode(tokens),
-ln+1,
-ln+vars.genamt
-)
+oairequest(subtxt, min, max)
else:
if(not vars.model in ["Colab", "OAI"]):
generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
sendtocolab(subtxt, min, max)
elif(vars.model == "OAI"):
oairequest(subtxt, min, max)
# For InferKit web API
else:
@@ -1337,7 +1399,7 @@ def calcsubmit(txt):
#==================================================================#
# Send text to generator and deal with output
#==================================================================#
-def generate(txt, min, max):
+def generate(txt, min, max, found_entries=set()):
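# Caution: found_entries defaults to a shared mutable set and is updated
# in place with |= below, so calls that omit the argument accumulate
# entries across invocations; calcsubmit() always passes the set from
# calcsubmitbudgetheader().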
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))
# Store context in memory to use it for comparison with generated content
@@ -1360,7 +1422,7 @@ def generate(txt, min, max):
model.config.vocab_size,
model.config.vocab_size + vars.sp.shape[0],
)
-gen_in = torch.cat((soft_tokens[None], gen_in), dim=1)
+gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
if(vars.hascuda and vars.usegpu):
gen_in = gen_in.to(0)
@@ -1371,12 +1433,22 @@ def generate(txt, min, max):
else:
gen_in = gen_in.to('cpu')
model.kai_scanner_head_length = gen_in.shape[-1]
model.kai_scanner_excluded_world_info = found_entries
actions = vars.actions
if(vars.dynamicscan):
actions = actions.copy()
with torch.no_grad():
already_generated = 0
numseqs = vars.numseqs if not vars.dynamicscan else 1
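# Resumable generation loop: run until the token budget is spent or the
# scanner reports a newly triggered world info entry. On a trigger, the
# context is rebuilt below with that entry included and generation
# resumes; max grows by the change in sequence length while
# max_length=max-already_generated keeps the remaining new-token budget
# constant. Dynamic scanning forces a single sequence because each
# candidate sequence could trigger different entries.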
while True:
genout = generator(
gen_in,
do_sample=True,
min_length=min,
-max_length=max,
+max_length=max-already_generated,
repetition_penalty=vars.rep_pen,
top_p=top_p,
top_k=top_k,
@@ -1384,9 +1456,37 @@ def generate(txt, min, max):
temperature=vars.temp,
bad_words_ids=vars.badwordsids,
use_cache=True,
return_full_text=False,
-num_return_sequences=vars.numseqs
+num_return_sequences=numseqs
)
already_generated += len(genout[0]) - len(gen_in[0])
if(not model.kai_scanner.any_new_entries):
break
txt = tokenizer.decode(genout[0, -already_generated:])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries |= _found_entries
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
genout = torch.cat(
(
encoded,
genout[..., -already_generated:],
),
dim=-1
)
if(vars.sp is not None):
soft_tokens = torch.arange(
model.config.vocab_size,
model.config.vocab_size + vars.sp.shape[0],
device=genout.device,
)
genout = torch.cat((soft_tokens[None], genout), dim=-1)
diff = genout.shape[-1] - gen_in.shape[-1]
min += diff
max += diff
gen_in = genout
model.kai_scanner_head_length = encoded.shape[-1]
numseqs = 1
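# The next generator() call builds a fresh stopping criteria from these
# attributes via the patched _get_stopping_criteria, with head_length
# now set to the rebuilt prompt's token count.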
except Exception as e:
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
print("{0}{1}{2}".format(colors.RED, e, colors.END))
@@ -1394,7 +1494,8 @@ def generate(txt, min, max):
return
# Need to manually strip and decode tokens if we're not using a pipeline
-genout = [{"generated_text": tokenizer.decode(tokens[len(gen_in[0])-len(tokens):])} for tokens in genout]
+#already_generated = -(len(gen_in[0]) - len(tokens))
+genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
@@ -1654,6 +1755,7 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatewidepth', 'data': vars.widepth}, broadcast=True)
emit('from_server', {'cmd': 'updateuseprompt', 'data': vars.useprompt}, broadcast=True)
emit('from_server', {'cmd': 'updateadventure', 'data': vars.adventure}, broadcast=True)
emit('from_server', {'cmd': 'updatedynamicscan', 'data': vars.dynamicscan}, broadcast=True)
emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)
@@ -1862,10 +1964,12 @@ def deletewi(num):
#==================================================================#
# Look for WI keys in text to generator
#==================================================================#
-def checkworldinfo(txt):
+def checkworldinfo(txt, force_use_txt=False):
+original_txt = txt
# Dont go any further if WI is empty
if(len(vars.worldinfo) == 0):
-return
+return "", set()
# Cache actions length
ln = len(vars.actions)
@@ -1875,7 +1979,7 @@ def checkworldinfo(txt):
depth = vars.widepth
# If this is not a continue, add 1 to widepth since submitted
# text is already in action history @ -1
if(txt != "" and vars.prompt != txt):
if(not force_use_txt and (txt != "" and vars.prompt != txt)):
txt = ""
depth += 1
@@ -1896,11 +2000,16 @@ def checkworldinfo(txt):
elif(ln == 0):
txt = vars.prompt
if(force_use_txt):
txt += original_txt
# Scan text for matches on WI keys
wimem = ""
found_entries = set()
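# found_entries collects the id()s of matched world info dicts so that
# generate() can subtract the entries already injected into the context
# and detect only newly triggered ones.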
for wi in vars.worldinfo:
if(wi.get("constant", False)):
wimem = wimem + wi["content"] + "\n"
found_entries.add(id(wi))
continue
if(wi["key"] != ""):
@@ -1922,15 +2031,17 @@ def checkworldinfo(txt):
ksy = ks.strip()
if ksy in txt:
wimem = wimem + wi["content"] + "\n"
found_entries.add(id(wi))
found = True
break
if found:
break
else:
wimem = wimem + wi["content"] + "\n"
found_entries.add(id(wi))
break
-return wimem
+return wimem, found_entries
#==================================================================#
# Commit changes to Memory storage

gensettings.py

@@ -118,6 +118,17 @@ gensettingstf = [{
"step": 1,
"default": 0,
"tooltip": "Turn this on if you are playing a Choose your Adventure model."
},
{
"uitype": "toggle",
"unit": "bool",
"label": "Dynamic WI Scan",
"id": "setdynamicscan",
"min": 0,
"max": 1,
"step": 1,
"default": 0,
"tooltip": "Scan the AI's output for world info keys as it's generating the output. Turning this on will set Gens Per Action to 1, as these two features are not currently compatible with each other."
}]
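# The toggle above reuses the same schema as the Adventure Mode entry: a
# 0/1 bool rendered as a UI switch whose id doubles as the socket
# command ("setdynamicscan") handled in aiserver.py.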
gensettingsik =[{

static/application.js

@@ -1593,6 +1593,9 @@ $(document).ready(function(){
$("#setadventure").prop('checked', msg.data).change();
// Update adventure state
setadventure(msg.data);
} else if(msg.cmd == "updatedynamicscan") {
// Update toggle state
$("#setdynamicscan").prop('checked', msg.data).change();
} else if(msg.cmd == "runs_remotely") {
remote = true;
hide([button_savetofile, button_import, button_importwi]);