mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-02 18:46:48 +01:00
Settings menu modularized.
Help text added to settings items. Settings now saved to client file when changed. Separated transformers settings and InferKit settings. Reorganized model select list.
This commit is contained in:
parent
ade5be39fb
commit
d632976fbf
289
aiserver.py
289
aiserver.py
@ -5,55 +5,61 @@
|
||||
#==================================================================#
|
||||
|
||||
from os import path, getcwd
|
||||
from tkinter import filedialog, messagebox
|
||||
import tkinter as tk
|
||||
from tkinter import messagebox
|
||||
import json
|
||||
import torch
|
||||
|
||||
import fileops
|
||||
import gensettings
|
||||
from utils import debounce
|
||||
|
||||
#==================================================================#
|
||||
# Variables & Storage
|
||||
#==================================================================#
|
||||
|
||||
# Terminal tags for colored text
|
||||
class colors:
|
||||
HEADER = '\033[95m'
|
||||
OKBLUE = '\033[94m'
|
||||
OKCYAN = '\033[96m'
|
||||
OKGREEN = '\033[92m'
|
||||
WARNING = '\033[93m'
|
||||
FAIL = '\033[91m'
|
||||
ENDC = '\033[0m'
|
||||
BOLD = '\033[1m'
|
||||
PURPLE = '\033[95m'
|
||||
BLUE = '\033[94m'
|
||||
CYAN = '\033[96m'
|
||||
GREEN = '\033[92m'
|
||||
YELLOW = '\033[93m'
|
||||
RED = '\033[91m'
|
||||
END = '\033[0m'
|
||||
UNDERLINE = '\033[4m'
|
||||
|
||||
# Transformers models
|
||||
# AI models
|
||||
modellist = [
|
||||
["InferKit API (requires API key)", "InferKit", ""],
|
||||
["GPT Neo 1.3B", "EleutherAI/gpt-neo-1.3B", "8GB"],
|
||||
["GPT Neo 2.7B", "EleutherAI/gpt-neo-2.7B", "16GB"],
|
||||
["GPT-2", "gpt2", "1.2GB"],
|
||||
["GPT-2 Med", "gpt2-medium", "2GB"],
|
||||
["GPT-2 Large", "gpt2-large", "16GB"],
|
||||
["GPT-2 XL", "gpt2-xl", "16GB"],
|
||||
["InferKit API (requires API key)", "InferKit", ""],
|
||||
["Custom Neo (eg Neo-horni)", "NeoCustom", ""],
|
||||
["Custom GPT-2 (eg CloverEdition)", "GPT2Custom", ""]
|
||||
]
|
||||
|
||||
# Variables
|
||||
class vars:
|
||||
lastact = "" # The last action submitted to the generator
|
||||
lastact = "" # The last action submitted to the generator
|
||||
model = ""
|
||||
noai = False # Runs the script without starting up the transformers pipeline
|
||||
aibusy = False # Stops submissions while the AI is working
|
||||
max_length = 500 # Maximum number of tokens to submit per action
|
||||
genamt = 60 # Amount of text for each action to generate
|
||||
rep_pen = 1.0 # Default generator repetition_penalty
|
||||
temp = 0.9 # Default generator temperature
|
||||
top_p = 1.0 # Default generator top_p
|
||||
max_length = 1024 # Maximum number of tokens to submit per action
|
||||
ikmax = 3000 # Maximum number of characters to submit to InferKit
|
||||
genamt = 60 # Amount of text for each action to generate
|
||||
ikgen = 200 # Number of characters for InferKit to generate
|
||||
rep_pen = 1.0 # Default generator repetition_penalty
|
||||
temp = 1.0 # Default generator temperature
|
||||
top_p = 1.0 # Default generator top_p
|
||||
gamestarted = False
|
||||
prompt = ""
|
||||
memory = ""
|
||||
authornote = ""
|
||||
andepth = 3 # How far back in history to append author's note
|
||||
andepth = 3 # How far back in history to append author's note
|
||||
actions = []
|
||||
mode = "play" # Whether the interface is in play, memory, or edit mode
|
||||
editln = 0 # Which line was last selected in Edit Mode
|
||||
@ -86,27 +92,21 @@ def getModelSelection():
|
||||
if(modelsel.isnumeric() and int(modelsel) > 0 and int(modelsel) <= len(modellist)):
|
||||
vars.model = modellist[int(modelsel)-1][1]
|
||||
else:
|
||||
print("{0}Please enter a valid selection.{1}".format(colors.FAIL, colors.ENDC))
|
||||
print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))
|
||||
|
||||
# If custom model was selected, get the filesystem location and store it
|
||||
if(vars.model == "NeoCustom" or vars.model == "GPT2Custom"):
|
||||
print("{0}Please choose the folder where pytorch_model.bin is located:{1}\n".format(colors.OKCYAN, colors.ENDC))
|
||||
print("{0}Please choose the folder where pytorch_model.bin is located:{1}\n".format(colors.CYAN, colors.END))
|
||||
|
||||
root = tk.Tk()
|
||||
root.attributes("-topmost", True)
|
||||
path = filedialog.askdirectory(
|
||||
initialdir=getcwd(),
|
||||
title="Select Model Folder",
|
||||
)
|
||||
root.destroy()
|
||||
modpath = fileops.getdirpath(getcwd(), "Select Model Folder")
|
||||
|
||||
if(path != None and path != ""):
|
||||
if(modpath):
|
||||
# Save directory to vars
|
||||
vars.custmodpth = path
|
||||
vars.custmodpth = modpath
|
||||
else:
|
||||
# Print error and retry model selection
|
||||
print("{0}Model select cancelled!{1}".format(colors.FAIL, colors.ENDC))
|
||||
print("{0}Select an AI model to continue:{1}\n".format(colors.OKCYAN, colors.ENDC))
|
||||
print("{0}Model select cancelled!{1}".format(colors.RED, colors.END))
|
||||
print("{0}Select an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
|
||||
getModelSelection()
|
||||
|
||||
#==================================================================#
|
||||
@ -114,20 +114,20 @@ def getModelSelection():
|
||||
#==================================================================#
|
||||
|
||||
# Test for GPU support
|
||||
print("{0}Looking for GPU support...{1}".format(colors.HEADER, colors.ENDC), end="")
|
||||
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
||||
vars.hascuda = torch.cuda.is_available()
|
||||
if(vars.hascuda):
|
||||
print("{0}FOUND!{1}".format(colors.OKGREEN, colors.ENDC))
|
||||
print("{0}FOUND!{1}".format(colors.GREEN, colors.END))
|
||||
else:
|
||||
print("{0}NOT FOUND!{1}".format(colors.WARNING, colors.ENDC))
|
||||
print("{0}NOT FOUND!{1}".format(colors.YELLOW, colors.END))
|
||||
|
||||
# Select a model to run
|
||||
print("{0}Welcome to the KoboldAI Client!\nSelect an AI model to continue:{1}\n".format(colors.OKCYAN, colors.ENDC))
|
||||
print("{0}Welcome to the KoboldAI Client!\nSelect an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
|
||||
getModelSelection()
|
||||
|
||||
# If transformers model was selected & GPU available, ask to use CPU or GPU
|
||||
if(vars.model != "InferKit" and vars.hascuda):
|
||||
print("{0}Use GPU or CPU for generation?: (Default GPU){1}\n".format(colors.OKCYAN, colors.ENDC))
|
||||
print("{0}Use GPU or CPU for generation?: (Default GPU){1}\n".format(colors.CYAN, colors.END))
|
||||
print(" 1 - GPU\n 2 - CPU\n")
|
||||
genselected = False
|
||||
while(genselected == False):
|
||||
@ -142,13 +142,13 @@ if(vars.model != "InferKit" and vars.hascuda):
|
||||
vars.usegpu = False
|
||||
genselected = True
|
||||
else:
|
||||
print("{0}Please enter a valid selection.{1}".format(colors.FAIL, colors.ENDC))
|
||||
print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))
|
||||
|
||||
# Ask for API key if InferKit was selected
|
||||
if(vars.model == "InferKit"):
|
||||
if(not path.exists("client.settings")):
|
||||
# If the client settings file doesn't exist, create it
|
||||
print("{0}Please enter your InferKit API key:{1}\n".format(colors.OKCYAN, colors.ENDC))
|
||||
print("{0}Please enter your InferKit API key:{1}\n".format(colors.CYAN, colors.END))
|
||||
vars.apikey = input("Key> ")
|
||||
# Write API key to file
|
||||
file = open("client.settings", "w")
|
||||
@ -157,10 +157,25 @@ if(vars.model == "InferKit"):
|
||||
finally:
|
||||
file.close()
|
||||
else:
|
||||
# Otherwise open it up and get the key
|
||||
# Otherwise open it up
|
||||
file = open("client.settings", "r")
|
||||
vars.apikey = json.load(file)["apikey"]
|
||||
file.close()
|
||||
# Check if API key exists
|
||||
js = json.load(file)
|
||||
if(js["apikey"] != ""):
|
||||
# API key exists, grab it and close the file
|
||||
vars.apikey = js["apikey"]
|
||||
file.close()
|
||||
else:
|
||||
# Get API key, add it to settings object, and write it to disk
|
||||
print("{0}Please enter your InferKit API key:{1}\n".format(colors.CYAN, colors.END))
|
||||
vars.apikey = input("Key> ")
|
||||
js["apikey"] = vars.apikey
|
||||
# Write API key to file
|
||||
file = open("client.settings", "w")
|
||||
try:
|
||||
file.write(json.dumps(js))
|
||||
finally:
|
||||
file.close()
|
||||
|
||||
# Set logging level to reduce chatter from Flask
|
||||
import logging
|
||||
@ -168,18 +183,18 @@ log = logging.getLogger('werkzeug')
|
||||
log.setLevel(logging.ERROR)
|
||||
|
||||
# Start flask & SocketIO
|
||||
print("{0}Initializing Flask... {1}".format(colors.HEADER, colors.ENDC), end="")
|
||||
print("{0}Initializing Flask... {1}".format(colors.PURPLE, colors.END), end="")
|
||||
from flask import Flask, render_template
|
||||
from flask_socketio import SocketIO, emit
|
||||
app = Flask(__name__)
|
||||
app.config['SECRET KEY'] = 'secret!'
|
||||
socketio = SocketIO(app)
|
||||
print("{0}OK!{1}".format(colors.OKGREEN, colors.ENDC))
|
||||
print("{0}OK!{1}".format(colors.GREEN, colors.END))
|
||||
|
||||
# Start transformers and create pipeline
|
||||
if(vars.model != "InferKit"):
|
||||
if(not vars.noai):
|
||||
print("{0}Initializing transformers, please wait...{1}".format(colors.HEADER, colors.ENDC))
|
||||
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
|
||||
|
||||
# If custom GPT Neo model was chosen
|
||||
@ -209,14 +224,10 @@ if(vars.model != "InferKit"):
|
||||
else:
|
||||
generator = pipeline('text-generation', model=vars.model)
|
||||
|
||||
print("{0}OK! {1} pipeline created!{2}".format(colors.OKGREEN, vars.model, colors.ENDC))
|
||||
print("{0}OK! {1} pipeline created!{2}".format(colors.GREEN, vars.model, colors.END))
|
||||
else:
|
||||
# Import requests library for HTTPS calls
|
||||
import requests
|
||||
|
||||
# Set generator variables to match InferKit's capabilities
|
||||
vars.max_length = 3000
|
||||
vars.genamt = 200
|
||||
|
||||
# Set up Flask routes
|
||||
@app.route('/')
|
||||
@ -231,14 +242,16 @@ def index():
|
||||
#==================================================================#
|
||||
@socketio.on('connect')
|
||||
def do_connect():
|
||||
print("{0}Client connected!{1}".format(colors.OKGREEN, colors.ENDC))
|
||||
print("{0}Client connected!{1}".format(colors.GREEN, colors.END))
|
||||
emit('from_server', {'cmd': 'connected'})
|
||||
if(not vars.gamestarted):
|
||||
setStartState()
|
||||
sendsettings()
|
||||
refresh_settings()
|
||||
else:
|
||||
# Game in session, send current game data and ready state to browser
|
||||
refresh_story()
|
||||
sendsettings()
|
||||
refresh_settings()
|
||||
if(vars.mode == "play"):
|
||||
if(not vars.aibusy):
|
||||
@ -255,7 +268,7 @@ def do_connect():
|
||||
#==================================================================#
|
||||
@socketio.on('message')
|
||||
def get_message(msg):
|
||||
print("{0}Data recieved:{1}{2}".format(colors.OKGREEN, msg, colors.ENDC))
|
||||
print("{0}Data recieved:{1}{2}".format(colors.GREEN, msg, colors.END))
|
||||
# Submit action
|
||||
if(msg['cmd'] == 'submit'):
|
||||
if(vars.mode == "play"):
|
||||
@ -305,17 +318,29 @@ def get_message(msg):
|
||||
elif(msg['cmd'] == 'newgame'):
|
||||
newGameRequest()
|
||||
elif(msg['cmd'] == 'settemp'):
|
||||
vars.temperature = float(msg['data'])
|
||||
vars.temp = float(msg['data'])
|
||||
emit('from_server', {'cmd': 'setlabeltemp', 'data': msg['data']})
|
||||
settingschanged()
|
||||
elif(msg['cmd'] == 'settopp'):
|
||||
vars.top_p = float(msg['data'])
|
||||
emit('from_server', {'cmd': 'setlabeltopp', 'data': msg['data']})
|
||||
settingschanged()
|
||||
elif(msg['cmd'] == 'setreppen'):
|
||||
vars.rep_pen = float(msg['data'])
|
||||
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']})
|
||||
settingschanged()
|
||||
elif(msg['cmd'] == 'setoutput'):
|
||||
vars.genamt = int(msg['data'])
|
||||
emit('from_server', {'cmd': 'setlabeloutput', 'data': msg['data']})
|
||||
settingschanged()
|
||||
elif(msg['cmd'] == 'settknmax'):
|
||||
vars.max_length = int(msg['data'])
|
||||
emit('from_server', {'cmd': 'setlabeltknmax', 'data': msg['data']})
|
||||
settingschanged()
|
||||
elif(msg['cmd'] == 'setikgen'):
|
||||
vars.ikgen = int(msg['data'])
|
||||
emit('from_server', {'cmd': 'setlabelikgen', 'data': msg['data']})
|
||||
settingschanged()
|
||||
# Author's Note field update
|
||||
elif(msg['cmd'] == 'anote'):
|
||||
anotesubmit(msg['data'])
|
||||
@ -323,6 +348,7 @@ def get_message(msg):
|
||||
elif(msg['cmd'] == 'anotedepth'):
|
||||
vars.andepth = int(msg['data'])
|
||||
emit('from_server', {'cmd': 'setlabelanotedepth', 'data': msg['data']})
|
||||
settingschanged()
|
||||
|
||||
#==================================================================#
|
||||
#
|
||||
@ -331,6 +357,71 @@ def setStartState():
|
||||
emit('from_server', {'cmd': 'updatescreen', 'data': '<span>Welcome to <span class="color_cyan">KoboldAI Client</span>! You are running <span class="color_green">'+vars.model+'</span>.<br/>Please load a game or enter a prompt below to begin!</span>'})
|
||||
emit('from_server', {'cmd': 'setgamestate', 'data': 'start'})
|
||||
|
||||
#==================================================================#
|
||||
#
|
||||
#==================================================================#
|
||||
def sendsettings():
|
||||
# Send settings for selected AI type
|
||||
if(vars.model != "InferKit"):
|
||||
for set in gensettings.gensettingstf:
|
||||
emit('from_server', {'cmd': 'addsetting', 'data': set})
|
||||
else:
|
||||
for set in gensettings.gensettingsik:
|
||||
emit('from_server', {'cmd': 'addsetting', 'data': set})
|
||||
|
||||
#==================================================================#
|
||||
#
|
||||
#==================================================================#
|
||||
def savesettings():
|
||||
# Build json to write
|
||||
js = {}
|
||||
js["apikey"] = vars.apikey
|
||||
js["andepth"] = vars.andepth
|
||||
js["temp"] = vars.temp
|
||||
js["top_p"] = vars.top_p
|
||||
js["rep_pen"] = vars.rep_pen
|
||||
js["genamt"] = vars.genamt
|
||||
js["max_length"] = vars.max_length
|
||||
js["ikgen"] = vars.ikgen
|
||||
|
||||
# Write it
|
||||
file = open("client.settings", "w")
|
||||
try:
|
||||
file.write(json.dumps(js))
|
||||
finally:
|
||||
file.close()
|
||||
|
||||
#==================================================================#
|
||||
#
|
||||
#==================================================================#
|
||||
def loadsettings():
|
||||
if(path.exists("client.settings")):
|
||||
# Read file contents into JSON object
|
||||
file = open("client.settings", "r")
|
||||
js = json.load(file)
|
||||
|
||||
# Copy file contents to vars
|
||||
#for set in js:
|
||||
# vars[set] = js[set]
|
||||
vars.apikey = js["apikey"]
|
||||
vars.andepth = js["andepth"]
|
||||
vars.temp = js["temp"]
|
||||
vars.top_p = js["top_p"]
|
||||
vars.rep_pen = js["rep_pen"]
|
||||
vars.genamt = js["genamt"]
|
||||
vars.max_length = js["max_length"]
|
||||
vars.ikgen = js["ikgen"]
|
||||
|
||||
file.close()
|
||||
|
||||
#==================================================================#
|
||||
#
|
||||
#==================================================================#
|
||||
@debounce(2)
|
||||
def settingschanged():
|
||||
print("{0}Saving settings!{1}".format(colors.GREEN, colors.END))
|
||||
savesettings()
|
||||
|
||||
#==================================================================#
|
||||
#
|
||||
#==================================================================#
|
||||
@ -439,7 +530,7 @@ def calcsubmit(txt):
|
||||
if(anotetxt != "" and actionlen < vars.andepth):
|
||||
forceanote = True
|
||||
|
||||
budget = vars.max_length - len(vars.prompt) - len(anotetxt) - len(vars.memory) - 1
|
||||
budget = vars.ikmax - len(vars.prompt) - len(anotetxt) - len(vars.memory) - 1
|
||||
subtxt = ""
|
||||
for n in range(actionlen):
|
||||
|
||||
@ -481,7 +572,7 @@ def calcsubmit(txt):
|
||||
# Send text to generator and deal with output
|
||||
#==================================================================#
|
||||
def generate(txt, min, max):
|
||||
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.WARNING, min, max, txt, colors.ENDC))
|
||||
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))
|
||||
|
||||
# Clear CUDA cache if using GPU
|
||||
if(vars.hascuda and vars.usegpu):
|
||||
@ -496,7 +587,7 @@ def generate(txt, min, max):
|
||||
repetition_penalty=vars.rep_pen,
|
||||
temperature=vars.temp
|
||||
)[0]["generated_text"]
|
||||
print("{0}{1}{2}".format(colors.OKCYAN, genout, colors.ENDC))
|
||||
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
|
||||
vars.actions.append(getnewcontent(genout))
|
||||
refresh_story()
|
||||
emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
|
||||
@ -540,11 +631,18 @@ def refresh_story():
|
||||
# Sends the current generator settings to the Game Menu
|
||||
#==================================================================#
|
||||
def refresh_settings():
|
||||
emit('from_server', {'cmd': 'updatetemp', 'data': vars.temp})
|
||||
emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p})
|
||||
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen})
|
||||
emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt})
|
||||
emit('from_server', {'cmd': 'updatanotedepth', 'data': vars.andepth})
|
||||
if(vars.model != "InferKit"):
|
||||
emit('from_server', {'cmd': 'updatetemp', 'data': vars.temp})
|
||||
emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p})
|
||||
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen})
|
||||
emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt})
|
||||
emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length})
|
||||
else:
|
||||
emit('from_server', {'cmd': 'updatetemp', 'data': vars.temp})
|
||||
emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p})
|
||||
emit('from_server', {'cmd': 'updateikgen', 'data': vars.ikgen})
|
||||
|
||||
emit('from_server', {'cmd': 'updateanotedepth', 'data': vars.andepth})
|
||||
|
||||
#==================================================================#
|
||||
# Sets the logical and display states for the AI Busy condition
|
||||
@ -637,12 +735,12 @@ def anotesubmit(data):
|
||||
#==================================================================#
|
||||
def ikrequest(txt):
|
||||
# Log request to console
|
||||
print("{0}Len:{1}, Txt:{2}{3}".format(colors.WARNING, len(txt), txt, colors.ENDC))
|
||||
print("{0}Len:{1}, Txt:{2}{3}".format(colors.YELLOW, len(txt), txt, colors.END))
|
||||
|
||||
# Build request JSON data
|
||||
reqdata = {
|
||||
'forceNoEnd': True,
|
||||
'length': vars.genamt,
|
||||
'length': vars.ikgen,
|
||||
'prompt': {
|
||||
'isContinuation': False,
|
||||
'text': txt
|
||||
@ -665,7 +763,7 @@ def ikrequest(txt):
|
||||
# Deal with the response
|
||||
if(req.status_code == 200):
|
||||
genout = req.json()["data"]["text"]
|
||||
print("{0}{1}{2}".format(colors.OKCYAN, genout, colors.ENDC))
|
||||
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
|
||||
vars.actions.append(genout)
|
||||
refresh_story()
|
||||
emit('from_server', {'cmd': 'texteffect', 'data': len(vars.actions)})
|
||||
@ -697,34 +795,25 @@ def exitModes():
|
||||
# Save the story to a file
|
||||
#==================================================================#
|
||||
def saveRequest():
|
||||
root = tk.Tk()
|
||||
root.attributes("-topmost", True)
|
||||
path = filedialog.asksaveasfile(
|
||||
initialdir=vars.savedir,
|
||||
title="Save Story As",
|
||||
filetypes = [("Json", "*.json")]
|
||||
)
|
||||
root.destroy()
|
||||
savpath = fileops.getsavepath(vars.savedir, "Save Story As", [("Json", "*.json")])
|
||||
|
||||
if(path != None and path != ''):
|
||||
if(savpath):
|
||||
# Leave Edit/Memory mode before continuing
|
||||
exitModes()
|
||||
|
||||
# Save path for future saves
|
||||
vars.savedir = path
|
||||
vars.savedir = savpath
|
||||
|
||||
# Build json to write
|
||||
js = {}
|
||||
#js["maxlegth"] = vars.max_length # This causes problems when switching to/from InfraKit
|
||||
#js["genamt"] = vars.genamt
|
||||
#js["rep_pen"] = vars.rep_pen
|
||||
#js["temp"] = vars.temp
|
||||
js["gamestarted"] = vars.gamestarted
|
||||
js["prompt"] = vars.prompt
|
||||
js["memory"] = vars.memory
|
||||
js["authorsnote"] = vars.authornote
|
||||
js["actions"] = vars.actions
|
||||
#js["savedir"] = path.name # For privacy, don't include savedir in save file
|
||||
|
||||
# Write it
|
||||
file = open(path.name, "w")
|
||||
file = open(savpath, "w")
|
||||
try:
|
||||
file.write(json.dumps(js))
|
||||
finally:
|
||||
@ -733,38 +822,31 @@ def saveRequest():
|
||||
#==================================================================#
|
||||
# Load a stored story from a file
|
||||
#==================================================================#
|
||||
def loadRequest():
|
||||
root = tk.Tk()
|
||||
root.attributes("-topmost", True)
|
||||
path = filedialog.askopenfilename(
|
||||
initialdir=vars.savedir,
|
||||
title="Select Story File",
|
||||
filetypes = [("Json", "*.json")]
|
||||
)
|
||||
root.destroy()
|
||||
def loadRequest():
|
||||
loadpath = fileops.getloadpath(vars.savedir, "Select Story File", [("Json", "*.json")])
|
||||
|
||||
if(path != None and path != ''):
|
||||
if(loadpath):
|
||||
# Leave Edit/Memory mode before continuing
|
||||
exitModes()
|
||||
|
||||
# Read file contents into JSON object
|
||||
file = open(path, "r")
|
||||
js = json.load(file)
|
||||
file = open(loadpath, "r")
|
||||
js = json.load(file)
|
||||
|
||||
# Copy file contents to vars
|
||||
#vars.max_length = js["maxlegth"] # This causes problems when switching to/from InfraKit
|
||||
#vars.genamt = js["genamt"]
|
||||
#vars.rep_pen = js["rep_pen"]
|
||||
#vars.temp = js["temp"]
|
||||
vars.gamestarted = js["gamestarted"]
|
||||
vars.prompt = js["prompt"]
|
||||
vars.memory = js["memory"]
|
||||
vars.actions = js["actions"]
|
||||
#vars.savedir = js["savedir"] # For privacy, don't include savedir in save file
|
||||
|
||||
# Try not to break older save files
|
||||
if("authorsnote" in js):
|
||||
vars.authornote = js["authorsnote"]
|
||||
else:
|
||||
vars.authornote = ""
|
||||
|
||||
file.close()
|
||||
|
||||
# Refresh game screen
|
||||
refresh_story()
|
||||
emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'})
|
||||
@ -776,25 +858,32 @@ def newGameRequest():
|
||||
# Ask for confirmation
|
||||
root = tk.Tk()
|
||||
root.attributes("-topmost", True)
|
||||
confirm = messagebox.askquestion("Confirm New Game", "Really start new Story?")
|
||||
confirm = tk.messagebox.askquestion("Confirm New Game", "Really start new Story?")
|
||||
root.destroy()
|
||||
|
||||
if(confirm == "yes"):
|
||||
# Leave Edit/Memory mode before continuing
|
||||
exitModes()
|
||||
|
||||
# Clear vars values
|
||||
vars.gamestarted = False
|
||||
vars.prompt = ""
|
||||
vars.memory = ""
|
||||
vars.actions = []
|
||||
vars.savedir = getcwd()+"\stories\\newstory.json"
|
||||
vars.savedir = getcwd()+"\stories"
|
||||
vars.authornote = ""
|
||||
|
||||
# Refresh game screen
|
||||
setStartState()
|
||||
|
||||
|
||||
#==================================================================#
|
||||
# Start Flask/SocketIO (Blocking, so this must be last method!)
|
||||
# Final startup commands to launch Flask app
|
||||
#==================================================================#
|
||||
if __name__ == "__main__":
|
||||
print("{0}Server started!\rYou may now connect with a browser at http://127.0.0.1:5000/{1}".format(colors.OKGREEN, colors.ENDC))
|
||||
# Load settings from client.settings
|
||||
loadsettings()
|
||||
|
||||
# Start Flask/SocketIO (Blocking, so this must be last method!)
|
||||
print("{0}Server started!\rYou may now connect with a browser at http://127.0.0.1:5000/{1}".format(colors.GREEN, colors.END))
|
||||
socketio.run(app)
|
||||
|
52
fileops.py
Normal file
52
fileops.py
Normal file
@ -0,0 +1,52 @@
|
||||
import tkinter as tk
|
||||
from tkinter import filedialog
|
||||
|
||||
#==================================================================#
|
||||
# Generic Method for prompting for file path
|
||||
#==================================================================#
|
||||
def getsavepath(dir, title, types):
|
||||
root = tk.Tk()
|
||||
root.attributes("-topmost", True)
|
||||
path = tk.filedialog.asksaveasfile(
|
||||
initialdir=dir,
|
||||
title=title,
|
||||
filetypes = types
|
||||
)
|
||||
root.destroy()
|
||||
if(path != "" and path != None):
|
||||
return path.name
|
||||
else:
|
||||
return None
|
||||
|
||||
#==================================================================#
|
||||
# Generic Method for prompting for file path
|
||||
#==================================================================#
|
||||
def getloadpath(dir, title, types):
|
||||
root = tk.Tk()
|
||||
root.attributes("-topmost", True)
|
||||
path = tk.filedialog.askopenfilename(
|
||||
initialdir=dir,
|
||||
title=title,
|
||||
filetypes = types
|
||||
)
|
||||
root.destroy()
|
||||
if(path != "" and path != None):
|
||||
return path
|
||||
else:
|
||||
return None
|
||||
|
||||
#==================================================================#
|
||||
# Generic Method for prompting for directory path
|
||||
#==================================================================#
|
||||
def getdirpath(dir, title):
|
||||
root = tk.Tk()
|
||||
root.attributes("-topmost", True)
|
||||
path = filedialog.askdirectory(
|
||||
initialdir=dir,
|
||||
title=title
|
||||
)
|
||||
root.destroy()
|
||||
if(path != "" and path != None):
|
||||
return path
|
||||
else:
|
||||
return None
|
89
gensettings.py
Normal file
89
gensettings.py
Normal file
@ -0,0 +1,89 @@
|
||||
gensettingstf = [{
|
||||
"uitype": "slider",
|
||||
"unit": "float",
|
||||
"label": "Temperature",
|
||||
"id": "settemp",
|
||||
"min": 0.1,
|
||||
"max": 2.0,
|
||||
"step": 0.05,
|
||||
"default": 1.0,
|
||||
"tooltip": "Randomness of sampling. High values can increase creativity but may make text less sensible. Lower values will make text more predictable but can become repetitious."
|
||||
},
|
||||
{
|
||||
"uitype": "slider",
|
||||
"unit": "float",
|
||||
"label": "Top p Sampling",
|
||||
"id": "settopp",
|
||||
"min": 0.1,
|
||||
"max": 1.0,
|
||||
"step": 0.05,
|
||||
"default": 1.0,
|
||||
"tooltip": "Used to discard unlikely text in the sampling process. Lower values will make text more predictable but can become repetitious."
|
||||
},
|
||||
{
|
||||
"uitype": "slider",
|
||||
"unit": "float",
|
||||
"label": "Repetition Penalty",
|
||||
"id": "setreppen",
|
||||
"min": 1.0,
|
||||
"max": 2.0,
|
||||
"step": 0.05,
|
||||
"default": 1.0,
|
||||
"tooltip": "Used to penalize words that were already generated or belong to the context"
|
||||
},
|
||||
{
|
||||
"uitype": "slider",
|
||||
"unit": "int",
|
||||
"label": "Amount to Generate",
|
||||
"id": "setoutput",
|
||||
"min": 10,
|
||||
"max": 500,
|
||||
"step": 2,
|
||||
"default": 60,
|
||||
"tooltip": "Number of tokens the AI should generate. Higher numbers will take longer to generate."
|
||||
},
|
||||
{
|
||||
"uitype": "slider",
|
||||
"unit": "int",
|
||||
"label": "Max Tokens",
|
||||
"id": "settknmax",
|
||||
"min": 512,
|
||||
"max": 2048,
|
||||
"step": 8,
|
||||
"default": 1024,
|
||||
"tooltip": "Number of tokens of context to submit to the AI for sampling."
|
||||
}]
|
||||
|
||||
gensettingsik =[{
|
||||
"uitype": "slider",
|
||||
"unit": "float",
|
||||
"label": "Temperature",
|
||||
"id": "settemp",
|
||||
"min": 0.1,
|
||||
"max": 2.0,
|
||||
"step": 0.05,
|
||||
"default": 1.0,
|
||||
"tooltip": "Randomness of sampling. High values can increase creativity but may make text less sensible. Lower values will make text more predictable but can become repetitious."
|
||||
},
|
||||
{
|
||||
"uitype": "slider",
|
||||
"unit": "float",
|
||||
"label": "Top p Sampling",
|
||||
"id": "settopp",
|
||||
"min": 0.1,
|
||||
"max": 1.0,
|
||||
"step": 0.05,
|
||||
"default": 1.0,
|
||||
"tooltip": "Used to discard unlikely text in the sampling process. Lower values will make text more predictable but can become repetitious."
|
||||
},
|
||||
{
|
||||
"uitype": "slider",
|
||||
"unit": "int",
|
||||
"label": "Amount to Generate",
|
||||
"id": "setikgen",
|
||||
"min": 50,
|
||||
"max": 3000,
|
||||
"step": 2,
|
||||
"default": 200,
|
||||
"tooltip": "Number of characters the AI should generate."
|
||||
}]
|
@ -6,6 +6,7 @@
|
||||
var socket;
|
||||
|
||||
// UI references for jQuery
|
||||
var connect_status;
|
||||
var button_newgame;
|
||||
var button_save;
|
||||
var button_load;
|
||||
@ -19,14 +20,7 @@ var button_delete;
|
||||
var game_text;
|
||||
var input_text;
|
||||
var message_text;
|
||||
var setting_temp;
|
||||
var setting_topp;
|
||||
var setting_reppen;
|
||||
var setting_outlen;
|
||||
var label_temp;
|
||||
var label_topp;
|
||||
var label_reppen;
|
||||
var label_outlen;
|
||||
var settings_menu;
|
||||
var anote_menu;
|
||||
var anote_input;
|
||||
var anote_labelcur;
|
||||
@ -40,6 +34,40 @@ var do_clear_ent = false;
|
||||
// METHODS
|
||||
//=================================================================//
|
||||
|
||||
function addSetting(ob) {
|
||||
// Add setting block to Settings Menu
|
||||
settings_menu.append("<div class=\"settingitem\">\
|
||||
<div class=\"settinglabel\">\
|
||||
<div class=\"justifyleft\">\
|
||||
"+ob.label+" <span class=\"helpicon\">?<span class=\"helptext\">"+ob.tooltip+"</span></span>\
|
||||
</div>\
|
||||
<div class=\"justifyright\" id=\""+ob.id+"cur\">\
|
||||
"+ob.default+"\
|
||||
</div>\
|
||||
</div>\
|
||||
<div>\
|
||||
<input type=\"range\" class=\"form-range airange\" min=\""+ob.min+"\" max=\""+ob.max+"\" step=\""+ob.step+"\" id=\""+ob.id+"\">\
|
||||
</div>\
|
||||
<div class=\"settingminmax\">\
|
||||
<div class=\"justifyleft\">\
|
||||
"+ob.min+"\
|
||||
</div>\
|
||||
<div class=\"justifyright\">\
|
||||
"+ob.max+"\
|
||||
</div>\
|
||||
</div>\
|
||||
</div>");
|
||||
// Set references to HTML objects
|
||||
refin = $("#"+ob.id);
|
||||
reflb = $("#"+ob.id+"cur");
|
||||
window["setting_"+ob.id] = refin;
|
||||
window["label_"+ob.id] = reflb;
|
||||
// Add event function to input
|
||||
refin.on("input", function () {
|
||||
socket.send({'cmd': $(this).attr('id'), 'data': $(this).val()});
|
||||
});
|
||||
}
|
||||
|
||||
function enableButtons(refs) {
|
||||
for(i=0; i<refs.length; i++) {
|
||||
refs[i].prop("disabled",false);
|
||||
@ -174,6 +202,7 @@ function newTextHighlight(ref) {
|
||||
$(document).ready(function(){
|
||||
|
||||
// Bind UI references
|
||||
connect_status = $('#connectstatus');
|
||||
button_newgame = $('#btn_newgame');
|
||||
button_save = $('#btn_save');
|
||||
button_load = $('#btn_load');
|
||||
@ -187,14 +216,7 @@ $(document).ready(function(){
|
||||
game_text = $('#gametext');
|
||||
input_text = $('#input_text');
|
||||
message_text = $('#messagefield');
|
||||
setting_temp = $('#settemp');
|
||||
setting_topp = $('#settopp');
|
||||
setting_reppen = $('#setreppen');
|
||||
setting_outlen = $('#setoutput');
|
||||
label_temp = $('#settempcur');
|
||||
label_topp = $('#settoppcur');
|
||||
label_reppen = $('#setreppencur');
|
||||
label_outlen = $('#setoutputcur');
|
||||
settings_menu = $("#settingsmenu");
|
||||
anote_menu = $('#anoterowcontainer');
|
||||
anote_input = $('#anoteinput');
|
||||
anote_labelcur = $('#anotecur');
|
||||
@ -206,9 +228,11 @@ $(document).ready(function(){
|
||||
socket.on('from_server', function(msg) {
|
||||
if(msg.cmd == "connected") {
|
||||
// Connected to Server Actions
|
||||
$('#connectstatus').html("<b>Connected to KoboldAI Process!</b>");
|
||||
$('#connectstatus').removeClass("color_orange");
|
||||
$('#connectstatus').addClass("color_green");
|
||||
connect_status.html("<b>Connected to KoboldAI Process!</b>");
|
||||
connect_status.removeClass("color_orange");
|
||||
connect_status.addClass("color_green");
|
||||
// Reset Settings Menu
|
||||
settings_menu.html("");
|
||||
} else if(msg.cmd == "updatescreen") {
|
||||
// Send game content to Game Screen
|
||||
game_text.html(msg.data);
|
||||
@ -259,33 +283,47 @@ $(document).ready(function(){
|
||||
newTextHighlight($("#n"+msg.data))
|
||||
} else if(msg.cmd == "updatetemp") {
|
||||
// Send current temp value to input
|
||||
setting_temp.val(parseFloat(msg.data));
|
||||
label_temp.html(msg.data);
|
||||
$("#settemp").val(parseFloat(msg.data));
|
||||
$("#settempcur").html(msg.data);
|
||||
} else if(msg.cmd == "updatetopp") {
|
||||
// Send current temp value to input
|
||||
setting_topp.val(parseFloat(msg.data));
|
||||
label_topp.html(msg.data);
|
||||
// Send current top p value to input
|
||||
$("#settopp").val(parseFloat(msg.data));
|
||||
$("#settoppcur").html(msg.data);
|
||||
} else if(msg.cmd == "updatereppen") {
|
||||
// Send current temp value to input
|
||||
setting_reppen.val(parseFloat(msg.data));
|
||||
label_reppen.html(msg.data);
|
||||
// Send current rep pen value to input
|
||||
$("#setreppen").val(parseFloat(msg.data));
|
||||
$("#setreppencur").html(msg.data);
|
||||
} else if(msg.cmd == "updateoutlen") {
|
||||
// Send current temp value to input
|
||||
setting_outlen.val(parseInt(msg.data));
|
||||
label_outlen.html(msg.data);
|
||||
// Send current output amt value to input
|
||||
$("#setoutput").val(parseInt(msg.data));
|
||||
$("#setoutputcur").html(msg.data);
|
||||
} else if(msg.cmd == "updatetknmax") {
|
||||
// Send current max tokens value to input
|
||||
$("#settknmax").val(parseInt(msg.data));
|
||||
$("#settknmaxcur").html(msg.data);
|
||||
} else if(msg.cmd == "updateikgen") {
|
||||
// Send current max tokens value to input
|
||||
$("#setikgen").val(parseInt(msg.data));
|
||||
$("#setikgencur").html(msg.data);
|
||||
} else if(msg.cmd == "setlabeltemp") {
|
||||
// Update setting label with value from server
|
||||
label_temp.html(msg.data);
|
||||
$("#settempcur").html(msg.data);
|
||||
} else if(msg.cmd == "setlabeltopp") {
|
||||
// Update setting label with value from server
|
||||
label_topp.html(msg.data);
|
||||
$("#settoppcur").html(msg.data);
|
||||
} else if(msg.cmd == "setlabelreppen") {
|
||||
// Update setting label with value from server
|
||||
label_reppen.html(msg.data);
|
||||
$("#setreppencur").html(msg.data);
|
||||
} else if(msg.cmd == "setlabeloutput") {
|
||||
// Update setting label with value from server
|
||||
label_outlen.html(msg.data);
|
||||
} else if(msg.cmd == "updatanotedepth") {
|
||||
$("#setoutputcur").html(msg.data);
|
||||
} else if(msg.cmd == "setlabeltknmax") {
|
||||
// Update setting label with value from server
|
||||
$("#settknmaxcur").html(msg.data);
|
||||
} else if(msg.cmd == "setlabelikgen") {
|
||||
// Update setting label with value from server
|
||||
$("#setikgencur").html(msg.data);
|
||||
} else if(msg.cmd == "updateanotedepth") {
|
||||
// Send current Author's Note depth value to input
|
||||
anote_slider.val(parseInt(msg.data));
|
||||
anote_labelcur.html(msg.data);
|
||||
@ -299,13 +337,16 @@ $(document).ready(function(){
|
||||
} else if(msg.cmd == "setanote") {
|
||||
// Set contents of Author's Note field
|
||||
anote_input.val(msg.data);
|
||||
} else if(msg.cmd == "addsetting") {
|
||||
// Add setting controls
|
||||
addSetting(msg.data);
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('disconnect', function() {
|
||||
$('#connectstatus').html("<b>Lost connection...</b>");
|
||||
$('#connectstatus').removeClass("color_green");
|
||||
$('#connectstatus').addClass("color_orange");
|
||||
connect_status.html("<b>Lost connection...</b>");
|
||||
connect_status.removeClass("color_green");
|
||||
connect_status.addClass("color_orange");
|
||||
});
|
||||
|
||||
// Bind actions to UI buttons
|
||||
@ -349,9 +390,8 @@ $(document).ready(function(){
|
||||
$('#settingsmenu').slideToggle("slow");
|
||||
});
|
||||
|
||||
// Bind settings to server calls
|
||||
$('input[type=range]').on('input', function () {
|
||||
socket.send({'cmd': $(this).attr('id'), 'data': $(this).val()});
|
||||
$("#btn_savesettings").on("click", function(ev) {
|
||||
socket.send({'cmd': 'savesettings', 'data': ''});
|
||||
});
|
||||
|
||||
// Bind Enter button to submit
|
||||
|
@ -170,6 +170,49 @@ chunk {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.helpicon {
|
||||
display: inline-block;
|
||||
font-family: sans-serif;
|
||||
font-weight: bold;
|
||||
text-align: center;
|
||||
width: 2.2ex;
|
||||
height: 2.4ex;
|
||||
font-size: 1.4ex;
|
||||
line-height: 1.8ex;
|
||||
border-radius: 1.2ex;
|
||||
margin-right: 4px;
|
||||
padding: 1px;
|
||||
color: #295071;
|
||||
background: #ffffff;
|
||||
border: 1px solid white;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.helpicon:hover {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.helpicon:hover .helptext {
|
||||
display: inline-block;
|
||||
width: 250px;
|
||||
background-color: #1f2931;
|
||||
color: #ffffff;
|
||||
font-size: 11pt;
|
||||
font-weight: normal;
|
||||
line-height: normal;
|
||||
border-radius: 6px;
|
||||
padding: 15px;
|
||||
margin-left:10px;
|
||||
border: 1px solid #337ab7;
|
||||
|
||||
position: absolute;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.helptext {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.justifyleft {
|
||||
text-align: left;
|
||||
}
|
||||
@ -179,12 +222,17 @@ chunk {
|
||||
}
|
||||
|
||||
.settingitem {
|
||||
width: 20%;
|
||||
width: 18%;
|
||||
padding-left: 10px;
|
||||
padding-right: 10px;
|
||||
display: inline-block;
|
||||
}
|
||||
|
||||
.settingsave {
|
||||
width: 10%;
|
||||
display: inline-block;
|
||||
}
|
||||
|
||||
.settinglabel {
|
||||
color: #ffffff;
|
||||
display: grid;
|
||||
|
@ -29,90 +29,6 @@
|
||||
</div>
|
||||
</div>
|
||||
<div class="row" id="settingsmenu">
|
||||
<div class="settingitem">
|
||||
<div class="settinglabel">
|
||||
<div class="justifyleft">
|
||||
Temperature
|
||||
</div>
|
||||
<div class="justifyright" id="settempcur">
|
||||
0.0
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<input type="range" class="form-range airange" min="0.1" max="2" step="0.05" id="settemp">
|
||||
</div>
|
||||
<div class="settingminmax">
|
||||
<div class="justifyleft">
|
||||
0.10
|
||||
</div>
|
||||
<div class="justifyright">
|
||||
2.00
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="settingitem">
|
||||
<div class="settinglabel">
|
||||
<div class="justifyleft">
|
||||
Top p Sampling
|
||||
</div>
|
||||
<div class="justifyright" id="settoppcur">
|
||||
0.0
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<input type="range" class="form-range airange" min="0.1" max="1" step="0.05" id="settopp">
|
||||
</div>
|
||||
<div class="settingminmax">
|
||||
<div class="justifyleft">
|
||||
0.10
|
||||
</div>
|
||||
<div class="justifyright">
|
||||
1.00
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="settingitem">
|
||||
<div class="settinglabel">
|
||||
<div class="justifyleft">
|
||||
Repetition Penalty
|
||||
</div>
|
||||
<div class="justifyright" id="setreppencur">
|
||||
0.0
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<input type="range" class="form-range airange" min="1" max="2" step="0.05" id="setreppen">
|
||||
</div>
|
||||
<div class="settingminmax">
|
||||
<div class="justifyleft">
|
||||
1.00
|
||||
</div>
|
||||
<div class="justifyright">
|
||||
2.00
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="settingitem">
|
||||
<div class="settinglabel">
|
||||
<div class="justifyleft">
|
||||
Amount to Generate
|
||||
</div>
|
||||
<div class="justifyright" id="setoutputcur">
|
||||
0
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<input type="range" class="form-range airange" min="10" max="500" step="2" id="setoutput">
|
||||
</div>
|
||||
<div class="settingminmax">
|
||||
<div class="justifyleft">
|
||||
10
|
||||
</div>
|
||||
<div class="justifyright">
|
||||
500
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="row" id="gamescreen">
|
||||
<span id="gametext">...</span>
|
||||
|
19
utils.py
Normal file
19
utils.py
Normal file
@ -0,0 +1,19 @@
|
||||
from threading import Timer
|
||||
|
||||
def debounce(wait):
|
||||
def decorator(fun):
|
||||
def debounced(*args, **kwargs):
|
||||
def call_it():
|
||||
fun(*args, **kwargs)
|
||||
|
||||
try:
|
||||
debounced.t.cancel()
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
debounced.t = Timer(wait, call_it)
|
||||
debounced.t.start()
|
||||
|
||||
return debounced
|
||||
|
||||
return decorator
|
Loading…
x
Reference in New Issue
Block a user