Lua API fixes

* `print()` and `warn()` now work correctly with `nil` arguments
* Typo: `gpt-neo-1.3M` has been corrected to `gpt-neo-1.3B`
* Regeneration is no longer triggered when writing to `keysecondary` of
  a non-selective key
* Handle `genamt` changes in generation modifier properly
* Writing to `kobold.settings.numseqs` from a generation modifier no
  longer affects
* Formatting options in `kobold.settings` have been fixed
* Added aliases for setting names
* Fix behaviour of editing story chunks from a generation modifier
* Warnings are now yellow instead of red
* kobold.logits is now the raw logits prior to being filtered, like
  the documentation says, rather than after being filtered
* Some erroneous comments and error messages have been corrected
* These parts of the API have now been implemented properly:
    * `compute_context()` methods
    * `kobold.authorsnote`
    * `kobold.restart_generation()`
This commit is contained in:
Gnome Ann
2021-12-19 20:18:28 -05:00
parent 4bb5e59d82
commit 341b153360
5 changed files with 420 additions and 189 deletions

View File

@ -111,8 +111,9 @@ class vars:
lua_koboldbridge = None # `koboldbridge` from bridge.lua
lua_kobold = None # `kobold` from` bridge.lua
lua_koboldcore = None # `koboldcore` from bridge.lua
lua_warper = None # Transformers logits warper controllable from Lua
lua_logname = ... # Name of previous userscript that logged to terminal
lua_edited = set() # Set of chunk numbers that were edited from a Lua generation modifier
lua_deleted = set() # Set of chunk numbers that were deleted from a Lua generation modifier
userscripts = [] # List of userscripts to load
corescript = "default.lua" # Filename of corescript to load
# badwords = [] # Array of str/chr values that should be removed from output
@ -615,7 +616,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
# Patch transformers to use our custom logit warpers
from transformers import LogitsProcessorList, LogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
class TailFreeLogitsWarper(LogitsWarper):
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@ -658,7 +659,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
class LuaLogitsWarper(LogitsWarper):
class LuaLogitsProcessor(LogitsProcessor):
def __init__(self):
pass
@ -686,6 +687,13 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
assert scores.shape == scores_shape
return scores
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
processors.insert(0, LuaLogitsProcessor())
return processors
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
def new_get_logits_warper(
top_k: int = None,
@ -703,18 +711,16 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
warper_list.append(TailFreeLogitsWarper(tfs=tfs, min_tokens_to_keep=1 + (beams > 1)))
if(temp is not None and temp != 1.0):
warper_list.append(TemperatureLogitsWarper(temperature=temp))
vars.lua_warper = LuaLogitsWarper()
warper_list.append(vars.lua_warper)
return warper_list
def new_sample(self, *args, **kwargs):
assert kwargs.pop("logits_warper", None) is not None
kwargs["logits_warper"] = new_get_logits_warper(
vars.top_k,
vars.top_p,
vars.tfs,
vars.temp,
1,
top_k=vars.top_k,
top_p=vars.top_p,
tfs=vars.tfs,
temp=vars.temp,
beams=1,
)
return new_sample.old_sample(self, *args, **kwargs)
new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
@ -1042,7 +1048,7 @@ def lua_warn(msg):
if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
vars.lua_logname = vars.lua_koboldbridge.logging_name
print(colors.BLUE + lua_log_format_name(vars.lua_logname) + ":" + colors.END, file=sys.stderr)
print(colors.RED + msg.replace("\033", "") + colors.END)
print(colors.YELLOW + msg.replace("\033", "") + colors.END)
#==================================================================#
# Decode tokens into a string using current tokenizer
@ -1062,12 +1068,36 @@ def lua_decode(tokens):
def lua_encode(string):
assert type(string) is str
if("tokenizer" not in globals()):
thinking = False
from transformers import GPT2TokenizerFast
global tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
return tokenizer.encode(string, max_length=int(4e9), truncation=True)
#==================================================================#
# Computes context given a submission, Lua array of entry UIDs and a Lua array
# of folder UIDs
#==================================================================#
def lua_compute_context(submission, entries, folders):
assert type(submission) is str
actions = vars._actions if vars.lua_koboldbridge.userstate == "genmod" else vars.actions
allowed_entries = None
allowed_folders = None
if(entries is not None):
allowed_entries = set()
i = 1
while(entries[i] is not None):
allowed_entries.add(int(entries[i]))
i += 1
if(folders is not None):
allowed_folders = set()
i = 1
while(folders[i] is not None):
allowed_folders.add(int(folders[i]))
i += 1
winfo, mem, anotetxt, _ = calcsubmitbudgetheader(submission, allowed_entries=allowed_entries, allowed_folders=allowed_folders, force_use_txt=True)
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
return txt
#==================================================================#
# Get property of a world info entry given its UID and property name
#==================================================================#
@ -1162,22 +1192,38 @@ def lua_set_numseqs(numseqs):
#==================================================================#
def lua_has_setting(setting):
return setting in (
"anotedepth",
"settemp",
"settopp",
"settopk",
"settfs",
"setreppen",
"settknmax",
"anotedepth",
"setwidepth",
"setuseprompt",
"setadventure",
"setdynamicscan",
"setnopromptgen",
"temp",
"topp",
"topk",
"tfs",
"reppen",
"tknmax",
"widepth",
"useprompt",
"adventure",
"dynamicscan",
"nopromptgen",
"frmttriminc",
"frmtrmblln",
"frmtrmspch",
"frmtadsnsp",
"frmtsingleline",
"triminc",
"rmblln",
"rmspch",
"adsnsp",
"singleline",
)
@ -1185,23 +1231,23 @@ def lua_has_setting(setting):
# Return the setting with the given name if it exists
#==================================================================#
def lua_get_setting(setting):
if(setting == "settemp"): return vars.temp
if(setting == "settopp"): return vars.top_p
if(setting == "settopk"): return vars.top_k
if(setting == "settfs"): return vars.tfs
if(setting == "setreppen"): return vars.rep_pen
if(setting == "settknmax"): return vars.max_length
if(setting in ("settemp", "temp")): return vars.temp
if(setting in ("settopp", "topp")): return vars.top_p
if(setting in ("settopk", "topk")): return vars.top_k
if(setting in ("settfs", "tfs")): return vars.tfs
if(setting in ("setreppen", "reppen")): return vars.rep_pen
if(setting in ("settknmax", "tknmax")): return vars.max_length
if(setting == "anotedepth"): return vars.andepth
if(setting == "setwidepth"): return vars.widepth
if(setting == "setuseprompt"): return vars.useprompt
if(setting == "setadventure"): return vars.adventure
if(setting == "setdynamicscan"): return vars.dynamicscan
if(setting == "nopromptgen"): return vars.nopromptgen
if(setting == "frmttriminc"): return vars.formatoptns["frmttriminc"]
if(setting == "frmtrmblln"): return vars.formatoptns["frmttriminc"]
if(setting == "frmtrmspch"): return vars.formatoptns["frmttriminc"]
if(setting == "frmtadsnsp"): return vars.formatoptns["frmttriminc"]
if(setting == "singleline"): return vars.formatoptns["frmttriminc"]
if(setting in ("setwidepth", "widepth")): return vars.widepth
if(setting in ("setuseprompt", "useprompt")): return vars.useprompt
if(setting in ("setadventure", "adventure")): return vars.adventure
if(setting in ("setdynamicscan", "dynamicscan")): return vars.dynamicscan
if(setting in ("setnopromptgen", "nopromptgen")): return vars.nopromptgen
if(setting in ("frmttriminc", "triminc")): return vars.formatoptns["frmttriminc"]
if(setting in ("frmtrmblln", "rmblln")): return vars.formatoptns["frmttrmblln"]
if(setting in ("frmtrmspch", "rmspch")): return vars.formatoptns["frmttrmspch"]
if(setting in ("frmtadsnsp", "adsnsp")): return vars.formatoptns["frmtadsnsp"]
if(setting in ("frmtsingleline", "singleline")): return vars.formatoptns["singleline"]
#==================================================================#
# Set the setting with the given name if it exists
@ -1211,25 +1257,25 @@ def lua_set_setting(setting, v):
assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float))
v = actual_type(v)
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set {setting} to {v}" + colors.END)
if(setting == "setadventure" and v):
if(setting in ("setadventure", "adventure") and v):
vars.actionmode = 1
if(setting == "settemp"): vars.temp = v
if(setting == "settopp"): vars.top_p = v
if(setting == "settopk"): vars.top_k = v
if(setting == "settfs"): vars.tfs = v
if(setting == "setreppen"): vars.rep_pen = v
if(setting == "settknmax"): vars.max_length = v
if(setting == "anotedepth"): vars.andepth = v
if(setting == "setwidepth"): vars.widepth = v
if(setting == "setuseprompt"): vars.useprompt = v
if(setting == "setadventure"): vars.adventure = v
if(setting == "setdynamicscan"): vars.dynamicscan = v
if(setting == "setnopromptgen"): vars.nopromptgen = v
if(setting == "frmttriminc"): vars.formatoptns["frmttriminc"] = v
if(setting == "frmtrmblln"): vars.formatoptns["frmttriminc"] = v
if(setting == "frmtrmspch"): vars.formatoptns["frmttriminc"] = v
if(setting == "frmtadsnsp"): vars.formatoptns["frmttriminc"] = v
if(setting == "singleline"): vars.formatoptns["frmttriminc"] = v
if(setting in ("settemp", "temp")): vars.temp = v
if(setting in ("settopp", "topp")): vars.top_p = v
if(setting in ("settopk", "topk")): vars.top_k = v
if(setting in ("settfs", "tfs")): vars.tfs = v
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
if(setting in ("settknmax", "tknmax")): vars.max_length = v; return True
if(setting == "anotedepth"): vars.andepth = v; return True
if(setting in ("setwidepth", "widepth")): vars.widepth = v; return True
if(setting in ("setuseprompt", "useprompt")): vars.useprompt = v; return True
if(setting in ("setadventure", "adventure")): vars.adventure = v
if(setting in ("setdynamicscan", "dynamicscan")): vars.dynamicscan = v
if(setting in ("setnopromptgen", "nopromptgen")): vars.nopromptgen = v
if(setting in ("frmttriminc", "triminc")): vars.formatoptns["frmttriminc"] = v
if(setting in ("frmtrmblln", "rmblln")): vars.formatoptns["frmttrmblln"] = v
if(setting in ("frmtrmspch", "rmspch")): vars.formatoptns["frmttrmspch"] = v
if(setting in ("frmtadsnsp", "adsnsp")): vars.formatoptns["frmtadsnsp"] = v
if(setting in ("frmtsingleline", "singleline")): vars.formatoptns["singleline"] = v
#==================================================================#
# Get contents of memory
@ -1244,6 +1290,19 @@ def lua_set_memory(m):
assert type(m) is str
vars.memory = m
#==================================================================#
# Get contents of author's note
#==================================================================#
def lua_get_authorsnote():
return vars.authornote
#==================================================================#
# Set contents of author's note
#==================================================================#
def lua_set_authorsnote(m):
assert type(m) is str
vars.authornote = m
#==================================================================#
# Save settings and send them to client
#==================================================================#
@ -1260,13 +1319,28 @@ def lua_set_chunk(k, v):
assert k != 0 or len(v) != 0
if(len(v) == 0):
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} deleted story chunk {k}" + colors.END)
inlinedelete(k)
chunk = int(k)
if(vars.lua_koboldbridge.userstate == "genmod"):
del vars._actions[chunk-1]
vars.lua_deleted.add(chunk)
if(vars._actions is not vars.actions):
del vars.actions[chunk-1]
else:
if(k == 0):
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} edited prompt chunk" + colors.END)
else:
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} edited story chunk {k}" + colors.END)
inlineedit(k, v)
chunk = int(k)
if(chunk == 0):
if(vars.lua_koboldbridge.userstate == "genmod"):
vars._prompt = v
vars.lua_edited.add(chunk)
vars.prompt = v
else:
if(vars.lua_koboldbridge.userstate == "genmod"):
vars._actions[chunk-1] = v
vars.lua_edited.add(chunk)
vars.actions[chunk-1] = v
#==================================================================#
# Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc.
@ -1288,8 +1362,8 @@ def lua_get_modeltype():
return "gpt2-xl"
if(vars.model == "NeoCustom" and hidden_size == 768):
return "gpt-neo-125M"
if(vars.model in ("EleutherAI/gpt-neo-1.3M",) or (vars.model == "NeoCustom" and hidden_size == 2048)):
return "gpt-neo-1.3M"
if(vars.model in ("EleutherAI/gpt-neo-1.3B",) or (vars.model == "NeoCustom" and hidden_size == 2048)):
return "gpt-neo-1.3B"
if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model == "NeoCustom" and hidden_size == 2560)):
return "gpt-neo-2.7B"
if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "NeoCustom" and hidden_size == 4096) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096)):
@ -1330,6 +1404,8 @@ def execute_inmod():
set_aibusy(0)
def execute_genmod():
vars.lua_edited = set()
vars.lua_deleted = set()
vars.lua_koboldbridge.execute_genmod()
def execute_outmod():
@ -1345,6 +1421,10 @@ def execute_outmod():
if(vars.lua_koboldbridge.resend_settings_required):
vars.lua_koboldbridge.resend_settings_required = False
lua_resend_settings()
for k in vars.lua_edited:
inlineedit(k, vars.actions[k])
for k in vars.lua_deleted:
inlinedelete(k)
#==================================================================#
# Lua runtime startup
@ -1373,6 +1453,9 @@ bridged = {
"set_genamt": lua_set_genamt,
"get_memory": lua_get_memory,
"set_memory": lua_set_memory,
"get_authorsnote": lua_get_authorsnote,
"set_authorsnote": lua_set_authorsnote,
"compute_context": lua_compute_context,
"get_numseqs": lua_get_numseqs,
"set_numseqs": lua_set_numseqs,
"has_setting": lua_has_setting,
@ -1455,6 +1538,7 @@ def get_message(msg):
# Submit action
if(msg['cmd'] == 'submit'):
if(vars.mode == "play"):
vars.lua_koboldbridge.feedback = None
actionsubmit(msg['data'], actionmode=msg['actionmode'])
elif(vars.mode == "edit"):
editsubmit(msg['data'])
@ -1873,83 +1957,51 @@ def actionsubmit(data, actionmode=0, force_submit=False):
# Ignore new submissions if the AI is currently busy
if(vars.aibusy):
return
set_aibusy(1)
vars.recentback = False
vars.recentedit = False
vars.actionmode = actionmode
while(True):
set_aibusy(1)
# "Action" mode
if(actionmode == 1):
data = data.strip().lstrip('>')
data = re.sub(r'\n+', ' ', data)
if(len(data)):
data = f"\n\n> {data}\n"
# If we're not continuing, store a copy of the raw input
if(data != ""):
vars.lastact = data
if(not vars.gamestarted):
vars.submission = data
execute_inmod()
data = vars.submission
if(not force_submit and len(data.strip()) == 0):
assert False
# Start the game
vars.gamestarted = True
if(not vars.noai and vars.lua_koboldbridge.generating and not vars.nopromptgen):
# Save this first action as the prompt
vars.prompt = data
# Clear the startup text from game screen
emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True)
calcsubmit(data) # Run the first action through the generator
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
else:
# Save this first action as the prompt
vars.prompt = data
for i in range(vars.numseqs):
vars.lua_koboldbridge.outputs[i+1] = ""
execute_outmod()
if(vars.lua_koboldbridge.regeneration_required):
vars.lua_koboldbridge.regeneration_required = False
genout = []
for i in range(vars.numseqs):
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
assert type(genout[-1]["generated_text"]) is str
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
else:
genselect(genout)
refresh_story()
set_aibusy(0)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
else:
# Apply input formatting & scripts before sending to tokenizer
if(vars.actionmode == 0):
data = applyinputformatting(data)
vars.submission = data
execute_inmod()
data = vars.submission
# Dont append submission if it's a blank/continue action
vars.recentback = False
vars.recentedit = False
vars.actionmode = actionmode
# "Action" mode
if(actionmode == 1):
data = data.strip().lstrip('>')
data = re.sub(r'\n+', ' ', data)
if(len(data)):
data = f"\n\n> {data}\n"
# If we're not continuing, store a copy of the raw input
if(data != ""):
# Store the result in the Action log
if(len(vars.prompt.strip()) == 0):
vars.lastact = data
if(not vars.gamestarted):
vars.submission = data
execute_inmod()
data = vars.submission
if(not force_submit and len(data.strip()) == 0):
assert False
# Start the game
vars.gamestarted = True
if(not vars.noai and vars.lua_koboldbridge.generating and not vars.nopromptgen):
# Save this first action as the prompt
vars.prompt = data
# Clear the startup text from game screen
emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True)
calcsubmit(data) # Run the first action through the generator
if(vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
data = ""
force_submit = True
continue
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
break
else:
vars.actions.append(data)
update_story_chunk('last')
if(not vars.noai and vars.lua_koboldbridge.generating):
# Off to the tokenizer!
calcsubmit(data)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
else:
for i in range(vars.numseqs):
vars.lua_koboldbridge.outputs[i+1] = ""
execute_outmod()
set_aibusy(0)
if(vars.lua_koboldbridge.regeneration_required):
# Save this first action as the prompt
vars.prompt = data
for i in range(vars.numseqs):
vars.lua_koboldbridge.outputs[i+1] = ""
execute_outmod()
vars.lua_koboldbridge.regeneration_required = False
genout = []
for i in range(vars.numseqs):
@ -1957,9 +2009,73 @@ def actionsubmit(data, actionmode=0, force_submit=False):
assert type(genout[-1]["generated_text"]) is str
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
if(vars.lua_koboldbridge.restart_sequence is not None):
refresh_story()
data = ""
force_submit = True
continue
else:
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
refresh_story()
data = ""
force_submit = True
continue
genselect(genout)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
refresh_story()
set_aibusy(0)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
break
else:
# Apply input formatting & scripts before sending to tokenizer
if(vars.actionmode == 0):
data = applyinputformatting(data)
vars.submission = data
execute_inmod()
data = vars.submission
# Dont append submission if it's a blank/continue action
if(data != ""):
# Store the result in the Action log
if(len(vars.prompt.strip()) == 0):
vars.prompt = data
else:
vars.actions.append(data)
update_story_chunk('last')
if(not vars.noai and vars.lua_koboldbridge.generating):
# Off to the tokenizer!
calcsubmit(data)
if(vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
data = ""
force_submit = True
continue
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
break
else:
for i in range(vars.numseqs):
vars.lua_koboldbridge.outputs[i+1] = ""
execute_outmod()
vars.lua_koboldbridge.regeneration_required = False
genout = []
for i in range(vars.numseqs):
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
assert type(genout[-1]["generated_text"]) is str
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
if(vars.lua_koboldbridge.restart_sequence is not None):
data = ""
force_submit = True
continue
else:
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
data = ""
force_submit = True
continue
genselect(genout)
set_aibusy(0)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
break
#==================================================================#
#
@ -1972,18 +2088,14 @@ def actionretry(data):
return
# Remove last action if possible and resubmit
if(vars.gamestarted if vars.useprompt else len(vars.actions) > 0):
set_aibusy(1)
if(not vars.recentback and len(vars.actions) != 0 and len(vars.genseqs) == 0): # Don't pop if we're in the "Select sequence to keep" menu or if there are no non-prompt actions
last_key = vars.actions.get_last_key()
vars.actions.pop()
remove_story_chunk(last_key + 1)
vars.genseqs = []
vars.submission = ""
execute_inmod()
calcsubmit(vars.submission)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
vars.recentback = False
vars.recentedit = False
vars.lua_koboldbridge.feedback = None
actionsubmit("", actionmode=vars.actionmode, force_submit=True)
elif(not vars.useprompt):
emit('from_server', {'cmd': 'errmsg', 'data': "Please enable \"Always Add Prompt\" to retry with your prompt."})
@ -2223,9 +2335,10 @@ def _generate(txt, minimum, maximum, found_entries):
model.kai_scanner_head_length = gen_in.shape[-1]
model.kai_scanner_excluded_world_info = found_entries
actions = vars.actions
vars._actions = vars.actions
vars._prompt = vars.prompt
if(vars.dynamicscan):
actions = actions.copy()
vars._actions = vars._actions.copy()
with torch.no_grad():
already_generated = 0
@ -2235,7 +2348,7 @@ def _generate(txt, minimum, maximum, found_entries):
gen_in,
do_sample=True,
min_length=minimum,
max_length=maximum-already_generated,
max_length=int(2e9),
repetition_penalty=vars.rep_pen,
bad_words_ids=vars.badwordsids,
use_cache=True,
@ -2257,7 +2370,7 @@ def _generate(txt, minimum, maximum, found_entries):
txt = tokenizer.decode(genout[i, -already_generated:])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions)
encoded.append(tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True)[0].long().to(genout.device))
max_length = len(max(encoded, key=len))
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
@ -2333,7 +2446,10 @@ def generate(txt, minimum, maximum, found_entries=None):
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
else:
genselect(genout)
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
else:
genselect(genout)
# Clear CUDA cache again if using GPU
if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
@ -2351,6 +2467,11 @@ def genresult(genout):
# Format output before continuing
genout = applyoutputformatting(genout)
vars.lua_koboldbridge.feedback = genout
if(len(genout) == 0):
return
# Add formatted text to Actions array and refresh the game screen
if(len(vars.prompt.strip()) == 0):
@ -2383,12 +2504,17 @@ def genselect(genout):
def selectsequence(n):
if(len(vars.genseqs) == 0):
return
vars.actions.append(vars.genseqs[int(n)]["generated_text"])
update_story_chunk('last')
emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() if len(vars.actions) else 0}, broadcast=True)
vars.lua_koboldbridge.feedback = vars.genseqs[int(n)]["generated_text"]
if(len(vars.lua_koboldbridge.feedback) != 0):
vars.actions.append(vars.lua_koboldbridge.feedback)
update_story_chunk('last')
emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() if len(vars.actions) else 0}, broadcast=True)
emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
vars.genseqs = []
if(vars.lua_koboldbridge.restart_sequence is not None):
actionsubmit("", actionmode=vars.actionmode, force_submit=True)
#==================================================================#
# Send transformers-style request to ngrok/colab host
#==================================================================#
@ -2447,7 +2573,10 @@ def sendtocolab(txt, min, max):
seqs = []
for seq in genout:
seqs.append({"generated_text": seq})
genselect(seqs)
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
else:
genselect(genout)
# Format output before continuing
#genout = applyoutputformatting(getnewcontent(genout))
@ -2544,7 +2673,10 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
else:
genselect(genout)
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
else:
genselect(genout)
set_aibusy(0)
@ -2995,7 +3127,7 @@ def deletewifolder(uid):
#==================================================================#
# Look for WI keys in text to generator
#==================================================================#
def checkworldinfo(txt, force_use_txt=False):
def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False):
original_txt = txt
# Dont go any further if WI is empty
@ -3038,6 +3170,11 @@ def checkworldinfo(txt, force_use_txt=False):
wimem = ""
found_entries = set()
for wi in vars.worldinfo:
if(allowed_entries is not None and wi["uid"] not in allowed_entries):
continue
if(allowed_folders is not None and wi["folder"] not in allowed_folders):
continue
if(wi.get("constant", False)):
wimem = wimem + wi["content"] + "\n"
found_entries.add(id(wi))
@ -3840,6 +3977,7 @@ def newGameRequest():
def randomGameRequest(topic):
newGameRequest()
vars.memory = "You generate the following " + topic + " story concept :"
vars.lua_koboldbridge.feedback = None
actionsubmit("", force_submit=True)
vars.memory = ""