diff --git a/aiserver.py b/aiserver.py
index 3638fec4..188d26bb 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -12,7 +12,8 @@ import tkinter as tk
 from tkinter import messagebox
 import json
 import collections
-from typing import Union
+import zipfile
+from typing import Union, Tuple
 import requests
 import html
@@ -103,7 +104,9 @@ class vars:
     formatoptns = {'frmttriminc': True, 'frmtrmblln': False, 'frmtrmspch': False, 'frmtadsnsp': False}  # Container for state of formatting options
     importnum   = -1     # Selection on import popup list
     importjs    = {}     # Temporary storage for import data
-    loadselect  = ""     # Temporary storage for filename to load
+    loadselect  = ""     # Temporary storage for story filename to load
+    spselect    = ""     # Temporary storage for soft prompt filename to load
+    sp          = None   # Current soft prompt (as a torch.Tensor)
     svowname    = ""     # Filename that was flagged for overwrite confirm
     saveow      = False  # Whether or not overwrite confirm has been displayed
     genseqs     = []     # Temporary storage for generated sequences
@@ -113,6 +116,8 @@ class vars:
     bmsupported = False  # Whether the breakmodel option is supported (GPT-Neo/GPT-J only, currently)
     smandelete  = False  # Whether stories can be deleted from inside the browser
     smanrename  = False  # Whether stories can be renamed from inside the browser
+    allowsp     = False  # Whether soft prompts may be used (enabled by default for GPT-2, GPT-Neo and GPT-J models)
+    modeldim    = -1     # Embedding dimension of your model (e.g. 4096 for GPT-J-6B, 2560 for GPT-Neo-2.7B)
     laststory   = None   # Filename (without extension) of most recent story JSON file we loaded
     acregex_ai  = re.compile(r'\n* *>(.|\n)*')  # Pattern for matching adventure actions from the AI so we can remove them
     acregex_ui  = re.compile(r'^ *(>.*)$', re.MULTILINE)  # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
@@ -311,6 +316,7 @@ else:
 
 # If transformers model was selected & GPU available, ask to use CPU or GPU
 if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
+    vars.allowsp = True
     # Test for GPU support
     import torch
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@@ -503,11 +509,31 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
 if(not vars.noai):
     print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
     from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
-    
+
+    # Patch transformers to prepend our soft prompt to the token embeddings
+    def patch_causallm(cls):
+        old_forward = cls.forward
+        def new_causallm_forward(self, *args, **kwargs):
+            input_ids = kwargs.get('input_ids')
+            assert input_ids is not None
+            input_ids = input_ids.to(self.device)
+            kwargs['input_ids'] = None
+            inputs_embeds = self.transformer.wte(input_ids)
+            if(vars.sp is not None):
+                sp = vars.sp.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
+                # Extend the attention mask so it also covers the soft prompt positions
+                mask = kwargs.get('attention_mask')
+                if(mask is not None):
+                    kwargs['attention_mask'] = torch.cat((
+                        mask.new_ones((mask.shape[0], sp.shape[0])), mask
+                    ), dim=1)
+                # Prepend the soft prompt only on the first forward pass; during
+                # cached generation it is already part of past_key_values
+                if(kwargs.get('past_key_values') is None):
+                    inputs_embeds = torch.cat((
+                        sp.tile((inputs_embeds.shape[0], 1, 1)), inputs_embeds
+                    ), dim=1)
+            kwargs['inputs_embeds'] = inputs_embeds
+            return old_forward(self, *args, **kwargs)
+        cls.forward = new_causallm_forward
+    for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
+        patch_causallm(cls)
+
     # If custom GPT Neo model was chosen
     if(vars.model == "NeoCustom"):
         model_config = open(vars.custmodpth + "/config.json", "r")
         js = json.load(model_config)
+        vars.modeldim = int(js['hidden_size'])
         if("model_type" in js):
             model = AutoModelForCausalLM.from_pretrained(vars.custmodpth)
         else:
@@ -525,6 +551,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
["InferKit", "Colab", "OAI", "ReadOnly"]): generator = pipeline('text-generation', model=model, tokenizer=tokenizer) # If custom GPT2 model was chosen elif(vars.model == "GPT2Custom"): + model_config = open(vars.custmodpth + "/config.json", "r") + js = json.load(model_config) + vars.modeldim = int(js['hidden_size']) model = GPT2LMHeadModel.from_pretrained(vars.custmodpth) tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth) # Is CUDA available? If so, use GPU, otherwise fall back to CPU @@ -538,13 +567,17 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): tokenizer = GPT2Tokenizer.from_pretrained(vars.model) if(vars.hascuda): if(vars.usegpu): - generator = pipeline('text-generation', model=vars.model, device=0) + model = AutoModelForCausalLM.from_pretrained(vars.model, device=0) + vars.modeldim = int(model.transformer.hidden_size) + generator = pipeline('text-generation', model=model, device=0) elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel) model = AutoModelForCausalLM.from_pretrained(vars.model) device_config(model) else: + model = AutoModelForCausalLM.from_pretrained(vars.model) generator = pipeline('text-generation', model=vars.model) else: + model = AutoModelForCausalLM.from_pretrained(vars.model) generator = pipeline('text-generation', model=vars.model) # Suppress Author's Note by flagging square brackets (Old implementation) @@ -807,10 +840,16 @@ def get_message(msg): save() elif(msg['cmd'] == 'loadlistrequest'): getloadlist() + elif(msg['cmd'] == 'splistrequest'): + getsplist() elif(msg['cmd'] == 'loadselect'): vars.loadselect = msg["data"] + elif(msg['cmd'] == 'spselect'): + vars.spselect = msg["data"] elif(msg['cmd'] == 'loadrequest'): loadRequest(fileops.storypath(vars.loadselect)) + elif(msg['cmd'] == 'sprequest'): + spRequest(vars.spselect) elif(msg['cmd'] == 'deletestory'): deletesave(msg['data']) elif(msg['cmd'] == 'renamestory'): @@ -846,6 +885,8 @@ def get_message(msg): #==================================================================# def setStartState(): txt = "Welcome to KoboldAI! You are running "+getmodelname()+".
" + if(vars.allowsp): + emit('from_server', {'cmd': 'allowsp', 'data': vars.allowsp}, broadcast=True) if(not vars.noai): txt = txt + "Please load a game or enter a prompt below to begin!
" else: @@ -1123,10 +1164,12 @@ def calcsubmit(txt): anotetkns = tokenizer.encode(anotetxt) lnanote = len(anotetkns) + lnsp = vars.sp.shape[0] if vars.sp is not None else 0 + if(vars.useprompt): - budget = vars.max_length - lnprompt - lnmem - lnanote - lnwi - vars.genamt + budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt else: - budget = vars.max_length - lnmem - lnanote - lnwi - vars.genamt + budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt if(actionlen == 0): # First/Prompt action @@ -2131,11 +2174,18 @@ def saveRequest(savpath): print("{0}Story saved to {1}!{2}".format(colors.GREEN, path.basename(savpath), colors.END)) #==================================================================# -# Load a saved story via file browser +# Show list of saved stories #==================================================================# def getloadlist(): emit('from_server', {'cmd': 'buildload', 'data': fileops.getstoryfiles()}) +#==================================================================# +# Show list of soft prompts +#==================================================================# +def getsplist(): + if(vars.allowsp): + emit('from_server', {'cmd': 'buildsp', 'data': fileops.getspfiles(vars.modeldim)}) + #==================================================================# # Load a saved story via file browser #==================================================================# @@ -2221,6 +2271,35 @@ def loadRequest(loadpath): emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True) print("{0}Story loaded from {1}!{2}".format(colors.GREEN, path.basename(loadpath), colors.END)) +#==================================================================# +# Load a soft prompt from a file +#==================================================================# +def spRequest(filename): + if(len(filename) == 0): + vars.sp = None + return + + global np + if 'np' not in globals(): + import numpy as np + + z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim) + assert isinstance(z, zipfile.ZipFile) + + with z.open('tensor.npy') as f: + tensor = np.load(f, allow_pickle=False) + + # If the tensor is in bfloat16 format, convert it to float32 + if(tensor.dtype == 'V2'): + tensor.dtype = np.uint16 + tensor = np.uint32(tensor) << 16 + tensor.dtype = np.float32 + + tensor = np.float16(tensor) + assert not np.isinf(tensor).any() and not np.isnan(tensor).any() + + vars.sp = torch.from_numpy(tensor) + #==================================================================# # Import an AIDungon game exported with Mimi's tool #==================================================================# diff --git a/fileops.py b/fileops.py index 7fcc44cb..c5ff440c 100644 --- a/fileops.py +++ b/fileops.py @@ -1,8 +1,10 @@ import tkinter as tk from tkinter import filedialog from os import getcwd, listdir, path +from typing import Tuple, Union, Optional import os import json +import zipfile #==================================================================# # Generic Method for prompting for file path @@ -61,6 +63,12 @@ def getdirpath(dir, title): def storypath(name): return path.join(path.dirname(path.realpath(__file__)), "stories", name + ".json") +#==================================================================# +# Returns the path (as a string) to the given soft prompt by its filename +#==================================================================# +def sppath(filename): + return 
diff --git a/fileops.py b/fileops.py
index 7fcc44cb..c5ff440c 100644
--- a/fileops.py
+++ b/fileops.py
@@ -1,8 +1,10 @@
 import tkinter as tk
 from tkinter import filedialog
 from os import getcwd, listdir, path
+from typing import Tuple, Union, Optional
 import os
 import json
+import zipfile
 
 #==================================================================#
 # Generic Method for prompting for file path
 #==================================================================#
@@ -61,6 +63,12 @@ def getdirpath(dir, title):
 def storypath(name):
     return path.join(path.dirname(path.realpath(__file__)), "stories", name + ".json")
 
+#==================================================================#
+# Returns the path (as a string) to the given soft prompt by its filename
+#==================================================================#
+def sppath(filename):
+    return path.join(path.dirname(path.realpath(__file__)), "softprompts", filename)
+
 #==================================================================#
 # Returns an array of dicts containing story files in /stories
 #==================================================================#
@@ -86,6 +94,70 @@ def getstoryfiles():
         list.append(ob)
     return list
 
+#==================================================================#
+# Checks if the given soft prompt file is valid
+#==================================================================#
+def checksp(filename: str, model_dimension: int) -> Tuple[Union[zipfile.ZipFile, int], Optional[Tuple[int, int]], Optional[Tuple[int, int]], Optional[bool], Optional['np.dtype']]:
+    global np
+    if 'np' not in globals():
+        import numpy as np
+    try:
+        z = zipfile.ZipFile(sppath(filename))
+        with z.open('tensor.npy') as f:
+            # Read only the header of the npy file, for efficiency reasons
+            version: Tuple[int, int] = np.lib.format.read_magic(f)
+            shape: Tuple[int, int]
+            fortran_order: bool
+            dtype: np.dtype
+            shape, fortran_order, dtype = np.lib.format._read_array_header(f, version)
+            assert len(shape) == 2
+    except:
+        # z is unbound if ZipFile() itself failed, so guard the close
+        try:
+            z.close()
+        except UnboundLocalError:
+            pass
+        return 1, None, None, None, None
+    if dtype not in ('V2', np.float16, np.float32):
+        z.close()
+        return 2, version, shape, fortran_order, dtype
+    if shape[1] != model_dimension:
+        z.close()
+        return 3, version, shape, fortran_order, dtype
+    if shape[0] >= 2048:
+        z.close()
+        return 4, version, shape, fortran_order, dtype
+    return z, version, shape, fortran_order, dtype
+
+#==================================================================#
+# Returns an array of dicts containing softprompt files in /softprompts
+#==================================================================#
+def getspfiles(model_dimension: int):
+    lst = []
+    os.makedirs(path.dirname(path.realpath(__file__))+"/softprompts", exist_ok=True)
+    for file in listdir(path.dirname(path.realpath(__file__))+"/softprompts"):
+        if not file.endswith(".zip"):
+            continue
+        z, version, shape, fortran_order, dtype = checksp(file, model_dimension)
+        if z == 1:
+            print(f"Browser SP loading error: {file} is malformed or not a soft prompt ZIP file.")
+            continue
+        if z == 2:
+            print(f"Browser SP loading error: {file} tensor.npy has unsupported dtype '{dtype.name}'.")
+            continue
+        if z == 3:
+            print(f"Browser SP loading error: {file} tensor.npy has model dimension {shape[1]}, which does not match your model's dimension of {model_dimension}. This usually means this soft prompt is not compatible with your model.")
+            continue
+        if z == 4:
+            print(f"Browser SP loading error: {file} tensor.npy has {shape[0]} tokens, but it is supposed to have fewer than 2048 tokens.")
+            continue
+        assert isinstance(z, zipfile.ZipFile)
+        try:
+            with z.open('meta.json') as f:
+                ob = json.load(f)
+        except:
+            ob = {}
+        z.close()
+        ob["filename"] = file
+        lst.append(ob)
+    return lst
+
 #==================================================================#
 # Returns True if json file exists with requested save name
 #==================================================================#
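
To exercise checksp() and getspfiles() end to end, a soft prompt file is just a ZIP holding a 2-D tensor.npy (float16, float32 or bfloat16; fewer than 2048 rows; second dimension equal to the model's embedding dimension) plus an optional meta.json. A throwaway generator sketch (not part of the patch; the dimension, token count and metadata values are illustrative only, and it assumes the softprompts/ folder exists, which getspfiles() above creates):

import io
import json
import zipfile

import numpy as np

model_dim = 2560  # must equal vars.modeldim (2560 is GPT-Neo-2.7B)
n_tokens = 20     # anything under the 2048-token limit enforced by checksp()

tensor = np.random.randn(n_tokens, model_dim).astype(np.float16)

buf = io.BytesIO()
np.save(buf, tensor, allow_pickle=False)  # writes the npy header that checksp() parses

with zipfile.ZipFile("softprompts/test.zip", "w") as z:
    z.writestr("tensor.npy", buf.getvalue())
    z.writestr("meta.json", json.dumps({
        "name": "Test soft prompt",
        "description": "Random embeddings for plumbing tests only.",
        "supported": ["GPT-Neo-2.7B"],
    }))

The resulting test.zip should then appear in the soft prompt popup with its name, description and supported-models list rendered by buildSPList() below.
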
diff --git a/static/application.js b/static/application.js
index 86bee330..a5f6df51 100644
--- a/static/application.js
+++ b/static/application.js
@@ -18,6 +18,7 @@ var button_importwi;
 var button_impaidg;
 var button_settings;
 var button_format;
+var button_softprompt;
 var button_mode;
 var button_mode_label;
 var button_send;
@@ -53,6 +54,10 @@ var loadpopup;
 var loadcontent;
 var load_accept;
 var load_close;
+var sppopup;
+var spcontent;
+var sp_accept;
+var sp_close;
 var nspopup;
 var ns_accept;
 var ns_close;
@@ -77,6 +82,7 @@ var saved_prompt = "...";
 var override_focusout = false;
 var sman_allow_delete = false;
 var sman_allow_rename = false;
+var allowsp = false;
 
 // This is true iff [we're in macOS and the browser is Safari] or [we're in iOS]
 var using_webkit_patch = true;
@@ -589,6 +595,17 @@ function hideLoadPopup() {
 	loadcontent.html("");
 }
 
+function showSPPopup() {
+	sppopup.removeClass("hidden");
+	sppopup.addClass("flex");
+}
+
+function hideSPPopup() {
+	sppopup.removeClass("flex");
+	sppopup.addClass("hidden");
+	spcontent.html("");
+}
+
 function buildLoadList(ar) {
 	disableButtons([load_accept]);
 	loadcontent.html("");
@@ -654,11 +671,51 @@ function buildLoadList(ar) {
 	}
 }
 
+function buildSPList(ar) {
+	disableButtons([sp_accept]);
+	spcontent.html("");
+	showSPPopup();
+	ar.push({filename: '', name: "[None]"});
+	for(var i = 0; i < ar.length; i++) {
+		var supported = !ar[i].supported
+			? ''
+			: Object.prototype.toString.call(ar[i].supported) === "[object Array]"
+			? "[" + ar[i].supported.join(', ') + "]"
+			: "[" + ar[i].supported.toString() + "]";
+		var name = ar[i].name || ar[i].filename;
+		name = name.length > 120 ? name.slice(0, 117) + '...' : name;
+		var desc = ar[i].description || '';
+		desc = desc.length > 500 ? desc.slice(0, 497) + '...' : desc;
+		spcontent.append("<div class=\"flex\">\
+			<div class=\"splistitem flex-row-container\" id=\"sp"+i+"\" name=\""+ar[i].filename+"\">\
+				<div class=\"flex-row\">\
+					<div>"+name+"</div>\
+					<div class=\"flex-push-right splistitemsub\">"+ar[i].filename+"</div>\
+				</div>\
+				<div class=\"flex-row\">\
+					<div>"+desc+"</div>\
+					<div class=\"flex-push-right splistitemsub\">"+supported+"</div>\
+				</div>\
+			</div>\
+		</div>\
"); + $("#sp"+i).on("click", function () { + enableButtons([sp_accept]); + socket.send({'cmd': 'spselect', 'data': $(this).attr("name")}); + highlightSPLine($(this)); + }); + } +} + function highlightLoadLine(ref) { $("#loadlistcontent > div > div.popuplistselected").removeClass("popuplistselected"); ref.addClass("popuplistselected"); } +function highlightSPLine(ref) { + $("#splistcontent > div > div.popuplistselected").removeClass("popuplistselected"); + ref.addClass("popuplistselected"); +} + function showNewStoryPopup() { nspopup.removeClass("hidden"); nspopup.addClass("flex"); @@ -1142,6 +1199,7 @@ $(document).ready(function(){ button_impaidg = $("#btn_impaidg"); button_settings = $('#btn_settings'); button_format = $('#btn_format'); + button_softprompt = $("#btn_softprompt"); button_mode = $('#btnmode') button_mode_label = $('#btnmode_label') button_send = $('#btnsend'); @@ -1177,6 +1235,10 @@ $(document).ready(function(){ loadcontent = $("#loadlistcontent"); load_accept = $("#btn_loadaccept"); load_close = $("#btn_loadclose"); + sppopup = $("#spcontainer"); + spcontent = $("#splistcontent"); + sp_accept = $("#btn_spaccept"); + sp_close = $("#btn_spclose"); nspopup = $("#newgamecontainer"); ns_accept = $("#btn_nsaccept"); ns_close = $("#btn_nsclose"); @@ -1314,6 +1376,13 @@ $(document).ready(function(){ } else if(msg.data == "start") { setStartState(); } + } else if(msg.cmd == "allowsp") { + allowsp = !!msg.data; + if(allowsp) { + button_softprompt.removeClass("hidden"); + } else { + button_softprompt.addClass("hidden"); + } } else if(msg.cmd == "setstoryname") { storyname = msg.data; } else if(msg.cmd == "editmode") { @@ -1480,6 +1549,8 @@ $(document).ready(function(){ } else if(msg.cmd == "buildload") { // Send array of save files to load UI buildLoadList(msg.data); + } else if(msg.cmd == "buildsp") { + buildSPList(msg.data); } else if(msg.cmd == "askforoverwrite") { // Show overwrite warning show([$(".saveasoverwrite")]); @@ -1654,6 +1725,10 @@ $(document).ready(function(){ button_load.on("click", function(ev) { socket.send({'cmd': 'loadlistrequest', 'data': ''}); }); + + button_softprompt.on("click", function(ev) { + socket.send({'cmd': 'splistrequest', 'data': ''}); + }); load_close.on("click", function(ev) { hideLoadPopup(); @@ -1664,6 +1739,15 @@ $(document).ready(function(){ socket.send({'cmd': 'loadrequest', 'data': ''}); hideLoadPopup(); }); + + sp_close.on("click", function(ev) { + hideSPPopup(); + }); + + sp_accept.on("click", function(ev) { + socket.send({'cmd': 'sprequest', 'data': ''}); + hideSPPopup(); + }); button_newgame.on("click", function(ev) { showNewStoryPopup(); diff --git a/static/custom.css b/static/custom.css index 2d12c00c..3b1ee1dc 100644 --- a/static/custom.css +++ b/static/custom.css @@ -307,6 +307,38 @@ chunk.editing, chunk.editing * { overflow-y: scroll; } +#sppopup { + width: 500px; + background-color: #262626; + margin-top: 100px; +} + +@media (max-width: 768px) { + #loadpopup { + width: 100%; + background-color: #262626; + margin-top: 100px; + } +} + +#sppopupdelete { + width: 350px; + background-color: #262626; + margin-top: 200px; +} + +#sppopuprename { + width: 350px; + background-color: #262626; + margin-top: 200px; +} + +#splistcontent { + height: 325px; + overflow-y: scroll; + overflow-wrap: anywhere; +} + #nspopup { width: 350px; background-color: #262626; @@ -423,6 +455,18 @@ chunk.editing, chunk.editing * { align-items: center; } +.flex-row-container { + display: flex; + flex-flow: wrap; +} + +.flex-row { + display: flex; + flex-flow: row; 
+	flex-grow: 1;
+	width: 100%;
+}
+
 .flex-push-right {
 	margin-left: auto;
 }
@@ -805,6 +849,34 @@ chunk.editing, chunk.editing * {
 	width: 50px;
 }
 
+.splistheader {
+	padding-left: 68px;
+	padding-right: 20px;
+	display: flex;
+	color: #737373;
+}
+
+.splistitem {
+	padding: 12px 10px 12px 10px;
+	display: flex;
+	flex-grow: 1;
+	color: #ffffff;
+
+	-moz-transition: background-color 0.25s ease-in;
+	-o-transition: background-color 0.25s ease-in;
+	-webkit-transition: background-color 0.25s ease-in;
+	transition: background-color 0.25s ease-in;
+}
+
+.splistitemsub {
+	color: #ba9;
+}
+
+.splistitem:hover {
+	cursor: pointer;
+	background-color: #688f1f;
+}
+
 .width-auto {
 	width: auto;
 }
diff --git a/templates/index.html b/templates/index.html
index 1de19e28..17d39019 100644
--- a/templates/index.html
+++ b/templates/index.html
@@ -6,14 +6,14 @@
-	<link rel="stylesheet" href="static/custom.css?ver=...">
+	<link rel="stylesheet" href="static/custom.css?ver=...">
-	<script src="static/application.js?ver=..."></script>
+	<script src="static/application.js?ver=..."></script>
@@ -67,6 +67,9 @@
+					<li class="nav-item">
+						<a class="nav-link" href="#" id="btn_softprompt">Soft Prompt</a>
+					</li>
@@ -225,6 +228,19 @@
+		<div class="popupcontainer hidden" id="spcontainer">
+			<div id="sppopup">
+				<div class="popuptitlebar">
+					<div class="popuptitletext">Select A Soft Prompt To Use</div>
+				</div>
+				<div id="splistcontent">
+				</div>
+				<div class="popupfooter">
+					<button type="button" class="btn btn-primary" id="btn_spaccept">Load</button>
+					<button type="button" class="btn btn-primary" id="btn_spclose">Cancel</button>
+				</div>
+			</div>
+		</div>
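
Finally, a worked example of the budget change in calcsubmit(): soft prompt tokens occupy context exactly like memory or author's note tokens, so loading a soft prompt shrinks the number of story tokens that survive trimming. A quick sketch with made-up token counts (not part of the patch):

max_length = 2048  # assumed context size
genamt = 60        # assumed generation length
lnmem, lnanote, lnwi = 120, 30, 80  # illustrative memory / author's note / world info sizes

for lnsp in (0, 20, 100):
    budget = max_length - lnsp - lnmem - lnanote - lnwi - genamt
    print(f"{lnsp:3d} soft prompt tokens -> {budget} tokens left for the story")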