Merge pull request #119 from VE-FORBRYDERNE/scripting-sp
Allow userscripts to change the soft prompt
This commit is contained in:
commit
372eb4c981
47
aiserver.py
47
aiserver.py
|
@ -219,6 +219,7 @@ class vars:
|
|||
abort = False # Whether or not generation was aborted by clicking on the submit button during generation
|
||||
compiling = False # If using a TPU Colab, this will be set to True when the TPU backend starts compiling and then set to False again
|
||||
checking = False # Whether or not we are actively checking to see if TPU backend is compiling or not
|
||||
sp_changed = False # This gets set to True whenever a userscript changes the soft prompt so that check_for_sp_change() can alert the browser that the soft prompt has changed
|
||||
spfilename = "" # Filename of soft prompt to load, or an empty string if not using a soft prompt
|
||||
userscripts = [] # List of userscripts to load
|
||||
last_userscripts = [] # List of previous userscript filenames from the previous time userscripts were send via usstatitems
|
||||
|
@ -699,13 +700,29 @@ def loadsettings():
|
|||
#==================================================================#
|
||||
# Load a soft prompt from a file
|
||||
#==================================================================#
|
||||
|
||||
def check_for_sp_change():
|
||||
while(True):
|
||||
time.sleep(0.1)
|
||||
if(vars.sp_changed):
|
||||
with app.app_context():
|
||||
emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, namespace=None, broadcast=True)
|
||||
vars.sp_changed = False
|
||||
|
||||
def spRequest(filename):
|
||||
if(not vars.allowsp):
|
||||
raise RuntimeError("Soft prompts are not supported by your current model/backend")
|
||||
|
||||
old_filename = vars.spfilename
|
||||
|
||||
vars.spfilename = ""
|
||||
settingschanged()
|
||||
|
||||
if(len(filename) == 0):
|
||||
vars.sp = None
|
||||
vars.sp_length = 0
|
||||
if(old_filename != filename):
|
||||
vars.sp_changed = True
|
||||
return
|
||||
|
||||
global np
|
||||
|
@ -713,7 +730,8 @@ def spRequest(filename):
|
|||
import numpy as np
|
||||
|
||||
z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
|
||||
assert isinstance(z, zipfile.ZipFile)
|
||||
if not isinstance(z, zipfile.ZipFile):
|
||||
raise RuntimeError(f"{repr(filename)} is not a valid soft prompt file")
|
||||
with z.open('meta.json') as f:
|
||||
vars.spmeta = json.load(f)
|
||||
z.close()
|
||||
|
@ -749,6 +767,8 @@ def spRequest(filename):
|
|||
|
||||
vars.spfilename = filename
|
||||
settingschanged()
|
||||
if(old_filename != filename):
|
||||
vars.sp_changed = True
|
||||
|
||||
#==================================================================#
|
||||
# Startup
|
||||
|
@ -1069,6 +1089,7 @@ from flask_socketio import SocketIO, emit
|
|||
app = Flask(__name__, root_path=os.getcwd())
|
||||
app.config['SECRET KEY'] = 'secret!'
|
||||
socketio = SocketIO(app, async_method="eventlet")
|
||||
socketio.start_background_task(check_for_sp_change)
|
||||
print("{0}OK!{1}".format(colors.GREEN, colors.END))
|
||||
|
||||
# Start transformers and create pipeline
|
||||
|
@ -2199,6 +2220,29 @@ def lua_get_modelbackend():
|
|||
def lua_is_custommodel():
|
||||
return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
|
||||
|
||||
#==================================================================#
|
||||
# Return the filename (as a string) of the current soft prompt, or
|
||||
# None if no soft prompt is loaded
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_get_spfilename():
|
||||
return vars.spfilename.strip() or None
|
||||
|
||||
#==================================================================#
|
||||
# When called with a string as argument, sets the current soft prompt;
|
||||
# when called with None as argument, uses no soft prompt.
|
||||
# Returns True if soft prompt changed, False otherwise.
|
||||
#==================================================================#
|
||||
@bridged_kwarg()
|
||||
def lua_set_spfilename(filename: Union[str, None]):
|
||||
if(filename is None):
|
||||
filename = ""
|
||||
filename = str(filename).strip()
|
||||
changed = lua_get_spfilename() != filename
|
||||
assert all(q not in filename for q in ("/", "\\"))
|
||||
spRequest(filename)
|
||||
return changed
|
||||
|
||||
#==================================================================#
|
||||
#
|
||||
#==================================================================#
|
||||
|
@ -2613,7 +2657,6 @@ def get_message(msg):
|
|||
loadRequest(fileops.storypath(vars.loadselect))
|
||||
elif(msg['cmd'] == 'sprequest'):
|
||||
spRequest(vars.spselect)
|
||||
emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, broadcast=True)
|
||||
elif(msg['cmd'] == 'deletestory'):
|
||||
deletesave(msg['data'])
|
||||
elif(msg['cmd'] == 'renamestory'):
|
||||
|
|
25
bridge.lua
25
bridge.lua
|
@ -165,7 +165,7 @@ return function(_python, _bridged)
|
|||
---@field num_outputs integer
|
||||
---@field feedback string
|
||||
---@field is_config_file_open boolean
|
||||
local kobold = setmetatable({API_VERSION = 1.0}, metawrapper)
|
||||
local kobold = setmetatable({API_VERSION = 1.1}, metawrapper)
|
||||
local KoboldLib_mt = setmetatable({}, metawrapper)
|
||||
local KoboldLib_getters = setmetatable({}, metawrapper)
|
||||
local KoboldLib_setters = setmetatable({}, metawrapper)
|
||||
|
@ -1050,11 +1050,34 @@ return function(_python, _bridged)
|
|||
return
|
||||
elseif not bridged.vars.gamestarted and v == "" then
|
||||
error("`KoboldLib.submission` must not be set to the empty string when the story is empty")
|
||||
return
|
||||
end
|
||||
bridged.vars.submission = v
|
||||
end
|
||||
|
||||
|
||||
--==========================================================================
|
||||
-- Userscript API: Soft prompt
|
||||
--==========================================================================
|
||||
|
||||
---@param t KoboldLib
|
||||
---@return string?
|
||||
function KoboldLib_getters.spfilename(t)
|
||||
return bridged.get_spfilename()
|
||||
end
|
||||
|
||||
---@param t KoboldLib
|
||||
---@param v string?
|
||||
function KoboldLib_setters.spfilename(t, v)
|
||||
if v:find("/") or v:find("\\") then
|
||||
error("Cannot set `KoboldLib.spfilename` to a string that contains slashes")
|
||||
end
|
||||
if bridged.set_spfilename(v) then
|
||||
maybe_require_regeneration()
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
--==========================================================================
|
||||
-- Userscript API: Model information
|
||||
--==========================================================================
|
||||
|
|
|
@ -122,7 +122,10 @@ def checksp(filename: str, model_dimension: int) -> Tuple[Union[zipfile.ZipFile,
|
|||
shape, fortran_order, dtype = np.lib.format._read_array_header(f, version)
|
||||
assert len(shape) == 2
|
||||
except:
|
||||
try:
|
||||
z.close()
|
||||
except UnboundLocalError:
|
||||
pass
|
||||
return 1, None, None, None, None
|
||||
if dtype not in ('V2', np.float16, np.float32):
|
||||
z.close()
|
||||
|
|
|
@ -56,6 +56,7 @@
|
|||
<li><a href="#kobold.num_outputs">kobold.num_outputs</a></li>
|
||||
<li><a href="#kobold.outputs">kobold.outputs</a></li>
|
||||
<li><a href="#kobold.settings">kobold.settings</a></li>
|
||||
<li><a href="#kobold.spfilename">kobold.spfilename</a></li>
|
||||
<li><a href="#kobold.story">kobold.story</a>
|
||||
<ul>
|
||||
<li></li>
|
||||
|
@ -172,6 +173,7 @@
|
|||
<li><code>kobold.num_outputs</code></li>
|
||||
<li><code>kobold.outputs</code></li>
|
||||
<li><code>kobold.settings</code></li>
|
||||
<li><code>kobold.spfilename</code></li>
|
||||
<li><code>kobold.story</code></li>
|
||||
<li><code>kobold.submission</code></li>
|
||||
<li><code>kobold.worldinfo</code></li>
|
||||
|
@ -394,6 +396,14 @@
|
|||
<li><code>kobold.settings.setwidepth</code> (World Info Depth)</li>
|
||||
<li><code>kobold.settings.setuseprompt</code> (Always Use Prompt)</li>
|
||||
</ul>
|
||||
<h1 id="kobold.spfilename">kobold.spfilename</h1>
|
||||
<p><em><strong>Readable from:</strong></em> anywhere<br>
|
||||
<em><strong>Writable from:</strong></em> anywhere</p>
|
||||
<pre class=" language-lua"><code class="prism language-lua">field kobold<span class="token punctuation">.</span>spfilename<span class="token punctuation">:</span> string?
|
||||
</code></pre>
|
||||
<p>The name of the soft prompt file to use (as a string), including the file extension. If not using a soft prompt, this is <code>nil</code> instead.</p>
|
||||
<p>You can also set the soft prompt to use by setting this to a string or <code>nil</code>.</p>
|
||||
<p>Modifying this field from inside of a generation modifier triggers a regeneration, which means that the context is recomputed after modification and generation begins again with the new context and previously generated tokens. This incurs a small performance penalty and should not be performed in excess.</p>
|
||||
<h1 id="kobold.story">kobold.story</h1>
|
||||
<p><em><strong>Readable from:</strong></em> anywhere<br>
|
||||
<em><strong>Writable from:</strong></em> nowhere</p>
|
||||
|
|
|
@ -29,6 +29,7 @@ global kobold: KoboldLib
|
|||
* `kobold.num_outputs`
|
||||
* `kobold.outputs`
|
||||
* `kobold.settings`
|
||||
* `kobold.spfilename`
|
||||
* `kobold.story`
|
||||
* `kobold.submission`
|
||||
* `kobold.worldinfo`
|
||||
|
@ -372,6 +373,21 @@ Modifying certain fields from inside of a generation modifier triggers a regener
|
|||
* `kobold.settings.setwidepth` (World Info Depth)
|
||||
* `kobold.settings.setuseprompt` (Always Use Prompt)
|
||||
|
||||
# kobold.spfilename
|
||||
|
||||
***Readable from:*** anywhere
|
||||
***Writable from:*** anywhere
|
||||
|
||||
```lua
|
||||
field kobold.spfilename: string?
|
||||
```
|
||||
|
||||
The name of the soft prompt file to use (as a string), including the file extension. If not using a soft prompt, this is `nil` instead.
|
||||
|
||||
You can also set the soft prompt to use by setting this to a string or `nil`.
|
||||
|
||||
Modifying this field from inside of a generation modifier triggers a regeneration, which means that the context is recomputed after modification and generation begins again with the new context and previously generated tokens. This incurs a small performance penalty and should not be performed in excess.
|
||||
|
||||
# kobold.story
|
||||
|
||||
***Readable from:*** anywhere
|
||||
|
|
Loading…
Reference in New Issue