Dynamic world info scan

Gnome Ann 2021-11-03 11:54:48 -04:00
parent aa998ba5e9
commit ec8ec55256
1 changed file with 224 additions and 132 deletions


@@ -13,7 +13,7 @@ from tkinter import messagebox
import json
import collections
import zipfile
from typing import Union, Tuple
from typing import Union, Dict, Set
import requests
import html
@@ -511,7 +511,8 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END))
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
import transformers.generation_utils
# Patch transformers to use our soft prompt
def patch_causallm(cls):
@@ -527,7 +528,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(vars.sp is not None):
vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
inputs_embeds = torch.where(
(shifted_input_ids >= 0)[:, :, None],
(shifted_input_ids >= 0)[..., None],
vars.sp[shifted_input_ids.clamp(min=0)],
inputs_embeds,
)
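For context on the soft-prompt patch this hunk touches: token ids at or above `vocab_size` are shifted down so that they index rows of `vars.sp`, while ordinary tokens come out negative after the shift and keep their normal embeddings. The `[..., None]` change simply broadcasts the mask over the embedding axis for any number of leading batch dimensions. A minimal, self-contained sketch of the substitution (illustrative shapes, not the patched model itself):

import torch

sp = torch.randn(20, 8)                             # stand-in for vars.sp: 20 soft tokens, embed dim 8
shifted_input_ids = torch.tensor([[3, -1, 5, -1]])  # >= 0 means "replace with a soft token"
inputs_embeds = torch.randn(1, 4, 8)                # ordinary embeddings from the model

mixed = torch.where(
    (shifted_input_ids >= 0)[..., None],            # mask broadcast over the embedding axis
    sp[shifted_input_ids.clamp(min=0)],             # clamp keeps the -1 positions from indexing; they are masked out anyway
    inputs_embeds,
)
assert mixed.shape == inputs_embeds.shape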
@@ -542,6 +543,50 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
except:
pass
# Sets up dynamic world info scanner
class DynamicWorldInfoScanCriteria(StoppingCriteria):
def __init__(
self,
tokenizer,
excluded_world_info: set,
#head_length: torch.LongTensor,
head_length: int,
):
self.any_new_entries = False
self.tokenizer = tokenizer
self.excluded_world_info = excluded_world_info
self.head_length = head_length
def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
**kwargs,
) -> bool:
assert input_ids.ndim == 2
#assert input_ids.shape[:-1] == self.head_length.shape
tail = input_ids[..., self.head_length:]
self.any_new_entries = False
for t in tail:
decoded = self.tokenizer.decode(t) # use the tokenizer passed to __init__ rather than the global
_, found = checkworldinfo(decoded, force_use_txt=True)
found -= self.excluded_world_info
if(len(found) != 0):
self.any_new_entries = True
break
return self.any_new_entries
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
def new_get_stopping_criteria(self, *args, **kwargs):
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
global tokenizer
self.kai_scanner = DynamicWorldInfoScanCriteria(
tokenizer=tokenizer,
excluded_world_info=self.kai_scanner_excluded_world_info,
head_length=self.kai_scanner_head_length,
)
stopping_criteria.append(self.kai_scanner)
return stopping_criteria
transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
# If custom GPT Neo model was chosen
if(vars.model == "NeoCustom"):
model_config = open(vars.custmodpth + "/config.json", "r")
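The scanner hooks into generation via `StoppingCriteria`: transformers invokes every criterion after each decoding step, so returning True halts `generate()` as soon as the decoded tail matches a world info key that was not already in context. A toy criterion illustrating the same contract (assumed names, not the KoboldAI scanner itself):

from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSubstring(StoppingCriteria):
    def __init__(self, tokenizer, substring, head_length):
        self.tokenizer = tokenizer
        self.substring = substring
        self.head_length = head_length  # prompt length; only scan newly generated tokens

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        tail = input_ids[..., self.head_length:]
        return any(self.substring in self.tokenizer.decode(t) for t in tail)

# A criterion is normally passed as generate(..., stopping_criteria=StoppingCriteriaList([...]));
# the commit monkey-patches _get_stopping_criteria instead, so the scanner is attached on every
# generate() call without touching each call site.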
@@ -1145,135 +1190,138 @@ def actionback():
vars.genseqs = []
#==================================================================#
# Take submitted text and build the text to be given to generator
#
#==================================================================#
def calcsubmit(txt):
anotetxt = "" # Placeholder for Author's Note text
lnanote = 0 # Placeholder for Author's Note length
forceanote = False # In case we don't have enough actions to hit A.N. depth
anoteadded = False # In case our budget runs out before we hit A.N. depth
actionlen = len(vars.actions)
def calcsubmitbudgetheader(txt, **kwargs):
# Scan for WorldInfo matches
winfo = checkworldinfo(txt)
winfo, found_entries = checkworldinfo(txt, **kwargs)
# Add a newline to the end of memory
if(vars.memory != "" and vars.memory[-1] != "\n"):
mem = vars.memory + "\n"
else:
mem = vars.memory
# Build Author's Note if set
if(vars.authornote != ""):
anotetxt = "\n[Author's note: "+vars.authornote+"]\n"
else:
anotetxt = ""
return winfo, mem, anotetxt, found_entries
def calcsubmitbudget(actionlen, winfo, mem, anotetxt):
anotetkns = [] # Placeholder for Author's Note tokens
lnanote = 0 # Placeholder for Author's Note length
forceanote = False # In case we don't have enough actions to hit A.N. depth
anoteadded = False # In case our budget runs out before we hit A.N. depth
# Calculate token budget
prompttkns = tokenizer.encode(vars.prompt)
lnprompt = len(prompttkns)
memtokens = tokenizer.encode(mem)
lnmem = len(memtokens)
witokens = tokenizer.encode(winfo)
lnwi = len(witokens)
if(anotetxt != ""):
anotetkns = tokenizer.encode(anotetxt)
lnanote = len(anotetkns)
lnsp = vars.sp.shape[0] if vars.sp is not None else 0
if(vars.useprompt):
budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
else:
budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
if(actionlen == 0):
# First/Prompt action
subtxt = mem + winfo + anotetxt + vars.prompt # use the newline-normalized mem so lnsub matches the token counts above
lnsub = lnsp + lnmem + lnwi + lnprompt + lnanote
return subtxt, lnsub+1, lnsub+vars.genamt
else:
tokens = []
# Check if we have the action depth to hit our A.N. depth
if(anotetxt != "" and actionlen < vars.andepth):
forceanote = True
# Get most recent action tokens up to our budget
n = 0
for key in reversed(vars.actions):
chunk = vars.actions[key]
if(budget <= 0):
break
acttkns = tokenizer.encode(chunk)
tknlen = len(acttkns)
if(tknlen < budget):
tokens = acttkns + tokens
budget -= tknlen
else:
count = budget * -1
tokens = acttkns[count:] + tokens
budget = 0
break
# Inject Author's Note if we've reached the desired depth
if(n == vars.andepth-1):
if(anotetxt != ""):
tokens = anotetkns + tokens # A.N. len already taken from bdgt
anoteadded = True
n += 1
# If we're not using the prompt every time and there's still budget left,
# add some prompt.
if(not vars.useprompt):
if(budget > 0):
prompttkns = prompttkns[-budget:]
else:
prompttkns = []
# Did we get to add the A.N.? If not, do it here
if(anotetxt != ""):
if((not anoteadded) or forceanote):
tokens = memtokens + witokens + anotetkns + prompttkns + tokens
else:
tokens = memtokens + witokens + prompttkns + tokens
else:
# Prepend Memory, WI, and Prompt before action tokens
tokens = memtokens + witokens + prompttkns + tokens
# Send completed bundle to generator
ln = len(tokens)
return tokenizer.decode(tokens), ln+1, ln+vars.genamt
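The budget arithmetic above reserves room for everything that must appear in the context before recent actions are added. With illustrative numbers and vars.useprompt off: max_length=2048, genamt=80, a 20-token soft prompt, 150 tokens of memory, 60 of world info, and 30 of Author's Note leave budget = 2048 - 20 - 150 - 30 - 60 - 80 = 1708 tokens for action history (when vars.useprompt is on, lnprompt is subtracted as well, since the prompt is then always included).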
#==================================================================#
# Take submitted text and build the text to be given to generator
#==================================================================#
def calcsubmit(txt):
anotetxt = "" # Placeholder for Author's Note text
forceanote = False # In case we don't have enough actions to hit A.N. depth
anoteadded = False # In case our budget runs out before we hit A.N. depth
actionlen = len(vars.actions)
winfo, mem, anotetxt, _ = calcsubmitbudgetheader(txt)
# For all transformers models
if(vars.model != "InferKit"):
anotetkns = [] # Placeholder for Author's Note tokens
# Calculate token budget
prompttkns = tokenizer.encode(vars.prompt)
lnprompt = len(prompttkns)
memtokens = tokenizer.encode(mem)
lnmem = len(memtokens)
witokens = tokenizer.encode(winfo)
lnwi = len(witokens)
if(anotetxt != ""):
anotetkns = tokenizer.encode(anotetxt)
lnanote = len(anotetkns)
lnsp = vars.sp.shape[0] if vars.sp is not None else 0
if(vars.useprompt):
budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
else:
budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt)
if(actionlen == 0):
# First/Prompt action
subtxt = vars.memory + winfo + anotetxt + vars.prompt
lnsub = lnsp + lnmem + lnwi + lnprompt + lnanote
if(not vars.model in ["Colab", "OAI"]):
generate(subtxt, lnsub+1, lnsub+vars.genamt)
generate(subtxt, min, max)
elif(vars.model == "Colab"):
sendtocolab(subtxt, lnsub+1, lnsub+vars.genamt)
sendtocolab(subtxt, min, max)
elif(vars.model == "OAI"):
oairequest(subtxt, lnsub+1, lnsub+vars.genamt)
oairequest(subtxt, min, max)
else:
tokens = []
# Check if we have the action depth to hit our A.N. depth
if(anotetxt != "" and actionlen < vars.andepth):
forceanote = True
# Get most recent action tokens up to our budget
n = 0
for key in reversed(vars.actions):
chunk = vars.actions[key]
if(budget <= 0):
break
acttkns = tokenizer.encode(chunk)
tknlen = len(acttkns)
if(tknlen < budget):
tokens = acttkns + tokens
budget -= tknlen
else:
count = budget * -1
tokens = acttkns[count:] + tokens
budget = 0
break
# Inject Author's Note if we've reached the desired depth
if(n == vars.andepth-1):
if(anotetxt != ""):
tokens = anotetkns + tokens # A.N. len already taken from bdgt
anoteadded = True
n += 1
# If we're not using the prompt every time and there's still budget left,
# add some prompt.
if(not vars.useprompt):
if(budget > 0):
prompttkns = prompttkns[-budget:]
else:
prompttkns = []
# Did we get to add the A.N.? If not, do it here
if(anotetxt != ""):
if((not anoteadded) or forceanote):
tokens = memtokens + witokens + anotetkns + prompttkns + tokens
else:
tokens = memtokens + witokens + prompttkns + tokens
else:
# Prepend Memory, WI, and Prompt before action tokens
tokens = memtokens + witokens + prompttkns + tokens
# Send completed bundle to generator
ln = len(tokens)
if(not vars.model in ["Colab", "OAI"]):
generate (
tokenizer.decode(tokens),
ln+1,
ln+vars.genamt
)
generate(subtxt, min, max)
elif(vars.model == "Colab"):
sendtocolab(
tokenizer.decode(tokens),
ln+1,
ln+vars.genamt
)
sendtocolab(subtxt, min, max)
elif(vars.model == "OAI"):
oairequest(
tokenizer.decode(tokens),
ln+1,
ln+vars.genamt
)
oairequest(subtxt, min, max)
# For InferKit web API
else:
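The net effect of this hunk is that the old monolithic `calcsubmit` is split in two: `calcsubmitbudgetheader` gathers world info, memory, and the Author's Note, while `calcsubmitbudget` turns them into a token-budgeted context string. The split matters because the generation loop further down can now rebuild its context mid-generation without re-running `calcsubmit`'s model dispatch. Roughly (a sketch using the document's own functions; `new_text` is a hypothetical stand-in for the decoded tail):

# Rebuild the context after the scanner finds a new world info entry:
winfo, mem, anotetxt, found = calcsubmitbudgetheader(new_text, force_use_txt=True)
subtxt, minimum, maximum = calcsubmitbudget(len(vars.actions), winfo, mem, anotetxt)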
@@ -1357,7 +1405,7 @@ def generate(txt, min, max):
model.config.vocab_size,
model.config.vocab_size + vars.sp.shape[0],
)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=1)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
if(vars.hascuda and vars.usegpu):
gen_in = gen_in.to(0)
@@ -1368,22 +1416,56 @@ def generate(txt, min, max):
else:
gen_in = gen_in.to('cpu')
model.kai_scanner_head_length = gen_in.shape[-1]
model.kai_scanner_excluded_world_info = set()
with torch.no_grad():
genout = generator(
gen_in,
do_sample=True,
min_length=min,
max_length=max,
repetition_penalty=vars.rep_pen,
top_p=top_p,
top_k=top_k,
tfs=tfs,
temperature=vars.temp,
bad_words_ids=vars.badwordsids,
use_cache=True,
return_full_text=False,
num_return_sequences=vars.numseqs
already_generated = 0
numseqs = vars.numseqs
found_entries = model.kai_scanner_excluded_world_info
while True:
genout = generator(
gen_in,
do_sample=True,
min_length=min+already_generated,
max_length=max,
repetition_penalty=vars.rep_pen,
top_p=top_p,
top_k=top_k,
tfs=tfs,
temperature=vars.temp,
bad_words_ids=vars.badwordsids,
use_cache=True,
num_return_sequences=numseqs
)
already_generated += len(genout[0]) - len(gen_in[0])
if(not model.kai_scanner.any_new_entries):
break
txt = tokenizer.decode(genout[0, -already_generated:])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries |= _found_entries
txt, _, _ = calcsubmitbudget(len(vars.actions), winfo, mem, anotetxt)
encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
genout = torch.cat(
(
encoded,
genout[..., -already_generated:],
),
dim=-1
)
if(vars.sp is not None):
soft_tokens = torch.arange(
model.config.vocab_size,
model.config.vocab_size + vars.sp.shape[0],
)
genout = torch.cat((soft_tokens[None], genout), dim=-1)
diff = genout.shape[-1] - gen_in.shape[-1]
min += diff
max += diff
gen_in = genout
model.kai_scanner_head_length = encoded.shape[-1]
numseqs = 1
except Exception as e:
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occurred during generator call, please check console.'}, broadcast=True)
print("{0}{1}{2}".format(colors.RED, e, colors.END))
@@ -1391,7 +1473,8 @@ def generate(txt, min, max):
return
# Need to manually strip and decode tokens if we're not using a pipeline
genout = [{"generated_text": tokenizer.decode(tokens[len(gen_in[0])-len(tokens):])} for tokens in genout]
#already_generated = -(len(gen_in[0]) - len(tokens))
genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
@@ -1859,10 +1942,12 @@ def deletewi(num):
#==================================================================#
# Look for WI keys in text to generator
#==================================================================#
def checkworldinfo(txt):
def checkworldinfo(txt, force_use_txt=False):
original_txt = txt
# Dont go any further if WI is empty
if(len(vars.worldinfo) == 0):
return
return "", set()
# Cache actions length
ln = len(vars.actions)
@@ -1872,7 +1957,7 @@ def checkworldinfo(txt):
depth = vars.widepth
# If this is not a continue, add 1 to widepth since submitted
# text is already in action history @ -1
if(txt != "" and vars.prompt != txt):
if(not force_use_txt and (txt != "" and vars.prompt != txt)):
txt = ""
depth += 1
@@ -1892,12 +1977,17 @@ def checkworldinfo(txt):
txt = vars.prompt + "".join(chunks)
elif(ln == 0):
txt = vars.prompt
if(force_use_txt):
txt += original_txt
# Scan text for matches on WI keys
wimem = ""
found_entries = set()
for wi in vars.worldinfo:
if(wi.get("constant", False)):
wimem = wimem + wi["content"] + "\n"
found_entries.add(id(wi))
continue
if(wi["key"] != ""):
@@ -1919,15 +2009,17 @@ def checkworldinfo(txt):
ksy = ks.strip()
if ksy in txt:
wimem = wimem + wi["content"] + "\n"
found_entries.add(id(wi))
found = True
break
if found:
break
else:
wimem = wimem + wi["content"] + "\n"
found_entries.add(id(wi))
break
return wimem
return wimem, found_entries
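`checkworldinfo` now returns a pair: the matched content, plus a set of `id()`s of the matched world info entries, which is what lets the scanner distinguish entries already in context from genuinely new ones. A sketch of the intended call pattern (`decoded_tail` and `previously_found` are hypothetical):

wimem, found = checkworldinfo(decoded_tail, force_use_txt=True)
new_hits = found - previously_found
if new_hits:
    ...  # a new entry matched: re-inject world info and resume generation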
#==================================================================#
# Commit changes to Memory storage