Connect bridge.lua to aiserver.py

Also enables the use of input modifiers and output modifiers, but not
generation modifiers.
This commit is contained in:
Gnome Ann 2021-12-11 12:45:45 -05:00
parent 68685698a4
commit e289a0d360
2 changed files with 292 additions and 20 deletions

View File

@ -24,6 +24,8 @@ import argparse
import sys
import gc
import lupa
# KoboldAI
import fileops
import gensettings
@ -31,6 +33,11 @@ from utils import debounce
import utils
import structures
if lupa.LUA_VERSION[:2] != (5, 4):
print(f"Please install lupa==1.10. You have lupa {lupa.__version__}.", file=sys.stderr)
#==================================================================#
# Variables & Storage
#==================================================================#
@ -91,6 +98,10 @@ class vars:
wifolders_d = {} # Dictionary of World Info folder UID-info pairs
wifolders_l = [] # List of World Info folder UIDs
wifolders_u = {} # Dictionary of pairs of folder UID - list of WI UID
lua_state = None # Lua state of the Lua scripting system
lua_koboldbridge = None # `koboldbridge` from bridge.lua
lua_kobold = None # `kobold` from` bridge.lua
lua_koboldcore = None # `koboldcore` from bridge.lua
# badwords = [] # Array of str/chr values that should be removed from output
badwordsids = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting
deletewi = -1 # Temporary storage for index to delete
@ -549,7 +560,7 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END))
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
try:
from transformers import GPTJModel
except:
@ -751,7 +762,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
else:
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
vars.modeldim = get_hidden_size_from_model(model)
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
if(vars.hascuda):
if(vars.usegpu):
@ -769,7 +780,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
js = json.load(model_config)
with(maybe_use_float16()):
model = GPT2LMHeadModel.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
vars.modeldim = get_hidden_size_from_model(model)
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
if(vars.hascuda and vars.usegpu):
@ -780,7 +791,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
# If base HuggingFace model was chosen
else:
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
tokenizer = GPT2Tokenizer.from_pretrained(vars.model, cache_dir="cache/")
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
if(vars.hascuda):
if(vars.usegpu):
with(maybe_use_float16()):
@ -810,14 +821,18 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
# vars.badwordsids.append([vocab[key]])
print("{0}OK! {1} pipeline created!{2}".format(colors.GREEN, vars.model, colors.END))
else:
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
else:
# If we're running Colab or OAI, we still need a tokenizer.
if(vars.model == "Colab"):
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B")
elif(vars.model == "OAI"):
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# Load the TPU backend if requested
elif(vars.model == "TPUMeshTransformerGPTJ"):
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
@ -875,6 +890,239 @@ def download():
save.headers.set('Content-Disposition', 'attachment', filename='%s.json' % filename)
return(save)
#============================ LUA API =============================#
#==================================================================#
# Event triggered when a userscript is loaded
#==================================================================#
def load_callback(filename):
print(colors.PURPLE + f"Loading Userscript [USERPLACEHOLDER] <{filename}>" + colors.END)
#==================================================================#
# Load all Lua scripts
#==================================================================#
def load_lua_scripts():
print(colors.PURPLE + "Loading Core Script [COREPLACEHOLDER] <default.lua>" + colors.END)
try:
vars.lua_koboldbridge.obliterate_multiverse()
vars.lua_koboldbridge.load_corescript("default.lua")
vars.lua_koboldbridge.load_userscripts([], [], [])
except lupa.LuaError as e:
print(e, file=sys.stderr)
exit(1)
#==================================================================#
# Decode tokens into a string using current tokenizer
#==================================================================#
def lua_decode(tokens):
assert type(tokens) is list
return tokenizer.decode(tokens)
#==================================================================#
# Encode string into list of token IDs using current tokenizer
#==================================================================#
def lua_encode(string):
assert type(string) is str
return tokenizer.encode(string, max_length=int(4e9), truncation=True)
#==================================================================#
# Get property of a world info entry given its UID and property name
#==================================================================#
def lua_get_attr(uid, k):
assert type(uid) is int and type(k) is str
if(uid in vars.worldinfo_u and k in (
"key",
"keysecondary",
"content",
"comment",
"folder",
"num",
"selective",
"constant",
"uid",
)):
return vars.worldinfo_u[uid][k]
#==================================================================#
# Set property of a world info entry given its UID, property name and new value
#==================================================================#
def lua_set_attr(uid, k, v):
assert type(uid) is int and type(k) is str
assert uid in vars.worldinfo_u and k in (
"key",
"keysecondary",
"content",
"comment",
"selective",
"constant",
)
if(type(vars.worldinfo_u[uid][k]) is int and type(v) is float):
v = int(v)
assert type(vars.worldinfo_u[uid][k]) is type(v)
vars.worldinfo_u[uid][k] = v
#==================================================================#
# Get property of a world info folder given its UID and property name
#==================================================================#
def lua_folder_get_attr(uid, k):
assert type(uid) is int and type(k) is str
if(uid in vars.wifolders_d and k in (
"comment",
)):
return vars.wifolders_d[uid][k]
#==================================================================#
# Set property of a world info folder given its UID, property name and new value
#==================================================================#
def lua_folder_set_attr(uid, k, v):
assert type(uid) is int and type(k) is str
assert uid in vars.wifolders_d and k in (
"comment",
)
if(type(vars.wifolders_d[uid][k]) is int and type(v) is float):
v = int(v)
assert type(vars.wifolders_d[uid][k]) is type(v)
vars.wifolders_d[uid][k] = v
#==================================================================#
# Get the "Amount to Generate"
#==================================================================#
def lua_get_gen_len():
return vars.genamt
#==================================================================#
# Set the "Amount to Generate"
#==================================================================#
def lua_set_gen_len(genamt):
assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0
vars.genamt = int(genamt)
#==================================================================#
# Get the "Gens Per Action"
#==================================================================#
def lua_get_numseqs():
return vars.numseqs
#==================================================================#
# Set the "Gens Per Action"
#==================================================================#
def lua_set_numseqs(numseqs):
assert type(numseqs) in (int, float) and numseqs >= 1
vars.genamt = int(numseqs)
#==================================================================#
# Check if a setting exists with the given name
#==================================================================#
def lua_has_setting(setting):
return setting in (
"settemp",
"settopp",
"settopk",
"settfs",
"setreppen",
"setoutput",
"settknmax",
"anotedepth",
"setwidepth",
"setuseprompt",
"setadventure",
"frmttriminc",
"frmtrmblln",
"frmtrmspch",
"frmtadsnsp",
"singleline",
)
#==================================================================#
# Return the setting with the given name if it exists
#==================================================================#
def lua_get_setting(setting):
if(setting == "settemp"): return vars.temp
if(setting == "settopp"): return vars.top_p
if(setting == "settopk"): return vars.top_k
if(setting == "settfs"): return vars.tfs
if(setting == "setreppen"): return vars.rep_pen
if(setting == "settknmax"): return vars.max_length
if(setting == "anotedepth"): return vars.andepth
if(setting == "setwidepth"): return vars.widepth
if(setting == "setuseprompt"): return vars.useprompt
if(setting == "setadventure"): return vars.adventure
if(setting == "frmttriminc"): return vars.formatoptns["frmttriminc"]
if(setting == "frmtrmblln"): return vars.formatoptns["frmttriminc"]
if(setting == "frmtrmspch"): return vars.formatoptns["frmttriminc"]
if(setting == "frmtadsnsp"): return vars.formatoptns["frmttriminc"]
if(setting == "singleline"): return vars.formatoptns["frmttriminc"]
#==================================================================#
# Set the setting with the given name if it exists
#==================================================================#
def lua_set_setting(setting, v):
actual_type = type(lua_get_setting(setting))
assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float))
v = actual_type(v)
print(colors.PURPLE + f"[USERPLACEHOLDER] set {setting} to {v}" + colors.END)
if(setting == "setadventure" and v):
vars.actionmode = 1
get_message({'cmd': setting, 'data': v})
#==================================================================#
# Get contents of memory
#==================================================================#
def lua_get_memory():
return vars.memory
#==================================================================#
# Set contents of memory
#==================================================================#
def lua_set_memory(m):
assert type(m) is str
vars.memory = m
#==================================================================#
# Lua runtime startup
#==================================================================#
print(colors.PURPLE + "Initializing Lua Bridge... " + colors.END, end="")
# Set up Lua state
vars.lua_state = lupa.LuaRuntime(unpack_returned_tuples=True)
# Load bridge.lua
bridged = {
"corescript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts", "corescripts"),
"userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts", "userscripts"),
"lib_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs"),
"load_callback": load_callback,
"decode": lua_decode,
"encode": lua_encode,
"get_attr": lua_get_attr,
"set_attr": lua_set_attr,
"get_gen_len": lua_get_gen_len,
"set_gen_len": lua_set_gen_len,
"get_memory": lua_get_memory,
"set_memory": lua_set_memory,
"get_numseqs": lua_get_numseqs,
"set_numseqs": lua_set_numseqs,
"get_setting": lua_get_setting,
"set_setting": lua_set_setting,
"vars": vars,
}
try:
vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))(
vars.lua_state.globals().python,
bridged,
)
except lupa.LuaError as e:
print(colors.RED + "ERROR!" + colors.END)
print(e, file=sys.stderr)
exit(1)
print(colors.GREEN + "OK!" + colors.END)
# Load scripts
load_lua_scripts()
#============================ METHODS =============================#
#==================================================================#
@ -1314,6 +1562,9 @@ def actionsubmit(data, actionmode=0, force_submit=False):
vars.recentedit = False
vars.actionmode = actionmode
# Run the core script's input modifier
vars.lua_koboldbridge.execute_inmod()
# "Action" mode
if(actionmode == 1):
data = data.strip().lstrip('>')
@ -1339,6 +1590,7 @@ def actionsubmit(data, actionmode=0, force_submit=False):
calcsubmit(data) # Run the first action through the generator
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
else:
vars.lua_koboldbridge.execute_outmod()
refresh_story()
set_aibusy(0)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
@ -1360,6 +1612,7 @@ def actionsubmit(data, actionmode=0, force_submit=False):
calcsubmit(data)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
else:
vars.lua_koboldbridge.execute_outmod()
set_aibusy(0)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
@ -1701,6 +1954,8 @@ def generate(txt, minimum, maximum, found_entries=None):
#already_generated = -(len(gen_in[0]) - len(tokens))
genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]
vars.lua_koboldbridge.execute_outmod()
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
else:
@ -1800,6 +2055,8 @@ def sendtocolab(txt, min, max):
else:
genout = js["seqs"]
vars.lua_koboldbridge.execute_outmod()
if(len(genout) == 1):
genresult(genout[0])
else:
@ -1881,6 +2138,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
genout = [{"generated_text": txt} for txt in genout]
vars.lua_koboldbridge.execute_outmod()
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
else:
@ -2459,6 +2718,7 @@ def ikrequest(txt):
# Deal with the response
if(req.status_code == 200):
genout = req.json()["data"]["text"]
vars.lua_koboldbridge.execute_outmod()
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
vars.actions.append(genout)
update_story_chunk('last')
@ -2509,6 +2769,7 @@ def oairequest(txt, min, max):
# Deal with the response
if(req.status_code == 200):
genout = req.json()["choices"][0]["text"]
vars.lua_koboldbridge.execute_outmod()
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
vars.actions.append(genout)
update_story_chunk('last')

