mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-17 12:10:49 +01:00
Connect bridge.lua to aiserver.py
Also enables the use of input modifiers and output modifiers, but not generation modifiers.
This commit is contained in:
parent
68685698a4
commit
e289a0d360
277
aiserver.py
277
aiserver.py
@ -24,6 +24,8 @@ import argparse
|
|||||||
import sys
|
import sys
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
|
import lupa
|
||||||
|
|
||||||
# KoboldAI
|
# KoboldAI
|
||||||
import fileops
|
import fileops
|
||||||
import gensettings
|
import gensettings
|
||||||
@ -31,6 +33,11 @@ from utils import debounce
|
|||||||
import utils
|
import utils
|
||||||
import structures
|
import structures
|
||||||
|
|
||||||
|
|
||||||
|
if lupa.LUA_VERSION[:2] != (5, 4):
|
||||||
|
print(f"Please install lupa==1.10. You have lupa {lupa.__version__}.", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Variables & Storage
|
# Variables & Storage
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
@ -91,6 +98,10 @@ class vars:
|
|||||||
wifolders_d = {} # Dictionary of World Info folder UID-info pairs
|
wifolders_d = {} # Dictionary of World Info folder UID-info pairs
|
||||||
wifolders_l = [] # List of World Info folder UIDs
|
wifolders_l = [] # List of World Info folder UIDs
|
||||||
wifolders_u = {} # Dictionary of pairs of folder UID - list of WI UID
|
wifolders_u = {} # Dictionary of pairs of folder UID - list of WI UID
|
||||||
|
lua_state = None # Lua state of the Lua scripting system
|
||||||
|
lua_koboldbridge = None # `koboldbridge` from bridge.lua
|
||||||
|
lua_kobold = None # `kobold` from` bridge.lua
|
||||||
|
lua_koboldcore = None # `koboldcore` from bridge.lua
|
||||||
# badwords = [] # Array of str/chr values that should be removed from output
|
# badwords = [] # Array of str/chr values that should be removed from output
|
||||||
badwordsids = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting
|
badwordsids = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting
|
||||||
deletewi = -1 # Temporary storage for index to delete
|
deletewi = -1 # Temporary storage for index to delete
|
||||||
@ -549,7 +560,7 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END))
|
|||||||
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
||||||
if(not vars.noai):
|
if(not vars.noai):
|
||||||
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||||
from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
|
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
|
||||||
try:
|
try:
|
||||||
from transformers import GPTJModel
|
from transformers import GPTJModel
|
||||||
except:
|
except:
|
||||||
@ -751,7 +762,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||||||
else:
|
else:
|
||||||
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
|
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
|
||||||
vars.modeldim = get_hidden_size_from_model(model)
|
vars.modeldim = get_hidden_size_from_model(model)
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
|
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
|
||||||
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
|
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
|
||||||
if(vars.hascuda):
|
if(vars.hascuda):
|
||||||
if(vars.usegpu):
|
if(vars.usegpu):
|
||||||
@ -769,7 +780,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||||||
js = json.load(model_config)
|
js = json.load(model_config)
|
||||||
with(maybe_use_float16()):
|
with(maybe_use_float16()):
|
||||||
model = GPT2LMHeadModel.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
|
model = GPT2LMHeadModel.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
|
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
|
||||||
vars.modeldim = get_hidden_size_from_model(model)
|
vars.modeldim = get_hidden_size_from_model(model)
|
||||||
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
|
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
|
||||||
if(vars.hascuda and vars.usegpu):
|
if(vars.hascuda and vars.usegpu):
|
||||||
@ -780,7 +791,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||||||
# If base HuggingFace model was chosen
|
# If base HuggingFace model was chosen
|
||||||
else:
|
else:
|
||||||
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
|
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained(vars.model, cache_dir="cache/")
|
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
|
||||||
if(vars.hascuda):
|
if(vars.hascuda):
|
||||||
if(vars.usegpu):
|
if(vars.usegpu):
|
||||||
with(maybe_use_float16()):
|
with(maybe_use_float16()):
|
||||||
@ -810,14 +821,18 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||||||
# vars.badwordsids.append([vocab[key]])
|
# vars.badwordsids.append([vocab[key]])
|
||||||
|
|
||||||
print("{0}OK! {1} pipeline created!{2}".format(colors.GREEN, vars.model, colors.END))
|
print("{0}OK! {1} pipeline created!{2}".format(colors.GREEN, vars.model, colors.END))
|
||||||
|
|
||||||
|
else:
|
||||||
|
from transformers import GPT2TokenizerFast
|
||||||
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||||
else:
|
else:
|
||||||
# If we're running Colab or OAI, we still need a tokenizer.
|
# If we're running Colab or OAI, we still need a tokenizer.
|
||||||
if(vars.model == "Colab"):
|
if(vars.model == "Colab"):
|
||||||
from transformers import GPT2Tokenizer
|
from transformers import GPT2TokenizerFast
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
|
tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B")
|
||||||
elif(vars.model == "OAI"):
|
elif(vars.model == "OAI"):
|
||||||
from transformers import GPT2Tokenizer
|
from transformers import GPT2TokenizerFast
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||||
# Load the TPU backend if requested
|
# Load the TPU backend if requested
|
||||||
elif(vars.model == "TPUMeshTransformerGPTJ"):
|
elif(vars.model == "TPUMeshTransformerGPTJ"):
|
||||||
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
|
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||||
@ -875,6 +890,239 @@ def download():
|
|||||||
save.headers.set('Content-Disposition', 'attachment', filename='%s.json' % filename)
|
save.headers.set('Content-Disposition', 'attachment', filename='%s.json' % filename)
|
||||||
return(save)
|
return(save)
|
||||||
|
|
||||||
|
|
||||||
|
#============================ LUA API =============================#
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Event triggered when a userscript is loaded
|
||||||
|
#==================================================================#
|
||||||
|
def load_callback(filename):
|
||||||
|
print(colors.PURPLE + f"Loading Userscript [USERPLACEHOLDER] <{filename}>" + colors.END)
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Load all Lua scripts
|
||||||
|
#==================================================================#
|
||||||
|
def load_lua_scripts():
|
||||||
|
print(colors.PURPLE + "Loading Core Script [COREPLACEHOLDER] <default.lua>" + colors.END)
|
||||||
|
try:
|
||||||
|
vars.lua_koboldbridge.obliterate_multiverse()
|
||||||
|
vars.lua_koboldbridge.load_corescript("default.lua")
|
||||||
|
vars.lua_koboldbridge.load_userscripts([], [], [])
|
||||||
|
except lupa.LuaError as e:
|
||||||
|
print(e, file=sys.stderr)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Decode tokens into a string using current tokenizer
|
||||||
|
#==================================================================#
|
||||||
|
def lua_decode(tokens):
|
||||||
|
assert type(tokens) is list
|
||||||
|
return tokenizer.decode(tokens)
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Encode string into list of token IDs using current tokenizer
|
||||||
|
#==================================================================#
|
||||||
|
def lua_encode(string):
|
||||||
|
assert type(string) is str
|
||||||
|
return tokenizer.encode(string, max_length=int(4e9), truncation=True)
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Get property of a world info entry given its UID and property name
|
||||||
|
#==================================================================#
|
||||||
|
def lua_get_attr(uid, k):
|
||||||
|
assert type(uid) is int and type(k) is str
|
||||||
|
if(uid in vars.worldinfo_u and k in (
|
||||||
|
"key",
|
||||||
|
"keysecondary",
|
||||||
|
"content",
|
||||||
|
"comment",
|
||||||
|
"folder",
|
||||||
|
"num",
|
||||||
|
"selective",
|
||||||
|
"constant",
|
||||||
|
"uid",
|
||||||
|
)):
|
||||||
|
return vars.worldinfo_u[uid][k]
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Set property of a world info entry given its UID, property name and new value
|
||||||
|
#==================================================================#
|
||||||
|
def lua_set_attr(uid, k, v):
|
||||||
|
assert type(uid) is int and type(k) is str
|
||||||
|
assert uid in vars.worldinfo_u and k in (
|
||||||
|
"key",
|
||||||
|
"keysecondary",
|
||||||
|
"content",
|
||||||
|
"comment",
|
||||||
|
"selective",
|
||||||
|
"constant",
|
||||||
|
)
|
||||||
|
if(type(vars.worldinfo_u[uid][k]) is int and type(v) is float):
|
||||||
|
v = int(v)
|
||||||
|
assert type(vars.worldinfo_u[uid][k]) is type(v)
|
||||||
|
vars.worldinfo_u[uid][k] = v
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Get property of a world info folder given its UID and property name
|
||||||
|
#==================================================================#
|
||||||
|
def lua_folder_get_attr(uid, k):
|
||||||
|
assert type(uid) is int and type(k) is str
|
||||||
|
if(uid in vars.wifolders_d and k in (
|
||||||
|
"comment",
|
||||||
|
)):
|
||||||
|
return vars.wifolders_d[uid][k]
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Set property of a world info folder given its UID, property name and new value
|
||||||
|
#==================================================================#
|
||||||
|
def lua_folder_set_attr(uid, k, v):
|
||||||
|
assert type(uid) is int and type(k) is str
|
||||||
|
assert uid in vars.wifolders_d and k in (
|
||||||
|
"comment",
|
||||||
|
)
|
||||||
|
if(type(vars.wifolders_d[uid][k]) is int and type(v) is float):
|
||||||
|
v = int(v)
|
||||||
|
assert type(vars.wifolders_d[uid][k]) is type(v)
|
||||||
|
vars.wifolders_d[uid][k] = v
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Get the "Amount to Generate"
|
||||||
|
#==================================================================#
|
||||||
|
def lua_get_gen_len():
|
||||||
|
return vars.genamt
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Set the "Amount to Generate"
|
||||||
|
#==================================================================#
|
||||||
|
def lua_set_gen_len(genamt):
|
||||||
|
assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0
|
||||||
|
vars.genamt = int(genamt)
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Get the "Gens Per Action"
|
||||||
|
#==================================================================#
|
||||||
|
def lua_get_numseqs():
|
||||||
|
return vars.numseqs
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Set the "Gens Per Action"
|
||||||
|
#==================================================================#
|
||||||
|
def lua_set_numseqs(numseqs):
|
||||||
|
assert type(numseqs) in (int, float) and numseqs >= 1
|
||||||
|
vars.genamt = int(numseqs)
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Check if a setting exists with the given name
|
||||||
|
#==================================================================#
|
||||||
|
def lua_has_setting(setting):
|
||||||
|
return setting in (
|
||||||
|
"settemp",
|
||||||
|
"settopp",
|
||||||
|
"settopk",
|
||||||
|
"settfs",
|
||||||
|
"setreppen",
|
||||||
|
"setoutput",
|
||||||
|
"settknmax",
|
||||||
|
"anotedepth",
|
||||||
|
"setwidepth",
|
||||||
|
"setuseprompt",
|
||||||
|
"setadventure",
|
||||||
|
"frmttriminc",
|
||||||
|
"frmtrmblln",
|
||||||
|
"frmtrmspch",
|
||||||
|
"frmtadsnsp",
|
||||||
|
"singleline",
|
||||||
|
)
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Return the setting with the given name if it exists
|
||||||
|
#==================================================================#
|
||||||
|
def lua_get_setting(setting):
|
||||||
|
if(setting == "settemp"): return vars.temp
|
||||||
|
if(setting == "settopp"): return vars.top_p
|
||||||
|
if(setting == "settopk"): return vars.top_k
|
||||||
|
if(setting == "settfs"): return vars.tfs
|
||||||
|
if(setting == "setreppen"): return vars.rep_pen
|
||||||
|
if(setting == "settknmax"): return vars.max_length
|
||||||
|
if(setting == "anotedepth"): return vars.andepth
|
||||||
|
if(setting == "setwidepth"): return vars.widepth
|
||||||
|
if(setting == "setuseprompt"): return vars.useprompt
|
||||||
|
if(setting == "setadventure"): return vars.adventure
|
||||||
|
if(setting == "frmttriminc"): return vars.formatoptns["frmttriminc"]
|
||||||
|
if(setting == "frmtrmblln"): return vars.formatoptns["frmttriminc"]
|
||||||
|
if(setting == "frmtrmspch"): return vars.formatoptns["frmttriminc"]
|
||||||
|
if(setting == "frmtadsnsp"): return vars.formatoptns["frmttriminc"]
|
||||||
|
if(setting == "singleline"): return vars.formatoptns["frmttriminc"]
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Set the setting with the given name if it exists
|
||||||
|
#==================================================================#
|
||||||
|
def lua_set_setting(setting, v):
|
||||||
|
actual_type = type(lua_get_setting(setting))
|
||||||
|
assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float))
|
||||||
|
v = actual_type(v)
|
||||||
|
print(colors.PURPLE + f"[USERPLACEHOLDER] set {setting} to {v}" + colors.END)
|
||||||
|
if(setting == "setadventure" and v):
|
||||||
|
vars.actionmode = 1
|
||||||
|
get_message({'cmd': setting, 'data': v})
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Get contents of memory
|
||||||
|
#==================================================================#
|
||||||
|
def lua_get_memory():
|
||||||
|
return vars.memory
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Set contents of memory
|
||||||
|
#==================================================================#
|
||||||
|
def lua_set_memory(m):
|
||||||
|
assert type(m) is str
|
||||||
|
vars.memory = m
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Lua runtime startup
|
||||||
|
#==================================================================#
|
||||||
|
|
||||||
|
print(colors.PURPLE + "Initializing Lua Bridge... " + colors.END, end="")
|
||||||
|
|
||||||
|
# Set up Lua state
|
||||||
|
vars.lua_state = lupa.LuaRuntime(unpack_returned_tuples=True)
|
||||||
|
|
||||||
|
# Load bridge.lua
|
||||||
|
bridged = {
|
||||||
|
"corescript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts", "corescripts"),
|
||||||
|
"userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts", "userscripts"),
|
||||||
|
"lib_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs"),
|
||||||
|
"load_callback": load_callback,
|
||||||
|
"decode": lua_decode,
|
||||||
|
"encode": lua_encode,
|
||||||
|
"get_attr": lua_get_attr,
|
||||||
|
"set_attr": lua_set_attr,
|
||||||
|
"get_gen_len": lua_get_gen_len,
|
||||||
|
"set_gen_len": lua_set_gen_len,
|
||||||
|
"get_memory": lua_get_memory,
|
||||||
|
"set_memory": lua_set_memory,
|
||||||
|
"get_numseqs": lua_get_numseqs,
|
||||||
|
"set_numseqs": lua_set_numseqs,
|
||||||
|
"get_setting": lua_get_setting,
|
||||||
|
"set_setting": lua_set_setting,
|
||||||
|
"vars": vars,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))(
|
||||||
|
vars.lua_state.globals().python,
|
||||||
|
bridged,
|
||||||
|
)
|
||||||
|
except lupa.LuaError as e:
|
||||||
|
print(colors.RED + "ERROR!" + colors.END)
|
||||||
|
print(e, file=sys.stderr)
|
||||||
|
exit(1)
|
||||||
|
print(colors.GREEN + "OK!" + colors.END)
|
||||||
|
|
||||||
|
# Load scripts
|
||||||
|
load_lua_scripts()
|
||||||
|
|
||||||
|
|
||||||
#============================ METHODS =============================#
|
#============================ METHODS =============================#
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
@ -1314,6 +1562,9 @@ def actionsubmit(data, actionmode=0, force_submit=False):
|
|||||||
vars.recentedit = False
|
vars.recentedit = False
|
||||||
vars.actionmode = actionmode
|
vars.actionmode = actionmode
|
||||||
|
|
||||||
|
# Run the core script's input modifier
|
||||||
|
vars.lua_koboldbridge.execute_inmod()
|
||||||
|
|
||||||
# "Action" mode
|
# "Action" mode
|
||||||
if(actionmode == 1):
|
if(actionmode == 1):
|
||||||
data = data.strip().lstrip('>')
|
data = data.strip().lstrip('>')
|
||||||
@ -1339,6 +1590,7 @@ def actionsubmit(data, actionmode=0, force_submit=False):
|
|||||||
calcsubmit(data) # Run the first action through the generator
|
calcsubmit(data) # Run the first action through the generator
|
||||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||||
else:
|
else:
|
||||||
|
vars.lua_koboldbridge.execute_outmod()
|
||||||
refresh_story()
|
refresh_story()
|
||||||
set_aibusy(0)
|
set_aibusy(0)
|
||||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||||
@ -1360,6 +1612,7 @@ def actionsubmit(data, actionmode=0, force_submit=False):
|
|||||||
calcsubmit(data)
|
calcsubmit(data)
|
||||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||||
else:
|
else:
|
||||||
|
vars.lua_koboldbridge.execute_outmod()
|
||||||
set_aibusy(0)
|
set_aibusy(0)
|
||||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||||
|
|
||||||
@ -1700,6 +1953,8 @@ def generate(txt, minimum, maximum, found_entries=None):
|
|||||||
# Need to manually strip and decode tokens if we're not using a pipeline
|
# Need to manually strip and decode tokens if we're not using a pipeline
|
||||||
#already_generated = -(len(gen_in[0]) - len(tokens))
|
#already_generated = -(len(gen_in[0]) - len(tokens))
|
||||||
genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]
|
genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]
|
||||||
|
|
||||||
|
vars.lua_koboldbridge.execute_outmod()
|
||||||
|
|
||||||
if(len(genout) == 1):
|
if(len(genout) == 1):
|
||||||
genresult(genout[0]["generated_text"])
|
genresult(genout[0]["generated_text"])
|
||||||
@ -1800,6 +2055,8 @@ def sendtocolab(txt, min, max):
|
|||||||
else:
|
else:
|
||||||
genout = js["seqs"]
|
genout = js["seqs"]
|
||||||
|
|
||||||
|
vars.lua_koboldbridge.execute_outmod()
|
||||||
|
|
||||||
if(len(genout) == 1):
|
if(len(genout) == 1):
|
||||||
genresult(genout[0])
|
genresult(genout[0])
|
||||||
else:
|
else:
|
||||||
@ -1881,6 +2138,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
|||||||
|
|
||||||
genout = [{"generated_text": txt} for txt in genout]
|
genout = [{"generated_text": txt} for txt in genout]
|
||||||
|
|
||||||
|
vars.lua_koboldbridge.execute_outmod()
|
||||||
|
|
||||||
if(len(genout) == 1):
|
if(len(genout) == 1):
|
||||||
genresult(genout[0]["generated_text"])
|
genresult(genout[0]["generated_text"])
|
||||||
else:
|
else:
|
||||||
@ -2459,6 +2718,7 @@ def ikrequest(txt):
|
|||||||
# Deal with the response
|
# Deal with the response
|
||||||
if(req.status_code == 200):
|
if(req.status_code == 200):
|
||||||
genout = req.json()["data"]["text"]
|
genout = req.json()["data"]["text"]
|
||||||
|
vars.lua_koboldbridge.execute_outmod()
|
||||||
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
|
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
|
||||||
vars.actions.append(genout)
|
vars.actions.append(genout)
|
||||||
update_story_chunk('last')
|
update_story_chunk('last')
|
||||||
@ -2509,6 +2769,7 @@ def oairequest(txt, min, max):
|
|||||||
# Deal with the response
|
# Deal with the response
|
||||||
if(req.status_code == 200):
|
if(req.status_code == 200):
|
||||||
genout = req.json()["choices"][0]["text"]
|
genout = req.json()["choices"][0]["text"]
|
||||||
|
vars.lua_koboldbridge.execute_outmod()
|
||||||
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
|
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
|
||||||
vars.actions.append(genout)
|
vars.actions.append(genout)
|
||||||
update_story_chunk('last')
|
update_story_chunk('last')
|
||||||
|
35
bridge.lua
35
bridge.lua
@ -252,7 +252,7 @@ return function(_python, _bridged)
|
|||||||
|
|
||||||
---@return boolean
|
---@return boolean
|
||||||
function KoboldWorldInfoEntry:is_valid()
|
function KoboldWorldInfoEntry:is_valid()
|
||||||
return bridged.worldinfo_u.get(self.uid) ~= nil
|
return bridged.vars.worldinfo_u.get(self.uid) ~= nil
|
||||||
end
|
end
|
||||||
|
|
||||||
---@return string
|
---@return string
|
||||||
@ -340,7 +340,7 @@ return function(_python, _bridged)
|
|||||||
if not check_validity(self) or type(u) ~= "number" then
|
if not check_validity(self) or type(u) ~= "number" then
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
local query = bridged.worldinfo_u.get(u)
|
local query = bridged.vars.worldinfo_u.get(u)
|
||||||
if query == nil or (rawget(self, "_name") == "KoboldWorldInfoFolder" and self.uid ~= query.get("folder")) then
|
if query == nil or (rawget(self, "_name") == "KoboldWorldInfoFolder" and self.uid ~= query.get("folder")) then
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
@ -373,7 +373,7 @@ return function(_python, _bridged)
|
|||||||
|
|
||||||
---@return boolean
|
---@return boolean
|
||||||
function KoboldWorldInfoFolder:is_valid()
|
function KoboldWorldInfoFolder:is_valid()
|
||||||
return bridged.wifolders_d.get(self.uid) ~= nil
|
return bridged.vars.wifolders_d.get(self.uid) ~= nil
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param t KoboldWorldInfoFolder
|
---@param t KoboldWorldInfoFolder
|
||||||
@ -382,7 +382,7 @@ return function(_python, _bridged)
|
|||||||
if not check_validity(t) then
|
if not check_validity(t) then
|
||||||
return 0
|
return 0
|
||||||
end
|
end
|
||||||
return _python.builtins.len(bridged.worldinfo_u.get(t.uid))
|
return _python.builtins.len(bridged.vars.worldinfo_u.get(t.uid))
|
||||||
end
|
end
|
||||||
|
|
||||||
KoboldWorldInfoFolder_mt._kobold_next = KoboldWorldInfoEntry_mt._kobold_next
|
KoboldWorldInfoFolder_mt._kobold_next = KoboldWorldInfoEntry_mt._kobold_next
|
||||||
@ -397,9 +397,9 @@ return function(_python, _bridged)
|
|||||||
elseif rawget(t, "_name") == "KoboldWorldInfoFolder" and k == "uid" then
|
elseif rawget(t, "_name") == "KoboldWorldInfoFolder" and k == "uid" then
|
||||||
return rawget(t, "_uid")
|
return rawget(t, "_uid")
|
||||||
elseif rawget(t, "_name") == "KoboldWorldInfoFolder" and k == "comment" then
|
elseif rawget(t, "_name") == "KoboldWorldInfoFolder" and k == "comment" then
|
||||||
return bridged.wifolders_d.get(t.uid).__getitem__("comment")
|
return bridged.folder_get_attr(t.uid, k)
|
||||||
elseif type(k) == "number" then
|
elseif type(k) == "number" then
|
||||||
local query = rawget(t, "_name") == "KoboldWorldInfoFolder" and bridged.wifolders_u.get(t.uid) or bridged.worldinfo
|
local query = rawget(t, "_name") == "KoboldWorldInfoFolder" and bridged.vars.wifolders_u.get(t.uid) or bridged.vars.worldinfo
|
||||||
k = math.tointeger(k)
|
k = math.tointeger(k)
|
||||||
if k == nil or k < 1 or k > _python.builtins.len(query) then
|
if k == nil or k < 1 or k > _python.builtins.len(query) then
|
||||||
return
|
return
|
||||||
@ -424,7 +424,7 @@ return function(_python, _bridged)
|
|||||||
error("`"..rawget(t, "_name").."."..k.."` must be a string; you attempted to set it to a "..type(v))
|
error("`"..rawget(t, "_name").."."..k.."` must be a string; you attempted to set it to a "..type(v))
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
bridged.safe_setitem(bridged.wifolders_d.get(t.uid), "comment", v)
|
bridged.folder_set_attr(t.uid, k, v)
|
||||||
return t
|
return t
|
||||||
else
|
else
|
||||||
return rawset(t, k, v)
|
return rawset(t, k, v)
|
||||||
@ -450,7 +450,7 @@ return function(_python, _bridged)
|
|||||||
if not check_validity(self) or type(u) ~= "number" then
|
if not check_validity(self) or type(u) ~= "number" then
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
local query = bridged.wifolders_d.get(u)
|
local query = bridged.vars.wifolders_d.get(u)
|
||||||
if query == nil then
|
if query == nil then
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
@ -480,11 +480,11 @@ return function(_python, _bridged)
|
|||||||
---@param t KoboldWorldInfoFolderSelector
|
---@param t KoboldWorldInfoFolderSelector
|
||||||
---@return KoboldWorldInfoFolder|nil
|
---@return KoboldWorldInfoFolder|nil
|
||||||
function KoboldWorldInfoFolderSelector_mt.__index(t, k)
|
function KoboldWorldInfoFolderSelector_mt.__index(t, k)
|
||||||
if not check_validity(t) or type(k) ~= "number" or math.tointeger(k) == nil or k < 1 or k > _python.builtins.len(bridged.wifolders_l) then
|
if not check_validity(t) or type(k) ~= "number" or math.tointeger(k) == nil or k < 1 or k > _python.builtins.len(bridged.vars.wifolders_l) then
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
local folder = deepcopy(KoboldWorldInfoFolder)
|
local folder = deepcopy(KoboldWorldInfoFolder)
|
||||||
rawset(folder, "_uid", bridged.wifolders_l.__getitem__(k))
|
rawset(folder, "_uid", bridged.vars.wifolders_l.__getitem__(k))
|
||||||
return folder
|
return folder
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -523,7 +523,7 @@ return function(_python, _bridged)
|
|||||||
if not check_validity(t) then
|
if not check_validity(t) then
|
||||||
return 0
|
return 0
|
||||||
end
|
end
|
||||||
return _python.builtins.len(bridged.worldinfo)
|
return _python.builtins.len(bridged.vars.worldinfo)
|
||||||
end
|
end
|
||||||
|
|
||||||
KoboldWorldInfo_mt._kobold_next = KoboldWorldInfoEntry_mt._kobold_next
|
KoboldWorldInfo_mt._kobold_next = KoboldWorldInfoEntry_mt._kobold_next
|
||||||
@ -577,6 +577,8 @@ return function(_python, _bridged)
|
|||||||
end
|
end
|
||||||
if k == "gen_len" then
|
if k == "gen_len" then
|
||||||
return bridged.get_gen_len()
|
return bridged.get_gen_len()
|
||||||
|
elseif k == "numseqs" then
|
||||||
|
return bridged.get_numseqs()
|
||||||
elseif bridged.has_setting(k) then
|
elseif bridged.has_setting(k) then
|
||||||
return bridged.get_setting(k), true
|
return bridged.get_setting(k), true
|
||||||
else
|
else
|
||||||
@ -588,7 +590,16 @@ return function(_python, _bridged)
|
|||||||
function KoboldSettings_mt.__newindex(t, k, v)
|
function KoboldSettings_mt.__newindex(t, k, v)
|
||||||
if k == "gen_len" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 0 then
|
if k == "gen_len" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 0 then
|
||||||
bridged.set_gen_len(v)
|
bridged.set_gen_len(v)
|
||||||
|
elseif k == "numseqs" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 1 then
|
||||||
|
if koboldbridge.userstate == "genmod" then
|
||||||
|
error("Cannot set numseqs from a generation modifier")
|
||||||
|
return
|
||||||
|
end
|
||||||
|
bridged.set_numseqs(v)
|
||||||
elseif type(k) == "string" and bridged.has_setting(k) and type(v) == type(bridged.get_setting(k)) then
|
elseif type(k) == "string" and bridged.has_setting(k) and type(v) == type(bridged.get_setting(k)) then
|
||||||
|
if k == "settknmax" or k == "anotedepth" or k == "setwidepth" or k == "setuseprompt" then
|
||||||
|
maybe_save_genmod_comparison_context()
|
||||||
|
end
|
||||||
return bridged.set_setting(k, v)
|
return bridged.set_setting(k, v)
|
||||||
end
|
end
|
||||||
return t
|
return t
|
||||||
@ -838,7 +849,7 @@ return function(_python, _bridged)
|
|||||||
local old_package_loaded = package.loaded
|
local old_package_loaded = package.loaded
|
||||||
local old_package_searchers = package.searchers
|
local old_package_searchers = package.searchers
|
||||||
---@param modname string
|
---@param modname string
|
||||||
---@param env? table<string, any>
|
---@param env table<string, any>
|
||||||
---@param search_path? string
|
---@param search_path? string
|
||||||
---@return any, string|nil
|
---@return any, string|nil
|
||||||
local function requirex(modname, env, search_path)
|
local function requirex(modname, env, search_path)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user