From 4996e0ff467688ccb96855fc28e56af9bc34b656 Mon Sep 17 00:00:00 2001
From: KoboldAI Dev
Date: Tue, 18 May 2021 17:59:59 -0400
Subject: [PATCH] Bugfixes:

Improvements to pruning context from text returned from the AI
Colab errors should no longer throw JSON decode errors in client
Improved logic for World Info scanning
Fix for index error in addsentencespacing
---
 aiserver.py | 89 +++++++++++++++++++++++++++++++++++------------------
 utils.py    |  7 ++++-
 2 files changed, 65 insertions(+), 31 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 028b20ab..96cc4ef2 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -48,7 +48,8 @@ modellist = [
 
 # Variables
 class vars:
-    lastact = ""    # The last action submitted to the generator
+    lastact = ""    # The last action received from the user
+    lastctx = ""    # The last context submitted to the generator
     model = ""
     noai = False    # Runs the script without starting up the transformers pipeline
     aibusy = False  # Stops submissions while the AI is working
@@ -69,6 +70,8 @@ class vars:
     badwords = []
     badwordsids = []
     deletewi = -1       # Temporary storage for index to delete
+    wirmvwhtsp = False  # Whether to remove leading whitespace from WI entries
+    widepth = 1         # How many historical actions to scan for WI hits
     mode = "play"       # Whether the interface is in play, memory, or edit mode
     editln = 0          # Which line was last selected in Edit Mode
     url = "https://api.inferkit.com/v1/models/standard/generate" # InferKit API URL
@@ -521,9 +524,15 @@ def settingschanged():
 #
 #==================================================================#
 def actionsubmit(data):
+    # Ignore new submissions if the AI is currently busy
     if(vars.aibusy):
         return
     set_aibusy(1)
+    
+    # If we're not continuing, store a copy of the raw input
+    if(data != ""):
+        vars.lastact = data
+    
     if(not vars.gamestarted):
         # Start the game
         vars.gamestarted = True
@@ -697,6 +706,9 @@ def calcsubmit(txt):
 def generate(txt, min, max):
     print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))
     
+    # Store context in memory to use it for comparison with generated content
+    vars.lastctx = txt
+    
     # Clear CUDA cache if using GPU
     if(vars.hascuda and vars.usegpu):
         torch.cuda.empty_cache()
@@ -736,6 +748,9 @@ def sendtocolab(txt, min, max):
     # Log request to console
     print("{0}Tokens:{1}, Txt:{2}{3}".format(colors.YELLOW, min-1, txt, colors.END))
     
+    # Store context in memory to use it for comparison with generated content
+    vars.lastctx = txt
+    
     # Build request JSON data
     reqdata = {
         'text': txt,
@@ -765,21 +780,9 @@ def sendtocolab(txt, min, max):
 
         refresh_story()
         emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
-        set_aibusy(0)
-    elif(req.status_code == 500):
-        errmsg = "Colab API Error: Failed to get a reply from the server. Please check the colab console."
-        print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
-        emit('from_server', {'cmd': 'errmsg', 'data': errmsg})
         set_aibusy(0)
     else:
-        # Send error message to web client
-        er = req.json()
-        if("error" in er):
-            code = er["error"]["extensions"]["code"]
-        elif("errors" in er):
-            code = er["errors"][0]["extensions"]["code"]
-        
-        errmsg = "Colab API Error: {0} - {1}".format(req.status_code, code)
+        errmsg = "Colab API Error: Failed to get a reply from the server. Please check the colab console."
         print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
         emit('from_server', {'cmd': 'errmsg', 'data': errmsg})
         set_aibusy(0)
@@ -795,17 +798,19 @@ def formatforhtml(txt):
 # Strips submitted text from the text returned by the AI
 #==================================================================#
 def getnewcontent(txt):
-    ln = len(vars.actions)
-    if(ln == 0):
-        lastact = tokenizer.encode(vars.prompt)
-    else:
-        lastact = tokenizer.encode(vars.actions[-1])
+    # If the submitted context was blank, then everything is new
+    if(vars.lastctx == ""):
+        return txt
     
-    delim = tokenizer.decode(lastact)
+    # Tokenize the last context and the generated content
+    ctxtokens = tokenizer.encode(vars.lastctx)
+    txttokens = tokenizer.encode(txt)
+    dif = (len(txttokens) - len(ctxtokens)) * -1
     
-    split = txt.split(delim)
+    # Remove the context from the returned text
+    newtokens = txttokens[dif:]
     
-    return (split[-1])
+    return tokenizer.decode(newtokens)
 
 #==================================================================#
 # Applies chosen formatting options to text submitted to AI
@@ -1031,16 +1036,31 @@ def deletewi(num):
         requestwi()
 
 #==================================================================#
-# Look for WI keys in text to generator 
+# Look for WI keys in text to generator
 #==================================================================#
 def checkworldinfo(txt):
     # Dont go any further if WI is empty
     if(len(vars.worldinfo) == 0):
         return
-    
-    # Join submitted text to last action
-    if(len(vars.actions) > 0):
-        txt = vars.actions[-1] + txt
+    
+    # Cache actions length
+    ln = len(vars.actions)
+    
+    # Don't bother calculating action history if widepth is 0
+    if(vars.widepth > 0):
+        depth = vars.widepth
+        # If this is not a continue, add 1 to widepth since submitted
+        # text is already in action history @ -1
+        if(txt != "" and vars.prompt != txt):
+            txt = ""
+            depth += 1
+        
+        if(ln >= depth):
+            txt = "".join(vars.actions[(depth*-1):])
+        elif(ln > 0):
+            txt = vars.prompt + "".join(vars.actions[(depth*-1):])
+        elif(ln == 0):
+            txt = vars.prompt
 
     # Scan text for matches on WI keys
     wimem = ""
@@ -1049,15 +1069,16 @@ def checkworldinfo(txt):
         # Split comma-separated keys
         keys = wi["key"].split(",")
         for k in keys:
-            # Remove leading/trailing spaces
-            ky = k.strip()
+            ky = k
+            # Remove leading/trailing spaces if the option is enabled
+            if(vars.wirmvwhtsp):
+                ky = k.strip()
             if ky in txt:
                 wimem = wimem + wi["content"] + "\n"
                 break
     
     return wimem
-    
 
 #==================================================================#
 # Commit changes to Memory storage
 #==================================================================#
@@ -1199,6 +1220,8 @@ def loadRequest():
         vars.memory = js["memory"]
         vars.actions = js["actions"]
         vars.worldinfo = []
+        vars.lastact = ""
+        vars.lastctx = ""
         
         # Try not to break older save files
         if("authorsnote" in js):
@@ -1299,6 +1322,8 @@ def importgame():
         vars.authornote = ref["authorsNote"] if type(ref["authorsNote"]) is str else ""
         vars.actions = []
         vars.worldinfo = []
+        vars.lastact = ""
+        vars.lastctx = ""
         
         # Get all actions except for prompt
         if("actions" in ref):
@@ -1350,6 +1375,8 @@ def importAidgRequest(id):
     vars.authornote = js["authorsNote"]
     vars.actions = []
     vars.worldinfo = []
+    vars.lastact = ""
+    vars.lastctx = ""
     
     num = 0
     for wi in js["worldInfos"]:
@@ -1388,6 +1415,8 @@ def newGameRequest():
     vars.savedir = getcwd()+"\stories"
     vars.authornote = ""
     vars.worldinfo = []
+    vars.lastact = ""
+    vars.lastctx = ""
     
     # Refresh game screen
     sendwi()
diff --git a/utils.py b/utils.py
index 2b43f167..816b1af9 100644
--- a/utils.py
+++ b/utils.py
@@ -69,7 +69,12 @@ def removespecialchars(txt):
 def addsentencespacing(txt, vars):
     # Get last character of last action
     if(len(vars.actions) > 0):
-        lastchar = vars.actions[-1][-1]
+        if(len(vars.actions[-1]) > 0):
+            lastchar = vars.actions[-1][-1]
+        else:
+            # Last action is blank, this should never happen, but
+            # since it did let's bail out.
+            return txt
     else:
         lastchar = vars.prompt[-1]
     if(lastchar == "." or lastchar == "!" or lastchar == "?" or lastchar == "," or lastchar == ";" or lastchar == ":"):
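
Note on the getnewcontent() rewrite: it slices the generated token list at the
length of the submitted context instead of string-splitting on the last action.
A minimal standalone sketch of the technique, assuming the GPT-2 tokenizer from
the transformers package (the tokenizer choice and the prune_context wrapper
are illustrative, not part of the patch):

    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # illustrative choice

    def prune_context(txt, lastctx):
        # If the submitted context was blank, then everything is new
        if lastctx == "":
            return txt
        # Tokenize the last context and the generated content
        ctxtokens = tokenizer.encode(lastctx)
        txttokens = tokenizer.encode(txt)
        # Negative offset from the end of the token list; equivalent to
        # txttokens[len(ctxtokens):] whenever the model added new tokens
        dif = (len(txttokens) - len(ctxtokens)) * -1
        return tokenizer.decode(txttokens[dif:])

    ctx = "You are standing in an open field west of a white house."
    out = ctx + " A small mailbox is here."
    print(prune_context(out, ctx))  # prints only the newly generated tail

One edge worth noting: if the model echoes exactly the context and adds nothing,
dif is 0 and txttokens[0:] is the whole text, so the blank-context guard does
not cover that case.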
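
The widepth logic in checkworldinfo() decides how much story history is scanned
for World Info keys. A runnable sketch of the same selection rules, with a
SimpleNamespace standing in for the global vars object (the sample prompt,
actions, and WI entry are invented for illustration; vars.actions[-depth:] is
the same slice as the patch's vars.actions[(depth*-1):]):

    from types import SimpleNamespace

    # Stand-in for the global vars object; all sample data is invented
    vars = SimpleNamespace(
        prompt = "You enter the castle.",
        actions = ["You draw your sword.", "The dragon wakes."],
        widepth = 1,
        wirmvwhtsp = True,
        worldinfo = [{"key": "dragon, wyrm", "content": "The dragon guards gold."}],
    )

    def checkworldinfo(txt):
        if len(vars.worldinfo) == 0:
            return ""
        ln = len(vars.actions)
        if vars.widepth > 0:
            depth = vars.widepth
            # A fresh submission already sits in actions[-1], so widen
            # the scan window by one action to cover it
            if txt != "" and vars.prompt != txt:
                txt = ""
                depth += 1
            if ln >= depth:
                txt = "".join(vars.actions[-depth:])
            elif ln > 0:
                txt = vars.prompt + "".join(vars.actions[-depth:])
            else:
                txt = vars.prompt
        # Scan the assembled text for matches on WI keys
        wimem = ""
        for wi in vars.worldinfo:
            for k in wi["key"].split(","):
                ky = k.strip() if vars.wirmvwhtsp else k
                if ky in txt:
                    wimem += wi["content"] + "\n"
                    break
        return wimem

    print(checkworldinfo("The dragon wakes."))  # -> "The dragon guards gold.\n"

With widepth = 1 and a fresh submission, depth becomes 2, so both stored actions
are scanned and the "dragon" key fires; with wirmvwhtsp = False the second key
would be searched as " wyrm", leading space intact.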
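
The utils.py hunk guards addsentencespacing() against a blank last action
before indexing it with [-1]. The hunk is truncated above before the body of
the punctuation branch, so in this sketch that body (prepending a space) is an
assumption made only to keep the guard testable:

    def addsentencespacing(txt, vars):
        # Get last character of last action
        if len(vars.actions) > 0:
            if len(vars.actions[-1]) == 0:
                # Blank last action: nothing to index, bail out
                return txt
            lastchar = vars.actions[-1][-1]
        else:
            lastchar = vars.prompt[-1]
        # Assumed body of the truncated branch: restore the missing space
        if lastchar in (".", "!", "?", ",", ";", ":") and not txt.startswith(" "):
            return " " + txt
        return txt

Before the guard, a blank entry in vars.actions made vars.actions[-1][-1] raise
an IndexError on the empty string; the early return now passes the submission
through unchanged.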