mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Connect bridge.lua to aiserver.py
Also enables the use of input modifiers and output modifiers, but not generation modifiers.
This commit is contained in:
277
aiserver.py
277
aiserver.py
@ -24,6 +24,8 @@ import argparse
|
||||
import sys
|
||||
import gc
|
||||
|
||||
import lupa
|
||||
|
||||
# KoboldAI
|
||||
import fileops
|
||||
import gensettings
|
||||
@ -31,6 +33,11 @@ from utils import debounce
|
||||
import utils
|
||||
import structures
|
||||
|
||||
|
||||
if lupa.LUA_VERSION[:2] != (5, 4):
|
||||
print(f"Please install lupa==1.10. You have lupa {lupa.__version__}.", file=sys.stderr)
|
||||
|
||||
|
||||
#==================================================================#
|
||||
# Variables & Storage
|
||||
#==================================================================#
|
||||
@ -91,6 +98,10 @@ class vars:
|
||||
wifolders_d = {} # Dictionary of World Info folder UID-info pairs
|
||||
wifolders_l = [] # List of World Info folder UIDs
|
||||
wifolders_u = {} # Dictionary of pairs of folder UID - list of WI UID
|
||||
lua_state = None # Lua state of the Lua scripting system
|
||||
lua_koboldbridge = None # `koboldbridge` from bridge.lua
|
||||
lua_kobold = None # `kobold` from` bridge.lua
|
||||
lua_koboldcore = None # `koboldcore` from bridge.lua
|
||||
# badwords = [] # Array of str/chr values that should be removed from output
|
||||
badwordsids = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting
|
||||
deletewi = -1 # Temporary storage for index to delete
|
||||
@ -549,7 +560,7 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END))
|
||||
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
||||
if(not vars.noai):
|
||||
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||
from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
|
||||
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
|
||||
try:
|
||||
from transformers import GPTJModel
|
||||
except:
|
||||
@ -751,7 +762,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
else:
|
||||
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
|
||||
vars.modeldim = get_hidden_size_from_model(model)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
|
||||
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
|
||||
if(vars.hascuda):
|
||||
if(vars.usegpu):
|
||||
@ -769,7 +780,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
js = json.load(model_config)
|
||||
with(maybe_use_float16()):
|
||||
model = GPT2LMHeadModel.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
|
||||
vars.modeldim = get_hidden_size_from_model(model)
|
||||
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
|
||||
if(vars.hascuda and vars.usegpu):
|
||||
@ -780,7 +791,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
# If base HuggingFace model was chosen
|
||||
else:
|
||||
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(vars.model, cache_dir="cache/")
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
|
||||
if(vars.hascuda):
|
||||
if(vars.usegpu):
|
||||
with(maybe_use_float16()):
|
||||
@ -810,14 +821,18 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
# vars.badwordsids.append([vocab[key]])
|
||||
|
||||
print("{0}OK! {1} pipeline created!{2}".format(colors.GREEN, vars.model, colors.END))
|
||||
|
||||
else:
|
||||
from transformers import GPT2TokenizerFast
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
else:
|
||||
# If we're running Colab or OAI, we still need a tokenizer.
|
||||
if(vars.model == "Colab"):
|
||||
from transformers import GPT2Tokenizer
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
|
||||
from transformers import GPT2TokenizerFast
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B")
|
||||
elif(vars.model == "OAI"):
|
||||
from transformers import GPT2Tokenizer
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
from transformers import GPT2TokenizerFast
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
# Load the TPU backend if requested
|
||||
elif(vars.model == "TPUMeshTransformerGPTJ"):
|
||||
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||
@ -875,6 +890,239 @@ def download():
|
||||
save.headers.set('Content-Disposition', 'attachment', filename='%s.json' % filename)
|
||||
return(save)
|
||||
|
||||
|
||||
#============================ LUA API =============================#
|
||||
|
||||
#==================================================================#
|
||||
# Event triggered when a userscript is loaded
|
||||
#==================================================================#
|
||||
def load_callback(filename):
|
||||
print(colors.PURPLE + f"Loading Userscript [USERPLACEHOLDER] <{filename}>" + colors.END)
|
||||
|
||||
#==================================================================#
|
||||
# Load all Lua scripts
|
||||
#==================================================================#
|
||||
def load_lua_scripts():
|
||||
print(colors.PURPLE + "Loading Core Script [COREPLACEHOLDER] <default.lua>" + colors.END)
|
||||
try:
|
||||
vars.lua_koboldbridge.obliterate_multiverse()
|
||||
vars.lua_koboldbridge.load_corescript("default.lua")
|
||||
vars.lua_koboldbridge.load_userscripts([], [], [])
|
||||
except lupa.LuaError as e:
|
||||
print(e, file=sys.stderr)
|
||||
exit(1)
|
||||
|
||||
#==================================================================#
|
||||
# Decode tokens into a string using current tokenizer
|
||||
#==================================================================#
|
||||
def lua_decode(tokens):
|
||||
assert type(tokens) is list
|
||||
return tokenizer.decode(tokens)
|
||||
|
||||
#==================================================================#
|
||||
# Encode string into list of token IDs using current tokenizer
|
||||
#==================================================================#
|
||||
def lua_encode(string):
|
||||
assert type(string) is str
|
||||
return tokenizer.encode(string, max_length=int(4e9), truncation=True)
|
||||
|
||||
#==================================================================#
|
||||
# Get property of a world info entry given its UID and property name
|
||||
#==================================================================#
|
||||
def lua_get_attr(uid, k):
|
||||
assert type(uid) is int and type(k) is str
|
||||
if(uid in vars.worldinfo_u and k in (
|
||||
"key",
|
||||
"keysecondary",
|
||||
"content",
|
||||
"comment",
|
||||
"folder",
|
||||
"num",
|
||||
"selective",
|
||||
"constant",
|
||||
"uid",
|
||||
)):
|
||||
return vars.worldinfo_u[uid][k]
|
||||
|
||||
#==================================================================#
|
||||
# Set property of a world info entry given its UID, property name and new value
|
||||
#==================================================================#
|
||||
def lua_set_attr(uid, k, v):
|
||||
assert type(uid) is int and type(k) is str
|
||||
assert uid in vars.worldinfo_u and k in (
|
||||
"key",
|
||||
"keysecondary",
|
||||
"content",
|
||||
"comment",
|
||||
"selective",
|
||||
"constant",
|
||||
)
|
||||
if(type(vars.worldinfo_u[uid][k]) is int and type(v) is float):
|
||||
v = int(v)
|
||||
assert type(vars.worldinfo_u[uid][k]) is type(v)
|
||||
vars.worldinfo_u[uid][k] = v
|
||||
|
||||
#==================================================================#
|
||||
# Get property of a world info folder given its UID and property name
|
||||
#==================================================================#
|
||||
def lua_folder_get_attr(uid, k):
|
||||
assert type(uid) is int and type(k) is str
|
||||
if(uid in vars.wifolders_d and k in (
|
||||
"comment",
|
||||
)):
|
||||
return vars.wifolders_d[uid][k]
|
||||
|
||||
#==================================================================#
|
||||
# Set property of a world info folder given its UID, property name and new value
|
||||
#==================================================================#
|
||||
def lua_folder_set_attr(uid, k, v):
|
||||
assert type(uid) is int and type(k) is str
|
||||
assert uid in vars.wifolders_d and k in (
|
||||
"comment",
|
||||
)
|
||||
if(type(vars.wifolders_d[uid][k]) is int and type(v) is float):
|
||||
v = int(v)
|
||||
assert type(vars.wifolders_d[uid][k]) is type(v)
|
||||
vars.wifolders_d[uid][k] = v
|
||||
|
||||
#==================================================================#
|
||||
# Get the "Amount to Generate"
|
||||
#==================================================================#
|
||||
def lua_get_gen_len():
|
||||
return vars.genamt
|
||||
|
||||
#==================================================================#
|
||||
# Set the "Amount to Generate"
|
||||
#==================================================================#
|
||||
def lua_set_gen_len(genamt):
|
||||
assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0
|
||||
vars.genamt = int(genamt)
|
||||
|
||||
#==================================================================#
|
||||
# Get the "Gens Per Action"
|
||||
#==================================================================#
|
||||
def lua_get_numseqs():
|
||||
return vars.numseqs
|
||||
|
||||
#==================================================================#
|
||||
# Set the "Gens Per Action"
|
||||
#==================================================================#
|
||||
def lua_set_numseqs(numseqs):
|
||||
assert type(numseqs) in (int, float) and numseqs >= 1
|
||||
vars.genamt = int(numseqs)
|
||||
|
||||
#==================================================================#
|
||||
# Check if a setting exists with the given name
|
||||
#==================================================================#
|
||||
def lua_has_setting(setting):
|
||||
return setting in (
|
||||
"settemp",
|
||||
"settopp",
|
||||
"settopk",
|
||||
"settfs",
|
||||
"setreppen",
|
||||
"setoutput",
|
||||
"settknmax",
|
||||
"anotedepth",
|
||||
"setwidepth",
|
||||
"setuseprompt",
|
||||
"setadventure",
|
||||
"frmttriminc",
|
||||
"frmtrmblln",
|
||||
"frmtrmspch",
|
||||
"frmtadsnsp",
|
||||
"singleline",
|
||||
)
|
||||
|
||||
#==================================================================#
|
||||
# Return the setting with the given name if it exists
|
||||
#==================================================================#
|
||||
def lua_get_setting(setting):
|
||||
if(setting == "settemp"): return vars.temp
|
||||
if(setting == "settopp"): return vars.top_p
|
||||
if(setting == "settopk"): return vars.top_k
|
||||
if(setting == "settfs"): return vars.tfs
|
||||
if(setting == "setreppen"): return vars.rep_pen
|
||||
if(setting == "settknmax"): return vars.max_length
|
||||
if(setting == "anotedepth"): return vars.andepth
|
||||
if(setting == "setwidepth"): return vars.widepth
|
||||
if(setting == "setuseprompt"): return vars.useprompt
|
||||
if(setting == "setadventure"): return vars.adventure
|
||||
if(setting == "frmttriminc"): return vars.formatoptns["frmttriminc"]
|
||||
if(setting == "frmtrmblln"): return vars.formatoptns["frmttriminc"]
|
||||
if(setting == "frmtrmspch"): return vars.formatoptns["frmttriminc"]
|
||||
if(setting == "frmtadsnsp"): return vars.formatoptns["frmttriminc"]
|
||||
if(setting == "singleline"): return vars.formatoptns["frmttriminc"]
|
||||
|
||||
#==================================================================#
|
||||
# Set the setting with the given name if it exists
|
||||
#==================================================================#
|
||||
def lua_set_setting(setting, v):
|
||||
actual_type = type(lua_get_setting(setting))
|
||||
assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float))
|
||||
v = actual_type(v)
|
||||
print(colors.PURPLE + f"[USERPLACEHOLDER] set {setting} to {v}" + colors.END)
|
||||
if(setting == "setadventure" and v):
|
||||
vars.actionmode = 1
|
||||
get_message({'cmd': setting, 'data': v})
|
||||
|
||||
#==================================================================#
|
||||
# Get contents of memory
|
||||
#==================================================================#
|
||||
def lua_get_memory():
|
||||
return vars.memory
|
||||
|
||||
#==================================================================#
|
||||
# Set contents of memory
|
||||
#==================================================================#
|
||||
def lua_set_memory(m):
|
||||
assert type(m) is str
|
||||
vars.memory = m
|
||||
|
||||
#==================================================================#
|
||||
# Lua runtime startup
|
||||
#==================================================================#
|
||||
|
||||
print(colors.PURPLE + "Initializing Lua Bridge... " + colors.END, end="")
|
||||
|
||||
# Set up Lua state
|
||||
vars.lua_state = lupa.LuaRuntime(unpack_returned_tuples=True)
|
||||
|
||||
# Load bridge.lua
|
||||
bridged = {
|
||||
"corescript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts", "corescripts"),
|
||||
"userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts", "userscripts"),
|
||||
"lib_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs"),
|
||||
"load_callback": load_callback,
|
||||
"decode": lua_decode,
|
||||
"encode": lua_encode,
|
||||
"get_attr": lua_get_attr,
|
||||
"set_attr": lua_set_attr,
|
||||
"get_gen_len": lua_get_gen_len,
|
||||
"set_gen_len": lua_set_gen_len,
|
||||
"get_memory": lua_get_memory,
|
||||
"set_memory": lua_set_memory,
|
||||
"get_numseqs": lua_get_numseqs,
|
||||
"set_numseqs": lua_set_numseqs,
|
||||
"get_setting": lua_get_setting,
|
||||
"set_setting": lua_set_setting,
|
||||
"vars": vars,
|
||||
}
|
||||
try:
|
||||
vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))(
|
||||
vars.lua_state.globals().python,
|
||||
bridged,
|
||||
)
|
||||
except lupa.LuaError as e:
|
||||
print(colors.RED + "ERROR!" + colors.END)
|
||||
print(e, file=sys.stderr)
|
||||
exit(1)
|
||||
print(colors.GREEN + "OK!" + colors.END)
|
||||
|
||||
# Load scripts
|
||||
load_lua_scripts()
|
||||
|
||||
|
||||
#============================ METHODS =============================#
|
||||
|
||||
#==================================================================#
|
||||
@ -1314,6 +1562,9 @@ def actionsubmit(data, actionmode=0, force_submit=False):
|
||||
vars.recentedit = False
|
||||
vars.actionmode = actionmode
|
||||
|
||||
# Run the core script's input modifier
|
||||
vars.lua_koboldbridge.execute_inmod()
|
||||
|
||||
# "Action" mode
|
||||
if(actionmode == 1):
|
||||
data = data.strip().lstrip('>')
|
||||
@ -1339,6 +1590,7 @@ def actionsubmit(data, actionmode=0, force_submit=False):
|
||||
calcsubmit(data) # Run the first action through the generator
|
||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||
else:
|
||||
vars.lua_koboldbridge.execute_outmod()
|
||||
refresh_story()
|
||||
set_aibusy(0)
|
||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||
@ -1360,6 +1612,7 @@ def actionsubmit(data, actionmode=0, force_submit=False):
|
||||
calcsubmit(data)
|
||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||
else:
|
||||
vars.lua_koboldbridge.execute_outmod()
|
||||
set_aibusy(0)
|
||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||
|
||||
@ -1700,6 +1953,8 @@ def generate(txt, minimum, maximum, found_entries=None):
|
||||
# Need to manually strip and decode tokens if we're not using a pipeline
|
||||
#already_generated = -(len(gen_in[0]) - len(tokens))
|
||||
genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]
|
||||
|
||||
vars.lua_koboldbridge.execute_outmod()
|
||||
|
||||
if(len(genout) == 1):
|
||||
genresult(genout[0]["generated_text"])
|
||||
@ -1800,6 +2055,8 @@ def sendtocolab(txt, min, max):
|
||||
else:
|
||||
genout = js["seqs"]
|
||||
|
||||
vars.lua_koboldbridge.execute_outmod()
|
||||
|
||||
if(len(genout) == 1):
|
||||
genresult(genout[0])
|
||||
else:
|
||||
@ -1881,6 +2138,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||
|
||||
genout = [{"generated_text": txt} for txt in genout]
|
||||
|
||||
vars.lua_koboldbridge.execute_outmod()
|
||||
|
||||
if(len(genout) == 1):
|
||||
genresult(genout[0]["generated_text"])
|
||||
else:
|
||||
@ -2459,6 +2718,7 @@ def ikrequest(txt):
|
||||
# Deal with the response
|
||||
if(req.status_code == 200):
|
||||
genout = req.json()["data"]["text"]
|
||||
vars.lua_koboldbridge.execute_outmod()
|
||||
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
|
||||
vars.actions.append(genout)
|
||||
update_story_chunk('last')
|
||||
@ -2509,6 +2769,7 @@ def oairequest(txt, min, max):
|
||||
# Deal with the response
|
||||
if(req.status_code == 200):
|
||||
genout = req.json()["choices"][0]["text"]
|
||||
vars.lua_koboldbridge.execute_outmod()
|
||||
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
|
||||
vars.actions.append(genout)
|
||||
update_story_chunk('last')
|
||||
|
Reference in New Issue
Block a user