From 63bb76b073535948eb69e7c59dc3bd8d881c6fa2 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Tue, 4 Jan 2022 14:13:36 -0500
Subject: [PATCH 1/9] Make sure `vars.wifolders_u` is set up properly on
loading a save
---
aiserver.py | 16 ++++++++++++++++
1 file changed, 16 insertions(+)
diff --git a/aiserver.py b/aiserver.py
index 80de124e..2bd43be0 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -3867,6 +3867,8 @@ def loadRequest(loadpath, filename=None):
break
vars.worldinfo_u[uid] = vars.worldinfo[-1]
vars.worldinfo[-1]["uid"] = uid
+ if(vars.worldinfo[-1]["folder"] is not None):
+ vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
num += 1
for uid in vars.wifolders_l + [None]:
@@ -3877,6 +3879,8 @@ def loadRequest(loadpath, filename=None):
break
vars.worldinfo_u[uid] = vars.worldinfo[-1]
vars.worldinfo[-1]["uid"] = uid
+ if(vars.worldinfo[-1]["folder"] is not None):
+ vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
stablesortwi()
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
@@ -4072,6 +4076,8 @@ def importgame():
break
vars.worldinfo_u[uid] = vars.worldinfo[-1]
vars.worldinfo[-1]["uid"] = uid
+ if(vars.worldinfo[-1]["folder"]) is not None:
+ vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
num += 1
for uid in vars.wifolders_l + [None]:
@@ -4082,6 +4088,8 @@ def importgame():
break
vars.worldinfo_u[uid] = vars.worldinfo[-1]
vars.worldinfo[-1]["uid"] = uid
+ if(vars.worldinfo[-1]["folder"] is not None):
+ vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
stablesortwi()
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
@@ -4151,6 +4159,8 @@ def importAidgRequest(id):
break
vars.worldinfo_u[uid] = vars.worldinfo[-1]
vars.worldinfo[-1]["uid"] = uid
+ if(vars.worldinfo[-1]["folder"]) is not None:
+ vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
num += 1
for uid in vars.wifolders_l + [None]:
@@ -4161,6 +4171,8 @@ def importAidgRequest(id):
break
vars.worldinfo_u[uid] = vars.worldinfo[-1]
vars.worldinfo[-1]["uid"] = uid
+ if(vars.worldinfo[-1]["folder"] is not None):
+ vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
stablesortwi()
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
@@ -4210,6 +4222,8 @@ def wiimportrequest():
break
vars.worldinfo_u[uid] = vars.worldinfo[-1]
vars.worldinfo[-1]["uid"] = uid
+ if(vars.worldinfo[-1]["folder"]) is not None:
+ vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
num += 1
for uid in [None]:
vars.worldinfo.append({"key": "", "keysecondary": "", "content": "", "comment": "", "folder": uid, "num": None, "init": False, "selective": False, "constant": False, "uid": None})
@@ -4219,6 +4233,8 @@ def wiimportrequest():
break
vars.worldinfo_u[uid] = vars.worldinfo[-1]
vars.worldinfo[-1]["uid"] = uid
+ if(vars.worldinfo[-1]["folder"] is not None):
+ vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
print("{0}".format(vars.worldinfo[0]))
From f46ebd235953fce7ed4a5c4dcb90dc8c7f3af36d Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Tue, 4 Jan 2022 14:18:58 -0500
Subject: [PATCH 2/9] Always pass 1.1 as repetition penalty to generator
The `dynamic_processor_wrap` makes it so that the repetition penalty is
read directly from `vars`, but this only works if the initial repetition
sent to `generator` is not equal to 1. So we are now forcing the initial
repetition penalty to be something other than 1.
---
aiserver.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/aiserver.py b/aiserver.py
index 2bd43be0..cdb7a661 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -2560,7 +2560,7 @@ def _generate(txt, minimum, maximum, found_entries):
do_sample=True,
min_length=minimum,
max_length=int(2e9),
- repetition_penalty=vars.rep_pen,
+ repetition_penalty=1.1,
bad_words_ids=vars.badwordsids,
use_cache=True,
num_return_sequences=numseqs
From e20452ddd880330bb257d395d7f8721432972486 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Tue, 4 Jan 2022 14:40:10 -0500
Subject: [PATCH 3/9] Retrying random story generation now also remembers
memory
---
aiserver.py | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/aiserver.py b/aiserver.py
index cdb7a661..3eefa16d 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -163,6 +163,7 @@ class vars:
genseqs = [] # Temporary storage for generated sequences
recentback = False # Whether Back button was recently used without Submitting or Retrying after
recentrng = None # If a new random game was recently generated without Submitting after, this is the topic used (as a string), otherwise this is None
+ recentrngm = None # If a new random game was recently generated without Submitting after, this is the memory used (as a string), otherwise this is None
useprompt = False # Whether to send the full prompt with every submit action
breakmodel = False # For GPU users, whether to use both system RAM and VRAM to conserve VRAM while offering speedup compared to CPU-only
bmsupported = False # Whether the breakmodel option is supported (GPT-Neo/GPT-J only, currently)
@@ -1645,7 +1646,7 @@ def get_message(msg):
vars.chatname = msg['chatname']
settingschanged()
emit('from_server', {'cmd': 'setchatname', 'data': vars.chatname}, broadcast=True)
- vars.recentrng = None
+ vars.recentrng = vars.recentrngm = None
actionsubmit(msg['data'], actionmode=msg['actionmode'])
elif(vars.mode == "edit"):
editsubmit(msg['data'])
@@ -2130,7 +2131,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
set_aibusy(1)
if(disable_recentrng):
- vars.recentrng = None
+ vars.recentrng = vars.recentrngm = None
vars.recentback = False
vars.recentedit = False
@@ -2272,7 +2273,7 @@ def actionretry(data):
if(vars.aibusy):
return
if(vars.recentrng is not None):
- randomGameRequest(vars.recentrng)
+ randomGameRequest(vars.recentrng, memory=vars.recentrngm)
return
# Remove last action if possible and resubmit
if(vars.gamestarted if vars.useprompt else len(vars.actions) > 0):
@@ -4282,6 +4283,7 @@ def randomGameRequest(topic, memory=""):
newGameRequest()
return
vars.recentrng = topic
+ vars.recentrngm = memory
newGameRequest()
_memory = memory
if(len(memory) > 0):
From 2fc0bdfcba5028729a81bf5155ee1418adb19a85 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Tue, 4 Jan 2022 14:41:31 -0500
Subject: [PATCH 4/9] Correct a typo in `restorePrompt()`
---
static/application.js | 2 +-
templates/index.html | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/static/application.js b/static/application.js
index 4d9ab676..f97333d6 100644
--- a/static/application.js
+++ b/static/application.js
@@ -1511,7 +1511,7 @@ function restorePrompt() {
if(shadow_text.length && shadow_text[0].firstChild && (shadow_text[0].firstChild.nodeType === 3 || shadow_text[0].firstChild.tagName === "BR")) {
detected = true;
ref = shadow_text;
- } else if(game_text.length && game_text[0].firstChild && game_text[0].firstChild.nodeType === 3 || game_text[0].firstChild.tagName === "BR") {
+ } else if(game_text.length && game_text[0].firstChild && (game_text[0].firstChild.nodeType === 3 || game_text[0].firstChild.tagName === "BR")) {
detected = true;
ref = game_text;
}
diff --git a/templates/index.html b/templates/index.html
index 41a81293..08357979 100644
--- a/templates/index.html
+++ b/templates/index.html
@@ -17,7 +17,7 @@
-
+
From aa86c6001c52746445575303b4298a48ea623591 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Tue, 4 Jan 2022 14:43:37 -0500
Subject: [PATCH 5/9] `--breakmodel_gpublocks` should handle -1 properly now
---
aiserver.py | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/aiserver.py b/aiserver.py
index 3eefa16d..8843b030 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -272,6 +272,13 @@ def device_config(model):
try:
breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(',')))
assert len(breakmodel.gpu_blocks) <= torch.cuda.device_count()
+ s = n_layers
+ for i in range(len(breakmodel.gpu_blocks)):
+ if(breakmodel.gpu_blocks[i] <= -1):
+ breakmodel.gpu_blocks[i] = s
+ break
+ else:
+ s -= breakmodel.gpu_blocks[i]
assert sum(breakmodel.gpu_blocks) <= n_layers
n_layers -= sum(breakmodel.gpu_blocks)
except:
From 6edc6387f4d583e13879372931aaa4f41a254820 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Tue, 4 Jan 2022 17:11:14 -0500
Subject: [PATCH 6/9] Accept command line arguments in `KOBOLDAI_ARGS`
environment var
So that you can use gunicorn or whatever with command-line arguments by
passing the arguments in an environment variable.
---
aiserver.py | 17 ++++++++++++-----
1 file changed, 12 insertions(+), 5 deletions(-)
diff --git a/aiserver.py b/aiserver.py
index 8843b030..67d150e0 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -15,8 +15,6 @@ from eventlet import tpool
from os import path, getcwd
import re
-import tkinter as tk
-from tkinter import messagebox
import json
import collections
import zipfile
@@ -388,7 +386,12 @@ parser.add_argument("--override_delete", action='store_true', help="Deleting sto
parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
-args = parser.parse_args()
+args: argparse.Namespace = None
+if(os.environ.get("KOBOLDAI_ARGS") is not None):
+ import shlex
+ args = parser.parse_args(shlex.split(os.environ["KOBOLDAI_ARGS"]))
+else:
+ args = parser.parse_args()
vars.model = args.model;
if args.remote:
@@ -4308,6 +4311,7 @@ loadsettings()
# Final startup commands to launch Flask app
#==================================================================#
if __name__ == "__main__":
+ print("{0}\nStarting webserver...{1}".format(colors.GREEN, colors.END))
# Start Flask/SocketIO (Blocking, so this must be last method!)
@@ -4321,12 +4325,15 @@ if __name__ == "__main__":
cloudflare = _run_cloudflared(5000)
with open('cloudflare.log', 'w') as cloudflarelog:
cloudflarelog.write("KoboldAI has finished loading and is available at the following link : " + cloudflare)
- print("\n" + format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link : " + cloudflare + format(colors.END))
+ print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link : " + cloudflare + format(colors.END))
vars.serverstarted = True
socketio.run(app, host='0.0.0.0', port=5000)
else:
import webbrowser
webbrowser.open_new('http://localhost:5000')
- print("{0}\nServer started!\nYou may now connect with a browser at http://127.0.0.1:5000/{1}".format(colors.GREEN, colors.END))
+ print("{0}Server started!\nYou may now connect with a browser at http://127.0.0.1:5000/{1}".format(colors.GREEN, colors.END))
vars.serverstarted = True
socketio.run(app, port=5000)
+
+else:
+ print("{0}\nServer started in WSGI mode!{1}".format(colors.GREEN, colors.END))
From fbf506207419df9b63884131af0c462ac0f90b91 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Tue, 4 Jan 2022 19:26:59 -0500
Subject: [PATCH 7/9] Add option to `compute_context()` to not scan story
---
aiserver.py | 26 +++++++++++++++++-----
bridge.lua | 10 +++++----
userscripts/kaipreset_location_scanner.lua | 4 ++--
3 files changed, 28 insertions(+), 12 deletions(-)
diff --git a/aiserver.py b/aiserver.py
index 67d150e0..1e5556a9 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1152,8 +1152,10 @@ def lua_encode(string):
# Computes context given a submission, Lua array of entry UIDs and a Lua array
# of folder UIDs
#==================================================================#
-def lua_compute_context(submission, entries, folders):
+def lua_compute_context(submission, entries, folders, kwargs):
assert type(submission) is str
+ if(kwargs is None):
+ kwargs = vars.lua_state.table()
actions = vars._actions if vars.lua_koboldbridge.userstate == "genmod" else vars.actions
allowed_entries = None
allowed_folders = None
@@ -1169,8 +1171,20 @@ def lua_compute_context(submission, entries, folders):
while(folders[i] is not None):
allowed_folders.add(int(folders[i]))
i += 1
- winfo, mem, anotetxt, _ = calcsubmitbudgetheader(submission, allowed_entries=allowed_entries, allowed_folders=allowed_folders, force_use_txt=True)
- txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
+ winfo, mem, anotetxt, _ = calcsubmitbudgetheader(
+ submission,
+ allowed_entries=allowed_entries,
+ allowed_folders=allowed_folders,
+ force_use_txt=True,
+ scan_story=kwargs["scan_story"] if kwargs["scan_story"] != None else True,
+ )
+ txt, _, _ = calcsubmitbudget(
+ len(actions),
+ winfo,
+ mem,
+ anotetxt,
+ actions,
+ )
return tokenizer.decode(txt)
#==================================================================#
@@ -3370,7 +3384,7 @@ def deletewifolder(uid):
#==================================================================#
# Look for WI keys in text to generator
#==================================================================#
-def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False):
+def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True):
original_txt = txt
# Dont go any further if WI is empty
@@ -3381,7 +3395,7 @@ def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_tx
ln = len(vars.actions)
# Don't bother calculating action history if widepth is 0
- if(vars.widepth > 0):
+ if(vars.widepth > 0 and scan_story):
depth = vars.widepth
# If this is not a continue, add 1 to widepth since submitted
# text is already in action history @ -1
@@ -3423,7 +3437,7 @@ def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_tx
found_entries.add(id(wi))
continue
- if(wi["key"] != ""):
+ if(len(wi["key"].strip()) > 0 and (not wi.get("selective", False) or len(wi.get("keysecondary", "").strip()) > 0)):
# Split comma-separated keys
keys = wi["key"].split(",")
keys_secondary = wi.get("keysecondary", "").split(",")
diff --git a/bridge.lua b/bridge.lua
index 460cbf8d..d5adbd01 100644
--- a/bridge.lua
+++ b/bridge.lua
@@ -383,8 +383,9 @@ return function(_python, _bridged)
end
---@param submission? string
+ ---@param kwargs? table
---@return string
- function KoboldWorldInfoEntry:compute_context(submission)
+ function KoboldWorldInfoEntry:compute_context(submission, kwargs)
if not check_validity(self) then
return ""
elseif submission == nil then
@@ -393,7 +394,7 @@ return function(_python, _bridged)
error("`compute_context` takes a string or nil as argument #1, but got a " .. type(submission))
return ""
end
- return bridged.compute_context(submission, {self.uid}, nil)
+ return bridged.compute_context(submission, {self.uid}, nil, kwargs)
end
---@generic K
@@ -484,8 +485,9 @@ return function(_python, _bridged)
---@param submission? string
---@param entries? KoboldWorldInfoEntry|table
+ ---@param kwargs? table
---@return string
- function KoboldWorldInfoFolder:compute_context(submission, entries)
+ function KoboldWorldInfoFolder:compute_context(submission, entries, kwargs)
if not check_validity(self) then
return ""
elseif submission == nil then
@@ -513,7 +515,7 @@ return function(_python, _bridged)
if self.name == "KoboldWorldInfoFolder" then
folders = {rawget(self, "_uid")}
end
- return bridged.compute_context(submission, _entries, folders)
+ return bridged.compute_context(submission, _entries, folders, kwargs)
end
---@return boolean
diff --git a/userscripts/kaipreset_location_scanner.lua b/userscripts/kaipreset_location_scanner.lua
index 4348e322..06f7e1ca 100644
--- a/userscripts/kaipreset_location_scanner.lua
+++ b/userscripts/kaipreset_location_scanner.lua
@@ -26,7 +26,7 @@ local example_config = [[;-- Location scanner
;-- Usage instructions:
;--
;-- 1. Create a world info folder with name containing the string
-;-- "<||ls||>" (without the double quotes). The comment can be anything as
+;-- "<||ls||>" (without the double quotes). The name can be anything as
;-- long as it contains that inside it somewhere -- for example, you could
;-- set the name to "Locations <||ls||>".
;--
@@ -124,7 +124,7 @@ function userscript.inmod()
key = e.key,
keysecondary = e.keysecondary,
}
- e.constant = e.constant or (not repeated and e:compute_context("") ~= e:compute_context(location))
+ e.constant = e.constant or (not repeated and e:compute_context("", {scan_story=false}) ~= e:compute_context(location, {scan_story=false}))
e.key = ""
e.keysecondary = ""
end
From fc6caa0df05a13016b6e69c80feacec78dddf9f2 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Tue, 4 Jan 2022 19:36:21 -0500
Subject: [PATCH 8/9] Easier method of adding kwargs to bridged in aiserver.py
---
aiserver.py | 66 ++++++++++++++++++++++++++++++-----------------------
1 file changed, 38 insertions(+), 28 deletions(-)
diff --git a/aiserver.py b/aiserver.py
index 1e5556a9..e902aa36 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -21,7 +21,7 @@ import zipfile
import packaging
import contextlib
import traceback
-from typing import Any, Union, Dict, Set, List
+from typing import Any, Callable, Union, Dict, Set, List
import requests
import html
@@ -1064,9 +1064,17 @@ if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
def lua_log_format_name(name):
return f"[{name}]" if type(name) is str else "CORE"
+_bridged = {}
+def bridged_kwarg(name=None):
+ def _bridged_kwarg(f: Callable):
+ _bridged[name if name is not None else f.__name__[4:] if f.__name__[:4] == "lua_" else f.__name__] = f
+ return f
+ return _bridged_kwarg
+
#==================================================================#
# Event triggered when a userscript is loaded
#==================================================================#
+@bridged_kwarg()
def load_callback(filename, modulename):
print(colors.GREEN + f"Loading Userscript [{modulename}] <{filename}>" + colors.END)
@@ -1110,6 +1118,7 @@ def load_lua_scripts():
#==================================================================#
# Print message that originates from the userscript with the given name
#==================================================================#
+@bridged_kwarg()
def lua_print(msg):
if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
vars.lua_logname = vars.lua_koboldbridge.logging_name
@@ -1119,6 +1128,7 @@ def lua_print(msg):
#==================================================================#
# Print warning that originates from the userscript with the given name
#==================================================================#
+@bridged_kwarg()
def lua_warn(msg):
if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
vars.lua_logname = vars.lua_koboldbridge.logging_name
@@ -1128,6 +1138,7 @@ def lua_warn(msg):
#==================================================================#
# Decode tokens into a string using current tokenizer
#==================================================================#
+@bridged_kwarg()
def lua_decode(tokens):
tokens = list(tokens.values())
assert type(tokens) is list
@@ -1140,6 +1151,7 @@ def lua_decode(tokens):
#==================================================================#
# Encode string into list of token IDs using current tokenizer
#==================================================================#
+@bridged_kwarg()
def lua_encode(string):
assert type(string) is str
if("tokenizer" not in globals()):
@@ -1152,6 +1164,7 @@ def lua_encode(string):
# Computes context given a submission, Lua array of entry UIDs and a Lua array
# of folder UIDs
#==================================================================#
+@bridged_kwarg()
def lua_compute_context(submission, entries, folders, kwargs):
assert type(submission) is str
if(kwargs is None):
@@ -1190,6 +1203,7 @@ def lua_compute_context(submission, entries, folders, kwargs):
#==================================================================#
# Get property of a world info entry given its UID and property name
#==================================================================#
+@bridged_kwarg()
def lua_get_attr(uid, k):
assert type(uid) is int and type(k) is str
if(uid in vars.worldinfo_u and k in (
@@ -1208,6 +1222,7 @@ def lua_get_attr(uid, k):
#==================================================================#
# Set property of a world info entry given its UID, property name and new value
#==================================================================#
+@bridged_kwarg()
def lua_set_attr(uid, k, v):
assert type(uid) is int and type(k) is str
assert uid in vars.worldinfo_u and k in (
@@ -1227,6 +1242,7 @@ def lua_set_attr(uid, k, v):
#==================================================================#
# Get property of a world info folder given its UID and property name
#==================================================================#
+@bridged_kwarg()
def lua_folder_get_attr(uid, k):
assert type(uid) is int and type(k) is str
if(uid in vars.wifolders_d and k in (
@@ -1237,6 +1253,7 @@ def lua_folder_get_attr(uid, k):
#==================================================================#
# Set property of a world info folder given its UID, property name and new value
#==================================================================#
+@bridged_kwarg()
def lua_folder_set_attr(uid, k, v):
assert type(uid) is int and type(k) is str
assert uid in vars.wifolders_d and k in (
@@ -1251,12 +1268,14 @@ def lua_folder_set_attr(uid, k, v):
#==================================================================#
# Get the "Amount to Generate"
#==================================================================#
+@bridged_kwarg()
def lua_get_genamt():
return vars.genamt
#==================================================================#
# Set the "Amount to Generate"
#==================================================================#
+@bridged_kwarg()
def lua_set_genamt(genamt):
assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set genamt to {int(genamt)}" + colors.END)
@@ -1265,12 +1284,14 @@ def lua_set_genamt(genamt):
#==================================================================#
# Get the "Gens Per Action"
#==================================================================#
+@bridged_kwarg()
def lua_get_numseqs():
return vars.numseqs
#==================================================================#
# Set the "Gens Per Action"
#==================================================================#
+@bridged_kwarg()
def lua_set_numseqs(numseqs):
assert type(numseqs) in (int, float) and numseqs >= 1
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set numseqs to {int(numseqs)}" + colors.END)
@@ -1279,6 +1300,7 @@ def lua_set_numseqs(numseqs):
#==================================================================#
# Check if a setting exists with the given name
#==================================================================#
+@bridged_kwarg()
def lua_has_setting(setting):
return setting in (
"anotedepth",
@@ -1326,6 +1348,7 @@ def lua_has_setting(setting):
#==================================================================#
# Return the setting with the given name if it exists
#==================================================================#
+@bridged_kwarg()
def lua_get_setting(setting):
if(setting in ("settemp", "temp")): return vars.temp
if(setting in ("settopp", "topp", "top_p")): return vars.top_p
@@ -1350,6 +1373,7 @@ def lua_get_setting(setting):
#==================================================================#
# Set the setting with the given name if it exists
#==================================================================#
+@bridged_kwarg()
def lua_set_setting(setting, v):
actual_type = type(lua_get_setting(setting))
assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float))
@@ -1380,12 +1404,14 @@ def lua_set_setting(setting, v):
#==================================================================#
# Get contents of memory
#==================================================================#
+@bridged_kwarg()
def lua_get_memory():
return vars.memory
#==================================================================#
# Set contents of memory
#==================================================================#
+@bridged_kwarg()
def lua_set_memory(m):
assert type(m) is str
vars.memory = m
@@ -1393,12 +1419,14 @@ def lua_set_memory(m):
#==================================================================#
# Get contents of author's note
#==================================================================#
+@bridged_kwarg()
def lua_get_authorsnote():
return vars.authornote
#==================================================================#
# Set contents of author's note
#==================================================================#
+@bridged_kwarg()
def lua_set_authorsnote(m):
assert type(m) is str
vars.authornote = m
@@ -1406,12 +1434,14 @@ def lua_set_authorsnote(m):
#==================================================================#
# Get contents of author's note template
#==================================================================#
+@bridged_kwarg()
def lua_get_authorsnotetemplate():
return vars.authornotetemplate
#==================================================================#
# Set contents of author's note template
#==================================================================#
+@bridged_kwarg()
def lua_set_authorsnotetemplate(m):
assert type(m) is str
vars.authornotetemplate = m
@@ -1419,6 +1449,7 @@ def lua_set_authorsnotetemplate(m):
#==================================================================#
# Save settings and send them to client
#==================================================================#
+@bridged_kwarg()
def lua_resend_settings():
settingschanged()
refresh_settings()
@@ -1426,6 +1457,7 @@ def lua_resend_settings():
#==================================================================#
# Set story chunk text and delete the chunk if the new chunk is empty
#==================================================================#
+@bridged_kwarg()
def lua_set_chunk(k, v):
assert type(k) in (int, None) and type(v) is str
assert k >= 0
@@ -1458,6 +1490,7 @@ def lua_set_chunk(k, v):
#==================================================================#
# Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc.
#==================================================================#
+@bridged_kwarg()
def lua_get_modeltype():
if(vars.noai):
return "readonly"
@@ -1486,6 +1519,7 @@ def lua_get_modeltype():
#==================================================================#
# Get model backend as "transformers" or "mtj"
#==================================================================#
+@bridged_kwarg()
def lua_get_modelbackend():
if(vars.noai):
return "readonly"
@@ -1498,6 +1532,7 @@ def lua_get_modelbackend():
#==================================================================#
# Check whether model is loaded from a custom path
#==================================================================#
+@bridged_kwarg()
def lua_is_custommodel():
return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ")
@@ -1558,35 +1593,10 @@ bridged = {
"userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
"config_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
"lib_paths": vars.lua_state.table(os.path.join(os.path.dirname(os.path.realpath(__file__)), "lualibs"), os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs")),
- "load_callback": load_callback,
- "print": lua_print,
- "warn": lua_warn,
- "decode": lua_decode,
- "encode": lua_encode,
- "get_attr": lua_get_attr,
- "set_attr": lua_set_attr,
- "folder_get_attr": lua_folder_get_attr,
- "folder_set_attr": lua_folder_set_attr,
- "get_genamt": lua_get_genamt,
- "set_genamt": lua_set_genamt,
- "get_memory": lua_get_memory,
- "set_memory": lua_set_memory,
- "get_authorsnote": lua_get_authorsnote,
- "set_authorsnote": lua_set_authorsnote,
- "get_authorsnote": lua_get_authorsnotetemplate,
- "set_authorsnote": lua_set_authorsnotetemplate,
- "compute_context": lua_compute_context,
- "get_numseqs": lua_get_numseqs,
- "set_numseqs": lua_set_numseqs,
- "has_setting": lua_has_setting,
- "get_setting": lua_get_setting,
- "set_setting": lua_set_setting,
- "set_chunk": lua_set_chunk,
- "get_modeltype": lua_get_modeltype,
- "get_modelbackend": lua_get_modelbackend,
- "is_custommodel": lua_is_custommodel,
"vars": vars,
}
+for kwarg in _bridged:
+ bridged[kwarg] = _bridged[kwarg]
try:
vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))(
vars.lua_state.globals().python,
From 01479c29eaf981e5585adb1e907e1d74fe0f573d Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Tue, 4 Jan 2022 20:48:34 -0500
Subject: [PATCH 9/9] Fix the type hint for `bridged_kwarg` decorator
---
aiserver.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/aiserver.py b/aiserver.py
index e902aa36..a2f3daf3 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -21,7 +21,7 @@ import zipfile
import packaging
import contextlib
import traceback
-from typing import Any, Callable, Union, Dict, Set, List
+from typing import Any, Callable, TypeVar, Union, Dict, Set, List
import requests
import html
@@ -1065,8 +1065,9 @@ def lua_log_format_name(name):
return f"[{name}]" if type(name) is str else "CORE"
_bridged = {}
+F = TypeVar("F", bound=Callable)
def bridged_kwarg(name=None):
- def _bridged_kwarg(f: Callable):
+ def _bridged_kwarg(f: F):
_bridged[name if name is not None else f.__name__[4:] if f.__name__[:4] == "lua_" else f.__name__] = f
return f
return _bridged_kwarg