Merge pull request #23 from VE-FORBRYDERNE/scan-test

Dynamic world info scan
This commit is contained in:
henk717 2021-11-10 03:31:42 +01:00 committed by GitHub
commit c2371cf801
3 changed files with 258 additions and 133 deletions


@@ -13,7 +13,7 @@ from tkinter import messagebox
import json
import collections
import zipfile
-from typing import Union, Tuple
+from typing import Union, Dict, Set
import requests
import html
@@ -124,6 +124,7 @@ class vars:
    acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE)    # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
    actionmode = 1
    adventure = False
+   dynamicscan = False
    remote = False

#==================================================================#
@@ -512,7 +513,8 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END))
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
    if(not vars.noai):
        print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
-       from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
+       from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
+       import transformers.generation_utils

        # Patch transformers to use our soft prompt
        def patch_causallm(cls):
@@ -528,7 +530,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
                if(vars.sp is not None):
                    vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
                    inputs_embeds = torch.where(
-                       (shifted_input_ids >= 0)[:, :, None],
+                       (shifted_input_ids >= 0)[..., None],
                        vars.sp[shifted_input_ids.clamp(min=0)],
                        inputs_embeds,
                    )
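
The `[..., None]` change above generalizes the soft-prompt substitution to inputs of any batch shape. For context, a minimal sketch of the idea (function name and argument layout are illustrative, not the repo's API): token ids shifted past the real vocabulary select rows from the soft-prompt matrix instead of the embedding table.

import torch

def inject_soft_prompt(inputs_embeds, shifted_input_ids, sp):
    # Positions with id >= 0 refer to soft-prompt rows; negative ids mark
    # ordinary tokens whose embeddings are kept as-is. [..., None] broadcasts
    # the mask over the embedding dimension regardless of batch shape,
    # which is why it replaces the batch-rank-specific [:, :, None].
    mask = (shifted_input_ids >= 0)[..., None]
    return torch.where(mask, sp[shifted_input_ids.clamp(min=0)], inputs_embeds)
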
@@ -543,6 +545,52 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
        except:
            pass

+       # Sets up dynamic world info scanner
+       class DynamicWorldInfoScanCriteria(StoppingCriteria):
+           def __init__(
+               self,
+               tokenizer,
+               excluded_world_info: set,
+               #head_length: torch.LongTensor,
+               head_length: int,
+           ):
+               self.any_new_entries = False
+               self.tokenizer = tokenizer
+               self.excluded_world_info = excluded_world_info
+               self.head_length = head_length
+           def __call__(
+               self,
+               input_ids: torch.LongTensor,
+               scores: torch.FloatTensor,
+               **kwargs,
+           ) -> bool:
+               assert input_ids.ndim == 2
+               #assert input_ids.shape[:-1] == self.head_length.shape
+               self.any_new_entries = False
+               if(not vars.dynamicscan):
+                   return False
+               tail = input_ids[..., self.head_length:]
+               for t in tail:
+                   decoded = tokenizer.decode(t)
+                   _, found = checkworldinfo(decoded, force_use_txt=True)
+                   found -= self.excluded_world_info
+                   if(len(found) != 0):
+                       self.any_new_entries = True
+                       break
+               return self.any_new_entries
+       old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
+       def new_get_stopping_criteria(self, *args, **kwargs):
+           stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
+           global tokenizer
+           self.kai_scanner = DynamicWorldInfoScanCriteria(
+               tokenizer=tokenizer,
+               excluded_world_info=self.kai_scanner_excluded_world_info,
+               head_length=self.kai_scanner_head_length,
+           )
+           stopping_criteria.append(self.kai_scanner)
+           return stopping_criteria
+       transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria

        # If custom GPT Neo model was chosen
        if(vars.model == "NeoCustom"):
            model_config = open(vars.custmodpth + "/config.json", "r")
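
For readers unfamiliar with the mechanism: transformers calls every stopping criterion after each decoding step and halts generation when one returns True. Here the criterion is injected by monkey-patching `_get_stopping_criteria`, presumably because the generation entry point used in this file does not expose a `stopping_criteria` argument. A minimal standalone sketch of the mechanism itself (the stop-on-substring behaviour is illustrative, not what this commit does):

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSubstring(StoppingCriteria):
    def __init__(self, tokenizer, substring: str, head_length: int):
        self.tokenizer = tokenizer
        self.substring = substring
        self.head_length = head_length  # prompt length in tokens; only new tokens are scanned

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode only the freshly generated tail and test it.
        tail = input_ids[0, self.head_length:]
        return self.substring in self.tokenizer.decode(tail)

# With a model whose generate() is called directly, no patching is needed:
# model.generate(input_ids, stopping_criteria=StoppingCriteriaList(
#     [StopOnSubstring(tokenizer, "\n>", input_ids.shape[-1])]))
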
@@ -901,6 +949,10 @@ def get_message(msg):
        vars.adventure = msg['data']
        settingschanged()
        refresh_settings()
+   elif(msg['cmd'] == 'setdynamicscan'):
+       vars.dynamicscan = msg['data']
+       settingschanged()
+       refresh_settings()
    elif(not vars.remote and msg['cmd'] == 'importwi'):
        wiimportrequest()
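
The new handler follows the same pattern as the other toggles: store the value, mark settings dirty, and let `refresh_settings()` broadcast the state back so every connected client stays in sync. A minimal sketch of that round trip with Flask-SocketIO (names and the `state` stand-in are illustrative):

from flask import Flask
from flask_socketio import SocketIO, emit

app = Flask(__name__)
socketio = SocketIO(app)
state = {"dynamicscan": False}  # stand-in for the real vars object

@socketio.on('message')
def get_message(msg):
    if(msg['cmd'] == 'setdynamicscan'):
        state["dynamicscan"] = msg['data']  # bool sent by the UI toggle
        # Broadcast so every open browser tab updates its checkbox.
        emit('from_server', {'cmd': 'updatedynamicscan', 'data': state["dynamicscan"]}, broadcast=True)
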
@@ -958,6 +1010,7 @@ def savesettings():
    js["widepth"]     = vars.widepth
    js["useprompt"]   = vars.useprompt
    js["adventure"]   = vars.adventure
+   js["dynamicscan"] = vars.dynamicscan

    # Write it
    if not os.path.exists('settings'):
@@ -1008,6 +1061,8 @@ def loadsettings():
            vars.useprompt = js["useprompt"]
        if("adventure" in js):
            vars.adventure = js["adventure"]
+       if("dynamicscan" in js):
+           vars.dynamicscan = js["dynamicscan"]

        file.close()
@@ -1032,6 +1087,8 @@ def loadmodelsettings():
        vars.rep_pen = js["rep_pen"]
    if("adventure" in js):
        vars.adventure = js["adventure"]
+   if("dynamicscan" in js):
+       vars.dynamicscan = js["dynamicscan"]
    if("formatoptns" in js):
        vars.formatoptns = js["formatoptns"]
    model_config.close()
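
The setting is persisted the same way as `adventure`: written unconditionally on save and read back guardedly, so settings files from versions that predate the toggle still load. A minimal sketch of the pattern (file names and keys illustrative):

import json
import os

def save_settings(path, settings):
    # Mirrors savesettings(): write every current value unconditionally.
    if(not os.path.exists(os.path.dirname(path))):
        os.mkdir(os.path.dirname(path))
    with open(path, "w") as f:
        json.dump(settings, f, indent=3)

def load_settings(path, defaults):
    # Mirrors loadsettings(): every key is optional, so older files
    # simply leave the newer toggles at their defaults.
    settings = dict(defaults)
    if(os.path.exists(path)):
        with open(path) as f:
            js = json.load(f)
        for key in settings:
            if(key in js):
                settings[key] = js[key]
    return settings

# Example round trip:
# save_settings("settings/kai.settings", {"adventure": False, "dynamicscan": True})
# load_settings("settings/kai.settings", {"adventure": False, "dynamicscan": False})
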
@@ -1148,17 +1205,11 @@ def actionback():
        vars.genseqs = []

#==================================================================#
-# Take submitted text and build the text to be given to generator
+#
#==================================================================#
-def calcsubmit(txt):
-    anotetxt     = ""     # Placeholder for Author's Note text
-    lnanote      = 0      # Placeholder for Author's Note length
-    forceanote   = False  # In case we don't have enough actions to hit A.N. depth
-    anoteadded   = False  # In case our budget runs out before we hit A.N. depth
-    actionlen    = len(vars.actions)
-
+def calcsubmitbudgetheader(txt, **kwargs):
    # Scan for WorldInfo matches
-    winfo = checkworldinfo(txt)
+    winfo, found_entries = checkworldinfo(txt, **kwargs)

    # Add a newline to the end of memory
    if(vars.memory != "" and vars.memory[-1] != "\n"):
@@ -1169,10 +1220,16 @@ def calcsubmit(txt):
    # Build Author's Note if set
    if(vars.authornote != ""):
        anotetxt = "\n[Author's note: "+vars.authornote+"]\n"
+   else:
+       anotetxt = ""

-   # For all transformers models
-   if(vars.model != "InferKit"):
+   return winfo, mem, anotetxt, found_entries
+
+def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
+   forceanote   = False  # In case we don't have enough actions to hit A.N. depth
+   anoteadded   = False  # In case our budget runs out before we hit A.N. depth
    anotetkns    = []     # Placeholder for Author's Note tokens
+   lnanote      = 0      # Placeholder for Author's Note length

    # Calculate token budget
    prompttkns = tokenizer.encode(vars.prompt)
@@ -1199,13 +1256,7 @@ def calcsubmit(txt):
        # First/Prompt action
        subtxt = vars.memory + winfo + anotetxt + vars.prompt
        lnsub  = lnsp + lnmem + lnwi + lnprompt + lnanote
-
-       if(not vars.model in ["Colab", "OAI"]):
-           generate(subtxt, lnsub+1, lnsub+vars.genamt)
-       elif(vars.model == "Colab"):
-           sendtocolab(subtxt, lnsub+1, lnsub+vars.genamt)
-       elif(vars.model == "OAI"):
-           oairequest(subtxt, lnsub+1, lnsub+vars.genamt)
+       return subtxt, lnsub+1, lnsub+vars.genamt

    else:
        tokens = []
@@ -1215,8 +1266,8 @@ def calcsubmit(txt):
        # Get most recent action tokens up to our budget
        n = 0
-       for key in reversed(vars.actions):
-           chunk = vars.actions[key]
+       for key in reversed(actions):
+           chunk = actions[key]
            if(budget <= 0):
                break
@@ -1258,25 +1309,36 @@ def calcsubmit(txt):
        # Send completed bundle to generator
        ln = len(tokens) + lnsp
-
-       if(not vars.model in ["Colab", "OAI"]):
-           generate (
-               tokenizer.decode(tokens),
-               ln+1,
-               ln+vars.genamt
-           )
-       elif(vars.model == "Colab"):
-           sendtocolab(
-               tokenizer.decode(tokens),
-               ln+1,
-               ln+vars.genamt
-           )
-       elif(vars.model == "OAI"):
-           oairequest(
-               tokenizer.decode(tokens),
-               ln+1,
-               ln+vars.genamt
-           )
+       return tokenizer.decode(tokens), ln+1, ln+vars.genamt
+
+#==================================================================#
+# Take submitted text and build the text to be given to generator
+#==================================================================#
+def calcsubmit(txt):
+    anotetxt     = ""     # Placeholder for Author's Note text
+    forceanote   = False  # In case we don't have enough actions to hit A.N. depth
+    anoteadded   = False  # In case our budget runs out before we hit A.N. depth
+    actionlen    = len(vars.actions)
+
+    winfo, mem, anotetxt, found_entries = calcsubmitbudgetheader(txt)
+
+    # For all transformers models
+    if(vars.model != "InferKit"):
+        subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions)
+        if(actionlen == 0):
+            if(not vars.model in ["Colab", "OAI"]):
+                generate(subtxt, min, max, found_entries=found_entries)
+            elif(vars.model == "Colab"):
+                sendtocolab(subtxt, min, max)
+            elif(vars.model == "OAI"):
+                oairequest(subtxt, min, max)
+        else:
+            if(not vars.model in ["Colab", "OAI"]):
+                generate(subtxt, min, max, found_entries=found_entries)
+            elif(vars.model == "Colab"):
+                sendtocolab(subtxt, min, max)
+            elif(vars.model == "OAI"):
+                oairequest(subtxt, min, max)

    # For InferKit web API
    else:
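
The point of this refactor: the old `calcsubmit()` both built the prompt and dispatched it, so nothing could rebuild the context mid-generation. Splitting out `calcsubmitbudgetheader()` (world info scan, memory, Author's Note) and `calcsubmitbudget()` (token-budget assembly, now returning text plus min/max lengths instead of calling the generator) lets `generate()` reuse them whenever a new world info entry triggers. A simplified sketch of the budget walk the latter performs (names illustrative; the real function also splices the Author's Note in at depth and handles the prompt separately):

def build_tail(tokenizer, actions, budget):
    # Walk the story newest-to-oldest, keeping whole chunks while they fit.
    tokens = []
    for chunk in reversed(list(actions.values())):
        enc = tokenizer.encode(chunk)
        if(len(enc) > budget):
            break
        tokens = enc + tokens
        budget -= len(enc)
    return tokens
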
@@ -1337,7 +1399,7 @@ def calcsubmit(txt):
#==================================================================#
# Send text to generator and deal with output
#==================================================================#
-def generate(txt, min, max):
+def generate(txt, min, max, found_entries=set()):
    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))

    # Store context in memory to use it for comparison with generated content
@@ -1360,7 +1422,7 @@ def generate(txt, min, max):
                model.config.vocab_size,
                model.config.vocab_size + vars.sp.shape[0],
            )
-           gen_in = torch.cat((soft_tokens[None], gen_in), dim=1)
+           gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)

    if(vars.hascuda and vars.usegpu):
        gen_in = gen_in.to(0)
@@ -1371,12 +1433,22 @@ def generate(txt, min, max):
    else:
        gen_in = gen_in.to('cpu')

+   model.kai_scanner_head_length = gen_in.shape[-1]
+   model.kai_scanner_excluded_world_info = found_entries
+
+   actions = vars.actions
+   if(vars.dynamicscan):
+       actions = actions.copy()
+
    with torch.no_grad():
+       already_generated = 0
+       numseqs = vars.numseqs if not vars.dynamicscan else 1
+       while True:
            genout = generator(
                gen_in,
                do_sample=True,
                min_length=min,
-               max_length=max,
+               max_length=max-already_generated,
                repetition_penalty=vars.rep_pen,
                top_p=top_p,
                top_k=top_k,
@@ -1384,9 +1456,37 @@ def generate(txt, min, max):
                temperature=vars.temp,
                bad_words_ids=vars.badwordsids,
                use_cache=True,
-               return_full_text=False,
-               num_return_sequences=vars.numseqs
+               num_return_sequences=numseqs
            )
+           already_generated += len(genout[0]) - len(gen_in[0])
+           if(not model.kai_scanner.any_new_entries):
+               break
+
+           txt = tokenizer.decode(genout[0, -already_generated:])
+           winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
+           found_entries |= _found_entries
+           txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
+           encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
+           genout = torch.cat(
+               (
+                   encoded,
+                   genout[..., -already_generated:],
+               ),
+               dim=-1
+           )
+           if(vars.sp is not None):
+               soft_tokens = torch.arange(
+                   model.config.vocab_size,
+                   model.config.vocab_size + vars.sp.shape[0],
+                   device=genout.device,
+               )
+               genout = torch.cat((soft_tokens[None], genout), dim=-1)
+           diff = genout.shape[-1] - gen_in.shape[-1]
+           min += diff
+           max += diff
+           gen_in = genout
+           model.kai_scanner_head_length = encoded.shape[-1]
+           numseqs = 1

    except Exception as e:
        emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
        print("{0}{1}{2}".format(colors.RED, e, colors.END))
@@ -1394,7 +1494,8 @@ def generate(txt, min, max):
        return

    # Need to manually strip and decode tokens if we're not using a pipeline
-   genout = [{"generated_text": tokenizer.decode(tokens[len(gen_in[0])-len(tokens):])} for tokens in genout]
+   #already_generated = -(len(gen_in[0]) - len(tokens))
+   genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]

    if(len(genout) == 1):
        genresult(genout[0]["generated_text"])
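
Taken together, these hunks turn the single `generator(...)` call into a resume loop: generate until the scanner reports a newly matched entry, rebuild the context so the entry's content is present, then continue from the tokens already produced; `already_generated` keeps min/max and the final decode consistent across restarts. A condensed sketch of the control flow (the function arguments are illustrative stand-ins for the module-level state used here):

import torch

def generate_with_rescan(generate_fn, rebuild_context_fn, scanner, gen_in, max_new_tokens):
    already_generated = 0
    while True:
        # generate_fn(input_ids, n) is assumed to return the full sequence,
        # prompt included, like transformers' generate().
        out = generate_fn(gen_in, max_new_tokens - already_generated)
        already_generated += out.shape[-1] - gen_in.shape[-1]
        if(not scanner.any_new_entries):
            # Return only the generated tail, mirroring tokens[-already_generated:].
            return out[..., -already_generated:]
        # A new WI entry matched: rebuild the head with its content included,
        # keep the tail generated so far, and resume from the combined sequence.
        head = rebuild_context_fn(out[..., -already_generated:])
        gen_in = torch.cat((head, out[..., -already_generated:]), dim=-1)
        scanner.head_length = head.shape[-1]
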
@@ -1654,6 +1755,7 @@ def refresh_settings():
    emit('from_server', {'cmd': 'updatewidepth', 'data': vars.widepth}, broadcast=True)
    emit('from_server', {'cmd': 'updateuseprompt', 'data': vars.useprompt}, broadcast=True)
    emit('from_server', {'cmd': 'updateadventure', 'data': vars.adventure}, broadcast=True)
+   emit('from_server', {'cmd': 'updatedynamicscan', 'data': vars.dynamicscan}, broadcast=True)
    emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
    emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)
@@ -1862,10 +1964,12 @@ def deletewi(num):
#==================================================================#
# Look for WI keys in text to generator
#==================================================================#
-def checkworldinfo(txt):
+def checkworldinfo(txt, force_use_txt=False):
+   original_txt = txt
+
    # Dont go any further if WI is empty
    if(len(vars.worldinfo) == 0):
-       return
+       return "", set()

    # Cache actions length
    ln = len(vars.actions)
@@ -1875,7 +1979,7 @@ def checkworldinfo(txt):
        depth = vars.widepth
        # If this is not a continue, add 1 to widepth since submitted
        # text is already in action history @ -1
-       if(txt != "" and vars.prompt != txt):
+       if(not force_use_txt and (txt != "" and vars.prompt != txt)):
            txt    = ""
            depth += 1
@@ -1896,11 +2000,16 @@ def checkworldinfo(txt):
        elif(ln == 0):
            txt = vars.prompt

+   if(force_use_txt):
+       txt += original_txt
+
    # Scan text for matches on WI keys
    wimem = ""
+   found_entries = set()
    for wi in vars.worldinfo:
        if(wi.get("constant", False)):
            wimem = wimem + wi["content"] + "\n"
+           found_entries.add(id(wi))
            continue

        if(wi["key"] != ""):
@@ -1922,15 +2031,17 @@ def checkworldinfo(txt):
                    ksy = ks.strip()
                    if ksy in txt:
                        wimem = wimem + wi["content"] + "\n"
+                       found_entries.add(id(wi))
                        found = True
                        break
                if found:
                    break
            else:
                wimem = wimem + wi["content"] + "\n"
+               found_entries.add(id(wi))
                break

-   return wimem
+   return wimem, found_entries

#==================================================================#
# Commit changes to Memory storage
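
`checkworldinfo()` now returns the matched entries alongside the memory text, identified by `id(wi)` so the stopping criterion can subtract entries already injected and react only to genuinely new matches. A condensed sketch of the scan (names illustrative; keys are comma-separated, constant entries are always included):

def scan_world_info(worldinfo, txt):
    wimem, found = "", set()
    for wi in worldinfo:
        keys = [k.strip() for k in wi["key"].split(",") if k.strip() != ""]
        if(wi.get("constant", False) or any(k in txt for k in keys)):
            wimem += wi["content"] + "\n"
            found.add(id(wi))   # identity, not content, distinguishes entries
    return wimem, found
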


@@ -118,6 +118,17 @@ gensettingstf = [{
    "step": 1,
    "default": 0,
    "tooltip": "Turn this on if you are playing a Choose your Adventure model."
+   },
+   {
+   "uitype": "toggle",
+   "unit": "bool",
+   "label": "Dynamic WI Scan",
+   "id": "setdynamicscan",
+   "min": 0,
+   "max": 1,
+   "step": 1,
+   "default": 0,
+   "tooltip": "Scan the AI's output for world info keys as it's generating the output. Turning this on will set Gens Per Action to 1, as these two features are not currently compatible with each other."
    }]

gensettingsik =[{


@@ -1593,6 +1593,9 @@ $(document).ready(function(){
                $("#setadventure").prop('checked', msg.data).change();
                // Update adventure state
                setadventure(msg.data);
+           } else if(msg.cmd == "updatedynamicscan") {
+               // Update toggle state
+               $("#setdynamicscan").prop('checked', msg.data).change();
            } else if(msg.cmd == "runs_remotely") {
                remote = true;
                hide([button_savetofile, button_import, button_importwi]);