Bugfixes:

Expanded bad_word flagging for square brackets to combat Author's Note leakage
World Info should now work properly if you have an Author's Note defined
Set generator to use cache to improve performance of custom Neo models
Added error handling for Colab disconnections
Now using tokenized & detokenized version of last action to parse out new content
Updated readme
This commit is contained in:
KoboldAI Dev
2021-05-17 20:28:18 -04:00
parent 2721a5e64a
commit 3d070f057e
2 changed files with 62 additions and 40 deletions

View File

@ -66,6 +66,8 @@ class vars:
andepth = 3 # How far back in history to append author's note
actions = []
worldinfo = []
badwords = []
badwordsids = []
deletewi = -1 # Temporary storage for index to delete
mode = "play" # Whether the interface is in play, memory, or edit mode
editln = 0 # Which line was last selected in Edit Mode
@ -114,6 +116,16 @@ def getModelSelection():
print("{0}Select an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
getModelSelection()
#==================================================================#
# Return all keys in tokenizer dictionary containing char
#==================================================================#
def gettokenids(char, keys=None):
    """Return every tokenizer vocabulary key that contains ``char``.

    char -- substring to look for inside each vocabulary key.
    keys -- optional iterable of vocabulary keys to scan. When omitted,
            the module-level ``vocab_keys`` is scanned, which is the
            behavior existing one-argument callers get.

    Returns a list of the matching keys, in the iteration order of ``keys``.
    """
    if keys is None:
        # Existing callers pass only ``char`` and rely on the module-level keys.
        keys = vocab_keys
    # ``char in key`` is the idiomatic equivalent of ``key.find(char) != -1``.
    return [key for key in keys if char in key]
#==================================================================#
# Startup
#==================================================================#
@ -238,6 +250,13 @@ if(not vars.model in ["InferKit", "Colab"]):
else:
generator = pipeline('text-generation', model=vars.model)
# Suppress Author's Note by flagging square brackets
vocab = tokenizer.get_vocab()
vocab_keys = vocab.keys()
vars.badwords = gettokenids("[")
for key in vars.badwords:
vars.badwordsids.append([vocab[key]])
print("{0}OK! {1} pipeline created!{2}".format(colors.GREEN, vars.model, colors.END))
else:
# If we're running Colab, we still need a tokenizer.
@ -512,6 +531,7 @@ def actionsubmit(data):
vars.prompt = data
# Clear the startup text from game screen
emit('from_server', {'cmd': 'updatescreen', 'data': 'Please wait, generating story...'})
calcsubmit(data) # Run the first action through the generator
else:
# Don't append submission if it's a blank/continue action
@ -528,7 +548,6 @@ def actionsubmit(data):
# Take submitted text and build the text to be given to generator
#==================================================================#
def calcsubmit(txt):
vars.lastact = txt # Store most recent action in memory (is this still needed?)
anotetxt = "" # Placeholder for Author's Note text
lnanote = 0 # Placeholder for Author's Note length
forceanote = False # In case we don't have enough actions to hit A.N. depth
@ -608,13 +627,15 @@ def calcsubmit(txt):
# Did we get to add the A.N.? If not, do it here
if(anotetxt != ""):
if((not anoteadded) or forceanote):
tokens = memtokens + anotetkns + prompttkns + tokens
tokens = memtokens + witokens + anotetkns + prompttkns + tokens
else:
tokens = memtokens + prompttkns + tokens
tokens = memtokens + witokens + prompttkns + tokens
else:
# Prepend Memory, WI, and Prompt before action tokens
tokens = memtokens + witokens + prompttkns + tokens
# Send completed bundle to generator
ln = len(tokens)
@ -680,11 +701,6 @@ def generate(txt, min, max):
if(vars.hascuda and vars.usegpu):
torch.cuda.empty_cache()
# Suppress Author's Note by flagging square brackets
bad_words = []
bad_words.append(tokenizer("[", add_prefix_space=True).input_ids)
bad_words.append(tokenizer("[", add_prefix_space=False).input_ids)
# Submit input text to generator
genout = generator(
txt,
@ -694,7 +710,8 @@ def generate(txt, min, max):
repetition_penalty=vars.rep_pen,
top_p=vars.top_p,
temperature=vars.temp,
bad_words_ids=bad_words
bad_words_ids=vars.badwordsids,
use_cache=True
)[0]["generated_text"]
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
@ -748,6 +765,11 @@ def sendtocolab(txt, min, max):
refresh_story()
emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
set_aibusy(0)
elif(req.status_code == 500):
errmsg = "Colab API Error: Failed to get a reply from the server. Please check the colab console."
print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
emit('from_server', {'cmd': 'errmsg', 'data': errmsg})
set_aibusy(0)
else:
# Send error message to web client
@ -756,8 +778,9 @@ def sendtocolab(txt, min, max):
code = er["error"]["extensions"]["code"]
elif("errors" in er):
code = er["errors"][0]["extensions"]["code"]
errmsg = "Colab API Error: {0} - {1}".format(req.status_code, code)
print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
emit('from_server', {'cmd': 'errmsg', 'data': errmsg})
set_aibusy(0)
@ -774,12 +797,11 @@ def formatforhtml(txt):
def getnewcontent(txt):
ln = len(vars.actions)
if(ln == 0):
delim = vars.prompt
lastact = tokenizer.encode(vars.prompt)
else:
delim = vars.actions[-1]
lastact = tokenizer.encode(vars.actions[-1])
# Fix issue with tokenizer replacing space+period with period
delim = delim.replace(" .", ".")
delim = tokenizer.decode(lastact)
split = txt.split(delim)
@ -1216,6 +1238,7 @@ def importRequest():
file = open(importpath, "rb")
vars.importjs = json.load(file)
# If a bundle file is being imported, select just the Adventures object
if type(vars.importjs) is dict and "stories" in vars.importjs:
vars.importjs = vars.importjs["stories"]
@ -1259,6 +1282,7 @@ def importgame():
# Copy game contents to vars
vars.gamestarted = True
# Support for different versions of export script
if("actions" in ref):
if(len(ref["actions"]) > 0):
vars.prompt = ref["actions"][0]["text"]