Merge pull request #23 from VE-FORBRYDERNE/scan-test

Dynamic world info scan
henk717 authored 2021-11-10 03:31:42 +01:00 · committed by GitHub
3 changed files with 258 additions and 133 deletions


@@ -13,7 +13,7 @@ from tkinter import messagebox
import json
import collections
import zipfile
from typing import Union, Tuple
from typing import Union, Dict, Set
import requests
import html
@@ -124,6 +124,7 @@ class vars:
acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
actionmode = 1
adventure = False
dynamicscan = False
remote = False
#==================================================================#
@@ -512,7 +513,8 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END))
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
import transformers.generation_utils
# Patch transformers to use our soft prompt
def patch_causallm(cls):
@@ -528,7 +530,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(vars.sp is not None):
vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
inputs_embeds = torch.where(
(shifted_input_ids >= 0)[:, :, None],
(shifted_input_ids >= 0)[..., None],
vars.sp[shifted_input_ids.clamp(min=0)],
inputs_embeds,
)
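
The only change above is widening the mask index from [:, :, None] to [..., None]; the soft-prompt substitution itself is untouched. A small, self-contained illustration of that broadcast, using made-up tensors rather than the model's real embeddings (not part of this commit):

import torch

shifted_input_ids = torch.tensor([[-1, 0, 2]])           # -1 marks ordinary vocabulary tokens
sp = torch.arange(3 * 4, dtype=torch.float).view(3, 4)   # 3 soft-prompt vectors of width 4
inputs_embeds = torch.zeros(1, 3, 4)                      # embeddings the model would normally use

out = torch.where(
    (shifted_input_ids >= 0)[..., None],                  # (1, 3, 1) mask, broadcast over the embedding dim
    sp[shifted_input_ids.clamp(min=0)],                   # (1, 3, 4) soft-prompt rows; clamp keeps -1 in range
    inputs_embeds,                                        # kept wherever the mask is False
)
print(out.shape)                                          # torch.Size([1, 3, 4])
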
@@ -543,6 +545,52 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
except:
pass
# Sets up dynamic world info scanner
class DynamicWorldInfoScanCriteria(StoppingCriteria):
def __init__(
self,
tokenizer,
excluded_world_info: set,
#head_length: torch.LongTensor,
head_length: int,
):
self.any_new_entries = False
self.tokenizer = tokenizer
self.excluded_world_info = excluded_world_info
self.head_length = head_length
def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
**kwargs,
) -> bool:
assert input_ids.ndim == 2
#assert input_ids.shape[:-1] == self.head_length.shape
self.any_new_entries = False
if(not vars.dynamicscan):
return False
tail = input_ids[..., self.head_length:]
for t in tail:
decoded = tokenizer.decode(t)
_, found = checkworldinfo(decoded, force_use_txt=True)
found -= self.excluded_world_info
if(len(found) != 0):
self.any_new_entries = True
break
return self.any_new_entries
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
def new_get_stopping_criteria(self, *args, **kwargs):
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
global tokenizer
self.kai_scanner = DynamicWorldInfoScanCriteria(
tokenizer=tokenizer,
excluded_world_info=self.kai_scanner_excluded_world_info,
head_length=self.kai_scanner_head_length,
)
stopping_criteria.append(self.kai_scanner)
return stopping_criteria
transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
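
For context on what is being hooked here: transformers consults every StoppingCriteria after each decoding step and halts generation as soon as one returns True. A minimal, stand-alone sketch of that mechanism (the KeywordStop class and the use of "gpt2" are illustrative only and are not part of this commit):

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          StoppingCriteria, StoppingCriteriaList)

class KeywordStop(StoppingCriteria):
    def __init__(self, tokenizer, keyword, head_length):
        self.tokenizer = tokenizer
        self.keyword = keyword
        self.head_length = head_length          # prompt length, so only new tokens are scanned

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        tail = input_ids[0, self.head_length:]  # tokens generated since the prompt
        return self.keyword in self.tokenizer.decode(tail)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
ids = tokenizer("The knight rode toward the", return_tensors="pt").input_ids
out = model.generate(
    ids,
    do_sample=True,
    max_length=40,
    stopping_criteria=StoppingCriteriaList([KeywordStop(tokenizer, "castle", ids.shape[-1])]),
)
print(tokenizer.decode(out[0]))

Patching GenerationMixin._get_stopping_criteria rather than passing such a list directly lets the commit build the criteria object per call from model.kai_scanner_head_length and model.kai_scanner_excluded_world_info, and retrieve it afterwards as model.kai_scanner, which generate() checks below.
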
# If custom GPT Neo model was chosen
if(vars.model == "NeoCustom"):
model_config = open(vars.custmodpth + "/config.json", "r")
@@ -901,6 +949,10 @@ def get_message(msg):
vars.adventure = msg['data']
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setdynamicscan'):
vars.dynamicscan = msg['data']
settingschanged()
refresh_settings()
elif(not vars.remote and msg['cmd'] == 'importwi'):
wiimportrequest()
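
The new branch mirrors the existing setadventure handler: the UI toggles the flag and the server stores it. A toy dispatcher showing the payload shape the branch expects (field names come from the code above; the real client-side emit lives in the project's JavaScript UI and is not part of this diff):

settings = {"dynamicscan": False}

def handle_message(msg):
    # simplified stand-in for the setdynamicscan branch of get_message()
    if msg["cmd"] == "setdynamicscan":
        settings["dynamicscan"] = bool(msg["data"])

handle_message({"cmd": "setdynamicscan", "data": True})
print(settings)   # {'dynamicscan': True}
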
@@ -958,6 +1010,7 @@ def savesettings():
js["widepth"] = vars.widepth
js["useprompt"] = vars.useprompt
js["adventure"] = vars.adventure
js["dynamicscan"] = vars.dynamicscan
# Write it
if not os.path.exists('settings'):
@@ -1008,6 +1061,8 @@ def loadsettings():
vars.useprompt = js["useprompt"]
if("adventure" in js):
vars.adventure = js["adventure"]
if("dynamicscan" in js):
vars.dynamicscan = js["dynamicscan"]
file.close()
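
savesettings() and loadsettings() now round-trip the new flag through the same JSON settings file as the other client options. A hedged illustration of that round trip (keys taken from the hunks above, values invented, unrelated keys omitted):

import json

js = {"widepth": 3, "useprompt": True, "adventure": False, "dynamicscan": True}
blob = json.dumps(js, indent=3)

loaded = json.loads(blob)
if "dynamicscan" in loaded:          # the same guarded read loadsettings() performs
    print(loaded["dynamicscan"])     # True
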
@@ -1032,6 +1087,8 @@ def loadmodelsettings():
vars.rep_pen = js["rep_pen"]
if("adventure" in js):
vars.adventure = js["adventure"]
if("dynamicscan" in js):
vars.dynamicscan = js["dynamicscan"]
if("formatoptns" in js):
vars.formatoptns = js["formatoptns"]
model_config.close()
@@ -1148,135 +1205,140 @@ def actionback():
vars.genseqs = []
#==================================================================#
# Take submitted text and build the text to be given to generator
#
#==================================================================#
def calcsubmit(txt):
anotetxt = "" # Placeholder for Author's Note text
lnanote = 0 # Placeholder for Author's Note length
forceanote = False # In case we don't have enough actions to hit A.N. depth
anoteadded = False # In case our budget runs out before we hit A.N. depth
actionlen = len(vars.actions)
def calcsubmitbudgetheader(txt, **kwargs):
# Scan for WorldInfo matches
winfo = checkworldinfo(txt)
winfo, found_entries = checkworldinfo(txt, **kwargs)
# Add a newline to the end of memory
if(vars.memory != "" and vars.memory[-1] != "\n"):
mem = vars.memory + "\n"
else:
mem = vars.memory
# Build Author's Note if set
if(vars.authornote != ""):
anotetxt = "\n[Author's note: "+vars.authornote+"]\n"
else:
anotetxt = ""
return winfo, mem, anotetxt, found_entries
def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
forceanote = False # In case we don't have enough actions to hit A.N. depth
anoteadded = False # In case our budget runs out before we hit A.N. depth
anotetkns = [] # Placeholder for Author's Note tokens
lnanote = 0 # Placeholder for Author's Note length
# Calculate token budget
prompttkns = tokenizer.encode(vars.prompt)
lnprompt = len(prompttkns)
memtokens = tokenizer.encode(mem)
lnmem = len(memtokens)
witokens = tokenizer.encode(winfo)
lnwi = len(witokens)
if(anotetxt != ""):
anotetkns = tokenizer.encode(anotetxt)
lnanote = len(anotetkns)
lnsp = vars.sp.shape[0] if vars.sp is not None else 0
if(vars.useprompt):
budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
else:
budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
if(actionlen == 0):
# First/Prompt action
subtxt = vars.memory + winfo + anotetxt + vars.prompt
lnsub = lnsp + lnmem + lnwi + lnprompt + lnanote
return subtxt, lnsub+1, lnsub+vars.genamt
else:
tokens = []
# Check if we have the action depth to hit our A.N. depth
if(anotetxt != "" and actionlen < vars.andepth):
forceanote = True
# Get most recent action tokens up to our budget
n = 0
for key in reversed(actions):
chunk = actions[key]
if(budget <= 0):
break
acttkns = tokenizer.encode(chunk)
tknlen = len(acttkns)
if(tknlen < budget):
tokens = acttkns + tokens
budget -= tknlen
else:
count = budget * -1
tokens = acttkns[count:] + tokens
budget = 0
break
# Inject Author's Note if we've reached the desired depth
if(n == vars.andepth-1):
if(anotetxt != ""):
tokens = anotetkns + tokens # A.N. len already taken from bdgt
anoteadded = True
n += 1
# If we're not using the prompt every time and there's still budget left,
# add some prompt.
if(not vars.useprompt):
if(budget > 0):
prompttkns = prompttkns[-budget:]
else:
prompttkns = []
# Did we get to add the A.N.? If not, do it here
if(anotetxt != ""):
if((not anoteadded) or forceanote):
tokens = memtokens + witokens + anotetkns + prompttkns + tokens
else:
tokens = memtokens + witokens + prompttkns + tokens
else:
# Prepend Memory, WI, and Prompt before action tokens
tokens = memtokens + witokens + prompttkns + tokens
# Send completed bundle to generator
ln = len(tokens) + lnsp
return tokenizer.decode(tokens), ln+1, ln+vars.genamt
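
The arithmetic that calcsubmitbudget() now owns is easiest to see with concrete numbers. The figures below are invented; in practice they come from the tokenizer and the user's settings:

max_length = 1024   # vars.max_length: total context the model accepts
genamt     = 80     # vars.genamt: tokens reserved for the upcoming generation
lnsp       = 20     # soft-prompt rows (vars.sp.shape[0]), 0 if no soft prompt is loaded
lnmem      = 50     # tokens of memory
lnwi       = 120    # tokens of triggered world info
lnanote    = 15     # tokens of Author's Note
lnprompt   = 200    # tokens of the original prompt

# vars.useprompt keeps the prompt unconditionally, so it is charged up front:
budget_with_prompt = max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - genamt
# Otherwise the prompt only receives whatever budget the recent actions leave over:
budget_without_prompt = max_length - lnsp - lnmem - lnanote - lnwi - genamt
print(budget_with_prompt, budget_without_prompt)   # 539 739
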
#==================================================================#
# Take submitted text and build the text to be given to generator
#==================================================================#
def calcsubmit(txt):
anotetxt = "" # Placeholder for Author's Note text
forceanote = False # In case we don't have enough actions to hit A.N. depth
anoteadded = False # In case our budget runs out before we hit A.N. depth
actionlen = len(vars.actions)
winfo, mem, anotetxt, found_entries = calcsubmitbudgetheader(txt)
# For all transformers models
if(vars.model != "InferKit"):
anotetkns = [] # Placeholder for Author's Note tokens
# Calculate token budget
prompttkns = tokenizer.encode(vars.prompt)
lnprompt = len(prompttkns)
memtokens = tokenizer.encode(mem)
lnmem = len(memtokens)
witokens = tokenizer.encode(winfo)
lnwi = len(witokens)
if(anotetxt != ""):
anotetkns = tokenizer.encode(anotetxt)
lnanote = len(anotetkns)
lnsp = vars.sp.shape[0] if vars.sp is not None else 0
if(vars.useprompt):
budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
else:
budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions)
if(actionlen == 0):
# First/Prompt action
subtxt = vars.memory + winfo + anotetxt + vars.prompt
lnsub = lnsp + lnmem + lnwi + lnprompt + lnanote
if(not vars.model in ["Colab", "OAI"]):
generate(subtxt, lnsub+1, lnsub+vars.genamt)
generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
sendtocolab(subtxt, lnsub+1, lnsub+vars.genamt)
sendtocolab(subtxt, min, max)
elif(vars.model == "OAI"):
oairequest(subtxt, lnsub+1, lnsub+vars.genamt)
oairequest(subtxt, min, max)
else:
tokens = []
# Check if we have the action depth to hit our A.N. depth
if(anotetxt != "" and actionlen < vars.andepth):
forceanote = True
# Get most recent action tokens up to our budget
n = 0
for key in reversed(vars.actions):
chunk = vars.actions[key]
if(budget <= 0):
break
acttkns = tokenizer.encode(chunk)
tknlen = len(acttkns)
if(tknlen < budget):
tokens = acttkns + tokens
budget -= tknlen
else:
count = budget * -1
tokens = acttkns[count:] + tokens
budget = 0
break
# Inject Author's Note if we've reached the desired depth
if(n == vars.andepth-1):
if(anotetxt != ""):
tokens = anotetkns + tokens # A.N. len already taken from bdgt
anoteadded = True
n += 1
# If we're not using the prompt every time and there's still budget left,
# add some prompt.
if(not vars.useprompt):
if(budget > 0):
prompttkns = prompttkns[-budget:]
else:
prompttkns = []
# Did we get to add the A.N.? If not, do it here
if(anotetxt != ""):
if((not anoteadded) or forceanote):
tokens = memtokens + witokens + anotetkns + prompttkns + tokens
else:
tokens = memtokens + witokens + prompttkns + tokens
else:
# Prepend Memory, WI, and Prompt before action tokens
tokens = memtokens + witokens + prompttkns + tokens
# Send completed bundle to generator
ln = len(tokens) + lnsp
if(not vars.model in ["Colab", "OAI"]):
generate (
tokenizer.decode(tokens),
ln+1,
ln+vars.genamt
)
generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"):
sendtocolab(
tokenizer.decode(tokens),
ln+1,
ln+vars.genamt
)
sendtocolab(subtxt, min, max)
elif(vars.model == "OAI"):
oairequest(
tokenizer.decode(tokens),
ln+1,
ln+vars.genamt
)
oairequest(subtxt, min, max)
# For InferKit web API
else:
@@ -1337,7 +1399,7 @@ def calcsubmit(txt):
#==================================================================#
# Send text to generator and deal with output
#==================================================================#
def generate(txt, min, max):
def generate(txt, min, max, found_entries=set()):
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))
# Store context in memory to use it for comparison with generated content
@@ -1360,7 +1422,7 @@ def generate(txt, min, max):
model.config.vocab_size,
model.config.vocab_size + vars.sp.shape[0],
)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=1)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
if(vars.hascuda and vars.usegpu):
gen_in = gen_in.to(0)
@@ -1371,22 +1433,60 @@ def generate(txt, min, max):
else:
gen_in = gen_in.to('cpu')
model.kai_scanner_head_length = gen_in.shape[-1]
model.kai_scanner_excluded_world_info = found_entries
actions = vars.actions
if(vars.dynamicscan):
actions = actions.copy()
with torch.no_grad():
genout = generator(
gen_in,
do_sample=True,
min_length=min,
max_length=max,
repetition_penalty=vars.rep_pen,
top_p=top_p,
top_k=top_k,
tfs=tfs,
temperature=vars.temp,
bad_words_ids=vars.badwordsids,
use_cache=True,
return_full_text=False,
num_return_sequences=vars.numseqs
already_generated = 0
numseqs = vars.numseqs if not vars.dynamicscan else 1
while True:
genout = generator(
gen_in,
do_sample=True,
min_length=min,
max_length=max-already_generated,
repetition_penalty=vars.rep_pen,
top_p=top_p,
top_k=top_k,
tfs=tfs,
temperature=vars.temp,
bad_words_ids=vars.badwordsids,
use_cache=True,
num_return_sequences=numseqs
)
already_generated += len(genout[0]) - len(gen_in[0])
if(not model.kai_scanner.any_new_entries):
break
txt = tokenizer.decode(genout[0, -already_generated:])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries |= _found_entries
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
genout = torch.cat(
(
encoded,
genout[..., -already_generated:],
),
dim=-1
)
if(vars.sp is not None):
soft_tokens = torch.arange(
model.config.vocab_size,
model.config.vocab_size + vars.sp.shape[0],
device=genout.device,
)
genout = torch.cat((soft_tokens[None], genout), dim=-1)
diff = genout.shape[-1] - gen_in.shape[-1]
min += diff
max += diff
gen_in = genout
model.kai_scanner_head_length = encoded.shape[-1]
numseqs = 1
except Exception as e:
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occurred during generator call, please check console.'}, broadcast=True)
print("{0}{1}{2}".format(colors.RED, e, colors.END))
@@ -1394,7 +1494,8 @@ def generate(txt, min, max):
return
# Need to manually strip and decode tokens if we're not using a pipeline
genout = [{"generated_text": tokenizer.decode(tokens[len(gen_in[0])-len(tokens):])} for tokens in genout]
#already_generated = -(len(gen_in[0]) - len(tokens))
genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
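
The heart of the new loop in generate() is the splice that keeps generation resumable: after the scanner stops the model, the context is rebuilt with the newly triggered world info and concatenated in front of the tokens produced since the last stop; min and max are then shifted by the size difference so the remaining budget is unchanged. A shape-only illustration with made-up tensors (not part of this commit):

import torch

encoded           = torch.arange(10)[None]     # freshly rebuilt context, shape (1, 10)
genout            = torch.arange(25)[None]     # previous model output, shape (1, 25)
already_generated = 5                          # tokens produced since the last stop

resumed = torch.cat((encoded, genout[..., -already_generated:]), dim=-1)
print(resumed.shape)                           # torch.Size([1, 15]): new context + kept tail
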
@@ -1654,6 +1755,7 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatewidepth', 'data': vars.widepth}, broadcast=True)
emit('from_server', {'cmd': 'updateuseprompt', 'data': vars.useprompt}, broadcast=True)
emit('from_server', {'cmd': 'updateadventure', 'data': vars.adventure}, broadcast=True)
emit('from_server', {'cmd': 'updatedynamicscan', 'data': vars.dynamicscan}, broadcast=True)
emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)
@@ -1862,10 +1964,12 @@ def deletewi(num):
#==================================================================#
# Look for WI keys in text to generator
#==================================================================#
def checkworldinfo(txt):
def checkworldinfo(txt, force_use_txt=False):
original_txt = txt
# Don't go any further if WI is empty
if(len(vars.worldinfo) == 0):
return
return "", set()
# Cache actions length
ln = len(vars.actions)
@@ -1875,7 +1979,7 @@ def checkworldinfo(txt):
depth = vars.widepth
# If this is not a continue, add 1 to widepth since submitted
# text is already in action history @ -1
if(txt != "" and vars.prompt != txt):
if(not force_use_txt and (txt != "" and vars.prompt != txt)):
txt = ""
depth += 1
@@ -1895,12 +1999,17 @@ def checkworldinfo(txt):
txt = vars.prompt + "".join(chunks)
elif(ln == 0):
txt = vars.prompt
if(force_use_txt):
txt += original_txt
# Scan text for matches on WI keys
wimem = ""
found_entries = set()
for wi in vars.worldinfo:
if(wi.get("constant", False)):
wimem = wimem + wi["content"] + "\n"
found_entries.add(id(wi))
continue
if(wi["key"] != ""):
@@ -1922,15 +2031,17 @@ def checkworldinfo(txt):
ksy = ks.strip()
if ksy in txt:
wimem = wimem + wi["content"] + "\n"
found_entries.add(id(wi))
found = True
break
if found:
break
else:
wimem = wimem + wi["content"] + "\n"
found_entries.add(id(wi))
break
return wimem
return wimem, found_entries
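
checkworldinfo() now returns both the assembled world-info text and the id()s of the entries that matched; generate() feeds those ids back in as excluded_world_info so the scanner only stops on entries that are genuinely new. A self-contained toy version of that contract (heavily simplified and not part of this commit: the real function also walks the action history and handles constant entries, comma-separated keys, and secondary keys):

worldinfo = [
    {"key": "dragon", "content": "Dragons breathe fire.", "constant": False},
    {"key": "castle", "content": "The old castle is haunted.", "constant": False},
]

def toy_checkworldinfo(txt):
    wimem, found_entries = "", set()
    for wi in worldinfo:
        if wi["key"] in txt:
            wimem += wi["content"] + "\n"
            found_entries.add(id(wi))
    return wimem, found_entries

_, excluded = toy_checkworldinfo("A dragon circles overhead.")    # entries already injected into the context
_, found = toy_checkworldinfo("The dragon lands on the castle.")  # scan of the generated tail
print(len(found - excluded) != 0)   # True: a new entry fired, so the scanner would stop generation
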
#==================================================================#
# Commit changes to Memory storage