Merge pull request #23 from VE-FORBRYDERNE/scan-test

Dynamic world info scan

commit c2371cf801

aiserver.py | 207
@@ -13,7 +13,7 @@ from tkinter import messagebox
import json
import collections
import zipfile
from typing import Union, Tuple
from typing import Union, Dict, Set

import requests
import html

@@ -124,6 +124,7 @@ class vars:
    acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
    actionmode = 1
    adventure = False
    dynamicscan = False
    remote = False

#==================================================================#

@@ -512,7 +513,8 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END))
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
    if(not vars.noai):
        print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
        from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
        from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
        import transformers.generation_utils

        # Patch transformers to use our soft prompt
        def patch_causallm(cls):

@@ -528,7 +530,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
                if(vars.sp is not None):
                    vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
                    inputs_embeds = torch.where(
                        (shifted_input_ids >= 0)[:, :, None],
                        (shifted_input_ids >= 0)[..., None],
                        vars.sp[shifted_input_ids.clamp(min=0)],
                        inputs_embeds,
                    )

@@ -543,6 +545,52 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
        except:
            pass

        # Sets up dynamic world info scanner
        class DynamicWorldInfoScanCriteria(StoppingCriteria):
            def __init__(
                self,
                tokenizer,
                excluded_world_info: set,
                #head_length: torch.LongTensor,
                head_length: int,
            ):
                self.any_new_entries = False
                self.tokenizer = tokenizer
                self.excluded_world_info = excluded_world_info
                self.head_length = head_length
            def __call__(
                self,
                input_ids: torch.LongTensor,
                scores: torch.FloatTensor,
                **kwargs,
            ) -> bool:
                assert input_ids.ndim == 2
                #assert input_ids.shape[:-1] == self.head_length.shape
                self.any_new_entries = False
                if(not vars.dynamicscan):
                    return False
                tail = input_ids[..., self.head_length:]
                for t in tail:
                    decoded = tokenizer.decode(t)
                    _, found = checkworldinfo(decoded, force_use_txt=True)
                    found -= self.excluded_world_info
                    if(len(found) != 0):
                        self.any_new_entries = True
                        break
                return self.any_new_entries
        old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
        def new_get_stopping_criteria(self, *args, **kwargs):
            stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
            global tokenizer
            self.kai_scanner = DynamicWorldInfoScanCriteria(
                tokenizer=tokenizer,
                excluded_world_info=self.kai_scanner_excluded_world_info,
                head_length=self.kai_scanner_head_length,
            )
            stopping_criteria.append(self.kai_scanner)
            return stopping_criteria
        transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria

    # If custom GPT Neo model was chosen
    if(vars.model == "NeoCustom"):
        model_config = open(vars.custmodpth + "/config.json", "r")
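Note on the mechanism above: DynamicWorldInfoScanCriteria is a standard Hugging Face StoppingCriteria, so at every decoding step it receives the tokens generated so far and can abort generation early. A minimal, self-contained illustration of that pattern (illustrative names only, and it assumes a transformers release recent enough to ship StoppingCriteria and to accept stopping_criteria in generate(), roughly 4.7 and later), stopping as soon as a keyword appears in the newly generated tail:

    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        StoppingCriteria,
        StoppingCriteriaList,
    )

    class KeywordStop(StoppingCriteria):
        # Stop once any keyword shows up in the tokens generated after the prompt.
        def __init__(self, tokenizer, keywords, head_length):
            self.tokenizer = tokenizer
            self.keywords = keywords
            self.head_length = head_length  # number of prompt tokens to skip

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            decoded = self.tokenizer.decode(input_ids[0, self.head_length:])
            return any(k in decoded for k in self.keywords)

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    prompt = tokenizer("The knight rode toward the", return_tensors="pt").input_ids
    out = model.generate(
        prompt,
        do_sample=True,
        max_length=prompt.shape[-1] + 40,
        stopping_criteria=StoppingCriteriaList(
            [KeywordStop(tokenizer, keywords=["castle"], head_length=prompt.shape[-1])]
        ),
    )
    print(tokenizer.decode(out[0]))

The diff cannot simply pass stopping_criteria into generate() because text is produced through the pipeline-style generator object, which is presumably why it wraps GenerationMixin._get_stopping_criteria and hangs the scanner off the model as kai_scanner instead.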
@@ -901,6 +949,10 @@ def get_message(msg):
        vars.adventure = msg['data']
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'setdynamicscan'):
        vars.dynamicscan = msg['data']
        settingschanged()
        refresh_settings()
    elif(not vars.remote and msg['cmd'] == 'importwi'):
        wiimportrequest()

@@ -958,6 +1010,7 @@ def savesettings():
    js["widepth"] = vars.widepth
    js["useprompt"] = vars.useprompt
    js["adventure"] = vars.adventure
    js["dynamicscan"] = vars.dynamicscan

    # Write it
    if not os.path.exists('settings'):

@@ -1008,6 +1061,8 @@ def loadsettings():
            vars.useprompt = js["useprompt"]
        if("adventure" in js):
            vars.adventure = js["adventure"]
        if("dynamicscan" in js):
            vars.dynamicscan = js["dynamicscan"]

        file.close()

@@ -1032,6 +1087,8 @@ def loadmodelsettings():
        vars.rep_pen = js["rep_pen"]
    if("adventure" in js):
        vars.adventure = js["adventure"]
    if("dynamicscan" in js):
        vars.dynamicscan = js["dynamicscan"]
    if("formatoptns" in js):
        vars.formatoptns = js["formatoptns"]
    model_config.close()
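The new flag is persisted alongside the existing ones: written into the per-model settings JSON on save and read back defensively on load, since settings files written before this change will not contain the key. A minimal sketch of that save/load pattern (hypothetical helper names, not the actual aiserver.py functions):

    import json

    def save_settings(path, settings: dict):
        with open(path, "w") as f:
            json.dump(settings, f, indent=4)

    def load_settings(path, defaults: dict) -> dict:
        with open(path) as f:
            js = json.load(f)
        # Older files may predate keys such as "dynamicscan", so fall back to a
        # default instead of assuming every key is present.
        return {key: js.get(key, default) for key, default in defaults.items()}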
@@ -1148,17 +1205,11 @@ def actionback():
        vars.genseqs = []

#==================================================================#
# Take submitted text and build the text to be given to generator
#
#==================================================================#
def calcsubmit(txt):
    anotetxt = "" # Placeholder for Author's Note text
    lnanote = 0 # Placeholder for Author's Note length
    forceanote = False # In case we don't have enough actions to hit A.N. depth
    anoteadded = False # In case our budget runs out before we hit A.N. depth
    actionlen = len(vars.actions)

def calcsubmitbudgetheader(txt, **kwargs):
    # Scan for WorldInfo matches
    winfo = checkworldinfo(txt)
    winfo, found_entries = checkworldinfo(txt, **kwargs)

    # Add a newline to the end of memory
    if(vars.memory != "" and vars.memory[-1] != "\n"):

@@ -1169,10 +1220,16 @@ def calcsubmit(txt):
    # Build Author's Note if set
    if(vars.authornote != ""):
        anotetxt = "\n[Author's note: "+vars.authornote+"]\n"
    else:
        anotetxt = ""

    # For all transformers models
    if(vars.model != "InferKit"):
    return winfo, mem, anotetxt, found_entries

def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
    forceanote = False # In case we don't have enough actions to hit A.N. depth
    anoteadded = False # In case our budget runs out before we hit A.N. depth
    anotetkns = [] # Placeholder for Author's Note tokens
    lnanote = 0 # Placeholder for Author's Note length

    # Calculate token budget
    prompttkns = tokenizer.encode(vars.prompt)

@@ -1199,13 +1256,7 @@ def calcsubmit(txt):
        # First/Prompt action
        subtxt = vars.memory + winfo + anotetxt + vars.prompt
        lnsub = lnsp + lnmem + lnwi + lnprompt + lnanote

        if(not vars.model in ["Colab", "OAI"]):
            generate(subtxt, lnsub+1, lnsub+vars.genamt)
        elif(vars.model == "Colab"):
            sendtocolab(subtxt, lnsub+1, lnsub+vars.genamt)
        elif(vars.model == "OAI"):
            oairequest(subtxt, lnsub+1, lnsub+vars.genamt)
        return subtxt, lnsub+1, lnsub+vars.genamt
    else:
        tokens = []

@@ -1215,8 +1266,8 @@ def calcsubmit(txt):

        # Get most recent action tokens up to our budget
        n = 0
        for key in reversed(vars.actions):
            chunk = vars.actions[key]
        for key in reversed(actions):
            chunk = actions[key]

            if(budget <= 0):
                break

@@ -1258,25 +1309,36 @@ def calcsubmit(txt):

        # Send completed bundle to generator
        ln = len(tokens) + lnsp
        return tokenizer.decode(tokens), ln+1, ln+vars.genamt

#==================================================================#
# Take submitted text and build the text to be given to generator
#==================================================================#
def calcsubmit(txt):
    anotetxt = "" # Placeholder for Author's Note text
    forceanote = False # In case we don't have enough actions to hit A.N. depth
    anoteadded = False # In case our budget runs out before we hit A.N. depth
    actionlen = len(vars.actions)

    winfo, mem, anotetxt, found_entries = calcsubmitbudgetheader(txt)

    # For all transformers models
    if(vars.model != "InferKit"):
        subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions)
        if(actionlen == 0):
            if(not vars.model in ["Colab", "OAI"]):
                generate (
                    tokenizer.decode(tokens),
                    ln+1,
                    ln+vars.genamt
                )
                generate(subtxt, min, max, found_entries=found_entries)
            elif(vars.model == "Colab"):
                sendtocolab(
                    tokenizer.decode(tokens),
                    ln+1,
                    ln+vars.genamt
                )
                sendtocolab(subtxt, min, max)
            elif(vars.model == "OAI"):
                oairequest(
                    tokenizer.decode(tokens),
                    ln+1,
                    ln+vars.genamt
                )
                oairequest(subtxt, min, max)
        else:
            if(not vars.model in ["Colab", "OAI"]):
                generate(subtxt, min, max, found_entries=found_entries)
            elif(vars.model == "Colab"):
                sendtocolab(subtxt, min, max)
            elif(vars.model == "OAI"):
                oairequest(subtxt, min, max)

    # For InferKit web API
    else:
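The net effect of the refactor above is that the old monolithic calcsubmit() is split in two: calcsubmitbudgetheader() gathers world info, memory and the Author's Note (and now also reports which entries matched), while calcsubmitbudget() does the token budgeting and returns the text to submit instead of calling the generator itself. That split is what later lets the generation loop rebuild its context mid-generation. A condensed restatement of the new local-model path (Colab, OAI and InferKit branches omitted; min and max are renamed here only to avoid shadowing the Python builtins):

    def calcsubmit(txt):
        winfo, mem, anotetxt, found_entries = calcsubmitbudgetheader(txt)
        subtxt, minimum, maximum = calcsubmitbudget(
            len(vars.actions), winfo, mem, anotetxt, vars.actions
        )
        generate(subtxt, minimum, maximum, found_entries=found_entries)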
@@ -1337,7 +1399,7 @@ def calcsubmit(txt):
#==================================================================#
# Send text to generator and deal with output
#==================================================================#
def generate(txt, min, max):
def generate(txt, min, max, found_entries=set()):
    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))

    # Store context in memory to use it for comparison with generated content

@@ -1360,7 +1422,7 @@ def generate(txt, min, max):
            model.config.vocab_size,
            model.config.vocab_size + vars.sp.shape[0],
        )
        gen_in = torch.cat((soft_tokens[None], gen_in), dim=1)
        gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)

    if(vars.hascuda and vars.usegpu):
        gen_in = gen_in.to(0)

@@ -1371,12 +1433,22 @@ def generate(txt, min, max):
    else:
        gen_in = gen_in.to('cpu')

    model.kai_scanner_head_length = gen_in.shape[-1]
    model.kai_scanner_excluded_world_info = found_entries

    actions = vars.actions
    if(vars.dynamicscan):
        actions = actions.copy()

    with torch.no_grad():
        already_generated = 0
        numseqs = vars.numseqs if not vars.dynamicscan else 1
        while True:
            genout = generator(
                gen_in,
                do_sample=True,
                min_length=min,
                max_length=max,
                max_length=max-already_generated,
                repetition_penalty=vars.rep_pen,
                top_p=top_p,
                top_k=top_k,

@@ -1384,9 +1456,37 @@ def generate(txt, min, max):
                temperature=vars.temp,
                bad_words_ids=vars.badwordsids,
                use_cache=True,
                return_full_text=False,
                num_return_sequences=vars.numseqs
                num_return_sequences=numseqs
            )
            already_generated += len(genout[0]) - len(gen_in[0])
            if(not model.kai_scanner.any_new_entries):
                break
            txt = tokenizer.decode(genout[0, -already_generated:])
            winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
            found_entries |= _found_entries
            txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
            encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
            genout = torch.cat(
                (
                    encoded,
                    genout[..., -already_generated:],
                ),
                dim=-1
            )
            if(vars.sp is not None):
                soft_tokens = torch.arange(
                    model.config.vocab_size,
                    model.config.vocab_size + vars.sp.shape[0],
                    device=genout.device,
                )
                genout = torch.cat((soft_tokens[None], genout), dim=-1)
            diff = genout.shape[-1] - gen_in.shape[-1]
            min += diff
            max += diff
            gen_in = genout
            model.kai_scanner_head_length = encoded.shape[-1]
            numseqs = 1

    except Exception as e:
        emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
        print("{0}{1}{2}".format(colors.RED, e, colors.END))
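The while True loop above is the core of the feature: generation runs until DynamicWorldInfoScanCriteria reports that the freshly generated text matched a world info entry that is not in the context yet, then the prompt is rebuilt with that entry injected, the already generated tokens are glued back on, and generation resumes with the remaining budget. A standalone sketch of that control flow, with the application-specific pieces abstracted into callables (all names illustrative; the soft-prompt handling and min/max bookkeeping from the diff are left out):

    import torch

    def generate_with_dynamic_scan(
        generate_step,   # (input_ids, remaining_budget) -> ids with new tokens appended
        scan_tail,       # (tail_ids) -> set of world info entries matched in the new text
        rebuild_context, # (known_entries) -> prompt ids re-encoded with those entries injected
        prompt_ids: torch.LongTensor,
        total_budget: int,
    ) -> torch.LongTensor:
        known: set = set()
        gen_in = prompt_ids
        already_generated = 0
        while True:
            # Assumes each pass produces at least one new token.
            out = generate_step(gen_in, total_budget - already_generated)
            already_generated += out.shape[-1] - gen_in.shape[-1]
            # Only the freshly generated tail is scanned, mirroring kai_scanner_head_length.
            new_entries = scan_tail(out[..., -already_generated:]) - known
            if not new_entries:
                return out
            known |= new_entries
            # Rebuild the prompt so the triggered entries are in context, then
            # reattach the tokens generated so far and keep generating.
            gen_in = torch.cat((rebuild_context(known), out[..., -already_generated:]), dim=-1)

This is also why the diff snapshots vars.actions with actions.copy() while dynamic scan is on (presumably so the history used for the rebuild stays stable across passes) and forces num_return_sequences down to 1, matching the tooltip's note that Gens Per Action is set to 1.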
@@ -1394,7 +1494,8 @@ def generate(txt, min, max):
        return

    # Need to manually strip and decode tokens if we're not using a pipeline
    genout = [{"generated_text": tokenizer.decode(tokens[len(gen_in[0])-len(tokens):])} for tokens in genout]
    #already_generated = -(len(gen_in[0]) - len(tokens))
    genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]

    if(len(genout) == 1):
        genresult(genout[0]["generated_text"])

@@ -1654,6 +1755,7 @@ def refresh_settings():
    emit('from_server', {'cmd': 'updatewidepth', 'data': vars.widepth}, broadcast=True)
    emit('from_server', {'cmd': 'updateuseprompt', 'data': vars.useprompt}, broadcast=True)
    emit('from_server', {'cmd': 'updateadventure', 'data': vars.adventure}, broadcast=True)
    emit('from_server', {'cmd': 'updatedynamicscan', 'data': vars.dynamicscan}, broadcast=True)

    emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
    emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)

@@ -1862,10 +1964,12 @@ def deletewi(num):
#==================================================================#
# Look for WI keys in text to generator
#==================================================================#
def checkworldinfo(txt):
def checkworldinfo(txt, force_use_txt=False):
    original_txt = txt

    # Dont go any further if WI is empty
    if(len(vars.worldinfo) == 0):
        return
        return "", set()

    # Cache actions length
    ln = len(vars.actions)

@@ -1875,7 +1979,7 @@ def checkworldinfo(txt):
        depth = vars.widepth
        # If this is not a continue, add 1 to widepth since submitted
        # text is already in action history @ -1
        if(txt != "" and vars.prompt != txt):
        if(not force_use_txt and (txt != "" and vars.prompt != txt)):
            txt = ""
            depth += 1

@@ -1896,11 +2000,16 @@ def checkworldinfo(txt):
        elif(ln == 0):
            txt = vars.prompt

    if(force_use_txt):
        txt += original_txt

    # Scan text for matches on WI keys
    wimem = ""
    found_entries = set()
    for wi in vars.worldinfo:
        if(wi.get("constant", False)):
            wimem = wimem + wi["content"] + "\n"
            found_entries.add(id(wi))
            continue

        if(wi["key"] != ""):

@@ -1922,15 +2031,17 @@ def checkworldinfo(txt):
                    ksy = ks.strip()
                    if ksy in txt:
                        wimem = wimem + wi["content"] + "\n"
                        found_entries.add(id(wi))
                        found = True
                        break
                if found:
                    break
            else:
                wimem = wimem + wi["content"] + "\n"
                found_entries.add(id(wi))
                break

    return wimem
    return wimem, found_entries

#==================================================================#
# Commit changes to Memory storage
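checkworldinfo() now returns a pair, (wimem, found_entries): wimem is the world info text to inject, as before, and found_entries holds the id()s of the entries that matched, with force_use_txt=True making the function scan caller-supplied text on top of the usual action history. That is the hook DynamicWorldInfoScanCriteria uses to spot entries triggered by generated text that are not yet in the context, roughly:

    # Condensed restatement of the scanner's use of the new return value
    # (see the DynamicWorldInfoScanCriteria class earlier in this diff).
    decoded = tokenizer.decode(tail_tokens)                 # newly generated text
    _, found = checkworldinfo(decoded, force_use_txt=True)  # ids of matched WI entries
    new_entries = found - excluded_world_info               # drop entries already in context
    if new_entries:
        pass  # stop generation so the context can be rebuilt with them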

@@ -118,6 +118,17 @@ gensettingstf = [{
    "step": 1,
    "default": 0,
    "tooltip": "Turn this on if you are playing a Choose your Adventure model."
    },
    {
    "uitype": "toggle",
    "unit": "bool",
    "label": "Dynamic WI Scan",
    "id": "setdynamicscan",
    "min": 0,
    "max": 1,
    "step": 1,
    "default": 0,
    "tooltip": "Scan the AI's output for world info keys as it's generating the output. Turning this on will set Gens Per Action to 1, as these two features are not currently compatible with each other."
    }]

gensettingsik =[{

@@ -1593,6 +1593,9 @@ $(document).ready(function(){
        $("#setadventure").prop('checked', msg.data).change();
        // Update adventure state
        setadventure(msg.data);
    } else if(msg.cmd == "updatedynamicscan") {
        // Update toggle state
        $("#setdynamicscan").prop('checked', msg.data).change();
    } else if(msg.cmd == "runs_remotely") {
        remote = true;
        hide([button_savetofile, button_import, button_importwi]);