diff --git a/aiserver.py b/aiserver.py
index 06152d53..085da1b5 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -13,7 +13,7 @@ from tkinter import messagebox
 import json
 import collections
 import zipfile
-from typing import Union, Tuple
+from typing import Union, Dict, Set
 import requests
 import html
@@ -124,6 +124,7 @@ class vars:
     acregex_ui  = re.compile(r'^ *(&gt;.*)$', re.MULTILINE)    # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
     actionmode  = 1
     adventure   = False
+    dynamicscan = False
     remote      = False
 
 #==================================================================#
@@ -512,7 +513,8 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END))
 if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
-        from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
+        from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
+        import transformers.generation_utils
 
         # Patch transformers to use our soft prompt
         def patch_causallm(cls):
@@ -528,7 +530,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
                 if(vars.sp is not None):
                     vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
                     inputs_embeds = torch.where(
-                        (shifted_input_ids >= 0)[:, :, None],
+                        (shifted_input_ids >= 0)[..., None],
                         vars.sp[shifted_input_ids.clamp(min=0)],
                         inputs_embeds,
                     )
@@ -543,6 +545,52 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
         except:
             pass
 
+        # Sets up dynamic world info scanner
+        class DynamicWorldInfoScanCriteria(StoppingCriteria):
+            def __init__(
+                self,
+                tokenizer,
+                excluded_world_info: set,
+                #head_length: torch.LongTensor,
+                head_length: int,
+            ):
+                self.any_new_entries = False
+                self.tokenizer = tokenizer
+                self.excluded_world_info = excluded_world_info
+                self.head_length = head_length
+            def __call__(
+                self,
+                input_ids: torch.LongTensor,
+                scores: torch.FloatTensor,
+                **kwargs,
+            ) -> bool:
+                assert input_ids.ndim == 2
+                #assert input_ids.shape[:-1] == self.head_length.shape
+                self.any_new_entries = False
+                if(not vars.dynamicscan):
+                    return False
+                tail = input_ids[..., self.head_length:]
+                for t in tail:
+                    decoded = tokenizer.decode(t)
+                    _, found = checkworldinfo(decoded, force_use_txt=True)
+                    found -= self.excluded_world_info
+                    if(len(found) != 0):
+                        self.any_new_entries = True
+                        break
+                return self.any_new_entries
+        old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
+        def new_get_stopping_criteria(self, *args, **kwargs):
+            stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
+            global tokenizer
+            self.kai_scanner = DynamicWorldInfoScanCriteria(
+                tokenizer=tokenizer,
+                excluded_world_info=self.kai_scanner_excluded_world_info,
+                head_length=self.kai_scanner_head_length,
+            )
+            stopping_criteria.append(self.kai_scanner)
+            return stopping_criteria
+        transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
+
     # If custom GPT Neo model was chosen
     if(vars.model == "NeoCustom"):
         model_config = open(vars.custmodpth + "/config.json", "r")
@@ -901,6 +949,10 @@ def get_message(msg):
         vars.adventure = msg['data']
         settingschanged()
         refresh_settings()
+    elif(msg['cmd'] == 'setdynamicscan'):
+        vars.dynamicscan = msg['data']
+        settingschanged()
+        refresh_settings()
     elif(not vars.remote and msg['cmd'] == 'importwi'):
         wiimportrequest()
@@ -958,6 +1010,7 @@ def savesettings():
     js["widepth"]     = vars.widepth
     js["useprompt"]   = vars.useprompt
     js["adventure"]   = vars.adventure
+    js["dynamicscan"] = vars.dynamicscan
 
     # Write it
     if not os.path.exists('settings'):
@@ -1008,6 +1061,8 @@ def loadsettings():
             vars.useprompt = js["useprompt"]
         if("adventure" in js):
             vars.adventure = js["adventure"]
+        if("dynamicscan" in js):
+            vars.dynamicscan = js["dynamicscan"]
 
         file.close()
@@ -1032,6 +1087,8 @@ def loadmodelsettings():
         vars.rep_pen    = js["rep_pen"]
     if("adventure" in js):
         vars.adventure = js["adventure"]
+    if("dynamicscan" in js):
+        vars.dynamicscan = js["dynamicscan"]
     if("formatoptns" in js):
         vars.formatoptns = js["formatoptns"]
     model_config.close()
@@ -1148,135 +1205,140 @@ def actionback():
         vars.genseqs = []
 
 #==================================================================#
-# Take submitted text and build the text to be given to generator
+# 
 #==================================================================#
-def calcsubmit(txt):
-    anotetxt     = ""    # Placeholder for Author's Note text
-    lnanote      = 0     # Placeholder for Author's Note length
-    forceanote   = False # In case we don't have enough actions to hit A.N. depth
-    anoteadded   = False # In case our budget runs out before we hit A.N. depth
-    actionlen    = len(vars.actions)
-
+def calcsubmitbudgetheader(txt, **kwargs):
     # Scan for WorldInfo matches
-    winfo = checkworldinfo(txt)
-
+    winfo, found_entries = checkworldinfo(txt, **kwargs)
+
     # Add a newline to the end of memory
     if(vars.memory != "" and vars.memory[-1] != "\n"):
         mem = vars.memory + "\n"
     else:
         mem = vars.memory
-
+
     # Build Author's Note if set
     if(vars.authornote != ""):
         anotetxt = "\n[Author's note: "+vars.authornote+"]\n"
+    else:
+        anotetxt = ""
+
+    return winfo, mem, anotetxt, found_entries
+
+def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
+    forceanote   = False # In case we don't have enough actions to hit A.N. depth
+    anoteadded   = False # In case our budget runs out before we hit A.N. depth
+    anotetkns    = []    # Placeholder for Author's Note tokens
+    lnanote      = 0     # Placeholder for Author's Note length
+
+    # Calculate token budget
+    prompttkns = tokenizer.encode(vars.prompt)
+    lnprompt   = len(prompttkns)
+
+    memtokens = tokenizer.encode(mem)
+    lnmem     = len(memtokens)
+
+    witokens  = tokenizer.encode(winfo)
+    lnwi      = len(witokens)
+
+    if(anotetxt != ""):
+        anotetkns = tokenizer.encode(anotetxt)
+        lnanote   = len(anotetkns)
+
+    lnsp = vars.sp.shape[0] if vars.sp is not None else 0
+
+    if(vars.useprompt):
+        budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
+    else:
+        budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
+
+    if(actionlen == 0):
+        # First/Prompt action
+        subtxt = vars.memory + winfo + anotetxt + vars.prompt
+        lnsub  = lnsp + lnmem + lnwi + lnprompt + lnanote
+        return subtxt, lnsub+1, lnsub+vars.genamt
+    else:
+        tokens = []
+
+        # Check if we have the action depth to hit our A.N. depth
+        if(anotetxt != "" and actionlen < vars.andepth):
+            forceanote = True
+
+        # Get most recent action tokens up to our budget
+        n = 0
+        for key in reversed(actions):
+            chunk = actions[key]
+
+            if(budget <= 0):
+                break
+            acttkns = tokenizer.encode(chunk)
+            tknlen = len(acttkns)
+            if(tknlen < budget):
+                tokens = acttkns + tokens
+                budget -= tknlen
+            else:
+                count = budget * -1
+                tokens = acttkns[count:] + tokens
+                budget = 0
+                break
+
+            # Inject Author's Note if we've reached the desired depth
+            if(n == vars.andepth-1):
+                if(anotetxt != ""):
+                    tokens = anotetkns + tokens # A.N. len already taken from bdgt
+                    anoteadded = True
+            n += 1
+
+        # If we're not using the prompt every time and there's still budget left,
+        # add some prompt.
+        if(not vars.useprompt):
+            if(budget > 0):
+                prompttkns = prompttkns[-budget:]
+            else:
+                prompttkns = []
+
+        # Did we get to add the A.N.? If not, do it here
+        if(anotetxt != ""):
+            if((not anoteadded) or forceanote):
+                tokens = memtokens + witokens + anotetkns + prompttkns + tokens
+            else:
+                tokens = memtokens + witokens + prompttkns + tokens
+        else:
+            # Prepend Memory, WI, and Prompt before action tokens
+            tokens = memtokens + witokens + prompttkns + tokens
+
+        # Send completed bundle to generator
+        ln = len(tokens) + lnsp
+        return tokenizer.decode(tokens), ln+1, ln+vars.genamt
+
+#==================================================================#
+# Take submitted text and build the text to be given to generator
+#==================================================================#
+def calcsubmit(txt):
+    anotetxt     = ""    # Placeholder for Author's Note text
+    forceanote   = False # In case we don't have enough actions to hit A.N. depth
+    anoteadded   = False # In case our budget runs out before we hit A.N. depth
+    actionlen    = len(vars.actions)
+
+    winfo, mem, anotetxt, found_entries = calcsubmitbudgetheader(txt)
+
     # For all transformers models
     if(vars.model != "InferKit"):
-        anotetkns    = []    # Placeholder for Author's Note tokens
-
-        # Calculate token budget
-        prompttkns = tokenizer.encode(vars.prompt)
-        lnprompt   = len(prompttkns)
-
-        memtokens = tokenizer.encode(mem)
-        lnmem     = len(memtokens)
-
-        witokens  = tokenizer.encode(winfo)
-        lnwi      = len(witokens)
-
-        if(anotetxt != ""):
-            anotetkns = tokenizer.encode(anotetxt)
-            lnanote   = len(anotetkns)
-
-        lnsp = vars.sp.shape[0] if vars.sp is not None else 0
-
-        if(vars.useprompt):
-            budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
-        else:
-            budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
-
+        subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions)
         if(actionlen == 0):
-            # First/Prompt action
-            subtxt = vars.memory + winfo + anotetxt + vars.prompt
-            lnsub  = lnsp + lnmem + lnwi + lnprompt + lnanote
-
             if(not vars.model in ["Colab", "OAI"]):
-                generate(subtxt, lnsub+1, lnsub+vars.genamt)
+                generate(subtxt, min, max, found_entries=found_entries)
             elif(vars.model == "Colab"):
-                sendtocolab(subtxt, lnsub+1, lnsub+vars.genamt)
+                sendtocolab(subtxt, min, max)
            elif(vars.model == "OAI"):
-                oairequest(subtxt, lnsub+1, lnsub+vars.genamt)
+                oairequest(subtxt, min, max)
         else:
-            tokens = []
-
-            # Check if we have the action depth to hit our A.N. depth
-            if(anotetxt != "" and actionlen < vars.andepth):
-                forceanote = True
-
-            # Get most recent action tokens up to our budget
-            n = 0
-            for key in reversed(vars.actions):
-                chunk = vars.actions[key]
-
-                if(budget <= 0):
-                    break
-                acttkns = tokenizer.encode(chunk)
-                tknlen = len(acttkns)
-                if(tknlen < budget):
-                    tokens = acttkns + tokens
-                    budget -= tknlen
-                else:
-                    count = budget * -1
-                    tokens = acttkns[count:] + tokens
-                    budget = 0
-                    break
-
-                # Inject Author's Note if we've reached the desired depth
-                if(n == vars.andepth-1):
-                    if(anotetxt != ""):
-                        tokens = anotetkns + tokens # A.N. len already taken from bdgt
-                        anoteadded = True
-                n += 1
-
-            # If we're not using the prompt every time and there's still budget left,
-            # add some prompt.
-            if(not vars.useprompt):
-                if(budget > 0):
-                    prompttkns = prompttkns[-budget:]
-                else:
-                    prompttkns = []
-
-            # Did we get to add the A.N.? If not, do it here
-            if(anotetxt != ""):
-                if((not anoteadded) or forceanote):
-                    tokens = memtokens + witokens + anotetkns + prompttkns + tokens
-                else:
-                    tokens = memtokens + witokens + prompttkns + tokens
-            else:
-                # Prepend Memory, WI, and Prompt before action tokens
-                tokens = memtokens + witokens + prompttkns + tokens
-
-            # Send completed bundle to generator
-            ln = len(tokens) + lnsp
-
             if(not vars.model in ["Colab", "OAI"]):
-                generate (
-                    tokenizer.decode(tokens),
-                    ln+1,
-                    ln+vars.genamt
-                    )
+                generate(subtxt, min, max, found_entries=found_entries)
             elif(vars.model == "Colab"):
-                sendtocolab(
-                    tokenizer.decode(tokens),
-                    ln+1,
-                    ln+vars.genamt
-                    )
+                sendtocolab(subtxt, min, max)
             elif(vars.model == "OAI"):
-                oairequest(
-                    tokenizer.decode(tokens),
-                    ln+1,
-                    ln+vars.genamt
-                    )
+                oairequest(subtxt, min, max)
 
     # For InferKit web API
     else:
@@ -1337,7 +1399,7 @@ def calcsubmit(txt):
 #==================================================================#
 # Send text to generator and deal with output
 #==================================================================#
-def generate(txt, min, max):
+def generate(txt, min, max, found_entries=set()):
     print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))
 
     # Store context in memory to use it for comparison with generated content
@@ -1360,7 +1422,7 @@ def generate(txt, min, max):
             model.config.vocab_size,
             model.config.vocab_size + vars.sp.shape[0],
         )
-        gen_in = torch.cat((soft_tokens[None], gen_in), dim=1)
+        gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
 
     if(vars.hascuda and vars.usegpu):
         gen_in = gen_in.to(0)
@@ -1371,22 +1433,60 @@ def generate(txt, min, max):
     else:
         gen_in = gen_in.to('cpu')
 
+    model.kai_scanner_head_length = gen_in.shape[-1]
+    model.kai_scanner_excluded_world_info = found_entries
+
+    actions = vars.actions
+    if(vars.dynamicscan):
+        actions = actions.copy()
+
     with torch.no_grad():
-        genout = generator(
-            gen_in,
-            do_sample=True,
-            min_length=min,
-            max_length=max,
-            repetition_penalty=vars.rep_pen,
-            top_p=top_p,
-            top_k=top_k,
-            tfs=tfs,
-            temperature=vars.temp,
-            bad_words_ids=vars.badwordsids,
-            use_cache=True,
-            return_full_text=False,
-            num_return_sequences=vars.numseqs
+        already_generated = 0
+        numseqs = vars.numseqs if not vars.dynamicscan else 1
+        while True:
+            genout = generator(
+                gen_in,
+                do_sample=True,
+                min_length=min,
+                max_length=max-already_generated,
+                repetition_penalty=vars.rep_pen,
+                top_p=top_p,
+                top_k=top_k,
+                tfs=tfs,
+                temperature=vars.temp,
+                bad_words_ids=vars.badwordsids,
+                use_cache=True,
+                num_return_sequences=numseqs
+            )
+            already_generated += len(genout[0]) - len(gen_in[0])
+            if(not model.kai_scanner.any_new_entries):
+                break
+            txt = tokenizer.decode(genout[0, -already_generated:])
+            winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
+            found_entries |= _found_entries
+            txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
+            encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
+            genout = torch.cat(
+                (
+                    encoded,
+                    genout[..., -already_generated:],
+                ),
+                dim=-1
             )
+            if(vars.sp is not None):
+                soft_tokens = torch.arange(
+                    model.config.vocab_size,
+                    model.config.vocab_size + vars.sp.shape[0],
+                    device=genout.device,
+                )
+                genout = torch.cat((soft_tokens[None], genout), dim=-1)
+            diff = genout.shape[-1] - gen_in.shape[-1]
+            min += diff
+            max += diff
+            gen_in = genout
+            model.kai_scanner_head_length = encoded.shape[-1]
+            numseqs = 1
+
     except Exception as e:
         emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
         print("{0}{1}{2}".format(colors.RED, e, colors.END))
@@ -1394,7 +1494,8 @@ def generate(txt, min, max):
         return
 
     # Need to manually strip and decode tokens if we're not using a pipeline
-    genout = [{"generated_text": tokenizer.decode(tokens[len(gen_in[0])-len(tokens):])} for tokens in genout]
+    #already_generated = -(len(gen_in[0]) - len(tokens))
+    genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]
 
     if(len(genout) == 1):
         genresult(genout[0]["generated_text"])
@@ -1654,6 +1755,7 @@ def refresh_settings():
     emit('from_server', {'cmd': 'updatewidepth', 'data': vars.widepth}, broadcast=True)
     emit('from_server', {'cmd': 'updateuseprompt', 'data': vars.useprompt}, broadcast=True)
     emit('from_server', {'cmd': 'updateadventure', 'data': vars.adventure}, broadcast=True)
+    emit('from_server', {'cmd': 'updatedynamicscan', 'data': vars.dynamicscan}, broadcast=True)
 
     emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
     emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)
@@ -1862,10 +1964,12 @@ def deletewi(num):
 #==================================================================#
 # Look for WI keys in text to generator
 #==================================================================#
-def checkworldinfo(txt):
+def checkworldinfo(txt, force_use_txt=False):
+    original_txt = txt
+
     # Dont go any further if WI is empty
     if(len(vars.worldinfo) == 0):
-        return
+        return "", set()
 
     # Cache actions length
     ln = len(vars.actions)
@@ -1875,7 +1979,7 @@ def checkworldinfo(txt):
         depth = vars.widepth
         # If this is not a continue, add 1 to widepth since submitted
        # text is already in action history @ -1
-        if(txt != "" and vars.prompt != txt):
+        if(not force_use_txt and (txt != "" and vars.prompt != txt)):
            txt    = ""
            depth += 1
@@ -1895,12 +1999,17 @@ def checkworldinfo(txt):
             txt = vars.prompt + "".join(chunks)
         elif(ln == 0):
             txt = vars.prompt
-
+
+    if(force_use_txt):
+        txt += original_txt
+
     # Scan text for matches on WI keys
     wimem = ""
+    found_entries = set()
     for wi in vars.worldinfo:
         if(wi.get("constant", False)):
             wimem = wimem + wi["content"] + "\n"
+            found_entries.add(id(wi))
             continue
 
         if(wi["key"] != ""):
@@ -1922,15 +2031,17 @@ def checkworldinfo(txt):
                     ksy = ks.strip()
                     if ksy in txt:
                         wimem = wimem + wi["content"] + "\n"
+                        found_entries.add(id(wi))
                         found = True
                         break
                 if found:
                     break
             else:
                 wimem = wimem + wi["content"] + "\n"
+                found_entries.add(id(wi))
                 break
 
-    return wimem
+    return wimem, found_entries
 
 #==================================================================#
 # Commit changes to Memory storage
diff --git a/gensettings.py b/gensettings.py
index b35b6a21..c567c94c 100644
--- a/gensettings.py
+++ b/gensettings.py
@@ -118,6 +118,17 @@ gensettingstf = [{
     "step": 1,
     "default": 0,
     "tooltip": "Turn this on if you are playing a Choose your Adventure model."
+    },
+    {
+    "uitype": "toggle",
+    "unit": "bool",
+    "label": "Dynamic WI Scan",
+    "id": "setdynamicscan",
+    "min": 0,
+    "max": 1,
+    "step": 1,
+    "default": 0,
+    "tooltip": "Scan the AI's output for world info keys as it's generating the output. Turning this on will set Gens Per Action to 1, as these two features are not currently compatible with each other."
     }]
 
 gensettingsik =[{
diff --git a/static/application.js b/static/application.js
index ce975326..044e471e 100644
--- a/static/application.js
+++ b/static/application.js
@@ -1593,6 +1593,9 @@ $(document).ready(function(){
             $("#setadventure").prop('checked', msg.data).change();
             // Update adventure state
             setadventure(msg.data);
+        } else if(msg.cmd == "updatedynamicscan") {
+            // Update toggle state
+            $("#setdynamicscan").prop('checked', msg.data).change();
         } else if(msg.cmd == "runs_remotely") {
             remote = true;
             hide([button_savetofile, button_import, button_importwi]);
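
Reviewer note: the heart of this patch is a custom StoppingCriteria that transformers consults after each generated token, so generation can be cut short the moment the new text matches a world info key. The patch has to monkey-patch GenerationMixin._get_stopping_criteria because the transformers version targeted here exposes no public hook for user-supplied criteria; newer releases accept a stopping_criteria argument to generate() directly. A minimal, self-contained sketch of the same technique under that newer API follows. The model choice, WATCHED_KEYS set, and KeywordScanCriteria class are illustrative stand-ins for this review, not part of the patch:

    # Sketch: scan freshly generated tokens and halt generation on a key match.
    # Assumes a transformers release whose generate() accepts stopping_criteria.
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        StoppingCriteria,
        StoppingCriteriaList,
    )

    WATCHED_KEYS = {"dragon", "castle"}  # stand-in for world info keys

    class KeywordScanCriteria(StoppingCriteria):
        def __init__(self, tokenizer, head_length: int):
            self.tokenizer = tokenizer
            self.head_length = head_length  # prompt tokens to skip when scanning

        def __call__(self, input_ids, scores, **kwargs) -> bool:
            # Decode only the newly generated tail, mirroring the patch's
            # tail = input_ids[..., self.head_length:] slice.
            for seq in input_ids[..., self.head_length:]:
                decoded = self.tokenizer.decode(seq)
                if any(key in decoded for key in WATCHED_KEYS):
                    return True  # halt; caller rebuilds context and resumes
            return False

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("The knight rode toward the", return_tensors="pt")
    criteria = StoppingCriteriaList(
        [KeywordScanCriteria(tokenizer, inputs["input_ids"].shape[-1])]
    )
    out = model.generate(
        **inputs,
        do_sample=True,
        max_new_tokens=40,
        stopping_criteria=criteria,
    )
    print(tokenizer.decode(out[0]))

As in the patch, stopping is only half the job: generate() returns the sequence produced so far, and the caller must rebuild the prompt with the newly triggered entries and resume, which is what the while True loop added to generate() in aiserver.py does.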