diff --git a/aiserver.py b/aiserver.py index 028b20ab..96cc4ef2 100644 --- a/aiserver.py +++ b/aiserver.py @@ -48,7 +48,8 @@ modellist = [ # Variables class vars: - lastact = "" # The last action submitted to the generator + lastact = "" # The last action received from the user + lastctx = "" # The last context submitted to the generator model = "" noai = False # Runs the script without starting up the transformers pipeline aibusy = False # Stops submissions while the AI is working @@ -69,6 +70,8 @@ class vars: badwords = [] badwordsids = [] deletewi = -1 # Temporary storage for index to delete + wirmvwhtsp = False # Whether to remove leading whitespace from WI entries + widepth = 1 # How many historical actions to scan for WI hits mode = "play" # Whether the interface is in play, memory, or edit mode editln = 0 # Which line was last selected in Edit Mode url = "https://api.inferkit.com/v1/models/standard/generate" # InferKit API URL @@ -521,9 +524,15 @@ def settingschanged(): # #==================================================================# def actionsubmit(data): + # Ignore new submissions if the AI is currently busy if(vars.aibusy): return set_aibusy(1) + + # If we're not continuing, store a copy of the raw input + if(data != ""): + vars.lastact = data + if(not vars.gamestarted): # Start the game vars.gamestarted = True @@ -697,6 +706,9 @@ def calcsubmit(txt): def generate(txt, min, max): print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END)) + # Store context in memory to use it for comparison with generated content + vars.lastctx = txt + # Clear CUDA cache if using GPU if(vars.hascuda and vars.usegpu): torch.cuda.empty_cache() @@ -736,6 +748,9 @@ def sendtocolab(txt, min, max): # Log request to console print("{0}Tokens:{1}, Txt:{2}{3}".format(colors.YELLOW, min-1, txt, colors.END)) + # Store context in memory to use it for comparison with generated content + vars.lastctx = txt + # Build request JSON data reqdata = { 'text': txt, @@ -765,21 +780,9 @@ def sendtocolab(txt, min, max): refresh_story() emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)}) - set_aibusy(0) - elif(req.status_code == 500): - errmsg = "Colab API Error: Failed to get a reply from the server. Please check the colab console." - print("{0}{1}{2}".format(colors.RED, errmsg, colors.END)) - emit('from_server', {'cmd': 'errmsg', 'data': errmsg}) set_aibusy(0) else: - # Send error message to web client - er = req.json() - if("error" in er): - code = er["error"]["extensions"]["code"] - elif("errors" in er): - code = er["errors"][0]["extensions"]["code"] - - errmsg = "Colab API Error: {0} - {1}".format(req.status_code, code) + errmsg = "Colab API Error: Failed to get a reply from the server. Please check the colab console." print("{0}{1}{2}".format(colors.RED, errmsg, colors.END)) emit('from_server', {'cmd': 'errmsg', 'data': errmsg}) set_aibusy(0) @@ -795,17 +798,19 @@ def formatforhtml(txt): # Strips submitted text from the text returned by the AI #==================================================================# def getnewcontent(txt): - ln = len(vars.actions) - if(ln == 0): - lastact = tokenizer.encode(vars.prompt) - else: - lastact = tokenizer.encode(vars.actions[-1]) + # If the submitted context was blank, then everything is new + if(vars.lastctx == ""): + return txt - delim = tokenizer.decode(lastact) + # Tokenize the last context and the generated content + ctxtokens = tokenizer.encode(vars.lastctx) + txttokens = tokenizer.encode(txt) + dif = (len(txttokens) - len(ctxtokens)) * -1 - split = txt.split(delim) + # Remove the context from the returned text + newtokens = txttokens[dif:] - return (split[-1]) + return tokenizer.decode(newtokens) #==================================================================# # Applies chosen formatting options to text submitted to AI @@ -1031,16 +1036,31 @@ def deletewi(num): requestwi() #==================================================================# -# Look for WI keys in text to generator +# Look for WI keys in text to generator #==================================================================# def checkworldinfo(txt): # Dont go any further if WI is empty if(len(vars.worldinfo) == 0): return - - # Join submitted text to last action - if(len(vars.actions) > 0): - txt = vars.actions[-1] + txt + + # Cache actions length + ln = len(vars.actions) + + # Don't bother calculating action history if widepth is 0 + if(vars.widepth > 0): + depth = vars.widepth + # If this is not a continue, add 1 to widepth since submitted + # text is already in action history @ -1 + if(txt != "" and vars.prompt != txt): + txt = "" + depth += 1 + + if(ln >= depth): + txt = "".join(vars.actions[(depth*-1):]) + elif(ln > 0): + txt = vars.prompt + "".join(vars.actions[(depth*-1):]) + elif(ln == 0): + txt = vars.prompt # Scan text for matches on WI keys wimem = "" @@ -1049,15 +1069,16 @@ def checkworldinfo(txt): # Split comma-separated keys keys = wi["key"].split(",") for k in keys: - # Remove leading/trailing spaces - ky = k.strip() + ky = k + # Remove leading/trailing spaces if the option is enabled + if(vars.wirmvwhtsp): + ky = k.strip() if ky in txt: wimem = wimem + wi["content"] + "\n" break return wimem - #==================================================================# # Commit changes to Memory storage #==================================================================# @@ -1199,6 +1220,8 @@ def loadRequest(): vars.memory = js["memory"] vars.actions = js["actions"] vars.worldinfo = [] + vars.lastact = "" + vars.lastctx = "" # Try not to break older save files if("authorsnote" in js): @@ -1299,6 +1322,8 @@ def importgame(): vars.authornote = ref["authorsNote"] if type(ref["authorsNote"]) is str else "" vars.actions = [] vars.worldinfo = [] + vars.lastact = "" + vars.lastctx = "" # Get all actions except for prompt if("actions" in ref): @@ -1350,6 +1375,8 @@ def importAidgRequest(id): vars.authornote = js["authorsNote"] vars.actions = [] vars.worldinfo = [] + vars.lastact = "" + vars.lastctx = "" num = 0 for wi in js["worldInfos"]: @@ -1388,6 +1415,8 @@ def newGameRequest(): vars.savedir = getcwd()+"\stories" vars.authornote = "" vars.worldinfo = [] + vars.lastact = "" + vars.lastctx = "" # Refresh game screen sendwi() diff --git a/utils.py b/utils.py index 2b43f167..816b1af9 100644 --- a/utils.py +++ b/utils.py @@ -69,7 +69,12 @@ def removespecialchars(txt): def addsentencespacing(txt, vars): # Get last character of last action if(len(vars.actions) > 0): - lastchar = vars.actions[-1][-1] + if(len(vars.actions[-1]) > 0): + lastchar = vars.actions[-1][-1] + else: + # Last action is blank, this should never happen, but + # since it did let's bail out. + return txt else: lastchar = vars.prompt[-1] if(lastchar == "." or lastchar == "!" or lastchar == "?" or lastchar == "," or lastchar == ";" or lastchar == ":"):