View File

@ -252,7 +252,7 @@ return function(_python, _bridged)
---@return boolean
function KoboldWorldInfoEntry:is_valid()
return bridged.worldinfo_u.get(self.uid) ~= nil
return bridged.vars.worldinfo_u.get(self.uid) ~= nil
end
---@return string
@ -340,7 +340,7 @@ return function(_python, _bridged)
if not check_validity(self) or type(u) ~= "number" then
return
end
local query = bridged.worldinfo_u.get(u)
local query = bridged.vars.worldinfo_u.get(u)
if query == nil or (rawget(self, "_name") == "KoboldWorldInfoFolder" and self.uid ~= query.get("folder")) then
return
end
@ -373,7 +373,7 @@ return function(_python, _bridged)
---@return boolean
function KoboldWorldInfoFolder:is_valid()
return bridged.wifolders_d.get(self.uid) ~= nil
return bridged.vars.wifolders_d.get(self.uid) ~= nil
end
---@param t KoboldWorldInfoFolder
@ -382,7 +382,7 @@ return function(_python, _bridged)
if not check_validity(t) then
return 0
end
return _python.builtins.len(bridged.worldinfo_u.get(t.uid))
return _python.builtins.len(bridged.vars.worldinfo_u.get(t.uid))
end
KoboldWorldInfoFolder_mt._kobold_next = KoboldWorldInfoEntry_mt._kobold_next
@ -397,9 +397,9 @@ return function(_python, _bridged)
elseif rawget(t, "_name") == "KoboldWorldInfoFolder" and k == "uid" then
return rawget(t, "_uid")
elseif rawget(t, "_name") == "KoboldWorldInfoFolder" and k == "comment" then
return bridged.wifolders_d.get(t.uid).__getitem__("comment")
return bridged.folder_get_attr(t.uid, k)
elseif type(k) == "number" then
local query = rawget(t, "_name") == "KoboldWorldInfoFolder" and bridged.wifolders_u.get(t.uid) or bridged.worldinfo
local query = rawget(t, "_name") == "KoboldWorldInfoFolder" and bridged.vars.wifolders_u.get(t.uid) or bridged.vars.worldinfo
k = math.tointeger(k)
if k == nil or k < 1 or k > _python.builtins.len(query) then
return
@ -424,7 +424,7 @@ return function(_python, _bridged)
error("`"..rawget(t, "_name").."."..k.."` must be a string; you attempted to set it to a "..type(v))
return
end
bridged.safe_setitem(bridged.wifolders_d.get(t.uid), "comment", v)
bridged.folder_set_attr(t.uid, k, v)
return t
else
return rawset(t, k, v)
@ -450,7 +450,7 @@ return function(_python, _bridged)
if not check_validity(self) or type(u) ~= "number" then
return
end
local query = bridged.wifolders_d.get(u)
local query = bridged.vars.wifolders_d.get(u)
if query == nil then
return
end
@ -480,11 +480,11 @@ return function(_python, _bridged)
---@param t KoboldWorldInfoFolderSelector
---@return KoboldWorldInfoFolder|nil
function KoboldWorldInfoFolderSelector_mt.__index(t, k)
if not check_validity(t) or type(k) ~= "number" or math.tointeger(k) == nil or k < 1 or k > _python.builtins.len(bridged.wifolders_l) then
if not check_validity(t) or type(k) ~= "number" or math.tointeger(k) == nil or k < 1 or k > _python.builtins.len(bridged.vars.wifolders_l) then
return
end
local folder = deepcopy(KoboldWorldInfoFolder)
rawset(folder, "_uid", bridged.wifolders_l.__getitem__(k))
rawset(folder, "_uid", bridged.vars.wifolders_l.__getitem__(k))
return folder
end
@ -523,7 +523,7 @@ return function(_python, _bridged)
if not check_validity(t) then
return 0
end
return _python.builtins.len(bridged.worldinfo)
return _python.builtins.len(bridged.vars.worldinfo)
end
KoboldWorldInfo_mt._kobold_next = KoboldWorldInfoEntry_mt._kobold_next
@ -577,6 +577,8 @@ return function(_python, _bridged)
end
if k == "gen_len" then
return bridged.get_gen_len()
elseif k == "numseqs" then
return bridged.get_numseqs()
elseif bridged.has_setting(k) then
return bridged.get_setting(k), true
else
@ -588,7 +590,16 @@ return function(_python, _bridged)
function KoboldSettings_mt.__newindex(t, k, v)
if k == "gen_len" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 0 then
bridged.set_gen_len(v)
elseif k == "numseqs" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 1 then
if koboldbridge.userstate == "genmod" then
error("Cannot set numseqs from a generation modifier")
return
end
bridged.set_numseqs(v)
elseif type(k) == "string" and bridged.has_setting(k) and type(v) == type(bridged.get_setting(k)) then
if k == "settknmax" or k == "anotedepth" or k == "setwidepth" or k == "setuseprompt" then
maybe_save_genmod_comparison_context()
end
return bridged.set_setting(k, v)
end
return t
@ -838,7 +849,7 @@ return function(_python, _bridged)
local old_package_loaded = package.loaded
local old_package_searchers = package.searchers
---@param modname string
---@param env? table<string, any>
---@param env table<string, any>
---@param search_path? string
---@return any, string|nil
local function requirex(modname, env, search_path)