Merge pull request #47 from VE-FORBRYDERNE/scripting

Lua API fixes
This commit is contained in:
henk717 2021-12-20 04:32:25 +01:00 committed by GitHub
commit 7b56940ed7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 424 additions and 189 deletions

View File

@ -111,8 +111,9 @@ class vars:
lua_koboldbridge = None # `koboldbridge` from bridge.lua
lua_kobold = None # `kobold` from` bridge.lua
lua_koboldcore = None # `koboldcore` from bridge.lua
lua_warper = None # Transformers logits warper controllable from Lua
lua_logname = ... # Name of previous userscript that logged to terminal
lua_edited = set() # Set of chunk numbers that were edited from a Lua generation modifier
lua_deleted = set() # Set of chunk numbers that were deleted from a Lua generation modifier
userscripts = [] # List of userscripts to load
corescript = "default.lua" # Filename of corescript to load
# badwords = [] # Array of str/chr values that should be removed from output
@ -615,7 +616,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
# Patch transformers to use our custom logit warpers
from transformers import LogitsProcessorList, LogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
class TailFreeLogitsWarper(LogitsWarper):
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@ -658,7 +659,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
class LuaLogitsWarper(LogitsWarper):
class LuaLogitsProcessor(LogitsProcessor):
def __init__(self):
pass
@ -686,6 +687,13 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
assert scores.shape == scores_shape
return scores
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
processors.insert(0, LuaLogitsProcessor())
return processors
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
def new_get_logits_warper(
top_k: int = None,
@ -703,18 +711,16 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
warper_list.append(TailFreeLogitsWarper(tfs=tfs, min_tokens_to_keep=1 + (beams > 1)))
if(temp is not None and temp != 1.0):
warper_list.append(TemperatureLogitsWarper(temperature=temp))
vars.lua_warper = LuaLogitsWarper()
warper_list.append(vars.lua_warper)
return warper_list
def new_sample(self, *args, **kwargs):
assert kwargs.pop("logits_warper", None) is not None
kwargs["logits_warper"] = new_get_logits_warper(
vars.top_k,
vars.top_p,
vars.tfs,
vars.temp,
1,
top_k=vars.top_k,
top_p=vars.top_p,
tfs=vars.tfs,
temp=vars.temp,
beams=1,
)
return new_sample.old_sample(self, *args, **kwargs)
new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
@ -1042,7 +1048,7 @@ def lua_warn(msg):
if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
vars.lua_logname = vars.lua_koboldbridge.logging_name
print(colors.BLUE + lua_log_format_name(vars.lua_logname) + ":" + colors.END, file=sys.stderr)
print(colors.RED + msg.replace("\033", "") + colors.END)
print(colors.YELLOW + msg.replace("\033", "") + colors.END)
#==================================================================#
# Decode tokens into a string using current tokenizer
@ -1062,12 +1068,36 @@ def lua_decode(tokens):
def lua_encode(string):
assert type(string) is str
if("tokenizer" not in globals()):
thinking = False
from transformers import GPT2TokenizerFast
global tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
return tokenizer.encode(string, max_length=int(4e9), truncation=True)
#==================================================================#
# Computes context given a submission, Lua array of entry UIDs and a Lua array
# of folder UIDs
#==================================================================#
def lua_compute_context(submission, entries, folders):
assert type(submission) is str
actions = vars._actions if vars.lua_koboldbridge.userstate == "genmod" else vars.actions
allowed_entries = None
allowed_folders = None
if(entries is not None):
allowed_entries = set()
i = 1
while(entries[i] is not None):
allowed_entries.add(int(entries[i]))
i += 1
if(folders is not None):
allowed_folders = set()
i = 1
while(folders[i] is not None):
allowed_folders.add(int(folders[i]))
i += 1
winfo, mem, anotetxt, _ = calcsubmitbudgetheader(submission, allowed_entries=allowed_entries, allowed_folders=allowed_folders, force_use_txt=True)
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
return txt
#==================================================================#
# Get property of a world info entry given its UID and property name
#==================================================================#
@ -1162,22 +1192,38 @@ def lua_set_numseqs(numseqs):
#==================================================================#
def lua_has_setting(setting):
return setting in (
"anotedepth",
"settemp",
"settopp",
"settopk",
"settfs",
"setreppen",
"settknmax",
"anotedepth",
"setwidepth",
"setuseprompt",
"setadventure",
"setdynamicscan",
"setnopromptgen",
"temp",
"topp",
"topk",
"tfs",
"reppen",
"tknmax",
"widepth",
"useprompt",
"adventure",
"dynamicscan",
"nopromptgen",
"frmttriminc",
"frmtrmblln",
"frmtrmspch",
"frmtadsnsp",
"frmtsingleline",
"triminc",
"rmblln",
"rmspch",
"adsnsp",
"singleline",
)
@ -1185,23 +1231,23 @@ def lua_has_setting(setting):
# Return the setting with the given name if it exists
#==================================================================#
def lua_get_setting(setting):
if(setting == "settemp"): return vars.temp
if(setting == "settopp"): return vars.top_p
if(setting == "settopk"): return vars.top_k
if(setting == "settfs"): return vars.tfs
if(setting == "setreppen"): return vars.rep_pen
if(setting == "settknmax"): return vars.max_length
if(setting in ("settemp", "temp")): return vars.temp
if(setting in ("settopp", "topp")): return vars.top_p
if(setting in ("settopk", "topk")): return vars.top_k
if(setting in ("settfs", "tfs")): return vars.tfs
if(setting in ("setreppen", "reppen")): return vars.rep_pen
if(setting in ("settknmax", "tknmax")): return vars.max_length
if(setting == "anotedepth"): return vars.andepth
if(setting == "setwidepth"): return vars.widepth
if(setting == "setuseprompt"): return vars.useprompt
if(setting == "setadventure"): return vars.adventure
if(setting == "setdynamicscan"): return vars.dynamicscan
if(setting == "nopromptgen"): return vars.nopromptgen
if(setting == "frmttriminc"): return vars.formatoptns["frmttriminc"]
if(setting == "frmtrmblln"): return vars.formatoptns["frmttriminc"]
if(setting == "frmtrmspch"): return vars.formatoptns["frmttriminc"]
if(setting == "frmtadsnsp"): return vars.formatoptns["frmttriminc"]
if(setting == "singleline"): return vars.formatoptns["frmttriminc"]
if(setting in ("setwidepth", "widepth")): return vars.widepth
if(setting in ("setuseprompt", "useprompt")): return vars.useprompt
if(setting in ("setadventure", "adventure")): return vars.adventure
if(setting in ("setdynamicscan", "dynamicscan")): return vars.dynamicscan
if(setting in ("setnopromptgen", "nopromptgen")): return vars.nopromptgen
if(setting in ("frmttriminc", "triminc")): return vars.formatoptns["frmttriminc"]
if(setting in ("frmtrmblln", "rmblln")): return vars.formatoptns["frmttrmblln"]
if(setting in ("frmtrmspch", "rmspch")): return vars.formatoptns["frmttrmspch"]
if(setting in ("frmtadsnsp", "adsnsp")): return vars.formatoptns["frmtadsnsp"]
if(setting in ("frmtsingleline", "singleline")): return vars.formatoptns["singleline"]
#==================================================================#
# Set the setting with the given name if it exists
@ -1211,25 +1257,25 @@ def lua_set_setting(setting, v):
assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float))
v = actual_type(v)
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set {setting} to {v}" + colors.END)
if(setting == "setadventure" and v):
if(setting in ("setadventure", "adventure") and v):
vars.actionmode = 1
if(setting == "settemp"): vars.temp = v
if(setting == "settopp"): vars.top_p = v
if(setting == "settopk"): vars.top_k = v
if(setting == "settfs"): vars.tfs = v
if(setting == "setreppen"): vars.rep_pen = v
if(setting == "settknmax"): vars.max_length = v
if(setting == "anotedepth"): vars.andepth = v
if(setting == "setwidepth"): vars.widepth = v
if(setting == "setuseprompt"): vars.useprompt = v
if(setting == "setadventure"): vars.adventure = v
if(setting == "setdynamicscan"): vars.dynamicscan = v
if(setting == "setnopromptgen"): vars.nopromptgen = v
if(setting == "frmttriminc"): vars.formatoptns["frmttriminc"] = v
if(setting == "frmtrmblln"): vars.formatoptns["frmttriminc"] = v
if(setting == "frmtrmspch"): vars.formatoptns["frmttriminc"] = v
if(setting == "frmtadsnsp"): vars.formatoptns["frmttriminc"] = v
if(setting == "singleline"): vars.formatoptns["frmttriminc"] = v
if(setting in ("settemp", "temp")): vars.temp = v
if(setting in ("settopp", "topp")): vars.top_p = v
if(setting in ("settopk", "topk")): vars.top_k = v
if(setting in ("settfs", "tfs")): vars.tfs = v
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
if(setting in ("settknmax", "tknmax")): vars.max_length = v; return True
if(setting == "anotedepth"): vars.andepth = v; return True
if(setting in ("setwidepth", "widepth")): vars.widepth = v; return True
if(setting in ("setuseprompt", "useprompt")): vars.useprompt = v; return True
if(setting in ("setadventure", "adventure")): vars.adventure = v
if(setting in ("setdynamicscan", "dynamicscan")): vars.dynamicscan = v
if(setting in ("setnopromptgen", "nopromptgen")): vars.nopromptgen = v
if(setting in ("frmttriminc", "triminc")): vars.formatoptns["frmttriminc"] = v
if(setting in ("frmtrmblln", "rmblln")): vars.formatoptns["frmttrmblln"] = v
if(setting in ("frmtrmspch", "rmspch")): vars.formatoptns["frmttrmspch"] = v
if(setting in ("frmtadsnsp", "adsnsp")): vars.formatoptns["frmtadsnsp"] = v
if(setting in ("frmtsingleline", "singleline")): vars.formatoptns["singleline"] = v
#==================================================================#
# Get contents of memory
@ -1244,6 +1290,19 @@ def lua_set_memory(m):
assert type(m) is str
vars.memory = m
#==================================================================#
# Get contents of author's note
#==================================================================#
def lua_get_authorsnote():
return vars.authornote
#==================================================================#
# Set contents of author's note
#==================================================================#
def lua_set_authorsnote(m):
assert type(m) is str
vars.authornote = m
#==================================================================#
# Save settings and send them to client
#==================================================================#
@ -1260,13 +1319,28 @@ def lua_set_chunk(k, v):
assert k != 0 or len(v) != 0
if(len(v) == 0):
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} deleted story chunk {k}" + colors.END)
inlinedelete(k)
chunk = int(k)
if(vars.lua_koboldbridge.userstate == "genmod"):
del vars._actions[chunk-1]
vars.lua_deleted.add(chunk)
if(vars._actions is not vars.actions):
del vars.actions[chunk-1]
else:
if(k == 0):
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} edited prompt chunk" + colors.END)
else:
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} edited story chunk {k}" + colors.END)
inlineedit(k, v)
chunk = int(k)
if(chunk == 0):
if(vars.lua_koboldbridge.userstate == "genmod"):
vars._prompt = v
vars.lua_edited.add(chunk)
vars.prompt = v
else:
if(vars.lua_koboldbridge.userstate == "genmod"):
vars._actions[chunk-1] = v
vars.lua_edited.add(chunk)
vars.actions[chunk-1] = v
#==================================================================#
# Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc.
@ -1288,8 +1362,8 @@ def lua_get_modeltype():
return "gpt2-xl"
if(vars.model == "NeoCustom" and hidden_size == 768):
return "gpt-neo-125M"
if(vars.model in ("EleutherAI/gpt-neo-1.3M",) or (vars.model == "NeoCustom" and hidden_size == 2048)):
return "gpt-neo-1.3M"
if(vars.model in ("EleutherAI/gpt-neo-1.3B",) or (vars.model == "NeoCustom" and hidden_size == 2048)):
return "gpt-neo-1.3B"
if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model == "NeoCustom" and hidden_size == 2560)):
return "gpt-neo-2.7B"
if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "NeoCustom" and hidden_size == 4096) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096)):
@ -1330,6 +1404,8 @@ def execute_inmod():
set_aibusy(0)
def execute_genmod():
vars.lua_edited = set()
vars.lua_deleted = set()
vars.lua_koboldbridge.execute_genmod()
def execute_outmod():
@ -1345,6 +1421,10 @@ def execute_outmod():
if(vars.lua_koboldbridge.resend_settings_required):
vars.lua_koboldbridge.resend_settings_required = False
lua_resend_settings()
for k in vars.lua_edited:
inlineedit(k, vars.actions[k])
for k in vars.lua_deleted:
inlinedelete(k)
#==================================================================#
# Lua runtime startup
@ -1373,6 +1453,9 @@ bridged = {
"set_genamt": lua_set_genamt,
"get_memory": lua_get_memory,
"set_memory": lua_set_memory,
"get_authorsnote": lua_get_authorsnote,
"set_authorsnote": lua_set_authorsnote,
"compute_context": lua_compute_context,
"get_numseqs": lua_get_numseqs,
"set_numseqs": lua_set_numseqs,
"has_setting": lua_has_setting,
@ -1455,6 +1538,7 @@ def get_message(msg):
# Submit action
if(msg['cmd'] == 'submit'):
if(vars.mode == "play"):
vars.lua_koboldbridge.feedback = None
actionsubmit(msg['data'], actionmode=msg['actionmode'])
elif(vars.mode == "edit"):
editsubmit(msg['data'])
@ -1873,83 +1957,51 @@ def actionsubmit(data, actionmode=0, force_submit=False):
# Ignore new submissions if the AI is currently busy
if(vars.aibusy):
return
set_aibusy(1)
vars.recentback = False
vars.recentedit = False
vars.actionmode = actionmode
while(True):
set_aibusy(1)
# "Action" mode
if(actionmode == 1):
data = data.strip().lstrip('>')
data = re.sub(r'\n+', ' ', data)
if(len(data)):
data = f"\n\n> {data}\n"
# If we're not continuing, store a copy of the raw input
if(data != ""):
vars.lastact = data
if(not vars.gamestarted):
vars.submission = data
execute_inmod()
data = vars.submission
if(not force_submit and len(data.strip()) == 0):
assert False
# Start the game
vars.gamestarted = True
if(not vars.noai and vars.lua_koboldbridge.generating and not vars.nopromptgen):
# Save this first action as the prompt
vars.prompt = data
# Clear the startup text from game screen
emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True)
calcsubmit(data) # Run the first action through the generator
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
else:
# Save this first action as the prompt
vars.prompt = data
for i in range(vars.numseqs):
vars.lua_koboldbridge.outputs[i+1] = ""
execute_outmod()
if(vars.lua_koboldbridge.regeneration_required):
vars.lua_koboldbridge.regeneration_required = False
genout = []
for i in range(vars.numseqs):
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
assert type(genout[-1]["generated_text"]) is str
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
else:
genselect(genout)
refresh_story()
set_aibusy(0)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
else:
# Apply input formatting & scripts before sending to tokenizer
if(vars.actionmode == 0):
data = applyinputformatting(data)
vars.submission = data
execute_inmod()
data = vars.submission
# Dont append submission if it's a blank/continue action
vars.recentback = False
vars.recentedit = False
vars.actionmode = actionmode
# "Action" mode
if(actionmode == 1):
data = data.strip().lstrip('>')
data = re.sub(r'\n+', ' ', data)
if(len(data)):
data = f"\n\n> {data}\n"
# If we're not continuing, store a copy of the raw input
if(data != ""):
# Store the result in the Action log
if(len(vars.prompt.strip()) == 0):
vars.lastact = data
if(not vars.gamestarted):
vars.submission = data
execute_inmod()
data = vars.submission
if(not force_submit and len(data.strip()) == 0):
assert False
# Start the game
vars.gamestarted = True
if(not vars.noai and vars.lua_koboldbridge.generating and not vars.nopromptgen):
# Save this first action as the prompt
vars.prompt = data
# Clear the startup text from game screen
emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True)
calcsubmit(data) # Run the first action through the generator
if(vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
data = ""
force_submit = True
continue
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
break
else:
vars.actions.append(data)
update_story_chunk('last')
if(not vars.noai and vars.lua_koboldbridge.generating):
# Off to the tokenizer!
calcsubmit(data)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
else:
for i in range(vars.numseqs):
vars.lua_koboldbridge.outputs[i+1] = ""
execute_outmod()
set_aibusy(0)
if(vars.lua_koboldbridge.regeneration_required):
# Save this first action as the prompt
vars.prompt = data
for i in range(vars.numseqs):
vars.lua_koboldbridge.outputs[i+1] = ""
execute_outmod()
vars.lua_koboldbridge.regeneration_required = False
genout = []
for i in range(vars.numseqs):
@ -1957,9 +2009,73 @@ def actionsubmit(data, actionmode=0, force_submit=False):
assert type(genout[-1]["generated_text"]) is str
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
if(vars.lua_koboldbridge.restart_sequence is not None):
refresh_story()
data = ""
force_submit = True
continue
else:
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
refresh_story()
data = ""
force_submit = True
continue
genselect(genout)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
refresh_story()
set_aibusy(0)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
break
else:
# Apply input formatting & scripts before sending to tokenizer
if(vars.actionmode == 0):
data = applyinputformatting(data)
vars.submission = data
execute_inmod()
data = vars.submission
# Dont append submission if it's a blank/continue action
if(data != ""):
# Store the result in the Action log
if(len(vars.prompt.strip()) == 0):
vars.prompt = data
else:
vars.actions.append(data)
update_story_chunk('last')
if(not vars.noai and vars.lua_koboldbridge.generating):
# Off to the tokenizer!
calcsubmit(data)
if(vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
data = ""
force_submit = True
continue
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
break
else:
for i in range(vars.numseqs):
vars.lua_koboldbridge.outputs[i+1] = ""
execute_outmod()
vars.lua_koboldbridge.regeneration_required = False
genout = []
for i in range(vars.numseqs):
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
assert type(genout[-1]["generated_text"]) is str
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
if(vars.lua_koboldbridge.restart_sequence is not None):
data = ""
force_submit = True
continue
else:
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
data = ""
force_submit = True
continue
genselect(genout)
set_aibusy(0)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
break
#==================================================================#
#
@ -1972,18 +2088,14 @@ def actionretry(data):
return
# Remove last action if possible and resubmit
if(vars.gamestarted if vars.useprompt else len(vars.actions) > 0):
set_aibusy(1)
if(not vars.recentback and len(vars.actions) != 0 and len(vars.genseqs) == 0): # Don't pop if we're in the "Select sequence to keep" menu or if there are no non-prompt actions
last_key = vars.actions.get_last_key()
vars.actions.pop()
remove_story_chunk(last_key + 1)
vars.genseqs = []
vars.submission = ""
execute_inmod()
calcsubmit(vars.submission)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
vars.recentback = False
vars.recentedit = False
vars.lua_koboldbridge.feedback = None
actionsubmit("", actionmode=vars.actionmode, force_submit=True)
elif(not vars.useprompt):
emit('from_server', {'cmd': 'errmsg', 'data': "Please enable \"Always Add Prompt\" to retry with your prompt."})
@ -2223,9 +2335,10 @@ def _generate(txt, minimum, maximum, found_entries):
model.kai_scanner_head_length = gen_in.shape[-1]
model.kai_scanner_excluded_world_info = found_entries
actions = vars.actions
vars._actions = vars.actions
vars._prompt = vars.prompt
if(vars.dynamicscan):
actions = actions.copy()
vars._actions = vars._actions.copy()
with torch.no_grad():
already_generated = 0
@ -2235,7 +2348,7 @@ def _generate(txt, minimum, maximum, found_entries):
gen_in,
do_sample=True,
min_length=minimum,
max_length=maximum-already_generated,
max_length=int(2e9),
repetition_penalty=vars.rep_pen,
bad_words_ids=vars.badwordsids,
use_cache=True,
@ -2257,7 +2370,7 @@ def _generate(txt, minimum, maximum, found_entries):
txt = tokenizer.decode(genout[i, -already_generated:])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions)
encoded.append(tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True)[0].long().to(genout.device))
max_length = len(max(encoded, key=len))
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
@ -2333,7 +2446,10 @@ def generate(txt, minimum, maximum, found_entries=None):
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
else:
genselect(genout)
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
else:
genselect(genout)
# Clear CUDA cache again if using GPU
if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
@ -2351,6 +2467,11 @@ def genresult(genout):
# Format output before continuing
genout = applyoutputformatting(genout)
vars.lua_koboldbridge.feedback = genout
if(len(genout) == 0):
return
# Add formatted text to Actions array and refresh the game screen
if(len(vars.prompt.strip()) == 0):
@ -2383,12 +2504,17 @@ def genselect(genout):
def selectsequence(n):
if(len(vars.genseqs) == 0):
return
vars.actions.append(vars.genseqs[int(n)]["generated_text"])
update_story_chunk('last')
emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() if len(vars.actions) else 0}, broadcast=True)
vars.lua_koboldbridge.feedback = vars.genseqs[int(n)]["generated_text"]
if(len(vars.lua_koboldbridge.feedback) != 0):
vars.actions.append(vars.lua_koboldbridge.feedback)
update_story_chunk('last')
emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() if len(vars.actions) else 0}, broadcast=True)
emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
vars.genseqs = []
if(vars.lua_koboldbridge.restart_sequence is not None):
actionsubmit("", actionmode=vars.actionmode, force_submit=True)
#==================================================================#
# Send transformers-style request to ngrok/colab host
#==================================================================#
@ -2447,7 +2573,10 @@ def sendtocolab(txt, min, max):
seqs = []
for seq in genout:
seqs.append({"generated_text": seq})
genselect(seqs)
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
else:
genselect(genout)
# Format output before continuing
#genout = applyoutputformatting(getnewcontent(genout))
@ -2544,7 +2673,10 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
else:
genselect(genout)
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
else:
genselect(genout)
set_aibusy(0)
@ -2995,7 +3127,7 @@ def deletewifolder(uid):
#==================================================================#
# Look for WI keys in text to generator
#==================================================================#
def checkworldinfo(txt, force_use_txt=False):
def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False):
original_txt = txt
# Dont go any further if WI is empty
@ -3038,6 +3170,11 @@ def checkworldinfo(txt, force_use_txt=False):
wimem = ""
found_entries = set()
for wi in vars.worldinfo:
if(allowed_entries is not None and wi["uid"] not in allowed_entries):
continue
if(allowed_folders is not None and wi["folder"] not in allowed_folders):
continue
if(wi.get("constant", False)):
wimem = wimem + wi["content"] + "\n"
found_entries.add(id(wi))
@ -3866,6 +4003,7 @@ def newGameRequest():
def randomGameRequest(topic):
newGameRequest()
vars.memory = "You generate the following " + topic + " story concept :"
vars.lua_koboldbridge.feedback = None
actionsubmit("", force_submit=True)
vars.memory = ""

