Merge pull request #119 from VE-FORBRYDERNE/scripting-sp

Allow userscripts to change the soft prompt
This commit is contained in:
henk717 2022-04-14 21:33:20 +02:00 committed by GitHub
commit 372eb4c981
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 99 additions and 4 deletions

View File

@ -219,6 +219,7 @@ class vars:
abort = False # Whether or not generation was aborted by clicking on the submit button during generation
compiling = False # If using a TPU Colab, this will be set to True when the TPU backend starts compiling and then set to False again
checking = False # Whether or not we are actively checking to see if TPU backend is compiling or not
sp_changed = False # This gets set to True whenever a userscript changes the soft prompt so that check_for_sp_change() can alert the browser that the soft prompt has changed
spfilename = "" # Filename of soft prompt to load, or an empty string if not using a soft prompt
userscripts = [] # List of userscripts to load
last_userscripts = [] # List of previous userscript filenames from the previous time userscripts were sent via usstatitems
@ -699,13 +700,29 @@ def loadsettings():
#==================================================================#
# Load a soft prompt from a file
#==================================================================#
def check_for_sp_change():
    """Background task: watch for soft prompt changes made by userscripts.

    Loops forever; whenever vars.sp_changed has been set (by spRequest),
    broadcasts the current soft prompt status ('spstatitems') to every
    connected browser and then clears the flag.
    """
    while True:
        time.sleep(0.1)
        if not vars.sp_changed:
            continue
        # Payload is {filename: metadata} when a soft prompt is active and
        # supported by the current backend, otherwise an empty dict.
        if vars.allowsp and len(vars.spfilename):
            payload = {vars.spfilename: vars.spmeta}
        else:
            payload = {}
        with app.app_context():
            emit('from_server', {'cmd': 'spstatitems', 'data': payload}, namespace=None, broadcast=True)
        vars.sp_changed = False
def spRequest(filename):
if(not vars.allowsp):
raise RuntimeError("Soft prompts are not supported by your current model/backend")
old_filename = vars.spfilename
vars.spfilename = ""
settingschanged()
if(len(filename) == 0):
vars.sp = None
vars.sp_length = 0
if(old_filename != filename):
vars.sp_changed = True
return
global np
@ -713,7 +730,8 @@ def spRequest(filename):
import numpy as np
z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
assert isinstance(z, zipfile.ZipFile)
if not isinstance(z, zipfile.ZipFile):
raise RuntimeError(f"{repr(filename)} is not a valid soft prompt file")
with z.open('meta.json') as f:
vars.spmeta = json.load(f)
z.close()
@ -749,6 +767,8 @@ def spRequest(filename):
vars.spfilename = filename
settingschanged()
if(old_filename != filename):
vars.sp_changed = True
#==================================================================#
# Startup
@ -1069,6 +1089,7 @@ from flask_socketio import SocketIO, emit
# Set up the Flask web server and SocketIO connection to the browser.
app = Flask(__name__, root_path=os.getcwd())
# Fix: the key was previously 'SECRET KEY' (with a space), which Flask
# silently ignores -- the real config key is 'SECRET_KEY'.
# NOTE(review): a hardcoded secret is acceptable for a local app, but should
# be randomized if this server is ever exposed to a network.
app.config['SECRET_KEY'] = 'secret!'
socketio = SocketIO(app, async_method="eventlet")
# Launch the background poller that notifies browsers of soft prompt changes.
socketio.start_background_task(check_for_sp_change)
print("{0}OK!{1}".format(colors.GREEN, colors.END))
# Start transformers and create pipeline
@ -2199,6 +2220,29 @@ def lua_get_modelbackend():
def lua_is_custommodel():
    """Return True if the currently loaded model is a custom/local variant."""
    custom_models = (
        "GPT2Custom",
        "NeoCustom",
        "TPUMeshTransformerGPTJ",
        "TPUMeshTransformerGPTNeoX",
    )
    return vars.model in custom_models
#==================================================================#
# Return the filename (as a string) of the current soft prompt, or
# None if no soft prompt is loaded
#==================================================================#
@bridged_kwarg()
def lua_get_spfilename():
    """Return the current soft prompt filename, or None when none is loaded."""
    filename = vars.spfilename.strip()
    if filename:
        return filename
    return None
#==================================================================#
# When called with a string as argument, sets the current soft prompt;
# when called with None as argument, uses no soft prompt.
# Returns True if soft prompt changed, False otherwise.
#==================================================================#
@bridged_kwarg()
def lua_set_spfilename(filename: Union[str, None]):
    """Set (or with None, unload) the current soft prompt from a userscript.

    filename: name of the soft prompt file to load, or None/"" to unload.
    Returns True if the soft prompt actually changed, False otherwise.
    Raises RuntimeError if the filename contains path separators.
    """
    if(filename is None):
        filename = ""
    filename = str(filename).strip()
    # lua_get_spfilename() returns None (not "") when no soft prompt is
    # loaded; normalize before comparing so that clearing an already-empty
    # soft prompt does not report a spurious change (which would trigger an
    # unnecessary regeneration on the Lua side).
    changed = (lua_get_spfilename() or "") != filename
    # Explicit validation instead of a bare assert: asserts are stripped
    # under `python -O`, which would silently drop this path-traversal guard.
    if any(q in filename for q in ("/", "\\")):
        raise RuntimeError("Soft prompt filename must not contain slashes")
    spRequest(filename)
    return changed
