Added option to generate multiple responses per action.

Added ability to import World Info files from AI Dungeon.
Added slider for setting World Info scan depth.
Added toggle to control whether prompt is submitted each action.
Added 'Read Only' mode with no AI to startup.
Fixed GPU/CPU choice prompt appearing when GPU isn't an option.
Added error handling to generator calls for CUDA OOM message
Added generator parameter to only return new text
This commit is contained in:
KoboldAI Dev
2021-05-29 05:46:03 -04:00
parent 2cc48e7163
commit bed1eba6eb
6 changed files with 467 additions and 110 deletions

View File

@ -44,7 +44,8 @@ modellist = [
["Custom Neo (eg Neo-horni)", "NeoCustom", ""],
["Custom GPT-2 (eg CloverEdition)", "GPT2Custom", ""],
["Google Colab", "Colab", ""],
["OpenAI API (requires API key)", "OAI", ""]
["OpenAI API (requires API key)", "OAI", ""],
["Read Only (No AI)", "ReadOnly", ""]
]
# Variables
@ -61,6 +62,7 @@ class vars:
rep_pen = 1.0 # Default generator repetition_penalty
temp = 1.0 # Default generator temperature
top_p = 1.0 # Default generator top_p
numseqs = 1 # Number of sequences to ask the generator to create
gamestarted = False # Whether the game has started (disables UI elements)
prompt = "" # Prompt
memory = "" # Text submitted to memory field
@ -89,8 +91,10 @@ class vars:
importnum = -1 # Selection on import popup list
importjs = {} # Temporary storage for import data
loadselect = "" # Temporary storage for filename to load
svowname = ""
saveow = False
svowname = "" # Filename that was flagged for overwrite confirm
saveow = False # Whether or not overwrite confirm has been displayed
genseqs = [] # Temporary storage for generated sequences
useprompt = True # Whether to send the full prompt with every submit action
#==================================================================#
# Function to get model selection at startup
@ -145,7 +149,7 @@ print("{0}Welcome to the KoboldAI Client!\nSelect an AI model to continue:{1}\n"
getModelSelection()
# If transformers model was selected & GPU available, ask to use CPU or GPU
if(not vars.model in ["InferKit", "Colab", "OAI"]):
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
# Test for GPU support
import torch
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@ -155,9 +159,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI"]):
else:
print("{0}NOT FOUND!{1}".format(colors.YELLOW, colors.END))
print("{0}Use GPU or CPU for generation?: (Default GPU){1}\n".format(colors.CYAN, colors.END))
if(vars.hascuda):
print("{0}Use GPU or CPU for generation?: (Default GPU){1}\n".format(colors.CYAN, colors.END))
print(" 1 - GPU\n 2 - CPU\n")
genselected = False
while(genselected == False):
@ -277,9 +280,12 @@ if(vars.model == "OAI"):
# Ask for ngrok url if Google Colab was selected
if(vars.model == "Colab"):
print("{0}Please enter the ngrok.io URL displayed in Google Colab:{1}\n".format(colors.CYAN, colors.END))
print("{0}Please enter the ngrok.io or trycloudflare.com URL displayed in Google Colab:{1}\n".format(colors.CYAN, colors.END))
vars.colaburl = input("URL> ") + "/request"
if(vars.model == "ReadOnly"):
vars.noai = True
# Set logging level to reduce chatter from Flask
import logging
log = logging.getLogger('werkzeug')
@ -295,7 +301,7 @@ socketio = SocketIO(app)
print("{0}OK!{1}".format(colors.GREEN, colors.END))
# Start transformers and create pipeline
if(not vars.model in ["InferKit", "Colab", "OAI"]):
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
@ -399,22 +405,10 @@ def get_message(msg):
memsubmit(msg['data'])
# Retry Action
elif(msg['cmd'] == 'retry'):
if(vars.aibusy):
return
set_aibusy(1)
# Remove last action if possible and resubmit
if(len(vars.actions) > 0):
vars.actions.pop()
refresh_story()
calcsubmit('')
actionretry(msg['data'])
# Back/Undo Action
elif(msg['cmd'] == 'back'):
if(vars.aibusy):
return
# Remove last index of actions and refresh game screen
if(len(vars.actions) > 0):
vars.actions.pop()
refresh_story()
actionback()
# EditMode Action
elif(msg['cmd'] == 'edit'):
if(vars.mode == "play"):
@ -521,12 +515,32 @@ def get_message(msg):
elif(msg['cmd'] == 'clearoverwrite'):
vars.svowname = ""
vars.saveow = False
elif(msg['cmd'] == 'seqsel'):
selectsequence(msg['data'])
elif(msg['cmd'] == 'setnumseq'):
vars.numseqs = int(msg['data'])
emit('from_server', {'cmd': 'setlabelnumseq', 'data': msg['data']})
settingschanged()
elif(msg['cmd'] == 'setwidepth'):
vars.widepth = int(msg['data'])
emit('from_server', {'cmd': 'setlabelwidepth', 'data': msg['data']})
settingschanged()
elif(msg['cmd'] == 'setuseprompt'):
vars.useprompt = msg['data']
settingschanged()
elif(msg['cmd'] == 'importwi'):
wiimportrequest()
#==================================================================#
# Send start message and tell Javascript to set UI state
#==================================================================#
def setStartState():
emit('from_server', {'cmd': 'updatescreen', 'data': '<span>Welcome to <span class="color_cyan">KoboldAI Client</span>! You are running <span class="color_green">'+vars.model+'</span>.<br/>Please load a game or enter a prompt below to begin!</span>'})
txt = "<span>Welcome to <span class=\"color_cyan\">KoboldAI Client</span>! You are running <span class=\"color_green\">"+vars.model+"</span>.<br/>"
if(not vars.noai):
txt = txt + "Please load a game or enter a prompt below to begin!</span>"
else:
txt = txt + "Please load or import a story to read. There is no AI in this mode."
emit('from_server', {'cmd': 'updatescreen', 'data': txt})
emit('from_server', {'cmd': 'setgamestate', 'data': 'start'})
#==================================================================#
@ -563,6 +577,9 @@ def savesettings():
js["max_length"] = vars.max_length
js["ikgen"] = vars.ikgen
js["formatoptns"] = vars.formatoptns
js["numseqs"] = vars.numseqs
js["widepth"] = vars.widepth
js["useprompt"] = vars.useprompt
# Write it
file = open("client.settings", "w")
@ -599,6 +616,12 @@ def loadsettings():
vars.ikgen = js["ikgen"]
if("formatoptns" in js):
vars.formatoptns = js["formatoptns"]
if("numseqs" in js):
vars.numseqs = js["numseqs"]
if("widepth" in js):
vars.widepth = js["widepth"]
if("useprompt" in js):
vars.useprompt = js["useprompt"]
file.close()
@ -628,10 +651,13 @@ def actionsubmit(data):
vars.gamestarted = True
# Save this first action as the prompt
vars.prompt = data
# Clear the startup text from game screen
emit('from_server', {'cmd': 'updatescreen', 'data': 'Please wait, generating story...'})
calcsubmit(data) # Run the first action through the generator
if(not vars.noai):
# Clear the startup text from game screen
emit('from_server', {'cmd': 'updatescreen', 'data': 'Please wait, generating story...'})
calcsubmit(data) # Run the first action through the generator
else:
refresh_story()
set_aibusy(0)
else:
# Dont append submission if it's a blank/continue action
if(data != ""):
@ -640,8 +666,39 @@ def actionsubmit(data):
# Store the result in the Action log
vars.actions.append(data)
# Off to the tokenizer!
calcsubmit(data)
if(not vars.noai):
# Off to the tokenizer!
calcsubmit(data)
else:
refresh_story()
set_aibusy(0)
#==================================================================#
#
#==================================================================#
def actionretry(data):
if(vars.noai):
emit('from_server', {'cmd': 'errmsg', 'data': "Retry function unavailable in Read Only mode."})
return
if(vars.aibusy):
return
set_aibusy(1)
# Remove last action if possible and resubmit
if(len(vars.actions) > 0):
vars.actions.pop()
refresh_story()
calcsubmit('')
#==================================================================#
#
#==================================================================#
def actionback():
if(vars.aibusy):
return
# Remove last index of actions and refresh game screen
if(len(vars.actions) > 0):
vars.actions.pop()
refresh_story()
#==================================================================#
# Take submitted text and build the text to be given to generator
@ -684,7 +741,10 @@ def calcsubmit(txt):
anotetkns = tokenizer.encode(anotetxt)
lnanote = len(anotetkns)
budget = vars.max_length - lnprompt - lnmem - lnanote - lnwi - vars.genamt
if(vars.useprompt):
budget = vars.max_length - lnprompt - lnmem - lnanote - lnwi - vars.genamt
else:
budget = vars.max_length - lnmem - lnanote - lnwi - vars.genamt
if(actionlen == 0):
# First/Prompt action
@ -717,6 +777,7 @@ def calcsubmit(txt):
else:
count = budget * -1
tokens = acttkns[count:] + tokens
budget = 0
break
# Inject Author's Note if we've reached the desired depth
@ -724,6 +785,14 @@ def calcsubmit(txt):
if(anotetxt != ""):
tokens = anotetkns + tokens # A.N. len already taken from bdgt
anoteadded = True
# If we're not using the prompt every time and there's still budget left,
# add some prompt.
if(not vars.useprompt):
if(budget > 0):
prompttkns = prompttkns[-budget:]
else:
prompttkns = []
# Did we get to add the A.N.? If not, do it here
if(anotetxt != ""):
@ -759,12 +828,15 @@ def calcsubmit(txt):
# For InferKit web API
else:
# Check if we have the action depth to hit our A.N. depth
if(anotetxt != "" and actionlen < vars.andepth):
forceanote = True
budget = vars.ikmax - len(vars.prompt) - len(anotetxt) - len(mem) - len(winfo) - 1
if(vars.useprompt):
budget = vars.ikmax - len(vars.prompt) - len(anotetxt) - len(mem) - len(winfo) - 1
else:
budget = vars.ikmax - len(anotetxt) - len(mem) - len(winfo) - 1
subtxt = ""
for n in range(actionlen):
@ -777,8 +849,18 @@ def calcsubmit(txt):
else:
count = budget * -1
subtxt = vars.actions[(-1-n)][count:] + subtxt
budget = 0
break
# If we're not using the prompt every time and there's still budget left,
# add some prompt.
prompt = vars.prompt
if(not vars.useprompt):
if(budget > 0):
prompt = vars.prompt[-budget:]
else:
prompt = ""
# Inject Author's Note if we've reached the desired depth
if(n == vars.andepth-1):
if(anotetxt != ""):
@ -788,11 +870,11 @@ def calcsubmit(txt):
# Did we get to add the A.N.? If not, do it here
if(anotetxt != ""):
if((not anoteadded) or forceanote):
subtxt = mem + winfo + anotetxt + vars.prompt + subtxt
subtxt = mem + winfo + anotetxt + prompt + subtxt
else:
subtxt = mem + winfo + vars.prompt + subtxt
subtxt = mem + winfo + prompt + subtxt
else:
subtxt = mem + winfo + vars.prompt + subtxt
subtxt = mem + winfo + prompt + subtxt
# Send it!
ikrequest(subtxt)
@ -811,26 +893,30 @@ def generate(txt, min, max):
torch.cuda.empty_cache()
# Submit input text to generator
genout = generator(
txt,
do_sample=True,
min_length=min,
max_length=max,
repetition_penalty=vars.rep_pen,
top_p=vars.top_p,
temperature=vars.temp,
bad_words_ids=vars.badwordsids,
use_cache=True
)[0]["generated_text"]
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
try:
genout = generator(
txt,
do_sample=True,
min_length=min,
max_length=max,
repetition_penalty=vars.rep_pen,
top_p=vars.top_p,
temperature=vars.temp,
bad_words_ids=vars.badwordsids,
use_cache=True,
return_full_text=False,
num_return_sequences=vars.numseqs
)
except Exception as e:
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'})
print("{0}{1}{2}".format(colors.RED, e, colors.END))
set_aibusy(0)
return
# Format output before continuing
genout = applyoutputformatting(getnewcontent(genout))
# Add formatted text to Actions array and refresh the game screen
vars.actions.append(genout)
refresh_story()
emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
else:
genselect(genout)
# Clear CUDA cache again if using GPU
if(vars.hascuda and vars.usegpu):
@ -838,6 +924,52 @@ def generate(txt, min, max):
set_aibusy(0)
#==================================================================#
# Deal with a single return sequence from generate()
#==================================================================#
def genresult(genout):
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
# Format output before continuing
genout = applyoutputformatting(genout)
# Add formatted text to Actions array and refresh the game screen
vars.actions.append(genout)
refresh_story()
emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
#==================================================================#
# Send generator sequences to the UI for selection
#==================================================================#
def genselect(genout):
i = 0
for result in genout:
# Apply output formatting rules to sequences
result["generated_text"] = applyoutputformatting(result["generated_text"])
print("{0}[Result {1}]\n{2}{3}".format(colors.CYAN, i, result["generated_text"], colors.END))
i += 1
# Store sequences in memory until selection is made
vars.genseqs = genout
# Send sequences to UI for selection
emit('from_server', {'cmd': 'genseqs', 'data': genout})
# Refresh story for any input text
refresh_story()
#==================================================================#
# Send selected sequence to action log and refresh UI
#==================================================================#
def selectsequence(n):
if(len(vars.genseqs) == 0):
return
vars.actions.append(vars.genseqs[int(n)]["generated_text"])
refresh_story()
emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
emit('from_server', {'cmd': 'hidegenseqs', 'data': ''})
vars.genseqs = []
#==================================================================#
# Send transformers-style request to ngrok/colab host
#==================================================================#
@ -855,7 +987,9 @@ def sendtocolab(txt, min, max):
'max': max,
'rep_pen': vars.rep_pen,
'temperature': vars.temp,
'top_p': vars.top_p
'top_p': vars.top_p,
'numseqs': vars.numseqs,
'retfultxt': False
}
# Create request
@ -866,16 +1000,30 @@ def sendtocolab(txt, min, max):
# Deal with the response
if(req.status_code == 200):
genout = req.json()["data"]["text"]
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
js = req.json()["data"]
# Try to be backwards compatible with outdated colab
if("text" in js):
genout = [getnewcontent(js["text"])]
else:
genout = js["seqs"]
if(len(genout) == 1):
genresult(genout[0])
else:
# Convert torch output format to transformers
seqs = []
for seq in genout:
seqs.append({"generated_text": seq})
genselect(seqs)
# Format output before continuing
genout = applyoutputformatting(getnewcontent(genout))
#genout = applyoutputformatting(getnewcontent(genout))
# Add formatted text to Actions array and refresh the game screen
vars.actions.append(genout)
refresh_story()
emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
#vars.actions.append(genout)
#refresh_story()
#emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
set_aibusy(0)
else:
@ -962,12 +1110,15 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen})
emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt})
emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length})
emit('from_server', {'cmd': 'updatenumseq', 'data': vars.numseqs})
else:
emit('from_server', {'cmd': 'updatetemp', 'data': vars.temp})
emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p})
emit('from_server', {'cmd': 'updateikgen', 'data': vars.ikgen})
emit('from_server', {'cmd': 'updateanotedepth', 'data': vars.andepth})
emit('from_server', {'cmd': 'updatewidepth', 'data': vars.widepth})
emit('from_server', {'cmd': 'updateuseprompt', 'data': vars.useprompt})
emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]})
emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]})
@ -1378,6 +1529,8 @@ def saveRequest(savpath):
file.write(json.dumps(js, indent=3))
finally:
file.close()
print("{0}Story saved to {1}!{2}".format(colors.GREEN, path.basename(savpath), colors.END))
#==================================================================#
# Load a saved story via file browser
@ -1442,6 +1595,8 @@ def loadRequest(loadpath):
sendwi()
refresh_story()
emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'})
emit('from_server', {'cmd': 'hidegenseqs', 'data': ''})
print("{0}Story loaded from {1}!{2}".format(colors.GREEN, path.basename(loadpath), colors.END))
#==================================================================#
# Import an AIDungon game exported with Mimi's tool
@ -1554,6 +1709,7 @@ def importgame():
sendwi()
refresh_story()
emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'})
emit('from_server', {'cmd': 'hidegenseqs', 'data': ''})
#==================================================================#
# Import an aidg.club prompt and start a new game with it.
@ -1595,6 +1751,34 @@ def importAidgRequest(id):
refresh_story()
emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'})
#==================================================================#
# Import World Info JSON file
#==================================================================#
def wiimportrequest():
importpath = fileops.getloadpath(vars.savedir, "Select World Info File", [("Json", "*.json")])
if(importpath):
file = open(importpath, "rb")
js = json.load(file)
if(len(js) > 0):
# If the most recent WI entry is blank, remove it.
if(not vars.worldinfo[-1]["init"]):
del vars.worldinfo[-1]
# Now grab the new stuff
num = len(vars.worldinfo)
for wi in js:
vars.worldinfo.append({
"key": wi["keys"],
"content": wi["entry"],
"num": num,
"init": True
})
num += 1
print("{0}".format(vars.worldinfo[0]))
# Refresh game screen
sendwi()
#==================================================================#
# Starts a new story
#==================================================================#