mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-01-23 13:50:25 +01:00
Add support for changing soft prompt from userscripts
This commit is contained in:
parent
9a2d346d60
commit
a3a52dc9c3
47
aiserver.py
47
aiserver.py
@ -217,6 +217,7 @@ class vars:
|
||||
abort = False # Whether or not generation was aborted by clicking on the submit button during generation
|
||||
compiling = False # If using a TPU Colab, this will be set to True when the TPU backend starts compiling and then set to False again
|
||||
checking = False # Whether or not we are actively checking to see if TPU backend is compiling or not
|
||||
sp_changed = False # This gets set to True whenever a userscript changes the soft prompt so that check_for_sp_change() can alert the browser that the soft prompt has changed
|
||||
spfilename = "" # Filename of soft prompt to load, or an empty string if not using a soft prompt
|
||||
userscripts = [] # List of userscripts to load
|
||||
last_userscripts = [] # List of previous userscript filenames from the previous time userscripts were send via usstatitems
|
||||
@ -697,13 +698,29 @@ def loadsettings():
|
||||
#==================================================================#
|
||||
# Load a soft prompt from a file
|
||||
#==================================================================#
|
||||
|
||||
def check_for_sp_change():
|
||||
while(True):
|
||||
time.sleep(0.1)
|
||||
if(vars.sp_changed):
|
||||
with app.app_context():
|
||||
emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, namespace=None, broadcast=True)
|
||||
vars.sp_changed = False
|
||||
|
||||
def spRequest(filename):
|
||||
if(not vars.allowsp):
|
||||
raise RuntimeError("Soft prompts are not supported by your current model/backend")
|
||||
|
||||
old_filename = vars.spfilename
|
||||
|
||||
vars.spfilename = ""
|
||||
settingschanged()
|
||||
|
||||
if(len(filename) == 0):
|
||||
vars.sp = None
|
||||
vars.sp_length = 0
|
||||
if(old_filename != filename):
|
||||
vars.sp_changed = True
|
||||
return
|
||||
|
||||
global np
|
||||
@ -711,7 +728,8 @@ def spRequest(filename):
|
||||
import numpy as np
|
||||
|
||||
z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
|
||||
assert isinstance(z, zipfile.ZipFile)
|
||||
if not isinstance(z, zipfile.ZipFile):
|
||||
raise RuntimeError(f"{repr(filename)} is not a valid soft prompt file")
|
||||
with z.open('meta.json') as f:
|
||||
vars.spmeta = json.load(f)
|
||||
z.close()
|
||||
@ -747,6 +765,8 @@ def spRequest(filename):
|
||||
|
||||
vars.spfilename = filename
|
||||
settingschanged()
|
||||
if(old_filename != filename):
|
||||
vars.sp_changed = True
|
||||
|
||||
#==================================================================#
|
||||
# Startup
|
||||
@ -1067,6 +1087,7 @@ from flask_socketio import SocketIO, emit
|
||||
app = Flask(__name__, root_path=os.getcwd())
|
||||
app.config['SECRET KEY'] = 'secret!'
|
||||
socketio = SocketIO(app, async_method="eventlet")
|
||||
socketio.start_background_task(check_for_sp_change)
|
||||
print("{0}OK!{1}".format(colors.GREEN, colors.END))
|
||||
|
||||
# Start transformers and create pipeline
|
||||
@ -2197,6 +2218,29 @@ def lua_get_modelbackend():
|
||||
def lua_is_custommodel():
|
||||
return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
|
||||
|
||||
#==================================================================#
|
||||
# Return the filename (as a string) of the current soft prompt, or
|
||||
# None if no soft prompt is loaded
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_get_spfilename():
|
||||
return vars.spfilename.strip() or None
|
||||
|
||||
#==================================================================#
|
||||
# When called with a string as argument, sets the current soft prompt;
|
||||
# when called with None as argument, uses no soft prompt.
|
||||
# Returns True if soft prompt changed, False otherwise.
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_set_spfilename(filename: Union[str, None]):
|
||||
if(filename is None):
|
||||
filename = ""
|
||||
filename = str(filename).strip()
|
||||
changed = lua_get_spfilename() != filename
|
||||
assert all(q not in filename for q in ("/", "\\"))
|
||||
spRequest(filename)
|
||||
return changed
|
||||
|
||||
#==================================================================#
|
||||
#
|
||||
#==================================================================#
|
||||
@ -2611,7 +2655,6 @@ def get_message(msg):
|
||||
loadRequest(fileops.storypath(vars.loadselect))
|
||||
elif(msg['cmd'] == 'sprequest'):
|
||||
spRequest(vars.spselect)
|
||||
emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, broadcast=True)
|
||||
elif(msg['cmd'] == 'deletestory'):
|
||||
deletesave(msg['data'])
|
||||
elif(msg['cmd'] == 'renamestory'):
|
||||
|
23
bridge.lua
23
bridge.lua
@ -1050,11 +1050,34 @@ return function(_python, _bridged)
|
||||
return
|
||||
elseif not bridged.vars.gamestarted and v == "" then
|
||||
error("`KoboldLib.submission` must not be set to the empty string when the story is empty")
|
||||
return
|
||||
end
|
||||
bridged.vars.submission = v
|
||||
end
|
||||
|
||||
|
||||
--==========================================================================
|
||||
-- Userscript API: Soft prompt
|
||||
--==========================================================================
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return string?
|
||||
function KoboldLib_getters.spfilename(t)
|
||||
return bridged.get_spfilename()
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@param v string?
|
||||
function KoboldLib_setters.spfilename(t, v)
|
||||
if v:find("/") or v:find("\\") then
|
||||
error("Cannot set `KoboldLib.spfilename` to a string that contains slashes")
|
||||
end
|
||||
if bridged.set_spfilename(v) then
|
||||
maybe_require_regeneration()
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
--==========================================================================
|
||||
-- Userscript API: Model information
|
||||
--==========================================================================
|
||||
|
@ -122,7 +122,10 @@ def checksp(filename: str, model_dimension: int) -> Tuple[Union[zipfile.ZipFile,
|
||||
shape, fortran_order, dtype = np.lib.format._read_array_header(f, version)
|
||||
assert len(shape) == 2
|
||||
except:
|
||||
z.close()
|
||||
try:
|
||||
z.close()
|
||||
except UnboundLocalError:
|
||||
pass
|
||||
return 1, None, None, None, None
|
||||
if dtype not in ('V2', np.float16, np.float32):
|
||||
z.close()
|
||||
|
Loading…
Reference in New Issue
Block a user