mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Bugfixes:
Expanded bad_word flagging for square brackets to combat Author's Note leakage World Info should now work properly if you have an Author's Note defined Set generator to use cache to improve performance of custom Neo models Added error handling for Colab disconnections Now using tokenized & detokenized version of last action to parse out new content Updated readme
This commit is contained in:
52
aiserver.py
52
aiserver.py
@ -66,6 +66,8 @@ class vars:
|
||||
andepth = 3 # How far back in history to append author's note
|
||||
actions = []
|
||||
worldinfo = []
|
||||
badwords = []
|
||||
badwordsids = []
|
||||
deletewi = -1 # Temporary storage for index to delete
|
||||
mode = "play" # Whether the interface is in play, memory, or edit mode
|
||||
editln = 0 # Which line was last selected in Edit Mode
|
||||
@ -114,6 +116,16 @@ def getModelSelection():
|
||||
print("{0}Select an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
|
||||
getModelSelection()
|
||||
|
||||
#==================================================================#
|
||||
# Return all keys in tokenizer dictionary containing char
|
||||
#==================================================================#
|
||||
def gettokenids(char):
|
||||
keys = []
|
||||
for key in vocab_keys:
|
||||
if(key.find(char) != -1):
|
||||
keys.append(key)
|
||||
return keys
|
||||
|
||||
#==================================================================#
|
||||
# Startup
|
||||
#==================================================================#
|
||||
@ -238,6 +250,13 @@ if(not vars.model in ["InferKit", "Colab"]):
|
||||
else:
|
||||
generator = pipeline('text-generation', model=vars.model)
|
||||
|
||||
# Suppress Author's Note by flagging square brackets
|
||||
vocab = tokenizer.get_vocab()
|
||||
vocab_keys = vocab.keys()
|
||||
vars.badwords = gettokenids("[")
|
||||
for key in vars.badwords:
|
||||
vars.badwordsids.append([vocab[key]])
|
||||
|
||||
print("{0}OK! {1} pipeline created!{2}".format(colors.GREEN, vars.model, colors.END))
|
||||
else:
|
||||
# If we're running Colab, we still need a tokenizer.
|
||||
@ -512,6 +531,7 @@ def actionsubmit(data):
|
||||
vars.prompt = data
|
||||
# Clear the startup text from game screen
|
||||
emit('from_server', {'cmd': 'updatescreen', 'data': 'Please wait, generating story...'})
|
||||
|
||||
calcsubmit(data) # Run the first action through the generator
|
||||
else:
|
||||
# Dont append submission if it's a blank/continue action
|
||||
@ -528,7 +548,6 @@ def actionsubmit(data):
|
||||
# Take submitted text and build the text to be given to generator
|
||||
#==================================================================#
|
||||
def calcsubmit(txt):
|
||||
vars.lastact = txt # Store most recent action in memory (is this still needed?)
|
||||
anotetxt = "" # Placeholder for Author's Note text
|
||||
lnanote = 0 # Placeholder for Author's Note length
|
||||
forceanote = False # In case we don't have enough actions to hit A.N. depth
|
||||
@ -608,13 +627,15 @@ def calcsubmit(txt):
|
||||
# Did we get to add the A.N.? If not, do it here
|
||||
if(anotetxt != ""):
|
||||
if((not anoteadded) or forceanote):
|
||||
tokens = memtokens + anotetkns + prompttkns + tokens
|
||||
tokens = memtokens + witokens + anotetkns + prompttkns + tokens
|
||||
else:
|
||||
tokens = memtokens + prompttkns + tokens
|
||||
tokens = memtokens + witokens + prompttkns + tokens
|
||||
else:
|
||||
# Prepend Memory, WI, and Prompt before action tokens
|
||||
tokens = memtokens + witokens + prompttkns + tokens
|
||||
|
||||
|
||||
|
||||
# Send completed bundle to generator
|
||||
ln = len(tokens)
|
||||
|
||||
@ -680,11 +701,6 @@ def generate(txt, min, max):
|
||||
if(vars.hascuda and vars.usegpu):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Suppress Author's Note by flagging square brackets
|
||||
bad_words = []
|
||||
bad_words.append(tokenizer("[", add_prefix_space=True).input_ids)
|
||||
bad_words.append(tokenizer("[", add_prefix_space=False).input_ids)
|
||||
|
||||
# Submit input text to generator
|
||||
genout = generator(
|
||||
txt,
|
||||
@ -694,7 +710,8 @@ def generate(txt, min, max):
|
||||
repetition_penalty=vars.rep_pen,
|
||||
top_p=vars.top_p,
|
||||
temperature=vars.temp,
|
||||
bad_words_ids=bad_words
|
||||
bad_words_ids=vars.badwordsids,
|
||||
use_cache=True
|
||||
)[0]["generated_text"]
|
||||
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
|
||||
|
||||
@ -748,6 +765,11 @@ def sendtocolab(txt, min, max):
|
||||
refresh_story()
|
||||
emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
|
||||
|
||||
set_aibusy(0)
|
||||
elif(req.status_code == 500):
|
||||
errmsg = "Colab API Error: Failed to get a reply from the server. Please check the colab console."
|
||||
print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg})
|
||||
set_aibusy(0)
|
||||
else:
|
||||
# Send error message to web client
|
||||
@ -756,8 +778,9 @@ def sendtocolab(txt, min, max):
|
||||
code = er["error"]["extensions"]["code"]
|
||||
elif("errors" in er):
|
||||
code = er["errors"][0]["extensions"]["code"]
|
||||
|
||||
|
||||
errmsg = "Colab API Error: {0} - {1}".format(req.status_code, code)
|
||||
print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg})
|
||||
set_aibusy(0)
|
||||
|
||||
@ -774,12 +797,11 @@ def formatforhtml(txt):
|
||||
def getnewcontent(txt):
|
||||
ln = len(vars.actions)
|
||||
if(ln == 0):
|
||||
delim = vars.prompt
|
||||
lastact = tokenizer.encode(vars.prompt)
|
||||
else:
|
||||
delim = vars.actions[-1]
|
||||
lastact = tokenizer.encode(vars.actions[-1])
|
||||
|
||||
# Fix issue with tokenizer replacing space+period with period
|
||||
delim = delim.replace(" .", ".")
|
||||
delim = tokenizer.decode(lastact)
|
||||
|
||||
split = txt.split(delim)
|
||||
|
||||
@ -1216,6 +1238,7 @@ def importRequest():
|
||||
file = open(importpath, "rb")
|
||||
vars.importjs = json.load(file)
|
||||
|
||||
# If a bundle file is being imported, select just the Adventures object
|
||||
if type(vars.importjs) is dict and "stories" in vars.importjs:
|
||||
vars.importjs = vars.importjs["stories"]
|
||||
|
||||
@ -1259,6 +1282,7 @@ def importgame():
|
||||
# Copy game contents to vars
|
||||
vars.gamestarted = True
|
||||
|
||||
# Support for different versions of export script
|
||||
if("actions" in ref):
|
||||
if(len(ref["actions"]) > 0):
|
||||
vars.prompt = ref["actions"][0]["text"]
|
||||
|
Reference in New Issue
Block a user