Bugfixes:

Improvements to pruning context from text returned from the AI
Colab errors should no longer throw JSON decode errors in client
Improved logic for World Info scanning
Fix for index error in addsentencespacing
This commit is contained in:
KoboldAI Dev 2021-05-18 17:59:59 -04:00
parent 3d070f057e
commit 4996e0ff46
2 changed files with 65 additions and 31 deletions

View File

@ -48,7 +48,8 @@ modellist = [
# Variables # Variables
class vars: class vars:
lastact = "" # The last action submitted to the generator lastact = "" # The last action received from the user
lastctx = "" # The last context submitted to the generator
model = "" model = ""
noai = False # Runs the script without starting up the transformers pipeline noai = False # Runs the script without starting up the transformers pipeline
aibusy = False # Stops submissions while the AI is working aibusy = False # Stops submissions while the AI is working
@ -69,6 +70,8 @@ class vars:
badwords = [] badwords = []
badwordsids = [] badwordsids = []
deletewi = -1 # Temporary storage for index to delete deletewi = -1 # Temporary storage for index to delete
wirmvwhtsp = False # Whether to remove leading whitespace from WI entries
widepth = 1 # How many historical actions to scan for WI hits
mode = "play" # Whether the interface is in play, memory, or edit mode mode = "play" # Whether the interface is in play, memory, or edit mode
editln = 0 # Which line was last selected in Edit Mode editln = 0 # Which line was last selected in Edit Mode
url = "https://api.inferkit.com/v1/models/standard/generate" # InferKit API URL url = "https://api.inferkit.com/v1/models/standard/generate" # InferKit API URL
@ -521,9 +524,15 @@ def settingschanged():
# #
#==================================================================# #==================================================================#
def actionsubmit(data): def actionsubmit(data):
# Ignore new submissions if the AI is currently busy
if(vars.aibusy): if(vars.aibusy):
return return
set_aibusy(1) set_aibusy(1)
# If we're not continuing, store a copy of the raw input
if(data != ""):
vars.lastact = data
if(not vars.gamestarted): if(not vars.gamestarted):
# Start the game # Start the game
vars.gamestarted = True vars.gamestarted = True
@ -697,6 +706,9 @@ def calcsubmit(txt):
def generate(txt, min, max): def generate(txt, min, max):
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END)) print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))
# Store context in memory to use it for comparison with generated content
vars.lastctx = txt
# Clear CUDA cache if using GPU # Clear CUDA cache if using GPU
if(vars.hascuda and vars.usegpu): if(vars.hascuda and vars.usegpu):
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -736,6 +748,9 @@ def sendtocolab(txt, min, max):
# Log request to console # Log request to console
print("{0}Tokens:{1}, Txt:{2}{3}".format(colors.YELLOW, min-1, txt, colors.END)) print("{0}Tokens:{1}, Txt:{2}{3}".format(colors.YELLOW, min-1, txt, colors.END))
# Store context in memory to use it for comparison with generated content
vars.lastctx = txt
# Build request JSON data # Build request JSON data
reqdata = { reqdata = {
'text': txt, 'text': txt,
@ -765,21 +780,9 @@ def sendtocolab(txt, min, max):
refresh_story() refresh_story()
emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)}) emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
set_aibusy(0)
elif(req.status_code == 500):
errmsg = "Colab API Error: Failed to get a reply from the server. Please check the colab console."
print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
emit('from_server', {'cmd': 'errmsg', 'data': errmsg})
set_aibusy(0) set_aibusy(0)
else: else:
# Send error message to web client errmsg = "Colab API Error: Failed to get a reply from the server. Please check the colab console."
er = req.json()
if("error" in er):
code = er["error"]["extensions"]["code"]
elif("errors" in er):
code = er["errors"][0]["extensions"]["code"]
errmsg = "Colab API Error: {0} - {1}".format(req.status_code, code)
print("{0}{1}{2}".format(colors.RED, errmsg, colors.END)) print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}) emit('from_server', {'cmd': 'errmsg', 'data': errmsg})
set_aibusy(0) set_aibusy(0)
@ -795,17 +798,19 @@ def formatforhtml(txt):
# Strips submitted text from the text returned by the AI # Strips submitted text from the text returned by the AI
#==================================================================# #==================================================================#
def getnewcontent(txt): def getnewcontent(txt):
ln = len(vars.actions) # If the submitted context was blank, then everything is new
if(ln == 0): if(vars.lastctx == ""):
lastact = tokenizer.encode(vars.prompt) return txt
else:
lastact = tokenizer.encode(vars.actions[-1])
delim = tokenizer.decode(lastact) # Tokenize the last context and the generated content
ctxtokens = tokenizer.encode(vars.lastctx)
txttokens = tokenizer.encode(txt)
dif = (len(txttokens) - len(ctxtokens)) * -1
split = txt.split(delim) # Remove the context from the returned text
newtokens = txttokens[dif:]
return (split[-1]) return tokenizer.decode(newtokens)
#==================================================================# #==================================================================#
# Applies chosen formatting options to text submitted to AI # Applies chosen formatting options to text submitted to AI
@ -1038,9 +1043,24 @@ def checkworldinfo(txt):
if(len(vars.worldinfo) == 0): if(len(vars.worldinfo) == 0):
return return
# Join submitted text to last action # Cache actions length
if(len(vars.actions) > 0): ln = len(vars.actions)
txt = vars.actions[-1] + txt
# Don't bother calculating action history if widepth is 0
if(vars.widepth > 0):
depth = vars.widepth
# If this is not a continue, add 1 to widepth since submitted
# text is already in action history @ -1
if(txt != "" and vars.prompt != txt):
txt = ""
depth += 1
if(ln >= depth):
txt = "".join(vars.actions[(depth*-1):])
elif(ln > 0):
txt = vars.prompt + "".join(vars.actions[(depth*-1):])
elif(ln == 0):
txt = vars.prompt
# Scan text for matches on WI keys # Scan text for matches on WI keys
wimem = "" wimem = ""
@ -1049,15 +1069,16 @@ def checkworldinfo(txt):
# Split comma-separated keys # Split comma-separated keys
keys = wi["key"].split(",") keys = wi["key"].split(",")
for k in keys: for k in keys:
# Remove leading/trailing spaces ky = k
ky = k.strip() # Remove leading/trailing spaces if the option is enabled
if(vars.wirmvwhtsp):
ky = k.strip()
if ky in txt: if ky in txt:
wimem = wimem + wi["content"] + "\n" wimem = wimem + wi["content"] + "\n"
break break
return wimem return wimem
#==================================================================# #==================================================================#
# Commit changes to Memory storage # Commit changes to Memory storage
#==================================================================# #==================================================================#
@ -1199,6 +1220,8 @@ def loadRequest():
vars.memory = js["memory"] vars.memory = js["memory"]
vars.actions = js["actions"] vars.actions = js["actions"]
vars.worldinfo = [] vars.worldinfo = []
vars.lastact = ""
vars.lastctx = ""
# Try not to break older save files # Try not to break older save files
if("authorsnote" in js): if("authorsnote" in js):
@ -1299,6 +1322,8 @@ def importgame():
vars.authornote = ref["authorsNote"] if type(ref["authorsNote"]) is str else "" vars.authornote = ref["authorsNote"] if type(ref["authorsNote"]) is str else ""
vars.actions = [] vars.actions = []
vars.worldinfo = [] vars.worldinfo = []
vars.lastact = ""
vars.lastctx = ""
# Get all actions except for prompt # Get all actions except for prompt
if("actions" in ref): if("actions" in ref):
@ -1350,6 +1375,8 @@ def importAidgRequest(id):
vars.authornote = js["authorsNote"] vars.authornote = js["authorsNote"]
vars.actions = [] vars.actions = []
vars.worldinfo = [] vars.worldinfo = []
vars.lastact = ""
vars.lastctx = ""
num = 0 num = 0
for wi in js["worldInfos"]: for wi in js["worldInfos"]:
@ -1388,6 +1415,8 @@ def newGameRequest():
vars.savedir = getcwd()+"\stories" vars.savedir = getcwd()+"\stories"
vars.authornote = "" vars.authornote = ""
vars.worldinfo = [] vars.worldinfo = []
vars.lastact = ""
vars.lastctx = ""
# Refresh game screen # Refresh game screen
sendwi() sendwi()

View File

@ -69,7 +69,12 @@ def removespecialchars(txt):
def addsentencespacing(txt, vars): def addsentencespacing(txt, vars):
# Get last character of last action # Get last character of last action
if(len(vars.actions) > 0): if(len(vars.actions) > 0):
lastchar = vars.actions[-1][-1] if(len(vars.actions[-1]) > 0):
lastchar = vars.actions[-1][-1]
else:
# Last action is blank, this should never happen, but
# since it did let's bail out.
return txt
else: else:
lastchar = vars.prompt[-1] lastchar = vars.prompt[-1]
if(lastchar == "." or lastchar == "!" or lastchar == "?" or lastchar == "," or lastchar == ";" or lastchar == ":"): if(lastchar == "." or lastchar == "!" or lastchar == "?" or lastchar == "," or lastchar == ";" or lastchar == ":"):