Dynamic world info scan

Gnome Ann
2021-11-03 11:54:48 -04:00
parent aa998ba5e9
commit ec8ec55256
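
Dynamic world info scanning hooks into the transformers generation loop through a custom StoppingCriteria: after every generated token, the text produced since the prompt is scanned for world info keys that were not already part of the context, and generation halts as soon as a new entry matches so the prompt can be rebuilt with that entry's content and generation resumed. The snippet below is a minimal, self-contained sketch of that hook, not code from this commit; WORLD_INFO, find_world_info_keys and the gpt2 checkpoint are illustrative stand-ins for KoboldAI's world info list, its checkworldinfo() function and the configured model.

from typing import Set

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

# Hypothetical world info table; checkworldinfo() plays this role in the real code.
WORLD_INFO = {"dragon": "Dragons in this world breathe frost, not fire.\n"}

def find_world_info_keys(text: str) -> Set[str]:
    # Stand-in for checkworldinfo(): return the set of keys that appear in the text.
    return {key for key in WORLD_INFO if key in text}

class WorldInfoScanCriteria(StoppingCriteria):
    # Halts generation as soon as the newly generated tokens trigger a new entry.
    def __init__(self, tokenizer, head_length: int, excluded: Set[str]):
        self.tokenizer = tokenizer
        self.head_length = head_length  # number of prompt tokens to skip when scanning
        self.excluded = excluded        # entries whose content is already in the prompt
        self.triggered = False

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        tail = input_ids[0, self.head_length:]                # only the new tokens
        found = find_world_info_keys(self.tokenizer.decode(tail))
        self.triggered = bool(found - self.excluded)
        return self.triggered                                 # True stops generate()

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
prompt_ids = tokenizer("The knight rode toward the mountain", return_tensors="pt").input_ids
scanner = WorldInfoScanCriteria(tokenizer, head_length=prompt_ids.shape[-1], excluded=set())

output_ids = model.generate(
    prompt_ids,
    do_sample=True,
    max_length=prompt_ids.shape[-1] + 40,
    stopping_criteria=StoppingCriteriaList([scanner]),
)
if scanner.triggered:
    # A caller would now rebuild the context with the matched entries' content
    # prepended and call generate() again, continuing from the partial output.
    print("world info entry triggered; rebuild the prompt and resume")

In the commit itself the criteria object is injected by monkey-patching GenerationMixin._get_stopping_criteria, and generate() wraps the generator call in a while True loop that rebuilds gen_in via calcsubmitbudgetheader()/calcsubmitbudget() until the scanner reports no new entries.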


@@ -13,7 +13,7 @@ from tkinter import messagebox
 import json
 import collections
 import zipfile
-from typing import Union, Tuple
+from typing import Union, Dict, Set
 import requests
 import html
@@ -511,7 +511,8 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END))
 if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
-        from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
+        from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
+        import transformers.generation_utils

         # Patch transformers to use our soft prompt
         def patch_causallm(cls):
@@ -527,7 +528,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
                 if(vars.sp is not None):
                     vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
                     inputs_embeds = torch.where(
-                        (shifted_input_ids >= 0)[:, :, None],
+                        (shifted_input_ids >= 0)[..., None],
                         vars.sp[shifted_input_ids.clamp(min=0)],
                         inputs_embeds,
                     )
@@ -542,6 +543,50 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
         except:
             pass

+        # Sets up dynamic world info scanner
+        class DynamicWorldInfoScanCriteria(StoppingCriteria):
+            def __init__(
+                self,
+                tokenizer,
+                excluded_world_info: set,
+                #head_length: torch.LongTensor,
+                head_length: int,
+            ):
+                self.any_new_entries = False
+                self.tokenizer = tokenizer
+                self.excluded_world_info = excluded_world_info
+                self.head_length = head_length
+            def __call__(
+                self,
+                input_ids: torch.LongTensor,
+                scores: torch.FloatTensor,
+                **kwargs,
+            ) -> bool:
+                assert input_ids.ndim == 2
+                #assert input_ids.shape[:-1] == self.head_length.shape
+                tail = input_ids[..., self.head_length:]
+                self.any_new_entries = False
+                for t in tail:
+                    decoded = tokenizer.decode(t)
+                    _, found = checkworldinfo(decoded, force_use_txt=True)
+                    found -= self.excluded_world_info
+                    if(len(found) != 0):
+                        self.any_new_entries = True
+                        break
+                return self.any_new_entries
+        old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
+        def new_get_stopping_criteria(self, *args, **kwargs):
+            stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
+            global tokenizer
+            self.kai_scanner = DynamicWorldInfoScanCriteria(
+                tokenizer=tokenizer,
+                excluded_world_info=self.kai_scanner_excluded_world_info,
+                head_length=self.kai_scanner_head_length,
+            )
+            stopping_criteria.append(self.kai_scanner)
+            return stopping_criteria
+        transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
+
         # If custom GPT Neo model was chosen
         if(vars.model == "NeoCustom"):
             model_config = open(vars.custmodpth + "/config.json", "r")
@@ -1145,17 +1190,11 @@ def actionback():
         vars.genseqs = []

 #==================================================================#
 # Take submitted text and build the text to be given to generator  #
 #==================================================================#
-def calcsubmit(txt):
-    anotetxt     = ""     # Placeholder for Author's Note text
-    lnanote      = 0      # Placeholder for Author's Note length
-    forceanote   = False  # In case we don't have enough actions to hit A.N. depth
-    anoteadded   = False  # In case our budget runs out before we hit A.N. depth
-    actionlen    = len(vars.actions)
-
+def calcsubmitbudgetheader(txt, **kwargs):
     # Scan for WorldInfo matches
-    winfo = checkworldinfo(txt)
+    winfo, found_entries = checkworldinfo(txt, **kwargs)

     # Add a newline to the end of memory
     if(vars.memory != "" and vars.memory[-1] != "\n"):
@@ -1166,10 +1205,14 @@ def calcsubmit(txt):
     # Build Author's Note if set
     if(vars.authornote != ""):
         anotetxt = "\n[Author's note: "+vars.authornote+"]\n"
+    else:
+        anotetxt = ""

-    # For all transformers models
-    if(vars.model != "InferKit"):
-        anotetkns    = []     # Placeholder for Author's Note tokens
+    return winfo, mem, anotetxt, found_entries
+
+def calcsubmitbudget(actionlen, winfo, mem, anotetxt):
+    anotetkns    = []     # Placeholder for Author's Note tokens
+    lnanote      = 0      # Placeholder for Author's Note length

     # Calculate token budget
     prompttkns = tokenizer.encode(vars.prompt)
@@ -1196,13 +1239,7 @@ def calcsubmit(txt):
         # First/Prompt action
         subtxt = vars.memory + winfo + anotetxt + vars.prompt
         lnsub  = lnsp + lnmem + lnwi + lnprompt + lnanote
-
-        if(not vars.model in ["Colab", "OAI"]):
-            generate(subtxt, lnsub+1, lnsub+vars.genamt)
-        elif(vars.model == "Colab"):
-            sendtocolab(subtxt, lnsub+1, lnsub+vars.genamt)
-        elif(vars.model == "OAI"):
-            oairequest(subtxt, lnsub+1, lnsub+vars.genamt)
+        return subtxt, lnsub+1, lnsub+vars.genamt

     else:
         tokens = []
@@ -1255,25 +1292,36 @@ def calcsubmit(txt):
         # Send completed bundle to generator
         ln = len(tokens)
-
-        if(not vars.model in ["Colab", "OAI"]):
-            generate (
-                tokenizer.decode(tokens),
-                ln+1,
-                ln+vars.genamt
-                )
-        elif(vars.model == "Colab"):
-            sendtocolab(
-                tokenizer.decode(tokens),
-                ln+1,
-                ln+vars.genamt
-                )
-        elif(vars.model == "OAI"):
-            oairequest(
-                tokenizer.decode(tokens),
-                ln+1,
-                ln+vars.genamt
-                )
+        return tokenizer.decode(tokens), ln+1, ln+vars.genamt
+
+#==================================================================#
+# Take submitted text and build the text to be given to generator
+#==================================================================#
+def calcsubmit(txt):
+    anotetxt     = ""     # Placeholder for Author's Note text
+    forceanote   = False  # In case we don't have enough actions to hit A.N. depth
+    anoteadded   = False  # In case our budget runs out before we hit A.N. depth
+    actionlen    = len(vars.actions)
+
+    winfo, mem, anotetxt, _ = calcsubmitbudgetheader(txt)
+
+    # For all transformers models
+    if(vars.model != "InferKit"):
+        subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt)
+        if(actionlen == 0):
+            if(not vars.model in ["Colab", "OAI"]):
+                generate(subtxt, min, max)
+            elif(vars.model == "Colab"):
+                sendtocolab(subtxt, min, max)
+            elif(vars.model == "OAI"):
+                oairequest(subtxt, min, max)
+        else:
+            if(not vars.model in ["Colab", "OAI"]):
+                generate(subtxt, min, max)
+            elif(vars.model == "Colab"):
+                sendtocolab(subtxt, min, max)
+            elif(vars.model == "OAI"):
+                oairequest(subtxt, min, max)

     # For InferKit web API
     else:
@@ -1357,7 +1405,7 @@ def generate(txt, min, max):
                 model.config.vocab_size,
                 model.config.vocab_size + vars.sp.shape[0],
             )
-            gen_in = torch.cat((soft_tokens[None], gen_in), dim=1)
+            gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)

         if(vars.hascuda and vars.usegpu):
             gen_in = gen_in.to(0)
@@ -1368,11 +1416,18 @@ def generate(txt, min, max):
         else:
             gen_in = gen_in.to('cpu')

+        model.kai_scanner_head_length = gen_in.shape[-1]
+        model.kai_scanner_excluded_world_info = set()
+
         with torch.no_grad():
+            already_generated = 0
+            numseqs = vars.numseqs
+            found_entries = model.kai_scanner_excluded_world_info
+            while True:
                 genout = generator(
                     gen_in,
                     do_sample=True,
-                    min_length=min,
+                    min_length=min+already_generated,
                     max_length=max,
                     repetition_penalty=vars.rep_pen,
                     top_p=top_p,
@@ -1381,9 +1436,36 @@ def generate(txt, min, max):
                     temperature=vars.temp,
                     bad_words_ids=vars.badwordsids,
                     use_cache=True,
-                    return_full_text=False,
-                    num_return_sequences=vars.numseqs
+                    num_return_sequences=numseqs
                     )
+                already_generated += len(genout[0]) - len(gen_in[0])
+                if(not model.kai_scanner.any_new_entries):
+                    break
+                txt = tokenizer.decode(genout[0, -already_generated:])
+                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
+                found_entries |= _found_entries
+                txt, _, _ = calcsubmitbudget(len(vars.actions), winfo, mem, anotetxt)
+                encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
+                gen_in = torch.cat(
+                    (
+                        encoded,
+                        genout[..., -already_generated:],
+                    ),
+                    dim=-1
+                )
+                if(vars.sp is not None):
+                    soft_tokens = torch.arange(
+                        model.config.vocab_size,
+                        model.config.vocab_size + vars.sp.shape[0],
+                    )
+                    gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
+                diff = gen_in.shape[-1] - genout.shape[-1]
+                min += diff
+                max += diff
+                gen_in = genout
+                model.kai_scanner_head_length = encoded.shape[-1]
+                numseqs = 1
     except Exception as e:
         emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
         print("{0}{1}{2}".format(colors.RED, e, colors.END))
@@ -1391,7 +1473,8 @@ def generate(txt, min, max):
         return

     # Need to manually strip and decode tokens if we're not using a pipeline
-    genout = [{"generated_text": tokenizer.decode(tokens[len(gen_in[0])-len(tokens):])} for tokens in genout]
+    #already_generated = -(len(gen_in[0]) - len(tokens))
+    genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]

     if(len(genout) == 1):
         genresult(genout[0]["generated_text"])
@@ -1859,10 +1942,12 @@ def deletewi(num):
 #==================================================================#
 # Look for WI keys in text to generator
 #==================================================================#
-def checkworldinfo(txt):
+def checkworldinfo(txt, force_use_txt=False):
+    original_txt = txt
+
     # Dont go any further if WI is empty
     if(len(vars.worldinfo) == 0):
-        return
+        return "", set()

     # Cache actions length
     ln = len(vars.actions)
@@ -1872,7 +1957,7 @@ def checkworldinfo(txt):
         depth = vars.widepth
         # If this is not a continue, add 1 to widepth since submitted
        # text is already in action history @ -1
-        if(txt != "" and vars.prompt != txt):
+        if(not force_use_txt and (txt != "" and vars.prompt != txt)):
             txt    = ""
             depth += 1
@@ -1893,11 +1978,16 @@ def checkworldinfo(txt):
         elif(ln == 0):
             txt = vars.prompt

+    if(force_use_txt):
+        txt += original_txt
+
     # Scan text for matches on WI keys
     wimem = ""
+    found_entries = set()
     for wi in vars.worldinfo:
         if(wi.get("constant", False)):
             wimem = wimem + wi["content"] + "\n"
+            found_entries.add(id(wi))
             continue

         if(wi["key"] != ""):
@@ -1919,15 +2009,17 @@ def checkworldinfo(txt):
                                 ksy = ks.strip()
                             if ksy in txt:
                                 wimem = wimem + wi["content"] + "\n"
+                                found_entries.add(id(wi))
                                 found = True
                                 break
                         if found:
                             break
                     else:
                         wimem = wimem + wi["content"] + "\n"
+                        found_entries.add(id(wi))
                         break

-    return wimem
+    return wimem, found_entries

 #==================================================================#
 # Commit changes to Memory storage