View File

@ -207,12 +207,14 @@ return function(_python, _bridged)
koboldbridge.regeneration_required = false
koboldbridge.resend_settings_required = false
koboldbridge.generating = true
koboldbridge.restart_sequence = nil
koboldbridge.userstate = nil
koboldbridge.logits = {}
koboldbridge.vocab_size = 0
koboldbridge.generated = {}
koboldbridge.generated_cols = 0
koboldbridge.outputs = {}
koboldbridge.feedback = nil ---@type string|nil
---@return nil
local function maybe_require_regeneration()
@ -285,12 +287,18 @@ return function(_python, _bridged)
return _python.as_attrgetter(bridged.vars.worldinfo_u).get(rawget(self, "_uid")) ~= nil
end
---@param submission? string
---@return string
function KoboldWorldInfoEntry:compute_context()
function KoboldWorldInfoEntry:compute_context(submission)
if not check_validity(self) then
return ""
elseif submission == nil then
submission = kobold.submission
elseif type(submission) ~= "string" then
error("`compute_context` takes a string or nil as argument #1, but got a " .. type(submission))
return ""
end
return bridged.compute_context({self.uid})
return bridged.compute_context(submission, {self.uid}, nil)
end
---@generic K
@ -335,7 +343,7 @@ return function(_python, _bridged)
error("`"..rawget(t, "_name").."."..k.."` must be a "..KoboldWorldInfoEntry_fieldtypes[k].."; you attempted to set it to a "..type(v))
return
else
if k ~= "comment" then
if k ~= "comment" and not (t.selective and k == "keysecondary") then
maybe_require_regeneration()
end
bridged.set_attr(t.uid, k, v)
@ -379,26 +387,38 @@ return function(_python, _bridged)
return entry
end
---@param submission? string
---@param entries? KoboldWorldInfoEntry|table<any, KoboldWorldInfoEntry>
---@return string
function KoboldWorldInfoFolder:compute_context(entries)
function KoboldWorldInfoFolder:compute_context(submission, entries)
if not check_validity(self) then
return
end
if entries ~= nil and type(entries) ~= "table" or (entries.name ~= nil and entries.name ~= "KoboldWorldInfoEntry") then
error("`compute_context` takes a KoboldWorldInfoEntry, table of KoboldWorldInfoEntries or nil as argument, but got a " .. type(entries))
return ""
elseif submission == nil then
submission = kobold.submission
elseif type(submission) ~= "string" then
error("`compute_context` takes a string or nil as argument #1, but got a " .. type(submission))
return ""
end
if entries.name == "KoboldWorldInfoEntry" then
entries = {entries}
end
local _entries
for k, v in pairs(entries) do
if type(v) == "table" and v.name == "KoboldWorldInfoEntry" and (rawget(self, "_name") ~= "KoboldWorldInfoFolder" or self.uid == v.uid) and v:is_valid() then
_entries[k] = v.uid
if entries ~= nil then
if type(entries) ~= "table" or (entries.name ~= nil and entries.name ~= "KoboldWorldInfoEntry") then
error("`compute_context` takes a KoboldWorldInfoEntry, table of KoboldWorldInfoEntries or nil as argument #2, but got a " .. type(entries))
return ""
elseif entries.name == "KoboldWorldInfoEntry" then
_entries = {entries}
else
for k, v in pairs(entries) do
if type(v) == "table" and v.name == "KoboldWorldInfoEntry" and v:is_valid() then
_entries[k] = v.uid
end
end
end
end
return bridged.compute_context(_entries)
local folders
if self.name == "KoboldWorldInfoFolder" then
folders = {rawget(self, "_uid")}
end
return bridged.compute_context(submission, _entries, folders)
end
---@return boolean
@ -607,10 +627,12 @@ return function(_python, _bridged)
if k == "content" then
if rawget(t, "_num") == 0 then
if bridged.vars.gamestarted then
return bridged.vars.prompt
local prompt = koboldbridge.userstate == "genmod" and bridged.vars._prompt or bridged.vars.prompt
return prompt
end
end
return _python.as_attrgetter(bridged.vars.actions).get(math.tointeger(rawget(t, "_num")) - 1)
local actions = koboldbridge.userstate == "genmod" and bridged.vars._actions or bridged.vars.actions
return _python.as_attrgetter(actions).get(math.tointeger(rawget(t, "_num")) - 1)
end
end
@ -631,7 +653,8 @@ return function(_python, _bridged)
error("Attempted to set the prompt chunk's content to the empty string; this is not allowed")
return
end
if _k ~= 0 and _python.as_attrgetter(bridged.vars.actions).get(_k-1) == nil then
local actions = koboldbridge.userstate == "genmod" and bridged.vars._actions or bridged.vars.actions
if _k ~= 0 and _python.as_attrgetter(actions).get(_k-1) == nil then
return
end
bridged.set_chunk(_k, v)
@ -655,7 +678,8 @@ return function(_python, _bridged)
---@return fun(): KoboldStoryChunk, table, nil
function KoboldStory:forward_iter()
local nxt, iterator = _python.iter(bridged.vars.actions)
local actions = koboldbridge.userstate == "genmod" and bridged.vars._actions or bridged.vars.actions
local nxt, iterator = _python.iter(actions)
local run_once = false
local f = function()
if not bridged.vars.gamestarted then
@ -682,7 +706,8 @@ return function(_python, _bridged)
---@return fun(): KoboldStoryChunk, table, nil
function KoboldStory:reverse_iter()
local nxt, iterator = _python.iter(_python.builtins.reversed(bridged.vars.actions))
local actions = koboldbridge.userstate == "genmod" and bridged.vars._actions or bridged.vars.actions
local nxt, iterator = _python.iter(_python.builtins.reversed(actions))
local last_run = false
local f = function()
if not bridged.vars.gamestarted or last_run then
@ -738,21 +763,38 @@ return function(_python, _bridged)
---@class KoboldSettings : KoboldSettings_base
---@field numseqs integer
---@field genamt integer
---@field anotedepth integer
---@field settemp number
---@field settopp number
---@field settopk integer
---@field settfs number
---@field setreppen number
---@field settknmax integer
---@field anotedepth integer
---@field setwidepth integer
---@field setuseprompt boolean
---@field setadventure boolean
---@field setdynamicscan boolean
---@field setnopromptgen boolean
---@field temp number
---@field topp number
---@field topk integer
---@field tfs number
---@field reppen number
---@field tknmax integer
---@field widepth integer
---@field useprompt boolean
---@field adventure boolean
---@field dynamicscan boolean
---@field nopromptgen boolean
---@field frmttriminc boolean
---@field frmtrmblln boolean
---@field frmtrmspch boolean
---@field frmtadsnsp boolean
---@field frmtsingleline boolean
---@field triminc boolean
---@field rmblln boolean
---@field rmspch boolean
---@field adsnsp boolean
---@field singleline boolean
local KoboldSettings = setmetatable({
_name = "KoboldSettings",
@ -783,9 +825,9 @@ return function(_python, _bridged)
if type(k) ~= "string" then
return
end
if k == "genamt" then
if k == "genamt" or k == "output" or k == "setoutput" then
return math.tointeger(bridged.get_genamt()), true
elseif k == "numseqs" then
elseif k == "numseqs" or k == "numseq" or k == "setnumseq" then
return math.tointeger(bridged.get_numseqs()), true
elseif bridged.has_setting(k) then
return bridged.get_setting(k), true
@ -796,11 +838,10 @@ return function(_python, _bridged)
---@param t KoboldSettings_base
function KoboldSettings_mt.__newindex(t, k, v)
if k == "genamt" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 0 then
if (k == "genamt" or k == "output" or k == "setoutput") and type(v) == "number" and math.tointeger(v) ~= nil and v >= 0 then
bridged.set_genamt(v)
maybe_require_regeneration()
koboldbridge.resend_settings_required = true
elseif k == "numseqs" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 1 then
elseif (k == "numseqs" or k == "numseq" or k == "setnumseq") and type(v) == "number" and math.tointeger(v) ~= nil and v >= 1 then
if koboldbridge.userstate == "genmod" then
error("Cannot set numseqs from a generation modifier")
return
@ -808,10 +849,9 @@ return function(_python, _bridged)
bridged.set_numseqs(v)
koboldbridge.resend_settings_required = true
elseif type(k) == "string" and bridged.has_setting(k) and type(v) == type(bridged.get_setting(k)) then
if k == "settknmax" or k == "anotedepth" or k == "setwidepth" or k == "setuseprompt" then
if bridged.set_setting(k, v) == true then
maybe_require_regeneration()
end
bridged.set_setting(k, v)
koboldbridge.resend_settings_required = true
end
return t
@ -821,7 +861,7 @@ return function(_python, _bridged)
--==========================================================================
-- Userscript API: Memory
-- Userscript API: Memory / Author's Note
--==========================================================================
---@param t KoboldLib
@ -842,6 +882,24 @@ return function(_python, _bridged)
bridged.set_memory(v)
end
---@param t KoboldLib
---@return string
function KoboldLib_getters.authorsnote(t)
return bridged.get_authorsnote()
end
---@param t KoboldLib
---@param v string
---@return KoboldLib
function KoboldLib_setters.authorsnote(t, v)
if type(v) ~= "string" then
error("`KoboldLib.authorsnote` must be a string; you attempted to set it to a "..type(v))
return
end
maybe_require_regeneration()
bridged.set_authorsnote(v)
end
--==========================================================================
-- Userscript API: User-submitted text (after applying input formatting)
@ -992,7 +1050,7 @@ return function(_python, _bridged)
error("Cannot write to `KoboldLib.logits` from outside of a generation modifer")
return
elseif type(v) ~= "table" then
error("`KoboldLib.logits` must be a 2D list (table) of numbers; you attempted to set it to a " .. type(v))
error("`KoboldLib.logits` must be a 2D array of numbers; you attempted to set it to a " .. type(v))
return
end
koboldbridge.logits = v
@ -1009,6 +1067,8 @@ return function(_python, _bridged)
local backend = kobold.modelbackend
if backend == "readonly" or backend == "api" then
return 0
elseif koboldbridge.userstate == "outmod" then
return koboldbridge.num_outputs
end
return kobold.settings.numseqs
end
@ -1058,7 +1118,7 @@ return function(_python, _bridged)
error("Cannot write to `KoboldLib.generated` from outside of a generation modifier")
return
elseif type(v) ~= "table" then
error("`KoboldLib.generated` must be a 2D list (table) of integers; you attempted to set it to a " .. type(v))
error("`KoboldLib.generated` must be a 2D array of integers; you attempted to set it to a " .. type(v))
return
end
koboldbridge.generated = v
@ -1072,14 +1132,13 @@ return function(_python, _bridged)
---@param t KoboldLib
---@return integer
function KoboldLib_getters.num_outputs(t)
local backend = kobold.modelbackend
if backend == "readonly" then
return 0
end
local model = kobold.model
if model == "OAI" or model == "InferKit" then
return 1
end
if koboldbridge.userstate == "outmod" then
return koboldbridge.num_outputs
end
return kobold.settings.numseqs
end
@ -1105,7 +1164,7 @@ return function(_python, _bridged)
error("Cannot write to `KoboldLib.generated` from outside of an output modifier")
return
elseif type(v) ~= "table" then
error("`KoboldLib.generated` must be a list (table) of strings; you attempted to set it to a " .. type(v))
error("`KoboldLib.generated` must be a 1D array of strings; you attempted to set it to a " .. type(v))
return
end
koboldbridge.outputs = v
@ -1158,6 +1217,41 @@ return function(_python, _bridged)
koboldbridge.generating = false
end
---@param sequence? integer
---@return nil
function kobold.restart_generation(sequence)
if sequence == nil then
sequence = 0
end
sequence_type = type(sequence)
sequence = math.tointeger(sequence)
if sequence_type ~= "number" then
error("`kobold.restart_generation` takes an integer greater than or equal to 0 or nil as argument, but got a " .. sequence_type)
return
elseif sequence < 0 then
error("`kobold.restart_generation` takes an integer greater than or equal to 0 or nil as argument, but got `" .. sequence .. "`")
return
end
if koboldbridge.userstate ~= "outmod" then
error("Can only call `kobold.restart_generation()` from an output modifier")
return
end
koboldbridge.restart_sequence = sequence
end
---@param t KoboldCoreLib
---@return string
function KoboldLib_getters.feedback(t)
return koboldbridge.feedback
end
---@param t KoboldCoreLib
---@param v string
---@return KoboldCoreLib
function KoboldLib_setters.feedback(t, v)
error("`KoboldLib.feedback` is a read-only attribute")
end
--==========================================================================
-- Core script API
@ -1383,17 +1477,17 @@ return function(_python, _bridged)
end
local function redirected_print(...)
local args = {...}
for k, v in ipairs(args) do
args[k] = tostring(v)
local args = table.pack(...)
for i = 1, args.n do
args[i] = tostring(args[i])
end
bridged.print(table.concat(args, "\t"))
end
local function redirected_warn(...)
local args = {...}
for k, v in ipairs(args) do
args[k] = tostring(v)
local args = table.pack(...)
for i = 1, args.n do
args[i] = tostring(args[i])
end
bridged.warn(table.concat(args, "\t"))
end
@ -1579,8 +1673,8 @@ return function(_python, _bridged)
koboldbridge.num_userscripts = 0
for i, filename in _python.enumerate(filenames) do
bridged.load_callback(filename, modulenames[i])
---@type KoboldUserScript
koboldbridge.logging_name = modulenames[i]
---@type KoboldUserScript
local _userscript = old_loadfile(join_folder_and_filename(bridged.userscript_path, filename), "t", koboldbridge.get_universe(filename))()
koboldbridge.logging_name = nil
local userscript = deepcopy(KoboldUserScriptModule)
@ -1606,9 +1700,10 @@ return function(_python, _bridged)
function koboldbridge.execute_inmod()
local r
koboldbridge.restart_sequence = nil
koboldbridge.userstate = "inmod"
koboldbridge.regeneration_required = false
koboldbridge.generating = true
koboldbridge.userstate = "inmod"
koboldbridge.generated_cols = 0
koboldbridge.generated = {}
for i = 1, kobold.settings.numseqs do
@ -1665,13 +1760,14 @@ return function(_python, _bridged)
local r
koboldbridge.generating = false
koboldbridge.userstate = "outmod"
koboldbridge.num_outputs = kobold.settings.numseqs
if koboldbridge.outmod ~= nil then
local _outputs = deepcopy(koboldbridge.outputs)
r = koboldbridge.outmod()
setmetatable(koboldbridge.outputs, nil)
for k, v in old_next, koboldbridge.outputs, nil do
if type(v) ~= "string" then
error("`kobold.outputs` must be a 1D list of strings, but found a non-string element at index " .. k)
error("`kobold.outputs` must be a 1D array of strings, but found a non-string element at index " .. k)
return r
end
if v ~= _outputs[k] then

View File

@ -1,6 +1,6 @@
-- Default core script
-- Runs all input modifiers and generation modifiers in forward order, and
-- runs all output modifiers in reverse order
-- Runs all generation modifiers and output modifiers in forward order, and
-- runs all input modifiers in reverse order
kobold, koboldcore = require("bridge")() -- This line is optional and is only for EmmyLua type annotations

View File

@ -646,6 +646,7 @@ function hideMessage() {
}
function showWaitAnimation() {
hideWaitAnimation();
$("#inputrowright").append("<img id=\"waitanim\" src=\"static/thinking.gif\"/>");
}

View File

@ -7,10 +7,10 @@
<script src="static/jquery-3.6.0.min.js"></script>
<script src="static/jquery-ui.sortable.min.js"></script>
<script src="static/socket.io.min.js"></script>
<script src="static/application.js?ver=1.16.4h"></script>
<script src="static/bootstrap.min.js"></script>
<script src="static/bootstrap-toggle.min.js"></script>
<script src="static/rangy-core.min.js"></script>
<script src="static/application.js?ver=1.16.4i"></script>
<link rel="stylesheet" href="static/jquery-ui.sortable.min.css">
<link rel="stylesheet" href="static/bootstrap.min.css">