Dynamic world info scan
parent aa998ba5e9
commit ec8ec55256

aiserver.py: 356 lines changed
@@ -13,7 +13,7 @@ from tkinter import messagebox
 import json
 import collections
 import zipfile
-from typing import Union, Tuple
+from typing import Union, Dict, Set

 import requests
 import html
@@ -511,7 +511,8 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END))
 if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
-        from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
+        from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
+        import transformers.generation_utils

         # Patch transformers to use our soft prompt
         def patch_causallm(cls):
@@ -527,7 +528,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
                 if(vars.sp is not None):
                     vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
                     inputs_embeds = torch.where(
-                        (shifted_input_ids >= 0)[:, :, None],
+                        (shifted_input_ids >= 0)[..., None],
                         vars.sp[shifted_input_ids.clamp(min=0)],
                         inputs_embeds,
                     )
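The `[:, :, None]` → `[..., None]` change above only generalizes the mask's broadcast to any number of leading dimensions; for the usual `(batch, seq)` ids the two spellings are identical. A self-contained sketch of the splice this code performs (shapes and values are illustrative, not from the commit):

import torch

sp = torch.randn(20, 768)                     # 20 learned soft-prompt embedding rows
inputs_embeds = torch.randn(1, 8, 768)        # (batch, seq, hidden) real token embeddings
# >= 0 marks positions holding "virtual" soft-prompt tokens; -1 marks real tokens
shifted_input_ids = torch.tensor([[0, 1, 2, -1, -1, -1, -1, -1]])

mixed = torch.where(
    (shifted_input_ids >= 0)[..., None],      # (batch, seq, 1) mask broadcasts over hidden dim
    sp[shifted_input_ids.clamp(min=0)],       # clamp keeps the gather in-bounds at real-token slots
    inputs_embeds,                            # fall back to the real embeddings elsewhere
)
print(mixed.shape)                            # torch.Size([1, 8, 768])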
@@ -542,6 +543,50 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
                 except:
                     pass

+        # Sets up dynamic world info scanner
+        class DynamicWorldInfoScanCriteria(StoppingCriteria):
+            def __init__(
+                self,
+                tokenizer,
+                excluded_world_info: set,
+                #head_length: torch.LongTensor,
+                head_length: int,
+            ):
+                self.any_new_entries = False
+                self.tokenizer = tokenizer
+                self.excluded_world_info = excluded_world_info
+                self.head_length = head_length
+            def __call__(
+                self,
+                input_ids: torch.LongTensor,
+                scores: torch.FloatTensor,
+                **kwargs,
+            ) -> bool:
+                assert input_ids.ndim == 2
+                #assert input_ids.shape[:-1] == self.head_length.shape
+                tail = input_ids[..., self.head_length:]
+                self.any_new_entries = False
+                for t in tail:
+                    decoded = tokenizer.decode(t)
+                    _, found = checkworldinfo(decoded, force_use_txt=True)
+                    found -= self.excluded_world_info
+                    if(len(found) != 0):
+                        self.any_new_entries = True
+                        break
+                return self.any_new_entries
+        old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
+        def new_get_stopping_criteria(self, *args, **kwargs):
+            stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
+            global tokenizer
+            self.kai_scanner = DynamicWorldInfoScanCriteria(
+                tokenizer=tokenizer,
+                excluded_world_info=self.kai_scanner_excluded_world_info,
+                head_length=self.kai_scanner_head_length,
+            )
+            stopping_criteria.append(self.kai_scanner)
+            return stopping_criteria
+        transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
+
     # If custom GPT Neo model was chosen
     if(vars.model == "NeoCustom"):
         model_config = open(vars.custmodpth + "/config.json", "r")
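`generate()` exposed no public hook for custom stopping criteria when this was written, hence the monkey-patch of the private `GenerationMixin._get_stopping_criteria`. A minimal sketch of the same idea against later transformers releases, which accept a `stopping_criteria` argument directly (the model, prompt, and keyword are illustrative, and the exact return convention for `StoppingCriteria.__call__` varies by version):

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          StoppingCriteria, StoppingCriteriaList)

class KeywordCriteria(StoppingCriteria):
    # Stop as soon as a keyword shows up in the newly generated tail,
    # mirroring how DynamicWorldInfoScanCriteria watches for WI matches.
    def __init__(self, tokenizer, keywords, head_length):
        self.tokenizer = tokenizer
        self.keywords = keywords
        self.head_length = head_length        # prompt tokens to skip when decoding

    def __call__(self, input_ids, scores, **kwargs):
        decoded = self.tokenizer.decode(input_ids[0, self.head_length:])
        return any(k in decoded for k in self.keywords)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The dragon flew over the", return_tensors="pt")
criteria = StoppingCriteriaList(
    [KeywordCriteria(tokenizer, ["castle"], inputs.input_ids.shape[-1])]
)
out = model.generate(**inputs, max_new_tokens=40, stopping_criteria=criteria)
print(tokenizer.decode(out[0]))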
@@ -1145,135 +1190,138 @@ def actionback():
         vars.genseqs = []

 #==================================================================#
-# Take submitted text and build the text to be given to generator
+# 
 #==================================================================#
-def calcsubmit(txt):
-    anotetxt = ""        # Placeholder for Author's Note text
-    lnanote = 0          # Placeholder for Author's Note length
-    forceanote = False   # In case we don't have enough actions to hit A.N. depth
-    anoteadded = False   # In case our budget runs out before we hit A.N. depth
-    actionlen = len(vars.actions)
-
+def calcsubmitbudgetheader(txt, **kwargs):
     # Scan for WorldInfo matches
-    winfo = checkworldinfo(txt)
+    winfo, found_entries = checkworldinfo(txt, **kwargs)

     # Add a newline to the end of memory
     if(vars.memory != "" and vars.memory[-1] != "\n"):
         mem = vars.memory + "\n"
     else:
         mem = vars.memory

     # Build Author's Note if set
     if(vars.authornote != ""):
         anotetxt = "\n[Author's note: "+vars.authornote+"]\n"
     else:
         anotetxt = ""

+    return winfo, mem, anotetxt, found_entries
+
+def calcsubmitbudget(actionlen, winfo, mem, anotetxt):
+    anotetkns = []       # Placeholder for Author's Note tokens
+    lnanote = 0          # Placeholder for Author's Note length
+
+    # Calculate token budget
+    prompttkns = tokenizer.encode(vars.prompt)
+    lnprompt = len(prompttkns)
+
+    memtokens = tokenizer.encode(mem)
+    lnmem = len(memtokens)
+
+    witokens = tokenizer.encode(winfo)
+    lnwi = len(witokens)
+
+    if(anotetxt != ""):
+        anotetkns = tokenizer.encode(anotetxt)
+        lnanote = len(anotetkns)
+
+    lnsp = vars.sp.shape[0] if vars.sp is not None else 0
+
+    if(vars.useprompt):
+        budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
+    else:
+        budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
+
+    if(actionlen == 0):
+        # First/Prompt action
+        subtxt = vars.memory + winfo + anotetxt + vars.prompt
+        lnsub = lnsp + lnmem + lnwi + lnprompt + lnanote
+        return subtxt, lnsub+1, lnsub+vars.genamt
+    else:
+        tokens = []
+
+        # Check if we have the action depth to hit our A.N. depth
+        if(anotetxt != "" and actionlen < vars.andepth):
+            forceanote = True
+
+        # Get most recent action tokens up to our budget
+        n = 0
+        for key in reversed(vars.actions):
+            chunk = vars.actions[key]
+
+            if(budget <= 0):
+                break
+            acttkns = tokenizer.encode(chunk)
+            tknlen = len(acttkns)
+            if(tknlen < budget):
+                tokens = acttkns + tokens
+                budget -= tknlen
+            else:
+                count = budget * -1
+                tokens = acttkns[count:] + tokens
+                budget = 0
+                break
+
+            # Inject Author's Note if we've reached the desired depth
+            if(n == vars.andepth-1):
+                if(anotetxt != ""):
+                    tokens = anotetkns + tokens # A.N. len already taken from bdgt
+                    anoteadded = True
+            n += 1
+
+        # If we're not using the prompt every time and there's still budget left,
+        # add some prompt.
+        if(not vars.useprompt):
+            if(budget > 0):
+                prompttkns = prompttkns[-budget:]
+            else:
+                prompttkns = []
+
+        # Did we get to add the A.N.? If not, do it here
+        if(anotetxt != ""):
+            if((not anoteadded) or forceanote):
+                tokens = memtokens + witokens + anotetkns + prompttkns + tokens
+            else:
+                tokens = memtokens + witokens + prompttkns + tokens
+        else:
+            # Prepend Memory, WI, and Prompt before action tokens
+            tokens = memtokens + witokens + prompttkns + tokens
+
+        # Send completed bundle to generator
+        ln = len(tokens)
+        return tokenizer.decode(tokens), ln+1, ln+vars.genamt
+
+#==================================================================#
+# Take submitted text and build the text to be given to generator
+#==================================================================#
+def calcsubmit(txt):
+    anotetxt = ""        # Placeholder for Author's Note text
+    forceanote = False   # In case we don't have enough actions to hit A.N. depth
+    anoteadded = False   # In case our budget runs out before we hit A.N. depth
+    actionlen = len(vars.actions)
+
+    winfo, mem, anotetxt, _ = calcsubmitbudgetheader(txt)
+
     # For all transformers models
     if(vars.model != "InferKit"):
-        anotetkns = []    # Placeholder for Author's Note tokens
-
-        # Calculate token budget
-        prompttkns = tokenizer.encode(vars.prompt)
-        lnprompt = len(prompttkns)
-
-        memtokens = tokenizer.encode(mem)
-        lnmem = len(memtokens)
-
-        witokens = tokenizer.encode(winfo)
-        lnwi = len(witokens)
-
-        if(anotetxt != ""):
-            anotetkns = tokenizer.encode(anotetxt)
-            lnanote = len(anotetkns)
-
-        lnsp = vars.sp.shape[0] if vars.sp is not None else 0
-
-        if(vars.useprompt):
-            budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
-        else:
-            budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
-
+        subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt)
         if(actionlen == 0):
-            # First/Prompt action
-            subtxt = vars.memory + winfo + anotetxt + vars.prompt
-            lnsub = lnsp + lnmem + lnwi + lnprompt + lnanote
-
             if(not vars.model in ["Colab", "OAI"]):
-                generate(subtxt, lnsub+1, lnsub+vars.genamt)
+                generate(subtxt, min, max)
             elif(vars.model == "Colab"):
-                sendtocolab(subtxt, lnsub+1, lnsub+vars.genamt)
+                sendtocolab(subtxt, min, max)
             elif(vars.model == "OAI"):
-                oairequest(subtxt, lnsub+1, lnsub+vars.genamt)
+                oairequest(subtxt, min, max)
         else:
-            tokens = []
-
-            # Check if we have the action depth to hit our A.N. depth
-            if(anotetxt != "" and actionlen < vars.andepth):
-                forceanote = True
-
-            # Get most recent action tokens up to our budget
-            n = 0
-            for key in reversed(vars.actions):
-                chunk = vars.actions[key]
-
-                if(budget <= 0):
-                    break
-                acttkns = tokenizer.encode(chunk)
-                tknlen = len(acttkns)
-                if(tknlen < budget):
-                    tokens = acttkns + tokens
-                    budget -= tknlen
-                else:
-                    count = budget * -1
-                    tokens = acttkns[count:] + tokens
-                    budget = 0
-                    break
-
-                # Inject Author's Note if we've reached the desired depth
-                if(n == vars.andepth-1):
-                    if(anotetxt != ""):
-                        tokens = anotetkns + tokens # A.N. len already taken from bdgt
-                        anoteadded = True
-                n += 1
-
-            # If we're not using the prompt every time and there's still budget left,
-            # add some prompt.
-            if(not vars.useprompt):
-                if(budget > 0):
-                    prompttkns = prompttkns[-budget:]
-                else:
-                    prompttkns = []
-
-            # Did we get to add the A.N.? If not, do it here
-            if(anotetxt != ""):
-                if((not anoteadded) or forceanote):
-                    tokens = memtokens + witokens + anotetkns + prompttkns + tokens
-                else:
-                    tokens = memtokens + witokens + prompttkns + tokens
-            else:
-                # Prepend Memory, WI, and Prompt before action tokens
-                tokens = memtokens + witokens + prompttkns + tokens
-
-            # Send completed bundle to generator
-            ln = len(tokens)
-
             if(not vars.model in ["Colab", "OAI"]):
-                generate (
-                    tokenizer.decode(tokens),
-                    ln+1,
-                    ln+vars.genamt
-                )
+                generate(subtxt, min, max)
             elif(vars.model == "Colab"):
-                sendtocolab(
-                    tokenizer.decode(tokens),
-                    ln+1,
-                    ln+vars.genamt
-                )
+                sendtocolab(subtxt, min, max)
             elif(vars.model == "OAI"):
-                oairequest(
-                    tokenizer.decode(tokens),
-                    ln+1,
-                    ln+vars.genamt
-                )
+                oairequest(subtxt, min, max)

     # For InferKit web API
     else:
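The refactor above moves the token-budget arithmetic into `calcsubmitbudget` so the dynamic scan can re-run it mid-generation. A toy sketch of that arithmetic (the lengths are made up for illustration):

def remaining_budget(max_length, genamt, lnsp, lnmem, lnanote, lnwi, lnprompt, useprompt):
    # Fixed costs: soft prompt, memory, author's note, world info, and the
    # tokens reserved for the model's reply all come off the top.
    budget = max_length - lnsp - lnmem - lnanote - lnwi - genamt
    if useprompt:                 # the prompt is always sent, so it costs budget up front
        budget -= lnprompt
    return budget

# e.g. a 2048-token context window with 60 tokens reserved for generation:
print(remaining_budget(2048, 60, 20, 45, 12, 80, 150, True))  # 1681 tokens left for actions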
@@ -1357,7 +1405,7 @@ def generate(txt, min, max):
                 model.config.vocab_size,
                 model.config.vocab_size + vars.sp.shape[0],
             )
-            gen_in = torch.cat((soft_tokens[None], gen_in), dim=1)
+            gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)

         if(vars.hascuda and vars.usegpu):
             gen_in = gen_in.to(0)
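For a 2-D `(batch, seq)` ids tensor, `dim=1` and `dim=-1` name the same axis, so this is a robustness fix rather than a behavior change: `dim=-1` keeps targeting the sequence axis regardless of rank. A quick check (values illustrative):

import torch

soft_tokens = torch.arange(50257, 50277)              # 20 virtual-token ids past the GPT-2 vocab
gen_in = torch.randint(0, 50257, (1, 8))              # (batch, seq) prompt ids
out = torch.cat((soft_tokens[None], gen_in), dim=-1)  # prepend along the sequence axis
print(out.shape)                                      # torch.Size([1, 28])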
@@ -1368,22 +1416,56 @@ def generate(txt, min, max):
         else:
             gen_in = gen_in.to('cpu')

+        model.kai_scanner_head_length = gen_in.shape[-1]
+        model.kai_scanner_excluded_world_info = set()
+
         with torch.no_grad():
-            genout = generator(
-                gen_in,
-                do_sample=True,
-                min_length=min,
-                max_length=max,
-                repetition_penalty=vars.rep_pen,
-                top_p=top_p,
-                top_k=top_k,
-                tfs=tfs,
-                temperature=vars.temp,
-                bad_words_ids=vars.badwordsids,
-                use_cache=True,
-                return_full_text=False,
-                num_return_sequences=vars.numseqs
-            )
+            already_generated = 0
+            numseqs = vars.numseqs
+            found_entries = model.kai_scanner_excluded_world_info
+            while True:
+                genout = generator(
+                    gen_in,
+                    do_sample=True,
+                    min_length=min+already_generated,
+                    max_length=max,
+                    repetition_penalty=vars.rep_pen,
+                    top_p=top_p,
+                    top_k=top_k,
+                    tfs=tfs,
+                    temperature=vars.temp,
+                    bad_words_ids=vars.badwordsids,
+                    use_cache=True,
+                    num_return_sequences=numseqs
+                )
+                already_generated += len(genout[0]) - len(gen_in[0])
+                if(not model.kai_scanner.any_new_entries):
+                    break
+                txt = tokenizer.decode(genout[0, -already_generated:])
+                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
+                found_entries |= _found_entries
+                txt, _, _ = calcsubmitbudget(len(vars.actions), winfo, mem, anotetxt)
+                encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
+                gen_in = torch.cat(
+                    (
+                        encoded,
+                        genout[..., -already_generated:],
+                    ),
+                    dim=-1
+                )
+                if(vars.sp is not None):
+                    soft_tokens = torch.arange(
+                        model.config.vocab_size,
+                        model.config.vocab_size + vars.sp.shape[0],
+                    )
+                    gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
+                diff = gen_in.shape[-1] - genout.shape[-1]
+                min += diff
+                max += diff
+                gen_in = genout
+                model.kai_scanner_head_length = encoded.shape[-1]
+                numseqs = 1

     except Exception as e:
         emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
         print("{0}{1}{2}".format(colors.RED, e, colors.END))
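The new `while True` loop is the heart of the dynamic scan: generation halts as soon as `DynamicWorldInfoScanCriteria` sees a fresh world-info match, the context is rebuilt with the matched entries injected, and generation resumes with the already-produced tail appended. A simplified sketch of that control flow (the helper names are hypothetical stand-ins, not aiserver functions):

def generate_with_dynamic_wi(generate_step, scan_tail, rebuild_context, genamt):
    tail, found = [], set()          # generated tokens so far; WI entries already injected
    while True:
        head = rebuild_context(found)                         # context with matched WI added
        new = generate_step(head + tail, genamt - len(tail))  # halts early on a new WI match
        tail += new
        new_entries = scan_tail(tail) - found                 # matches we haven't injected yet
        if not new_entries or len(tail) >= genamt:
            return tail              # nothing new matched, or the generation budget is spent
        found |= new_entries         # exclude these entries from future scans

# Toy demo: "generation" emits token 7, which the scanner treats as a WI match.
out = generate_with_dynamic_wi(
    generate_step=lambda ctx, n: [7] * min(n, 2),
    scan_tail=lambda tail: {"entry-7"} if 7 in tail else set(),
    rebuild_context=lambda found: [1, 2] + ([9] if found else []),
    genamt=6,
)
print(out)  # [7, 7, 7, 7]: the second pass finds no new entries, so the loop ends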
@@ -1391,7 +1473,8 @@ def generate(txt, min, max):
         return

     # Need to manually strip and decode tokens if we're not using a pipeline
-    genout = [{"generated_text": tokenizer.decode(tokens[len(gen_in[0])-len(tokens):])} for tokens in genout]
+    #already_generated = -(len(gen_in[0]) - len(tokens))
+    genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]

     if(len(genout) == 1):
         genresult(genout[0]["generated_text"])
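Because the scan loop can rebuild the context mid-generation, "everything after the original input length" no longer identifies the generated text; counting `already_generated` tokens back from the end does. A trivial illustration:

# 4 final-context tokens followed by 2 newly generated ones
tokens = [11, 22, 33, 44, 55, 66]
already_generated = 2
print(tokens[-already_generated:])   # [55, 66], regardless of how often the head was rebuilt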
@@ -1859,10 +1942,12 @@ def deletewi(num):
 #==================================================================#
 # Look for WI keys in text to generator
 #==================================================================#
-def checkworldinfo(txt):
+def checkworldinfo(txt, force_use_txt=False):
+    original_txt = txt
+
     # Dont go any further if WI is empty
     if(len(vars.worldinfo) == 0):
-        return
+        return "", set()

     # Cache actions length
     ln = len(vars.actions)
@@ -1872,7 +1957,7 @@ def checkworldinfo(txt):
         depth = vars.widepth
         # If this is not a continue, add 1 to widepth since submitted
         # text is already in action history @ -1
-        if(txt != "" and vars.prompt != txt):
+        if(not force_use_txt and (txt != "" and vars.prompt != txt)):
             txt = ""
             depth += 1

@@ -1892,12 +1977,17 @@ def checkworldinfo(txt):
             txt = vars.prompt + "".join(chunks)
         elif(ln == 0):
             txt = vars.prompt

+    if(force_use_txt):
+        txt += original_txt
+
     # Scan text for matches on WI keys
     wimem = ""
+    found_entries = set()
     for wi in vars.worldinfo:
         if(wi.get("constant", False)):
             wimem = wimem + wi["content"] + "\n"
+            found_entries.add(id(wi))
             continue

         if(wi["key"] != ""):
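Taken together, these hunks change `checkworldinfo`'s contract: it now returns the assembled world-info text plus the set of matched entries, keyed by `id()`, so the scanner can skip entries it has already injected. A condensed sketch of that contract (selective/secondary keys omitted; not the literal aiserver code):

def scan_world_info(txt, worldinfo):
    wimem, found_entries = "", set()
    for wi in worldinfo:
        if wi.get("constant", False):            # constant entries always fire
            wimem += wi["content"] + "\n"
            found_entries.add(id(wi))
            continue
        keys = [k.strip() for k in wi["key"].split(",") if k.strip()]
        if any(k in txt for k in keys):          # plain substring match, as in aiserver
            wimem += wi["content"] + "\n"
            found_entries.add(id(wi))
    return wimem, found_entries

entries = [{"key": "dragon, wyrm", "content": "Dragons hoard gold."}]
print(scan_world_info("The wyrm stirred.", entries))
# ('Dragons hoard gold.\n', {<id of the matched entry>})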
@@ -1919,15 +2009,17 @@ def checkworldinfo(txt):
                     ksy = ks.strip()
                     if ksy in txt:
                         wimem = wimem + wi["content"] + "\n"
+                        found_entries.add(id(wi))
                         found = True
                         break
                 if found:
                     break
             else:
                 wimem = wimem + wi["content"] + "\n"
+                found_entries.add(id(wi))
                 break

-    return wimem
+    return wimem, found_entries

 #==================================================================#
 # Commit changes to Memory storage