diff --git a/aiserver.py b/aiserver.py
index 3638fec4..188d26bb 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -12,7 +12,8 @@ import tkinter as tk
from tkinter import messagebox
import json
import collections
-from typing import Union
+import zipfile
+from typing import Union, Tuple
import requests
import html
@@ -103,7 +104,9 @@ class vars:
formatoptns = {'frmttriminc': True, 'frmtrmblln': False, 'frmtrmspch': False, 'frmtadsnsp': False} # Container for state of formatting options
importnum = -1 # Selection on import popup list
importjs = {} # Temporary storage for import data
- loadselect = "" # Temporary storage for filename to load
+ loadselect = "" # Temporary storage for story filename to load
+ spselect = "" # Temporary storage for soft prompt filename to load
+ sp = None # Current soft prompt tensor (as a NumPy array)
svowname = "" # Filename that was flagged for overwrite confirm
saveow = False # Whether or not overwrite confirm has been displayed
genseqs = [] # Temporary storage for generated sequences
@@ -113,6 +116,8 @@ class vars:
bmsupported = False # Whether the breakmodel option is supported (GPT-Neo/GPT-J only, currently)
smandelete = False # Whether stories can be deleted from inside the browser
smanrename = False # Whether stories can be renamed from inside the browser
+ allowsp = False # Whether we are allowed to use soft prompts (by default enabled if we're using GPT-2, GPT-Neo or GPT-J)
+ modeldim = -1 # Embedding dimension of your model (e.g. it's 4096 for GPT-J-6B and 2560 for GPT-Neo-2.7B)
laststory = None # Filename (without extension) of most recent story JSON file we loaded
acregex_ai = re.compile(r'\n* *>(.|\n)*') # Pattern for matching adventure actions from the AI so we can remove them
acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
@@ -311,6 +316,7 @@ else:
# If transformers model was selected & GPU available, ask to use CPU or GPU
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
+ vars.allowsp = True
# Test for GPU support
import torch
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@@ -503,11 +509,31 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
-
+
+ # Patch transformers to use our soft prompt
+ def patch_causallm(cls):
+ old_forward = cls.forward
+ def new_causallm_forward(self, *args, **kwargs):
+ input_ids = kwargs.get('input_ids')
+ assert input_ids is not None
+ input_ids = input_ids.to(self.device)
+ kwargs['input_ids'] = None
+ inputs_embeds = self.transformer.wte(input_ids)
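+ # If a soft prompt is loaded, prepend its embeddings to the token embeddings along the sequence dimension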
+ if(vars.sp is not None):
+ inputs_embeds = torch.cat((
+ vars.sp.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device).tile((inputs_embeds.shape[0], 1, 1)),  # match embedding dtype/device
+ inputs_embeds
+ ), dim=1).to(self.device)
+ kwargs['inputs_embeds'] = inputs_embeds
+ return old_forward(self, *args, **kwargs)
+ cls.forward = new_causallm_forward
+ for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
+ patch_causallm(cls)
+
# If custom GPT Neo model was chosen
if(vars.model == "NeoCustom"):
model_config = open(vars.custmodpth + "/config.json", "r")
js = json.load(model_config)
+ vars.modeldim = int(js['hidden_size'])
if("model_type" in js):
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth)
else:
@@ -525,6 +551,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
# If custom GPT2 model was chosen
elif(vars.model == "GPT2Custom"):
+ model_config = open(vars.custmodpth + "/config.json", "r")
+ js = json.load(model_config)
+ vars.modeldim = int(js['n_embd'])  # GPT-2 configs store the embedding dimension as "n_embd"
model = GPT2LMHeadModel.from_pretrained(vars.custmodpth)
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth)
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
@@ -538,13 +567,17 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
tokenizer = GPT2Tokenizer.from_pretrained(vars.model)
if(vars.hascuda):
if(vars.usegpu):
- generator = pipeline('text-generation', model=vars.model, device=0)
+ model = AutoModelForCausalLM.from_pretrained(vars.model)
+ vars.modeldim = int(model.config.hidden_size)
+ generator = pipeline('text-generation', model=model, device=0)
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
model = AutoModelForCausalLM.from_pretrained(vars.model)
+ vars.modeldim = int(model.config.hidden_size)
device_config(model)
else:
+ model = AutoModelForCausalLM.from_pretrained(vars.model)
+ vars.modeldim = int(model.config.hidden_size)
- generator = pipeline('text-generation', model=vars.model)
+ generator = pipeline('text-generation', model=model)
else:
+ model = AutoModelForCausalLM.from_pretrained(vars.model)
+ vars.modeldim = int(model.config.hidden_size)
- generator = pipeline('text-generation', model=vars.model)
+ generator = pipeline('text-generation', model=model)
# Suppress Author's Note by flagging square brackets (Old implementation)
@@ -807,10 +840,16 @@ def get_message(msg):
save()
elif(msg['cmd'] == 'loadlistrequest'):
getloadlist()
+ elif(msg['cmd'] == 'splistrequest'):
+ getsplist()
elif(msg['cmd'] == 'loadselect'):
vars.loadselect = msg["data"]
+ elif(msg['cmd'] == 'spselect'):
+ vars.spselect = msg["data"]
elif(msg['cmd'] == 'loadrequest'):
loadRequest(fileops.storypath(vars.loadselect))
+ elif(msg['cmd'] == 'sprequest'):
+ spRequest(vars.spselect)
elif(msg['cmd'] == 'deletestory'):
deletesave(msg['data'])
elif(msg['cmd'] == 'renamestory'):
@@ -846,6 +885,8 @@ def get_message(msg):
#==================================================================#
def setStartState():
txt = "Welcome to KoboldAI! You are running "+getmodelname()+".
"
+ if(vars.allowsp):
+ emit('from_server', {'cmd': 'allowsp', 'data': vars.allowsp}, broadcast=True)
if(not vars.noai):
txt = txt + "Please load a game or enter a prompt below to begin!"
else:
@@ -1123,10 +1164,12 @@ def calcsubmit(txt):
anotetkns = tokenizer.encode(anotetxt)
lnanote = len(anotetkns)
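+ # The soft prompt occupies context space too, so its token count must come out of the budget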
+ lnsp = vars.sp.shape[0] if vars.sp is not None else 0
+
if(vars.useprompt):
- budget = vars.max_length - lnprompt - lnmem - lnanote - lnwi - vars.genamt
+ budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
else:
- budget = vars.max_length - lnmem - lnanote - lnwi - vars.genamt
+ budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
if(actionlen == 0):
# First/Prompt action
@@ -2131,11 +2174,18 @@ def saveRequest(savpath):
print("{0}Story saved to {1}!{2}".format(colors.GREEN, path.basename(savpath), colors.END))
#==================================================================#
-# Load a saved story via file browser
+# Show list of saved stories
#==================================================================#
def getloadlist():
emit('from_server', {'cmd': 'buildload', 'data': fileops.getstoryfiles()})
+#==================================================================#
+# Show list of soft prompts
+#==================================================================#
+def getsplist():
+ if(vars.allowsp):
+ emit('from_server', {'cmd': 'buildsp', 'data': fileops.getspfiles(vars.modeldim)})
+
#==================================================================#
# Load a saved story via file browser
#==================================================================#
@@ -2221,6 +2271,35 @@ def loadRequest(loadpath):
emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
print("{0}Story loaded from {1}!{2}".format(colors.GREEN, path.basename(loadpath), colors.END))
+#==================================================================#
+# Load a soft prompt from a file
+#==================================================================#
+def spRequest(filename):
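+ # An empty filename means the user chose "[None]", i.e. unload the current soft prompt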
+ if(len(filename) == 0):
+ vars.sp = None
+ return
+
+ global np
+ if 'np' not in globals():
+ import numpy as np
+
+ z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
+ assert isinstance(z, zipfile.ZipFile)
+
+ with z.open('tensor.npy') as f:
+ tensor = np.load(f, allow_pickle=False)
+
+ # If the tensor is in bfloat16 format, convert it to float32
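+ # (bfloat16 is the upper 16 bits of a float32's bit pattern, so widening to uint32 and shifting left by 16 recovers the float32 value)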
+ if(tensor.dtype == 'V2'):
+ tensor.dtype = np.uint16
+ tensor = np.uint32(tensor) << 16
+ tensor.dtype = np.float32
+
+ tensor = np.float16(tensor)
+ assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
+
+ vars.sp = torch.from_numpy(tensor)
+
#==================================================================#
# Import an AIDungon game exported with Mimi's tool
#==================================================================#
diff --git a/fileops.py b/fileops.py
index 7fcc44cb..c5ff440c 100644
--- a/fileops.py
+++ b/fileops.py
@@ -1,8 +1,10 @@
import tkinter as tk
from tkinter import filedialog
from os import getcwd, listdir, path
+from typing import Tuple, Union, Optional
import os
import json
+import zipfile
#==================================================================#
# Generic Method for prompting for file path
@@ -61,6 +63,12 @@ def getdirpath(dir, title):
def storypath(name):
return path.join(path.dirname(path.realpath(__file__)), "stories", name + ".json")
+#==================================================================#
+# Returns the path (as a string) to the given soft prompt by its filename
+#==================================================================#
+def sppath(filename):
+ return path.join(path.dirname(path.realpath(__file__)), "softprompts", filename)
+
#==================================================================#
# Returns an array of dicts containing story files in /stories
#==================================================================#
@@ -86,6 +94,70 @@ def getstoryfiles():
list.append(ob)
return list
+#==================================================================#
+# Checks if the given soft prompt file is valid
+#==================================================================#
+def checksp(filename: str, model_dimension: int) -> Tuple[Union[zipfile.ZipFile, int], Optional[Tuple[int, int]], Optional[Tuple[int, int]], Optional[bool], Optional['np.dtype']]:
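+ # On success, returns (z, version, shape, fortran_order, dtype) where z is an open ZipFile.
+ # On failure, the first element is an error code instead: 1 = malformed ZIP/npy file,
+ # 2 = unsupported dtype, 3 = embedding dimension does not match the model, 4 = 2048 or more tokens.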
+ global np
+ if 'np' not in globals():
+ import numpy as np
+ try:
+ z = zipfile.ZipFile(sppath(filename))
+ with z.open('tensor.npy') as f:
+ # Read only the header of the npy file, for efficiency reasons
+ version: Tuple[int, int] = np.lib.format.read_magic(f)
+ shape: Tuple[int, int]
+ fortran_order: bool
+ dtype: np.dtype
+ shape, fortran_order, dtype = np.lib.format._read_array_header(f, version)
+ assert len(shape) == 2
+ except:
+ # z is only defined if the ZipFile was opened successfully
+ if 'z' in locals():
+ z.close()
+ return 1, None, None, None, None
+ if dtype not in ('V2', np.float16, np.float32):
+ z.close()
+ return 2, version, shape, fortran_order, dtype
+ if shape[1] != model_dimension:
+ z.close()
+ return 3, version, shape, fortran_order, dtype
+ if shape[0] >= 2048:
+ z.close()
+ return 4, version, shape, fortran_order, dtype
+ return z, version, shape, fortran_order, dtype
+
+#==================================================================#
+# Returns an array of dicts containing softprompt files in /softprompts
+#==================================================================#
+def getspfiles(model_dimension: int):
+ lst = []
+ os.makedirs(path.dirname(path.realpath(__file__))+"/softprompts", exist_ok=True)
+ for file in listdir(path.dirname(path.realpath(__file__))+"/softprompts"):
+ if not file.endswith(".zip"):
+ continue
+ z, version, shape, fortran_order, dtype = checksp(file, model_dimension)
+ if z == 1:
+ print(f"Browser SP loading error: {file} is malformed or not a soft prompt ZIP file.")
+ continue
+ if z == 2:
+ print(f"Browser SP loading error: {file} tensor.npy has unsupported dtype '{dtype.name}'.")
+ continue
+ if z == 3:
+ print(f"Browser SP loading error: {file} tensor.npy has model dimension {shape[1]} which does not match your model's model dimension of {model_dimension}. This usually means this soft prompt is not compatible with your model.")
+ continue
+ if z == 4:
+ print(f"Browser SP loading error: {file} tensor.npy has {shape[0]} tokens but it is supposed to have less than 2048 tokens.")
+ continue
+ assert isinstance(z, zipfile.ZipFile)
+ try:
+ with z.open('meta.json') as f:
+ ob = json.load(f)
+ except:
+ ob = {}
+ z.close()
+ ob["filename"] = file
+ lst.append(ob)
+ return lst
+
#==================================================================#
# Returns True if json file exists with requested save name
#==================================================================#
diff --git a/static/application.js b/static/application.js
index 86bee330..a5f6df51 100644
--- a/static/application.js
+++ b/static/application.js
@@ -18,6 +18,7 @@ var button_importwi;
var button_impaidg;
var button_settings;
var button_format;
+var button_softprompt;
var button_mode;
var button_mode_label;
var button_send;
@@ -53,6 +54,10 @@ var loadpopup;
var loadcontent;
var load_accept;
var load_close;
+var sppopup;
+var spcontent;
+var sp_accept;
+var sp_close;
var nspopup;
var ns_accept;
var ns_close;
@@ -77,6 +82,7 @@ var saved_prompt = "...";
var override_focusout = false;
var sman_allow_delete = false;
var sman_allow_rename = false;
+var allowsp = false;
// This is true iff [we're in macOS and the browser is Safari] or [we're in iOS]
var using_webkit_patch = true;
@@ -589,6 +595,17 @@ function hideLoadPopup() {
loadcontent.html("");
}
+function showSPPopup() {
+ sppopup.removeClass("hidden");
+ sppopup.addClass("flex");
+}
+
+function hideSPPopup() {
+ sppopup.removeClass("flex");
+ sppopup.addClass("hidden");
+ spcontent.html("");
+}
+
function buildLoadList(ar) {
disableButtons([load_accept]);
loadcontent.html("");
@@ -654,11 +671,51 @@ function buildLoadList(ar) {
}
}
+function buildSPList(ar) {
+ disableButtons([sp_accept]);
+ spcontent.html("");
+ showSPPopup();
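+ // Add a "[None]" entry so the user can unload the currently selected soft prompt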
+ ar.push({filename: '', name: "[None]"});
+ for(var i = 0; i < ar.length; i++) {
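+ // Format the soft prompt's list of supported models (if any) as a bracketed, comma-separated list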
+ var supported = !ar[i].supported
+ ? ''
+ : Object.prototype.toString.call(ar[i].supported) === "[object Array]"
+ ? "[" + ar[i].supported.join(', ') + "]"
+ : "[" + ar[i].supported.toString() + "]";
+ var name = ar[i].name || ar[i].filename;
+ name = name.length > 120 ? name.slice(0, 117) + '...' : name;
+ var desc = ar[i].description || '';
+ desc = desc.length > 500 ? desc.slice(0, 497) + '...' : desc;
+ spcontent.append("