diff --git a/UPDATE YOUR COLAB NOTEBOOK.txt b/UPDATE YOUR COLAB NOTEBOOK.txt
new file mode 100644
index 00000000..67a821f4
--- /dev/null
+++ b/UPDATE YOUR COLAB NOTEBOOK.txt
@@ -0,0 +1,3 @@
+If you use Google Colab to run your models, and you made a local copy of the Colab notebook in Google Drive instead of using the community notebook, you MUST make a new copy of the community notebook to use the new multiple-sequence generation feature. The link is below:
+
+https://colab.research.google.com/drive/1uGe9f4ruIQog3RLxfUsoThakvLpHjIkX?usp=sharing
\ No newline at end of file
diff --git a/aiserver.py b/aiserver.py
index 74cd5fd3..81c11275 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -44,7 +44,8 @@ modellist = [
["Custom Neo (eg Neo-horni)", "NeoCustom", ""],
["Custom GPT-2 (eg CloverEdition)", "GPT2Custom", ""],
["Google Colab", "Colab", ""],
- ["OpenAI API (requires API key)", "OAI", ""]
+ ["OpenAI API (requires API key)", "OAI", ""],
+ ["Read Only (No AI)", "ReadOnly", ""]
]
# Variables
@@ -61,6 +62,7 @@ class vars:
rep_pen = 1.0 # Default generator repetition_penalty
temp = 1.0 # Default generator temperature
top_p = 1.0 # Default generator top_p
+ numseqs = 1 # Number of sequences to ask the generator to create
gamestarted = False # Whether the game has started (disables UI elements)
prompt = "" # Prompt
memory = "" # Text submitted to memory field
@@ -89,8 +91,10 @@ class vars:
importnum = -1 # Selection on import popup list
importjs = {} # Temporary storage for import data
loadselect = "" # Temporary storage for filename to load
- svowname = ""
- saveow = False
+ svowname = "" # Filename that was flagged for overwrite confirm
+ saveow = False # Whether or not overwrite confirm has been displayed
+ genseqs = [] # Temporary storage for generated sequences
+ useprompt = True # Whether to send the full prompt with every submit action
#==================================================================#
# Function to get model selection at startup
@@ -145,7 +149,7 @@ print("{0}Welcome to the KoboldAI Client!\nSelect an AI model to continue:{1}\n"
getModelSelection()
# If transformers model was selected & GPU available, ask to use CPU or GPU
-if(not vars.model in ["InferKit", "Colab", "OAI"]):
+if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
# Test for GPU support
import torch
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@@ -155,9 +159,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI"]):
else:
print("{0}NOT FOUND!{1}".format(colors.YELLOW, colors.END))
- print("{0}Use GPU or CPU for generation?: (Default GPU){1}\n".format(colors.CYAN, colors.END))
-
if(vars.hascuda):
+ print("{0}Use GPU or CPU for generation?: (Default GPU){1}\n".format(colors.CYAN, colors.END))
print(" 1 - GPU\n 2 - CPU\n")
genselected = False
while(genselected == False):
@@ -277,9 +280,12 @@ if(vars.model == "OAI"):
# Ask for ngrok url if Google Colab was selected
if(vars.model == "Colab"):
- print("{0}Please enter the ngrok.io URL displayed in Google Colab:{1}\n".format(colors.CYAN, colors.END))
+ print("{0}Please enter the ngrok.io or trycloudflare.com URL displayed in Google Colab:{1}\n".format(colors.CYAN, colors.END))
vars.colaburl = input("URL> ") + "/request"
+if(vars.model == "ReadOnly"):
+ vars.noai = True
+
# Set logging level to reduce chatter from Flask
import logging
log = logging.getLogger('werkzeug')
@@ -295,7 +301,7 @@ socketio = SocketIO(app)
print("{0}OK!{1}".format(colors.GREEN, colors.END))
# Start transformers and create pipeline
-if(not vars.model in ["InferKit", "Colab", "OAI"]):
+if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
@@ -399,22 +405,10 @@ def get_message(msg):
memsubmit(msg['data'])
# Retry Action
elif(msg['cmd'] == 'retry'):
- if(vars.aibusy):
- return
- set_aibusy(1)
- # Remove last action if possible and resubmit
- if(len(vars.actions) > 0):
- vars.actions.pop()
- refresh_story()
- calcsubmit('')
+ actionretry(msg['data'])
# Back/Undo Action
elif(msg['cmd'] == 'back'):
- if(vars.aibusy):
- return
- # Remove last index of actions and refresh game screen
- if(len(vars.actions) > 0):
- vars.actions.pop()
- refresh_story()
+ actionback()
# EditMode Action
elif(msg['cmd'] == 'edit'):
if(vars.mode == "play"):
@@ -521,12 +515,32 @@ def get_message(msg):
elif(msg['cmd'] == 'clearoverwrite'):
vars.svowname = ""
vars.saveow = False
+ elif(msg['cmd'] == 'seqsel'):
+ selectsequence(msg['data'])
+ elif(msg['cmd'] == 'setnumseq'):
+ vars.numseqs = int(msg['data'])
+ emit('from_server', {'cmd': 'setlabelnumseq', 'data': msg['data']})
+ settingschanged()
+ elif(msg['cmd'] == 'setwidepth'):
+ vars.widepth = int(msg['data'])
+ emit('from_server', {'cmd': 'setlabelwidepth', 'data': msg['data']})
+ settingschanged()
+ elif(msg['cmd'] == 'setuseprompt'):
+ vars.useprompt = msg['data']
+ settingschanged()
+ elif(msg['cmd'] == 'importwi'):
+ wiimportrequest()
#==================================================================#
# Send start message and tell Javascript to set UI state
#==================================================================#
def setStartState():
- emit('from_server', {'cmd': 'updatescreen', 'data': 'Welcome to KoboldAI Client! You are running '+vars.model+'.
Please load a game or enter a prompt below to begin!'})
+ txt = "Welcome to KoboldAI Client! You are running "+vars.model+".
"
+ if(not vars.noai):
+ txt = txt + "Please load a game or enter a prompt below to begin!"
+ else:
+ txt = txt + "Please load or import a story to read. There is no AI in this mode."
+ emit('from_server', {'cmd': 'updatescreen', 'data': txt})
emit('from_server', {'cmd': 'setgamestate', 'data': 'start'})
#==================================================================#
@@ -563,6 +577,9 @@ def savesettings():
js["max_length"] = vars.max_length
js["ikgen"] = vars.ikgen
js["formatoptns"] = vars.formatoptns
+ js["numseqs"] = vars.numseqs
+ js["widepth"] = vars.widepth
+ js["useprompt"] = vars.useprompt
# Write it
file = open("client.settings", "w")
@@ -599,6 +616,12 @@ def loadsettings():
vars.ikgen = js["ikgen"]
if("formatoptns" in js):
vars.formatoptns = js["formatoptns"]
+ if("numseqs" in js):
+ vars.numseqs = js["numseqs"]
+ if("widepth" in js):
+ vars.widepth = js["widepth"]
+ if("useprompt" in js):
+ vars.useprompt = js["useprompt"]
file.close()
@@ -628,10 +651,13 @@ def actionsubmit(data):
vars.gamestarted = True
# Save this first action as the prompt
vars.prompt = data
- # Clear the startup text from game screen
- emit('from_server', {'cmd': 'updatescreen', 'data': 'Please wait, generating story...'})
-
- calcsubmit(data) # Run the first action through the generator
+ if(not vars.noai):
+ # Clear the startup text from game screen
+ emit('from_server', {'cmd': 'updatescreen', 'data': 'Please wait, generating story...'})
+ calcsubmit(data) # Run the first action through the generator
+ else:
+ refresh_story()
+ set_aibusy(0)
else:
# Dont append submission if it's a blank/continue action
if(data != ""):
@@ -640,8 +666,39 @@ def actionsubmit(data):
# Store the result in the Action log
vars.actions.append(data)
- # Off to the tokenizer!
- calcsubmit(data)
+ if(not vars.noai):
+ # Off to the tokenizer!
+ calcsubmit(data)
+ else:
+ refresh_story()
+ set_aibusy(0)
+
+#==================================================================#
+# Retry Action - remove the last action (if any) and resubmit to the generator
+#==================================================================#
+def actionretry(data):
+ if(vars.noai):
+ emit('from_server', {'cmd': 'errmsg', 'data': "Retry function unavailable in Read Only mode."})
+ return
+ if(vars.aibusy):
+ return
+ set_aibusy(1)
+ # Remove last action if possible and resubmit
+ if(len(vars.actions) > 0):
+ vars.actions.pop()
+ refresh_story()
+ calcsubmit('')
+
+#==================================================================#
+# Back/Undo Action - remove the last action (if any) and refresh the story
+#==================================================================#
+def actionback():
+ if(vars.aibusy):
+ return
+ # Remove last index of actions and refresh game screen
+ if(len(vars.actions) > 0):
+ vars.actions.pop()
+ refresh_story()
#==================================================================#
# Take submitted text and build the text to be given to generator
@@ -684,7 +741,10 @@ def calcsubmit(txt):
anotetkns = tokenizer.encode(anotetxt)
lnanote = len(anotetkns)
- budget = vars.max_length - lnprompt - lnmem - lnanote - lnwi - vars.genamt
+ if(vars.useprompt):
+ budget = vars.max_length - lnprompt - lnmem - lnanote - lnwi - vars.genamt
+ else:
+ budget = vars.max_length - lnmem - lnanote - lnwi - vars.genamt
if(actionlen == 0):
# First/Prompt action
@@ -717,6 +777,7 @@ def calcsubmit(txt):
else:
count = budget * -1
tokens = acttkns[count:] + tokens
+ budget = 0
break
# Inject Author's Note if we've reached the desired depth
@@ -724,6 +785,14 @@ def calcsubmit(txt):
if(anotetxt != ""):
tokens = anotetkns + tokens # A.N. len already taken from bdgt
anoteadded = True
+
+ # If we're not using the prompt every time and there's still budget left,
+ # add some prompt.
+ if(not vars.useprompt):
+ if(budget > 0):
+ prompttkns = prompttkns[-budget:]
+ else:
+ prompttkns = []
# Did we get to add the A.N.? If not, do it here
if(anotetxt != ""):
@@ -759,12 +828,15 @@ def calcsubmit(txt):
# For InferKit web API
else:
-
# Check if we have the action depth to hit our A.N. depth
if(anotetxt != "" and actionlen < vars.andepth):
forceanote = True
- budget = vars.ikmax - len(vars.prompt) - len(anotetxt) - len(mem) - len(winfo) - 1
+ if(vars.useprompt):
+ budget = vars.ikmax - len(vars.prompt) - len(anotetxt) - len(mem) - len(winfo) - 1
+ else:
+ budget = vars.ikmax - len(anotetxt) - len(mem) - len(winfo) - 1
+
subtxt = ""
for n in range(actionlen):
@@ -777,8 +849,18 @@ def calcsubmit(txt):
else:
count = budget * -1
subtxt = vars.actions[(-1-n)][count:] + subtxt
+ budget = 0
break
+ # If we're not using the prompt every time and there's still budget left,
+ # add some prompt.
+ prompt = vars.prompt
+ if(not vars.useprompt):
+ if(budget > 0):
+ prompt = vars.prompt[-budget:]
+ else:
+ prompt = ""
+
# Inject Author's Note if we've reached the desired depth
if(n == vars.andepth-1):
if(anotetxt != ""):
@@ -788,11 +870,11 @@ def calcsubmit(txt):
# Did we get to add the A.N.? If not, do it here
if(anotetxt != ""):
if((not anoteadded) or forceanote):
- subtxt = mem + winfo + anotetxt + vars.prompt + subtxt
+ subtxt = mem + winfo + anotetxt + prompt + subtxt
else:
- subtxt = mem + winfo + vars.prompt + subtxt
+ subtxt = mem + winfo + prompt + subtxt
else:
- subtxt = mem + winfo + vars.prompt + subtxt
+ subtxt = mem + winfo + prompt + subtxt
# Send it!
ikrequest(subtxt)
@@ -811,26 +893,30 @@ def generate(txt, min, max):
torch.cuda.empty_cache()
# Submit input text to generator
- genout = generator(
- txt,
- do_sample=True,
- min_length=min,
- max_length=max,
- repetition_penalty=vars.rep_pen,
- top_p=vars.top_p,
- temperature=vars.temp,
- bad_words_ids=vars.badwordsids,
- use_cache=True
- )[0]["generated_text"]
- print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
+ try:
+ genout = generator(
+ txt,
+ do_sample=True,
+ min_length=min,
+ max_length=max,
+ repetition_penalty=vars.rep_pen,
+ top_p=vars.top_p,
+ temperature=vars.temp,
+ bad_words_ids=vars.badwordsids,
+ use_cache=True,
+ return_full_text=False,
+ num_return_sequences=vars.numseqs
+ )
+ except Exception as e:
+ emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'})
+ print("{0}{1}{2}".format(colors.RED, e, colors.END))
+ set_aibusy(0)
+ return
- # Format output before continuing
- genout = applyoutputformatting(getnewcontent(genout))
-
- # Add formatted text to Actions array and refresh the game screen
- vars.actions.append(genout)
- refresh_story()
- emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
+ if(len(genout) == 1):
+ genresult(genout[0]["generated_text"])
+ else:
+ genselect(genout)
# Clear CUDA cache again if using GPU
if(vars.hascuda and vars.usegpu):
@@ -838,6 +924,52 @@ def generate(txt, min, max):
set_aibusy(0)
+#==================================================================#
+# Deal with a single return sequence from generate()
+#==================================================================#
+def genresult(genout):
+ print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
+
+ # Format output before continuing
+ genout = applyoutputformatting(genout)
+
+ # Add formatted text to Actions array and refresh the game screen
+ vars.actions.append(genout)
+ refresh_story()
+ emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
+
+#==================================================================#
+# Send generator sequences to the UI for selection
+#==================================================================#
+def genselect(genout):
+ i = 0
+ for result in genout:
+ # Apply output formatting rules to sequences
+ result["generated_text"] = applyoutputformatting(result["generated_text"])
+ print("{0}[Result {1}]\n{2}{3}".format(colors.CYAN, i, result["generated_text"], colors.END))
+ i += 1
+
+ # Store sequences in memory until selection is made
+ vars.genseqs = genout
+
+ # Send sequences to UI for selection
+ emit('from_server', {'cmd': 'genseqs', 'data': genout})
+
+ # Refresh story for any input text
+ refresh_story()
+
+#==================================================================#
+# Send selected sequence to action log and refresh UI
+#==================================================================#
+def selectsequence(n):
+ if(len(vars.genseqs) == 0):
+ return
+ vars.actions.append(vars.genseqs[int(n)]["generated_text"])
+ refresh_story()
+ emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
+ emit('from_server', {'cmd': 'hidegenseqs', 'data': ''})
+ vars.genseqs = []
+
#==================================================================#
# Send transformers-style request to ngrok/colab host
#==================================================================#
@@ -855,7 +987,9 @@ def sendtocolab(txt, min, max):
'max': max,
'rep_pen': vars.rep_pen,
'temperature': vars.temp,
- 'top_p': vars.top_p
+ 'top_p': vars.top_p,
+ 'numseqs': vars.numseqs,
+ 'retfultxt': False
}
# Create request
@@ -866,16 +1000,30 @@ def sendtocolab(txt, min, max):
# Deal with the response
if(req.status_code == 200):
- genout = req.json()["data"]["text"]
- print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
+ js = req.json()["data"]
+
+ # Try to be backwards compatible with outdated colab
+ if("text" in js):
+ genout = [getnewcontent(js["text"])]
+ else:
+ genout = js["seqs"]
+
+ if(len(genout) == 1):
+ genresult(genout[0])
+ else:
+ # Convert torch output format to transformers
+ seqs = []
+ for seq in genout:
+ seqs.append({"generated_text": seq})
+ genselect(seqs)
# Format output before continuing
- genout = applyoutputformatting(getnewcontent(genout))
+ #genout = applyoutputformatting(getnewcontent(genout))
# Add formatted text to Actions array and refresh the game screen
- vars.actions.append(genout)
- refresh_story()
- emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
+ #vars.actions.append(genout)
+ #refresh_story()
+ #emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
set_aibusy(0)
else:
@@ -962,12 +1110,15 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen})
emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt})
emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length})
+ emit('from_server', {'cmd': 'updatenumseq', 'data': vars.numseqs})
else:
emit('from_server', {'cmd': 'updatetemp', 'data': vars.temp})
emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p})
emit('from_server', {'cmd': 'updateikgen', 'data': vars.ikgen})
emit('from_server', {'cmd': 'updateanotedepth', 'data': vars.andepth})
+ emit('from_server', {'cmd': 'updatewidepth', 'data': vars.widepth})
+ emit('from_server', {'cmd': 'updateuseprompt', 'data': vars.useprompt})
emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]})
emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]})
@@ -1378,6 +1529,8 @@ def saveRequest(savpath):
file.write(json.dumps(js, indent=3))
finally:
file.close()
+
+ print("{0}Story saved to {1}!{2}".format(colors.GREEN, path.basename(savpath), colors.END))
#==================================================================#
# Load a saved story via file browser
@@ -1442,6 +1595,8 @@ def loadRequest(loadpath):
sendwi()
refresh_story()
emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'})
+ emit('from_server', {'cmd': 'hidegenseqs', 'data': ''})
+ print("{0}Story loaded from {1}!{2}".format(colors.GREEN, path.basename(loadpath), colors.END))
#==================================================================#
# Import an AIDungon game exported with Mimi's tool
@@ -1554,6 +1709,7 @@ def importgame():
sendwi()
refresh_story()
emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'})
+ emit('from_server', {'cmd': 'hidegenseqs', 'data': ''})
#==================================================================#
# Import an aidg.club prompt and start a new game with it.
@@ -1595,6 +1751,34 @@ def importAidgRequest(id):
refresh_story()
emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'})
+#==================================================================#
+# Import World Info JSON file
+#==================================================================#
+def wiimportrequest():
+ importpath = fileops.getloadpath(vars.savedir, "Select World Info File", [("Json", "*.json")])
+ if(importpath):
+ file = open(importpath, "rb")
+ js = json.load(file)
+ if(len(js) > 0):
+ # If the most recent WI entry is blank, remove it.
+ if(not vars.worldinfo[-1]["init"]):
+ del vars.worldinfo[-1]
+ # Now grab the new stuff
+ num = len(vars.worldinfo)
+ for wi in js:
+ vars.worldinfo.append({
+ "key": wi["keys"],
+ "content": wi["entry"],
+ "num": num,
+ "init": True
+ })
+ num += 1
+
+ print("{0}".format(vars.worldinfo[0]))
+
+ # Refresh game screen
+ sendwi()
+
#==================================================================#
# Starts a new story
#==================================================================#
diff --git a/gensettings.py b/gensettings.py
index 84c180f9..8f6a67a0 100644
--- a/gensettings.py
+++ b/gensettings.py
@@ -52,6 +52,39 @@ gensettingstf = [{
"step": 8,
"default": 512,
"tooltip": "Max number of tokens of context to submit to the AI for sampling. Make sure this is higher than Amount to Generate. Higher values increase VRAM/RAM usage."
+ },
+ {
+ "uitype": "slider",
+ "unit": "int",
+ "label": "Gens Per Action",
+ "id": "setnumseq",
+ "min": 1,
+ "max": 5,
+ "step": 1,
+ "default": 1,
+ "tooltip": "Number of results to generate per submission. Increases VRAM/RAM usage."
+ },
+ {
+ "uitype": "slider",
+ "unit": "int",
+ "label": "W Info Depth",
+ "id": "setwidepth",
+ "min": 1,
+ "max": 5,
+ "step": 1,
+ "default": 1,
+ "tooltip": "Number of historic actions to scan for W Info keys."
+ },
+ {
+ "uitype": "toggle",
+ "unit": "bool",
+ "label": "Always Add Prompt",
+ "id": "setuseprompt",
+ "min": 0,
+ "max": 1,
+ "step": 1,
+ "default": 1,
+ "tooltip": "Whether the prompt should be sent in the context of every action."
}]
gensettingsik =[{
@@ -86,6 +119,28 @@ gensettingsik =[{
"step": 2,
"default": 200,
"tooltip": "Number of characters the AI should generate."
+ },
+ {
+ "uitype": "slider",
+ "unit": "int",
+ "label": "W Info Depth",
+ "id": "setwidepth",
+ "min": 1,
+ "max": 5,
+ "step": 1,
+ "default": 1,
+ "tooltip": "Number of historic actions to scan for W Info keys."
+ },
+ {
+ "uitype": "toggle",
+ "unit": "bool",
+ "label": "Always Add Prompt",
+ "id": "setuseprompt",
+ "min": 0,
+ "max": 1,
+ "step": 1,
+ "default": 1,
+ "tooltip": "Whether the prompt should be sent in the context of every action."
}]
formatcontrols = [{
diff --git a/static/application.js b/static/application.js
index 54a6debd..e5ebc806 100644
--- a/static/application.js
+++ b/static/application.js
@@ -13,6 +13,7 @@ var button_saveas;
var button_savetofile;
var button_load;
var button_import;
+var button_importwi;
var button_impaidg;
var button_settings;
var button_format;
@@ -54,6 +55,8 @@ var load_close;
var nspopup;
var ns_accept;
var ns_close;
+var seqselmenu;
+var seqselcontents;
// Key states
var shift_down = false;
@@ -69,36 +72,51 @@ var formatcount = 0;
function addSetting(ob) {
// Add setting block to Settings Menu
- settings_menu.append("