Add support for changing soft prompt from userscripts

Gnome Ann 2022-04-12 15:59:05 -04:00
parent 9a2d346d60
commit a3a52dc9c3
3 changed files with 72 additions and 3 deletions

aiserver.py

@@ -217,6 +217,7 @@ class vars:
    abort = False # Whether or not generation was aborted by clicking on the submit button during generation
    compiling = False # If using a TPU Colab, this will be set to True when the TPU backend starts compiling and then set to False again
    checking = False # Whether or not we are actively checking to see if TPU backend is compiling or not
    sp_changed = False # This gets set to True whenever a userscript changes the soft prompt so that check_for_sp_change() can alert the browser that the soft prompt has changed
    spfilename = "" # Filename of soft prompt to load, or an empty string if not using a soft prompt
    userscripts = [] # List of userscripts to load
    last_userscripts = [] # List of previous userscript filenames from the previous time userscripts were sent via usstatitems
@@ -697,13 +698,29 @@ def loadsettings():
#==================================================================#
# Load a soft prompt from a file
#==================================================================#
def check_for_sp_change():
    while(True):
        time.sleep(0.1)
        if(vars.sp_changed):
            with app.app_context():
                emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, namespace=None, broadcast=True)
            vars.sp_changed = False

def spRequest(filename):
    if(not vars.allowsp):
        raise RuntimeError("Soft prompts are not supported by your current model/backend")

    old_filename = vars.spfilename

    vars.spfilename = ""
    settingschanged()

    if(len(filename) == 0):
        vars.sp = None
        vars.sp_length = 0
        if(old_filename != filename):
            vars.sp_changed = True
        return

    global np
@@ -711,7 +728,8 @@ def spRequest(filename):
        import numpy as np

    z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
    assert isinstance(z, zipfile.ZipFile)
    if not isinstance(z, zipfile.ZipFile):
        raise RuntimeError(f"{repr(filename)} is not a valid soft prompt file")
    with z.open('meta.json') as f:
        vars.spmeta = json.load(f)
    z.close()
@@ -747,6 +765,8 @@ def spRequest(filename):
    vars.spfilename = filename
    settingschanged()
    if(old_filename != filename):
        vars.sp_changed = True

#==================================================================#
# Startup
@@ -1067,6 +1087,7 @@ from flask_socketio import SocketIO, emit
app = Flask(__name__, root_path=os.getcwd())
app.config['SECRET KEY'] = 'secret!'
socketio = SocketIO(app, async_method="eventlet")
socketio.start_background_task(check_for_sp_change)
print("{0}OK!{1}".format(colors.GREEN, colors.END))
# Start transformers and create pipeline
@@ -2197,6 +2218,29 @@ def lua_get_modelbackend():
def lua_is_custommodel():
    return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")

#==================================================================#
# Return the filename (as a string) of the current soft prompt, or
# None if no soft prompt is loaded
#==================================================================#
@bridged_kwarg()
def lua_get_spfilename():
    return vars.spfilename.strip() or None

#==================================================================#
# When called with a string as argument, sets the current soft prompt;
# when called with None as argument, uses no soft prompt.
# Returns True if soft prompt changed, False otherwise.
#==================================================================#
@bridged_kwarg()
def lua_set_spfilename(filename: Union[str, None]):
    if(filename is None):
        filename = ""
    filename = str(filename).strip()
    changed = lua_get_spfilename() != filename
    assert all(q not in filename for q in ("/", "\\"))
    spRequest(filename)
    return changed

#==================================================================#
#
#==================================================================#
@@ -2611,7 +2655,6 @@ def get_message(msg):
        loadRequest(fileops.storypath(vars.loadselect))
    elif(msg['cmd'] == 'sprequest'):
        spRequest(vars.spselect)
        emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, broadcast=True)
    elif(msg['cmd'] == 'deletestory'):
        deletesave(msg['data'])
    elif(msg['cmd'] == 'renamestory'):

bridge.lua

@@ -1050,11 +1050,34 @@ return function(_python, _bridged)
            return
        elseif not bridged.vars.gamestarted and v == "" then
            error("`KoboldLib.submission` must not be set to the empty string when the story is empty")
            return
        end
        bridged.vars.submission = v
    end

    --==========================================================================
    -- Userscript API: Soft prompt
    --==========================================================================

    ---@param t KoboldLib
    ---@return string?
    function KoboldLib_getters.spfilename(t)
        return bridged.get_spfilename()
    end

    ---@param t KoboldLib
    ---@param v string?
    function KoboldLib_setters.spfilename(t, v)
        if v:find("/") or v:find("\\") then
            error("Cannot set `KoboldLib.spfilename` to a string that contains slashes")
        end
        if bridged.set_spfilename(v) then
            maybe_require_regeneration()
        end
    end

    --==========================================================================
    -- Userscript API: Model information
    --==========================================================================
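
Together, the two files above expose the soft prompt to userscripts through the new `KoboldLib.spfilename` field: reading it returns the current soft prompt filename (or nil), and assigning to it calls `bridged.set_spfilename()`, which loads the soft prompt via spRequest() and flags the browser update. A minimal usage sketch, not part of this commit, is shown below; it assumes the usual userscript layout in which the script returns a table of modifier hooks and KoboldLib is available to userscripts as the `kobold` global, and the soft prompt filename used here is purely hypothetical.

-- Sketch of a userscript using the new soft prompt API; "my_softprompt.zip"
-- is a placeholder and would need to exist in the softprompts folder.
local userscript = {}

function userscript.inmod()
    -- Reading the field returns the current soft prompt filename, or nil if none is loaded
    print("Current soft prompt: " .. tostring(kobold.spfilename))

    -- Assigning a filename loads that soft prompt; filenames containing slashes
    -- are rejected, and an actual change triggers maybe_require_regeneration()
    if kobold.spfilename ~= "my_softprompt.zip" then
        kobold.spfilename = "my_softprompt.zip"
    end
end

return userscript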

fileops.py

@@ -122,7 +122,10 @@ def checksp(filename: str, model_dimension: int) -> Tuple[Union[zipfile.ZipFile,
            shape, fortran_order, dtype = np.lib.format._read_array_header(f, version)
            assert len(shape) == 2
    except:
        z.close()
        try:
            z.close()
        except UnboundLocalError:
            pass
        return 1, None, None, None, None
    if dtype not in ('V2', np.float16, np.float32):
        z.close()