Merge pull request #23 from VE-FORBRYDERNE/scan-test

Dynamic world info scan
This commit is contained in:
henk717 2021-11-10 03:31:42 +01:00 committed by GitHub
commit c2371cf801
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 258 additions and 133 deletions

View File

@ -13,7 +13,7 @@ from tkinter import messagebox
import json
import collections
import zipfile
from typing import Union, Tuple
from typing import Union, Dict, Set
import requests
import html
@ -124,6 +124,7 @@ class vars:
acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
actionmode = 1
adventure = False
dynamicscan = False
remote = False
#==================================================================#
@ -512,7 +513,8 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END))
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
import transformers.generation_utils
# Patch transformers to use our soft prompt
def patch_causallm(cls):
@ -528,7 +530,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(vars.sp is not None):
vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
inputs_embeds = torch.where(
(shifted_input_ids >= 0)[:, :, None],
(shifted_input_ids >= 0)[..., None],
vars.sp[shifted_input_ids.clamp(min=0)],
inputs_embeds,
)
@ -543,6 +545,52 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
except:
pass
# Sets up dynamic world info scanner
class DynamicWorldInfoScanCriteria(StoppingCriteria):
def __init__(
self,
tokenizer,
excluded_world_info: set,
#head_length: torch.LongTensor,
head_length: int,
):
self.any_new_entries = False
self.tokenizer = tokenizer
self.excluded_world_info = excluded_world_info
self.head_length = head_length
def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
**kwargs,
) -> bool:
assert input_ids.ndim == 2
#assert input_ids.shape[:-1] == self.head_length.shape
self.any_new_entries = False
if(not vars.dynamicscan):
return False
tail = input_ids[..., self.head_length:]
for t in tail:
decoded = tokenizer.decode(t)
_, found = checkworldinfo(decoded, force_use_txt=True)
found -= self.excluded_world_info
if(len(found) != 0):
self.any_new_entries = True
break
return self.any_new_entries
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
def new_get_stopping_criteria(self, *args, **kwargs):
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
global tokenizer
self.kai_scanner = DynamicWorldInfoScanCriteria(
tokenizer=tokenizer,
excluded_world_info=self.kai_scanner_excluded_world_info,
head_length=self.kai_scanner_head_length,
)
stopping_criteria.append(self.kai_scanner)
return stopping_criteria
transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
# If custom GPT Neo model was chosen
if(vars.model == "NeoCustom"):
model_config = open(vars.custmodpth + "/config.json", "r")
@ -901,6 +949,10 @@ def get_message(msg):
vars.adventure = msg['data']
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setdynamicscan'):
vars.dynamicscan = msg['data']
settingschanged()
refresh_settings()
elif(not vars.remote and msg['cmd'] == 'importwi'):
wiimportrequest()
@ -958,6 +1010,7 @@ def savesettings():
js["widepth"] = vars.widepth
js["useprompt"] = vars.useprompt
js["adventure"] = vars.adventure
js["dynamicscan"] = vars.dynamicscan
# Write it
if not os.path.exists('settings'):
@ -1008,6 +1061,8 @@ def loadsettings():
vars.useprompt = js["useprompt"]
if("adventure" in js):
vars.adventure = js["adventure"]
if("dynamicscan" in js):
vars.dynamicscan = js["dynamicscan"]
file.close()
@ -1032,6 +1087,8 @@ def loadmodelsettings():
vars.rep_pen = js["rep_pen"]
if("adventure" in js):
vars.adventure = js["adventure"]
if("dynamicscan" in js):
vars.dynamicscan = js["dynamicscan"]
if("formatoptns" in js):
vars.formatoptns = js["formatoptns"]
model_config.close()
@ -1148,135 +1205,140 @@ def actionback():
vars.genseqs = []
#==================================================================#
# Take submitted text and build the text to be given to generator
#
#==================================================================#
def calcsubmit(txt):
anotetxt = "" # Placeholder for Author's Note text
lnanote = 0 # Placeholder for Author's Note length
forceanote = False # In case we don't have enough actions to hit A.N. depth
anoteadded = False # In case our budget runs out before we hit A.N. depth
actionlen = len(vars.actions)
def calcsubmitbudgetheader(txt, **kwargs):
# Scan for WorldInfo matches
winfo = checkworldinfo(txt)
winfo, found_entries = checkworldinfo(txt, **kwargs)
# Add a newline to the end of memory
if(vars.memory != "" and vars.memory[-1] != "\n"):
mem = vars.memory + "\n"
else:
mem = vars.memory
# Build Author's Note if set
if(vars.authornote != ""):
anotetxt = "\n[Author's note: "+vars.authornote+"]\n"
else:
anotetxt = ""
return winfo, mem, anotetxt, found_entries
def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
forceanote = False # In case we don't have enough actions to hit A.N. depth
anoteadded = False # In case our budget runs out before we hit A.N. depth
anotetkns = [] # Placeholder for Author's Note tokens
lnanote = 0 # Placeholder for Author's Note length
# Calculate token budget
prompttkns = tokenizer.encode(vars.prompt)
lnprompt = len(prompttkns)
memtokens = tokenizer.encode(mem)
lnmem = len(memtokens)
witokens = tokenizer.encode(winfo)
lnwi = len(witokens)
if(anotetxt != ""):
anotetkns = tokenizer.encode(anotetxt)
lnanote = len(anotetkns)
lnsp = vars.sp.shape[0] if vars.sp is not None else 0
if(vars.useprompt):
budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
else:
budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
if(actionlen == 0):
# First/Prompt action
subtxt = vars.memory + winfo + anotetxt + vars.prompt
lnsub = lnsp + lnmem + lnwi + lnprompt + lnanote
return subtxt, lnsub+1, lnsub+vars.genamt
else:
tokens = []
# Check if we have the action depth to hit our A.N. depth
if(anotetxt != "" and actionlen < vars.andepth):
forceanote = True
# Get most recent action tokens up to our budget
n = 0
for key in reversed(actions):
chunk = actions[key]
if(budget <= 0):
break
acttkns = tokenizer.encode(chunk)
tknlen = len(acttkns)
if(tknlen < budget):
tokens = acttkns + tokens
budget -= tknlen
else:
count = budget * -1
tokens = acttkns[count:] + tokens
budget = 0
break
# Inject Author's Note if we've reached the desired depth
if(n == vars.andepth-1):
if(anotetxt != ""):
tokens = anotetkns + tokens # A.N. len already taken from bdgt
anoteadded = True
n += 1
# If we're not using the prompt every time and there's still budget left,
# add some prompt.
if(not vars.useprompt):
if(budget > 0):
prompttkns = prompttkns[-budget:]
else:
prompttkns = []
# Did we get to add the A.N.? If not, do it here
if(anotetxt != ""):
if((not anoteadded) or forceanote):
tokens = memtokens + witokens + anotetkns + prompttkns + tokens
else:
tokens = memtokens + witokens + prompttkns + tokens
else:
# Prepend Memory, WI, and Prompt before action tokens
tokens = memtokens + witokens + prompttkns + tokens
# Send completed bundle to generator
ln = len(tokens) + lnsp
return tokenizer.decode(tokens), ln+1, ln+vars.genamt
#==================================================================#
# Take submitted text and build the text to be given to generator
#==================================================================#
def calcsubmit(txt):
anotetxt = "" # Placeholder for Author's Note text
forceanote = False # In case we don't have enough actions to hit A.N. depth
anoteadded = False # In case our budget runs out before we hit A.N. depth
actionlen = len(vars.actions)
winfo, mem, anotetxt, found_entries = calcsubmitbudgetheader(txt)
# For all transformers models
if(vars.model != "InferKit"):
anotetkns = [] # Placeholder for Author's Note tokens
# Calculate token budget
prompttkns = tokenizer.encode(vars.prompt)
lnprompt = len(prompttkns)
memtokens = tokenizer.encode(mem)
lnmem = len(memtokens)
witokens = tokenizer.encode(winfo)
lnwi = len(witokens)
if(anotetxt != ""):
anotetkns = tokenizer.encode(anotetxt)
lnanote = len(anotetkns)
lnsp = vars.sp.shape[0] if vars.sp is not None else 0
if(vars.useprompt):
budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
else:
budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions)
if(actionlen == 0):
# First/Prompt action
subtxt = vars.memory + winfo + anotetxt + vars.prompt
lnsub = lnsp + lnmem + lnwi + lnprompt + lnanote
if(not vars.model in ["Colab", "OAI"]):
generate(subtxt, lnsub+1, lnsub+vars.genamt)
generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
sendtocolab(subtxt, lnsub+1, lnsub+vars.genamt)
sendtocolab(subtxt, min, max)
elif(vars.model == "OAI"):
oairequest(subtxt, lnsub+1, lnsub+vars.genamt)
oairequest(subtxt, min, max)
else:
tokens = []
# Check if we have the action depth to hit our A.N. depth
if(anotetxt != "" and actionlen < vars.andepth):
forceanote = True
# Get most recent action tokens up to our budget
n = 0
for key in reversed(vars.actions):
chunk = vars.actions[key]
if(budget <= 0):
break
acttkns = tokenizer.encode(chunk)
tknlen = len(acttkns)
if(tknlen < budget):
tokens = acttkns + tokens
budget -= tknlen
else:
count = budget * -1
tokens = acttkns[count:] + tokens
budget = 0
break
# Inject Author's Note if we've reached the desired depth
if(n == vars.andepth-1):
if(anotetxt != ""):
tokens = anotetkns + tokens # A.N. len already taken from bdgt
anoteadded = True
n += 1
# If we're not using the prompt every time and there's still budget left,
# add some prompt.
if(not vars.useprompt):
if(budget > 0):
prompttkns = prompttkns[-budget:]
else:
prompttkns = []
# Did we get to add the A.N.? If not, do it here
if(anotetxt != ""):
if((not anoteadded) or forceanote):
tokens = memtokens + witokens + anotetkns + prompttkns + tokens
else:
tokens = memtokens + witokens + prompttkns + tokens
else:
# Prepend Memory, WI, and Prompt before action tokens
tokens = memtokens + witokens + prompttkns + tokens
# Send completed bundle to generator
ln = len(tokens) + lnsp
if(not vars.model in ["Colab", "OAI"]):
generate (
tokenizer.decode(tokens),
ln+1,
ln+vars.genamt
)
generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
sendtocolab(
tokenizer.decode(tokens),
ln+1,
ln+vars.genamt
)
sendtocolab(subtxt, min, max)
elif(vars.model == "OAI"):
oairequest(
tokenizer.decode(tokens),
ln+1,
ln+vars.genamt
)
oairequest(subtxt, min, max)
# For InferKit web API
else:
@ -1337,7 +1399,7 @@ def calcsubmit(txt):
#==================================================================#
# Send text to generator and deal with output
#==================================================================#
def generate(txt, min, max):
def generate(txt, min, max, found_entries=set()):
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))
# Store context in memory to use it for comparison with generated content
@ -1360,7 +1422,7 @@ def generate(txt, min, max):
model.config.vocab_size,
model.config.vocab_size + vars.sp.shape[0],
)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=1)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
if(vars.hascuda and vars.usegpu):
gen_in = gen_in.to(0)
@ -1371,22 +1433,60 @@ def generate(txt, min, max):
else:
gen_in = gen_in.to('cpu')
model.kai_scanner_head_length = gen_in.shape[-1]
model.kai_scanner_excluded_world_info = found_entries
actions = vars.actions
if(vars.dynamicscan):
actions = actions.copy()
with torch.no_grad():
genout = generator(
gen_in,
do_sample=True,
min_length=min,
max_length=max,
repetition_penalty=vars.rep_pen,
top_p=top_p,
top_k=top_k,
tfs=tfs,
temperature=vars.temp,
bad_words_ids=vars.badwordsids,
use_cache=True,
return_full_text=False,
num_return_sequences=vars.numseqs
already_generated = 0
numseqs = vars.numseqs if not vars.dynamicscan else 1
while True:
genout = generator(
gen_in,
do_sample=True,
min_length=min,
max_length=max-already_generated,
repetition_penalty=vars.rep_pen,
top_p=top_p,
top_k=top_k,
tfs=tfs,
temperature=vars.temp,
bad_words_ids=vars.badwordsids,
use_cache=True,
num_return_sequences=numseqs
)
already_generated += len(genout[0]) - len(gen_in[0])
if(not model.kai_scanner.any_new_entries):
break
txt = tokenizer.decode(genout[0, -already_generated:])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries |= _found_entries
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
genout = torch.cat(
(
encoded,
genout[..., -already_generated:],
),
dim=-1
)
if(vars.sp is not None):
soft_tokens = torch.arange(
model.config.vocab_size,
model.config.vocab_size + vars.sp.shape[0],
device=genout.device,
)
genout = torch.cat((soft_tokens[None], genout), dim=-1)
diff = genout.shape[-1] - gen_in.shape[-1]
min += diff
max += diff
gen_in = genout
model.kai_scanner_head_length = encoded.shape[-1]
numseqs = 1
except Exception as e:
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
print("{0}{1}{2}".format(colors.RED, e, colors.END))
@ -1394,7 +1494,8 @@ def generate(txt, min, max):
return
# Need to manually strip and decode tokens if we're not using a pipeline
genout = [{"generated_text": tokenizer.decode(tokens[len(gen_in[0])-len(tokens):])} for tokens in genout]
#already_generated = -(len(gen_in[0]) - len(tokens))
genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
@ -1654,6 +1755,7 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatewidepth', 'data': vars.widepth}, broadcast=True)
emit('from_server', {'cmd': 'updateuseprompt', 'data': vars.useprompt}, broadcast=True)
emit('from_server', {'cmd': 'updateadventure', 'data': vars.adventure}, broadcast=True)
emit('from_server', {'cmd': 'updatedynamicscan', 'data': vars.dynamicscan}, broadcast=True)
emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)
@ -1862,10 +1964,12 @@ def deletewi(num):
#==================================================================#
# Look for WI keys in text to generator
#==================================================================#
def checkworldinfo(txt):
def checkworldinfo(txt, force_use_txt=False):
original_txt = txt
# Dont go any further if WI is empty
if(len(vars.worldinfo) == 0):
return
return "", set()
# Cache actions length
ln = len(vars.actions)
@ -1875,7 +1979,7 @@ def checkworldinfo(txt):
depth = vars.widepth
# If this is not a continue, add 1 to widepth since submitted
# text is already in action history @ -1
if(txt != "" and vars.prompt != txt):
if(not force_use_txt and (txt != "" and vars.prompt != txt)):
txt = ""
depth += 1
@ -1895,12 +1999,17 @@ def checkworldinfo(txt):
txt = vars.prompt + "".join(chunks)
elif(ln == 0):
txt = vars.prompt
if(force_use_txt):
txt += original_txt
# Scan text for matches on WI keys
wimem = ""
found_entries = set()
for wi in vars.worldinfo:
if(wi.get("constant", False)):
wimem = wimem + wi["content"] + "\n"
found_entries.add(id(wi))
continue
if(wi["key"] != ""):
@ -1922,15 +2031,17 @@ def checkworldinfo(txt):
ksy = ks.strip()
if ksy in txt:
wimem = wimem + wi["content"] + "\n"
found_entries.add(id(wi))
found = True
break
if found:
break
else:
wimem = wimem + wi["content"] + "\n"
found_entries.add(id(wi))
break
return wimem
return wimem, found_entries
#==================================================================#
# Commit changes to Memory storage

View File

@ -118,6 +118,17 @@ gensettingstf = [{
"step": 1,
"default": 0,
"tooltip": "Turn this on if you are playing a Choose your Adventure model."
},
{
"uitype": "toggle",
"unit": "bool",
"label": "Dynamic WI Scan",
"id": "setdynamicscan",
"min": 0,
"max": 1,
"step": 1,
"default": 0,
"tooltip": "Scan the AI's output for world info keys as it's generating the output. Turning this on will set Gens Per Action to 1, as these two features are not currently compatible with each other."
}]
gensettingsik =[{

View File

@ -1593,6 +1593,9 @@ $(document).ready(function(){
$("#setadventure").prop('checked', msg.data).change();
// Update adventure state
setadventure(msg.data);
} else if(msg.cmd == "updatedynamicscan") {
// Update toggle state
$("#setdynamicscan").prop('checked', msg.data).change();
} else if(msg.cmd == "runs_remotely") {
remote = true;
hide([button_savetofile, button_import, button_importwi]);