#==================================================================#
#
#==================================================================#
@ -2613,7 +2657,6 @@ def get_message(msg):
loadRequest(fileops.storypath(vars.loadselect))
elif(msg['cmd'] == 'sprequest'):
spRequest(vars.spselect)
emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, broadcast=True)
elif(msg['cmd'] == 'deletestory'):
deletesave(msg['data'])
elif(msg['cmd'] == 'renamestory'):

View File

@ -165,7 +165,7 @@ return function(_python, _bridged)
---@field num_outputs integer
---@field feedback string
---@field is_config_file_open boolean
local kobold = setmetatable({API_VERSION = 1.0}, metawrapper)
local kobold = setmetatable({API_VERSION = 1.1}, metawrapper)
local KoboldLib_mt = setmetatable({}, metawrapper)
local KoboldLib_getters = setmetatable({}, metawrapper)
local KoboldLib_setters = setmetatable({}, metawrapper)
@ -1050,11 +1050,34 @@ return function(_python, _bridged)
return
elseif not bridged.vars.gamestarted and v == "" then
error("`KoboldLib.submission` must not be set to the empty string when the story is empty")
return
end
bridged.vars.submission = v
end
--==========================================================================
-- Userscript API: Soft prompt
--==========================================================================
--- Getter for `kobold.spfilename`: the current soft prompt filename.
---@param t KoboldLib
---@return string?
function KoboldLib_getters.spfilename(t)
    local filename = bridged.get_spfilename()
    return filename
end
--- Setter for `kobold.spfilename`: loads (or with nil, unloads) a soft prompt.
---@param t KoboldLib
---@param v string?
function KoboldLib_setters.spfilename(t, v)
    -- v may legitimately be nil (documented way to unload the soft prompt);
    -- the original called v:find() unconditionally, which raised
    -- "attempt to index a nil value" for nil.  Only validate strings, and
    -- use plain-text find (4th arg true) so no pattern magic is involved.
    if v ~= nil and (string.find(v, "/", 1, true) or string.find(v, "\\", 1, true)) then
        error("Cannot set `KoboldLib.spfilename` to a string that contains slashes")
    end
    if bridged.set_spfilename(v) then
        maybe_require_regeneration()
    end
end
--==========================================================================
-- Userscript API: Model information
--==========================================================================

View File

@ -122,7 +122,10 @@ def checksp(filename: str, model_dimension: int) -> Tuple[Union[zipfile.ZipFile,
shape, fortran_order, dtype = np.lib.format._read_array_header(f, version)
assert len(shape) == 2
except:
z.close()
try:
z.close()
except UnboundLocalError:
pass
return 1, None, None, None, None
if dtype not in ('V2', np.float16, np.float32):
z.close()

View File

@ -56,6 +56,7 @@
<li><a href="#kobold.num_outputs">kobold.num_outputs</a></li>
<li><a href="#kobold.outputs">kobold.outputs</a></li>
<li><a href="#kobold.settings">kobold.settings</a></li>
<li><a href="#kobold.spfilename">kobold.spfilename</a></li>
<li><a href="#kobold.story">kobold.story</a>
<ul>
<li></li>
@ -172,6 +173,7 @@
<li><code>kobold.num_outputs</code></li>
<li><code>kobold.outputs</code></li>
<li><code>kobold.settings</code></li>
<li><code>kobold.spfilename</code></li>
<li><code>kobold.story</code></li>
<li><code>kobold.submission</code></li>
<li><code>kobold.worldinfo</code></li>
@ -394,6 +396,14 @@
<li><code>kobold.settings.setwidepth</code> (World Info Depth)</li>
<li><code>kobold.settings.setuseprompt</code> (Always Use Prompt)</li>
</ul>
<h1 id="kobold.spfilename">kobold.spfilename</h1>
<p><em><strong>Readable from:</strong></em> anywhere<br>
<em><strong>Writable from:</strong></em> anywhere</p>
<pre class=" language-lua"><code class="prism language-lua">field kobold<span class="token punctuation">.</span>spfilename<span class="token punctuation">:</span> string?
</code></pre>
<p>The name of the soft prompt file to use (as a string), including the file extension. If not using a soft prompt, this is <code>nil</code> instead.</p>
<p>You can also set the soft prompt to use by setting this to a string or <code>nil</code>.</p>
<p>Modifying this field from inside of a generation modifier triggers a regeneration, which means that the context is recomputed after modification and generation begins again with the new context and previously generated tokens. This incurs a small performance penalty and should not be performed in excess.</p>
<h1 id="kobold.story">kobold.story</h1>
<p><em><strong>Readable from:</strong></em> anywhere<br>
<em><strong>Writable from:</strong></em> nowhere</p>

View File

@ -29,6 +29,7 @@ global kobold: KoboldLib
* `kobold.num_outputs`
* `kobold.outputs`
* `kobold.settings`
* `kobold.spfilename`
* `kobold.story`
* `kobold.submission`
* `kobold.worldinfo`
@ -372,6 +373,21 @@ Modifying certain fields from inside of a generation modifier triggers a regener
* `kobold.settings.setwidepth` (World Info Depth)
* `kobold.settings.setuseprompt` (Always Use Prompt)
# kobold.spfilename
***Readable from:*** anywhere
***Writable from:*** anywhere
```lua
field kobold.spfilename: string?
```
The name of the soft prompt file to use (as a string), including the file extension. If not using a soft prompt, this is `nil` instead.
You can also set the soft prompt to use by setting this to a string or `nil`.
Modifying this field from inside of a generation modifier triggers a regeneration, which means that the context is recomputed after modification and generation begins again with the new context and previously generated tokens. This incurs a small performance penalty and should not be performed in excess.
# kobold.story
***Readable from:*** anywhere