Add support for changing soft prompt from userscripts

This commit is contained in:
Gnome Ann
2022-04-12 15:59:05 -04:00
parent 9a2d346d60
commit a3a52dc9c3
3 changed files with 72 additions and 3 deletions

View File

@ -217,6 +217,7 @@ class vars:
abort = False # Whether or not generation was aborted by clicking on the submit button during generation
compiling = False # If using a TPU Colab, this will be set to True when the TPU backend starts compiling and then set to False again
checking = False # Whether or not we are actively checking to see if TPU backend is compiling or not
sp_changed = False # This gets set to True whenever a userscript changes the soft prompt so that check_for_sp_change() can alert the browser that the soft prompt has changed
spfilename = "" # Filename of soft prompt to load, or an empty string if not using a soft prompt
userscripts = [] # List of userscripts to load
last_userscripts = [] # List of previous userscript filenames from the previous time userscripts were send via usstatitems
@ -697,13 +698,29 @@ def loadsettings():
#==================================================================#
# Load a soft prompt from a file
#==================================================================#
def check_for_sp_change():
while(True):
time.sleep(0.1)
if(vars.sp_changed):
with app.app_context():
emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, namespace=None, broadcast=True)
vars.sp_changed = False
def spRequest(filename):
if(not vars.allowsp):
raise RuntimeError("Soft prompts are not supported by your current model/backend")
old_filename = vars.spfilename
vars.spfilename = ""
settingschanged()
if(len(filename) == 0):
vars.sp = None
vars.sp_length = 0
if(old_filename != filename):
vars.sp_changed = True
return
global np
@ -711,7 +728,8 @@ def spRequest(filename):
import numpy as np
z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
assert isinstance(z, zipfile.ZipFile)
if not isinstance(z, zipfile.ZipFile):
raise RuntimeError(f"{repr(filename)} is not a valid soft prompt file")
with z.open('meta.json') as f:
vars.spmeta = json.load(f)
z.close()
@ -747,6 +765,8 @@ def spRequest(filename):
vars.spfilename = filename
settingschanged()
if(old_filename != filename):
vars.sp_changed = True
#==================================================================#
# Startup
@ -1067,6 +1087,7 @@ from flask_socketio import SocketIO, emit
app = Flask(__name__, root_path=os.getcwd())
app.config['SECRET KEY'] = 'secret!'
socketio = SocketIO(app, async_method="eventlet")
socketio.start_background_task(check_for_sp_change)
print("{0}OK!{1}".format(colors.GREEN, colors.END))
# Start transformers and create pipeline
@ -2197,6 +2218,29 @@ def lua_get_modelbackend():
def lua_is_custommodel():
return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
#==================================================================#
# Return the filename (as a string) of the current soft prompt, or
# None if no soft prompt is loaded
#==================================================================#
@bridged_kwarg()
def lua_get_spfilename():
return vars.spfilename.strip() or None
#==================================================================#
# When called with a string as argument, sets the current soft prompt;
# when called with None as argument, uses no soft prompt.
# Returns True if soft prompt changed, False otherwise.
#==================================================================#
@bridged_kwarg()
def lua_set_spfilename(filename: Union[str, None]):
if(filename is None):
filename = ""
filename = str(filename).strip()
changed = lua_get_spfilename() != filename
assert all(q not in filename for q in ("/", "\\"))
spRequest(filename)
return changed
#==================================================================#
#
#==================================================================#
@ -2611,7 +2655,6 @@ def get_message(msg):
loadRequest(fileops.storypath(vars.loadselect))
elif(msg['cmd'] == 'sprequest'):
spRequest(vars.spselect)
emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, broadcast=True)
elif(msg['cmd'] == 'deletestory'):
deletesave(msg['data'])
elif(msg['cmd'] == 'renamestory'):