Fix Lua settings API

This commit is contained in:
Gnome Ann 2021-12-11 17:01:41 -05:00
parent f8aa578f41
commit 3327f1b471
2 changed files with 73 additions and 12 deletions

View File

@ -994,14 +994,15 @@ def lua_folder_set_attr(uid, k, v):
#==================================================================# #==================================================================#
# Get the "Amount to Generate" # Get the "Amount to Generate"
#==================================================================# #==================================================================#
def lua_get_gen_len(): def lua_get_genamt():
return vars.genamt return vars.genamt
#==================================================================# #==================================================================#
# Set the "Amount to Generate" # Set the "Amount to Generate"
#==================================================================# #==================================================================#
def lua_set_gen_len(genamt): def lua_set_genamt(genamt):
assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0 assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0
print(colors.PURPLE + f"[USERPLACEHOLDER] set genamt to {int(genamt)}" + colors.END)
vars.genamt = int(genamt) vars.genamt = int(genamt)
#==================================================================# #==================================================================#
@ -1015,6 +1016,7 @@ def lua_get_numseqs():
#==================================================================# #==================================================================#
def lua_set_numseqs(numseqs): def lua_set_numseqs(numseqs):
assert type(numseqs) in (int, float) and numseqs >= 1 assert type(numseqs) in (int, float) and numseqs >= 1
print(colors.PURPLE + f"[USERPLACEHOLDER] set numseqs to {int(numseqs)}" + colors.END)
vars.genamt = int(numseqs) vars.genamt = int(numseqs)
#==================================================================# #==================================================================#
@ -1027,7 +1029,6 @@ def lua_has_setting(setting):
"settopk", "settopk",
"settfs", "settfs",
"setreppen", "setreppen",
"setoutput",
"settknmax", "settknmax",
"anotedepth", "anotedepth",
"setwidepth", "setwidepth",
@ -1070,7 +1071,21 @@ def lua_set_setting(setting, v):
print(colors.PURPLE + f"[USERPLACEHOLDER] set {setting} to {v}" + colors.END) print(colors.PURPLE + f"[USERPLACEHOLDER] set {setting} to {v}" + colors.END)
if(setting == "setadventure" and v): if(setting == "setadventure" and v):
vars.actionmode = 1 vars.actionmode = 1
get_message({'cmd': setting, 'data': v}) if(setting == "settemp"): vars.temp = v
if(setting == "settopp"): vars.top_p = v
if(setting == "settopk"): vars.top_k = v
if(setting == "settfs"): vars.tfs = v
if(setting == "setreppen"): vars.rep_pen = v
if(setting == "settknmax"): vars.max_length = v
if(setting == "anotedepth"): vars.andepth = v
if(setting == "setwidepth"): vars.widepth = v
if(setting == "setuseprompt"): vars.useprompt = v
if(setting == "setadventure"): vars.adventure = v
if(setting == "frmttriminc"): vars.formatoptns["frmttriminc"] = v
if(setting == "frmtrmblln"): vars.formatoptns["frmttriminc"] = v
if(setting == "frmtrmspch"): vars.formatoptns["frmttriminc"] = v
if(setting == "frmtadsnsp"): vars.formatoptns["frmttriminc"] = v
if(setting == "singleline"): vars.formatoptns["frmttriminc"] = v
#==================================================================# #==================================================================#
# Get contents of memory # Get contents of memory
@ -1085,6 +1100,25 @@ def lua_set_memory(m):
assert type(m) is str assert type(m) is str
vars.memory = m vars.memory = m
#==================================================================#
# Save settings and send them to client
#==================================================================#
def lua_resend_settings():
settingschanged()
refresh_settings()
#==================================================================#
#
#==================================================================#
def execute_inmod():
vars.lua_koboldbridge.execute_inmod()
def execute_genmod():
vars.lua_koboldbridge.execute_genmod()
def execute_outmod():
vars.lua_koboldbridge.execute_outmod()
#==================================================================# #==================================================================#
# Lua runtime startup # Lua runtime startup
#==================================================================# #==================================================================#
@ -1104,14 +1138,16 @@ bridged = {
"encode": lua_encode, "encode": lua_encode,
"get_attr": lua_get_attr, "get_attr": lua_get_attr,
"set_attr": lua_set_attr, "set_attr": lua_set_attr,
"get_gen_len": lua_get_gen_len, "get_genamt": lua_get_genamt,
"set_gen_len": lua_set_gen_len, "set_genamt": lua_set_genamt,
"get_memory": lua_get_memory, "get_memory": lua_get_memory,
"set_memory": lua_set_memory, "set_memory": lua_set_memory,
"get_numseqs": lua_get_numseqs, "get_numseqs": lua_get_numseqs,
"set_numseqs": lua_set_numseqs, "set_numseqs": lua_set_numseqs,
"has_setting": lua_has_setting,
"get_setting": lua_get_setting, "get_setting": lua_get_setting,
"set_setting": lua_set_setting, "set_setting": lua_set_setting,
"resend_settings": lua_resend_settings,
"vars": vars, "vars": vars,
} }
try: try:

View File

@ -182,6 +182,7 @@ return function(_python, _bridged)
koboldbridge.genmod_comparison_context = nil koboldbridge.genmod_comparison_context = nil
koboldbridge.regeneration_required = false koboldbridge.regeneration_required = false
koboldbridge.resend_settings_required = false
koboldbridge.generating = true koboldbridge.generating = true
koboldbridge.userstate = "inmod" koboldbridge.userstate = "inmod"
@ -547,6 +548,23 @@ return function(_python, _bridged)
local _ = {} local _ = {}
---@class KoboldSettings : KoboldSettings_base ---@class KoboldSettings : KoboldSettings_base
---@field numseqs integer
---@field genamt integer
---@field settemp number
---@field settopp number
---@field settopk integer
---@field settfs number
---@field setreppen number
---@field settknmax integer
---@field anotedepth integer
---@field setwidepth integer
---@field setuseprompt boolean
---@field setadventure boolean
---@field frmttriminc boolean
---@field frmtrmblln boolean
---@field frmtrmspch boolean
---@field frmtadsnsp boolean
---@field singleline boolean
local KoboldSettings = setmetatable({ local KoboldSettings = setmetatable({
_name = "KoboldSettings", _name = "KoboldSettings",
}, metawrapper) }, metawrapper)
@ -576,10 +594,10 @@ return function(_python, _bridged)
if type(k) ~= "string" then if type(k) ~= "string" then
return return
end end
if k == "gen_len" then if k == "genamt" then
return bridged.get_gen_len() return math.tointeger(bridged.get_genamt()), true
elseif k == "numseqs" then elseif k == "numseqs" then
return bridged.get_numseqs() return math.tointeger(bridged.get_numseqs()), true
elseif bridged.has_setting(k) then elseif bridged.has_setting(k) then
return bridged.get_setting(k), true return bridged.get_setting(k), true
else else
@ -589,20 +607,23 @@ return function(_python, _bridged)
---@param t KoboldSettings_base ---@param t KoboldSettings_base
function KoboldSettings_mt.__newindex(t, k, v) function KoboldSettings_mt.__newindex(t, k, v)
if k == "gen_len" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 0 then if k == "genamt" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 0 then
bridged.set_genamt(v)
maybe_require_regeneration() maybe_require_regeneration()
bridged.set_gen_len(v) koboldbridge.resend_settings_required = true
elseif k == "numseqs" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 1 then elseif k == "numseqs" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 1 then
if koboldbridge.userstate == "genmod" then if koboldbridge.userstate == "genmod" then
error("Cannot set numseqs from a generation modifier") error("Cannot set numseqs from a generation modifier")
return return
end end
bridged.set_numseqs(v) bridged.set_numseqs(v)
koboldbridge.resend_settings_required = true
elseif type(k) == "string" and bridged.has_setting(k) and type(v) == type(bridged.get_setting(k)) then elseif type(k) == "string" and bridged.has_setting(k) and type(v) == type(bridged.get_setting(k)) then
if k == "settknmax" or k == "anotedepth" or k == "setwidepth" or k == "setuseprompt" then if k == "settknmax" or k == "anotedepth" or k == "setwidepth" or k == "setuseprompt" then
maybe_require_regeneration() maybe_require_regeneration()
end end
return bridged.set_setting(k, v) bridged.set_setting(k, v)
koboldbridge.resend_settings_required = true
end end
return t return t
end end
@ -1125,6 +1146,9 @@ return function(_python, _bridged)
if koboldbridge.outmod ~= nil then if koboldbridge.outmod ~= nil then
r = koboldbridge.outmod() r = koboldbridge.outmod()
end end
if koboldbridge.resend_settings_required then
bridged.resend_settings()
end
koboldbridge.generating = true koboldbridge.generating = true
koboldbridge.userstate = "inmod" koboldbridge.userstate = "inmod"
return r return r
@ -1140,6 +1164,7 @@ return function(_python, _bridged)
setmetatable(KoboldWorldInfoFolder, KoboldWorldInfoFolder_mt) setmetatable(KoboldWorldInfoFolder, KoboldWorldInfoFolder_mt)
setmetatable(KoboldWorldInfoFolderSelector, KoboldWorldInfoFolderSelector_mt) setmetatable(KoboldWorldInfoFolderSelector, KoboldWorldInfoFolderSelector_mt)
setmetatable(KoboldWorldInfo, KoboldWorldInfo_mt) setmetatable(KoboldWorldInfo, KoboldWorldInfo_mt)
setmetatable(KoboldSettings, KoboldSettings_mt)
setmetatable(KoboldUserScriptModule, KoboldUserScriptModule_mt) setmetatable(KoboldUserScriptModule, KoboldUserScriptModule_mt)
setmetatable(KoboldUserScriptList, KoboldUserScriptList_mt) setmetatable(KoboldUserScriptList, KoboldUserScriptList_mt)
setmetatable(kobold, KoboldLib_mt) setmetatable(kobold, KoboldLib_mt)