mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Lua API fixes
* `print()` and `warn()` now work correctly with `nil` arguments * Typo: `gpt-neo-1.3M` has been corrected to `gpt-neo-1.3B` * Regeneration is no longer triggered when writing to `keysecondary` of a non-selective key * Handle `genamt` changes in generation modifier properly * Writing to `kobold.settings.numseqs` from a generation modifier no longer affects * Formatting options in `kobold.settings` have been fixed * Added aliases for setting names * Fix behaviour of editing story chunks from a generation modifier * Warnings are now yellow instead of red * kobold.logits is now the raw logits prior to being filtered, like the documentation says, rather than after being filtered * Some erroneous comments and error messages have been corrected * These parts of the API have now been implemented properly: * `compute_context()` methods * `kobold.authorsnote` * `kobold.restart_generation()`
This commit is contained in:
420
aiserver.py
420
aiserver.py
@ -111,8 +111,9 @@ class vars:
|
||||
lua_koboldbridge = None # `koboldbridge` from bridge.lua
|
||||
lua_kobold = None # `kobold` from` bridge.lua
|
||||
lua_koboldcore = None # `koboldcore` from bridge.lua
|
||||
lua_warper = None # Transformers logits warper controllable from Lua
|
||||
lua_logname = ... # Name of previous userscript that logged to terminal
|
||||
lua_edited = set() # Set of chunk numbers that were edited from a Lua generation modifier
|
||||
lua_deleted = set() # Set of chunk numbers that were deleted from a Lua generation modifier
|
||||
userscripts = [] # List of userscripts to load
|
||||
corescript = "default.lua" # Filename of corescript to load
|
||||
# badwords = [] # Array of str/chr values that should be removed from output
|
||||
@ -615,7 +616,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
|
||||
|
||||
# Patch transformers to use our custom logit warpers
|
||||
from transformers import LogitsProcessorList, LogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
|
||||
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
|
||||
class TailFreeLogitsWarper(LogitsWarper):
|
||||
|
||||
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
@ -658,7 +659,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
class LuaLogitsWarper(LogitsWarper):
|
||||
class LuaLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
@ -686,6 +687,13 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
assert scores.shape == scores_shape
|
||||
|
||||
return scores
|
||||
|
||||
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
|
||||
processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
|
||||
processors.insert(0, LuaLogitsProcessor())
|
||||
return processors
|
||||
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
|
||||
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
|
||||
|
||||
def new_get_logits_warper(
|
||||
top_k: int = None,
|
||||
@ -703,18 +711,16 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||
warper_list.append(TailFreeLogitsWarper(tfs=tfs, min_tokens_to_keep=1 + (beams > 1)))
|
||||
if(temp is not None and temp != 1.0):
|
||||
warper_list.append(TemperatureLogitsWarper(temperature=temp))
|
||||
vars.lua_warper = LuaLogitsWarper()
|
||||
warper_list.append(vars.lua_warper)
|
||||
return warper_list
|
||||
|
||||
def new_sample(self, *args, **kwargs):
|
||||
assert kwargs.pop("logits_warper", None) is not None
|
||||
kwargs["logits_warper"] = new_get_logits_warper(
|
||||
vars.top_k,
|
||||
vars.top_p,
|
||||
vars.tfs,
|
||||
vars.temp,
|
||||
1,
|
||||
top_k=vars.top_k,
|
||||
top_p=vars.top_p,
|
||||
tfs=vars.tfs,
|
||||
temp=vars.temp,
|
||||
beams=1,
|
||||
)
|
||||
return new_sample.old_sample(self, *args, **kwargs)
|
||||
new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
|
||||
@ -1042,7 +1048,7 @@ def lua_warn(msg):
|
||||
if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
|
||||
vars.lua_logname = vars.lua_koboldbridge.logging_name
|
||||
print(colors.BLUE + lua_log_format_name(vars.lua_logname) + ":" + colors.END, file=sys.stderr)
|
||||
print(colors.RED + msg.replace("\033", "") + colors.END)
|
||||
print(colors.YELLOW + msg.replace("\033", "") + colors.END)
|
||||
|
||||
#==================================================================#
|
||||
# Decode tokens into a string using current tokenizer
|
||||
@ -1062,12 +1068,36 @@ def lua_decode(tokens):
|
||||
def lua_encode(string):
|
||||
assert type(string) is str
|
||||
if("tokenizer" not in globals()):
|
||||
thinking = False
|
||||
from transformers import GPT2TokenizerFast
|
||||
global tokenizer
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
return tokenizer.encode(string, max_length=int(4e9), truncation=True)
|
||||
|
||||
#==================================================================#
|
||||
# Computes context given a submission, Lua array of entry UIDs and a Lua array
|
||||
# of folder UIDs
|
||||
#==================================================================#
|
||||
def lua_compute_context(submission, entries, folders):
|
||||
assert type(submission) is str
|
||||
actions = vars._actions if vars.lua_koboldbridge.userstate == "genmod" else vars.actions
|
||||
allowed_entries = None
|
||||
allowed_folders = None
|
||||
if(entries is not None):
|
||||
allowed_entries = set()
|
||||
i = 1
|
||||
while(entries[i] is not None):
|
||||
allowed_entries.add(int(entries[i]))
|
||||
i += 1
|
||||
if(folders is not None):
|
||||
allowed_folders = set()
|
||||
i = 1
|
||||
while(folders[i] is not None):
|
||||
allowed_folders.add(int(folders[i]))
|
||||
i += 1
|
||||
winfo, mem, anotetxt, _ = calcsubmitbudgetheader(submission, allowed_entries=allowed_entries, allowed_folders=allowed_folders, force_use_txt=True)
|
||||
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
|
||||
return txt
|
||||
|
||||
#==================================================================#
|
||||
# Get property of a world info entry given its UID and property name
|
||||
#==================================================================#
|
||||
@ -1162,22 +1192,38 @@ def lua_set_numseqs(numseqs):
|
||||
#==================================================================#
|
||||
def lua_has_setting(setting):
|
||||
return setting in (
|
||||
"anotedepth",
|
||||
"settemp",
|
||||
"settopp",
|
||||
"settopk",
|
||||
"settfs",
|
||||
"setreppen",
|
||||
"settknmax",
|
||||
"anotedepth",
|
||||
"setwidepth",
|
||||
"setuseprompt",
|
||||
"setadventure",
|
||||
"setdynamicscan",
|
||||
"setnopromptgen",
|
||||
"temp",
|
||||
"topp",
|
||||
"topk",
|
||||
"tfs",
|
||||
"reppen",
|
||||
"tknmax",
|
||||
"widepth",
|
||||
"useprompt",
|
||||
"adventure",
|
||||
"dynamicscan",
|
||||
"nopromptgen",
|
||||
"frmttriminc",
|
||||
"frmtrmblln",
|
||||
"frmtrmspch",
|
||||
"frmtadsnsp",
|
||||
"frmtsingleline",
|
||||
"triminc",
|
||||
"rmblln",
|
||||
"rmspch",
|
||||
"adsnsp",
|
||||
"singleline",
|
||||
)
|
||||
|
||||
@ -1185,23 +1231,23 @@ def lua_has_setting(setting):
|
||||
# Return the setting with the given name if it exists
|
||||
#==================================================================#
|
||||
def lua_get_setting(setting):
|
||||
if(setting == "settemp"): return vars.temp
|
||||
if(setting == "settopp"): return vars.top_p
|
||||
if(setting == "settopk"): return vars.top_k
|
||||
if(setting == "settfs"): return vars.tfs
|
||||
if(setting == "setreppen"): return vars.rep_pen
|
||||
if(setting == "settknmax"): return vars.max_length
|
||||
if(setting in ("settemp", "temp")): return vars.temp
|
||||
if(setting in ("settopp", "topp")): return vars.top_p
|
||||
if(setting in ("settopk", "topk")): return vars.top_k
|
||||
if(setting in ("settfs", "tfs")): return vars.tfs
|
||||
if(setting in ("setreppen", "reppen")): return vars.rep_pen
|
||||
if(setting in ("settknmax", "tknmax")): return vars.max_length
|
||||
if(setting == "anotedepth"): return vars.andepth
|
||||
if(setting == "setwidepth"): return vars.widepth
|
||||
if(setting == "setuseprompt"): return vars.useprompt
|
||||
if(setting == "setadventure"): return vars.adventure
|
||||
if(setting == "setdynamicscan"): return vars.dynamicscan
|
||||
if(setting == "nopromptgen"): return vars.nopromptgen
|
||||
if(setting == "frmttriminc"): return vars.formatoptns["frmttriminc"]
|
||||
if(setting == "frmtrmblln"): return vars.formatoptns["frmttriminc"]
|
||||
if(setting == "frmtrmspch"): return vars.formatoptns["frmttriminc"]
|
||||
if(setting == "frmtadsnsp"): return vars.formatoptns["frmttriminc"]
|
||||
if(setting == "singleline"): return vars.formatoptns["frmttriminc"]
|
||||
if(setting in ("setwidepth", "widepth")): return vars.widepth
|
||||
if(setting in ("setuseprompt", "useprompt")): return vars.useprompt
|
||||
if(setting in ("setadventure", "adventure")): return vars.adventure
|
||||
if(setting in ("setdynamicscan", "dynamicscan")): return vars.dynamicscan
|
||||
if(setting in ("setnopromptgen", "nopromptgen")): return vars.nopromptgen
|
||||
if(setting in ("frmttriminc", "triminc")): return vars.formatoptns["frmttriminc"]
|
||||
if(setting in ("frmtrmblln", "rmblln")): return vars.formatoptns["frmttrmblln"]
|
||||
if(setting in ("frmtrmspch", "rmspch")): return vars.formatoptns["frmttrmspch"]
|
||||
if(setting in ("frmtadsnsp", "adsnsp")): return vars.formatoptns["frmtadsnsp"]
|
||||
if(setting in ("frmtsingleline", "singleline")): return vars.formatoptns["singleline"]
|
||||
|
||||
#==================================================================#
|
||||
# Set the setting with the given name if it exists
|
||||
@ -1211,25 +1257,25 @@ def lua_set_setting(setting, v):
|
||||
assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float))
|
||||
v = actual_type(v)
|
||||
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set {setting} to {v}" + colors.END)
|
||||
if(setting == "setadventure" and v):
|
||||
if(setting in ("setadventure", "adventure") and v):
|
||||
vars.actionmode = 1
|
||||
if(setting == "settemp"): vars.temp = v
|
||||
if(setting == "settopp"): vars.top_p = v
|
||||
if(setting == "settopk"): vars.top_k = v
|
||||
if(setting == "settfs"): vars.tfs = v
|
||||
if(setting == "setreppen"): vars.rep_pen = v
|
||||
if(setting == "settknmax"): vars.max_length = v
|
||||
if(setting == "anotedepth"): vars.andepth = v
|
||||
if(setting == "setwidepth"): vars.widepth = v
|
||||
if(setting == "setuseprompt"): vars.useprompt = v
|
||||
if(setting == "setadventure"): vars.adventure = v
|
||||
if(setting == "setdynamicscan"): vars.dynamicscan = v
|
||||
if(setting == "setnopromptgen"): vars.nopromptgen = v
|
||||
if(setting == "frmttriminc"): vars.formatoptns["frmttriminc"] = v
|
||||
if(setting == "frmtrmblln"): vars.formatoptns["frmttriminc"] = v
|
||||
if(setting == "frmtrmspch"): vars.formatoptns["frmttriminc"] = v
|
||||
if(setting == "frmtadsnsp"): vars.formatoptns["frmttriminc"] = v
|
||||
if(setting == "singleline"): vars.formatoptns["frmttriminc"] = v
|
||||
if(setting in ("settemp", "temp")): vars.temp = v
|
||||
if(setting in ("settopp", "topp")): vars.top_p = v
|
||||
if(setting in ("settopk", "topk")): vars.top_k = v
|
||||
if(setting in ("settfs", "tfs")): vars.tfs = v
|
||||
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
|
||||
if(setting in ("settknmax", "tknmax")): vars.max_length = v; return True
|
||||
if(setting == "anotedepth"): vars.andepth = v; return True
|
||||
if(setting in ("setwidepth", "widepth")): vars.widepth = v; return True
|
||||
if(setting in ("setuseprompt", "useprompt")): vars.useprompt = v; return True
|
||||
if(setting in ("setadventure", "adventure")): vars.adventure = v
|
||||
if(setting in ("setdynamicscan", "dynamicscan")): vars.dynamicscan = v
|
||||
if(setting in ("setnopromptgen", "nopromptgen")): vars.nopromptgen = v
|
||||
if(setting in ("frmttriminc", "triminc")): vars.formatoptns["frmttriminc"] = v
|
||||
if(setting in ("frmtrmblln", "rmblln")): vars.formatoptns["frmttrmblln"] = v
|
||||
if(setting in ("frmtrmspch", "rmspch")): vars.formatoptns["frmttrmspch"] = v
|
||||
if(setting in ("frmtadsnsp", "adsnsp")): vars.formatoptns["frmtadsnsp"] = v
|
||||
if(setting in ("frmtsingleline", "singleline")): vars.formatoptns["singleline"] = v
|
||||
|
||||
#==================================================================#
|
||||
# Get contents of memory
|
||||
@ -1244,6 +1290,19 @@ def lua_set_memory(m):
|
||||
assert type(m) is str
|
||||
vars.memory = m
|
||||
|
||||
#==================================================================#
|
||||
# Get contents of author's note
|
||||
#==================================================================#
|
||||
def lua_get_authorsnote():
|
||||
return vars.authornote
|
||||
|
||||
#==================================================================#
|
||||
# Set contents of author's note
|
||||
#==================================================================#
|
||||
def lua_set_authorsnote(m):
|
||||
assert type(m) is str
|
||||
vars.authornote = m
|
||||
|
||||
#==================================================================#
|
||||
# Save settings and send them to client
|
||||
#==================================================================#
|
||||
@ -1260,13 +1319,28 @@ def lua_set_chunk(k, v):
|
||||
assert k != 0 or len(v) != 0
|
||||
if(len(v) == 0):
|
||||
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} deleted story chunk {k}" + colors.END)
|
||||
inlinedelete(k)
|
||||
chunk = int(k)
|
||||
if(vars.lua_koboldbridge.userstate == "genmod"):
|
||||
del vars._actions[chunk-1]
|
||||
vars.lua_deleted.add(chunk)
|
||||
if(vars._actions is not vars.actions):
|
||||
del vars.actions[chunk-1]
|
||||
else:
|
||||
if(k == 0):
|
||||
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} edited prompt chunk" + colors.END)
|
||||
else:
|
||||
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} edited story chunk {k}" + colors.END)
|
||||
inlineedit(k, v)
|
||||
chunk = int(k)
|
||||
if(chunk == 0):
|
||||
if(vars.lua_koboldbridge.userstate == "genmod"):
|
||||
vars._prompt = v
|
||||
vars.lua_edited.add(chunk)
|
||||
vars.prompt = v
|
||||
else:
|
||||
if(vars.lua_koboldbridge.userstate == "genmod"):
|
||||
vars._actions[chunk-1] = v
|
||||
vars.lua_edited.add(chunk)
|
||||
vars.actions[chunk-1] = v
|
||||
|
||||
#==================================================================#
|
||||
# Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc.
|
||||
@ -1288,8 +1362,8 @@ def lua_get_modeltype():
|
||||
return "gpt2-xl"
|
||||
if(vars.model == "NeoCustom" and hidden_size == 768):
|
||||
return "gpt-neo-125M"
|
||||
if(vars.model in ("EleutherAI/gpt-neo-1.3M",) or (vars.model == "NeoCustom" and hidden_size == 2048)):
|
||||
return "gpt-neo-1.3M"
|
||||
if(vars.model in ("EleutherAI/gpt-neo-1.3B",) or (vars.model == "NeoCustom" and hidden_size == 2048)):
|
||||
return "gpt-neo-1.3B"
|
||||
if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model == "NeoCustom" and hidden_size == 2560)):
|
||||
return "gpt-neo-2.7B"
|
||||
if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "NeoCustom" and hidden_size == 4096) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096)):
|
||||
@ -1330,6 +1404,8 @@ def execute_inmod():
|
||||
set_aibusy(0)
|
||||
|
||||
def execute_genmod():
|
||||
vars.lua_edited = set()
|
||||
vars.lua_deleted = set()
|
||||
vars.lua_koboldbridge.execute_genmod()
|
||||
|
||||
def execute_outmod():
|
||||
@ -1345,6 +1421,10 @@ def execute_outmod():
|
||||
if(vars.lua_koboldbridge.resend_settings_required):
|
||||
vars.lua_koboldbridge.resend_settings_required = False
|
||||
lua_resend_settings()
|
||||
for k in vars.lua_edited:
|
||||
inlineedit(k, vars.actions[k])
|
||||
for k in vars.lua_deleted:
|
||||
inlinedelete(k)
|
||||
|
||||
#==================================================================#
|
||||
# Lua runtime startup
|
||||
@ -1373,6 +1453,9 @@ bridged = {
|
||||
"set_genamt": lua_set_genamt,
|
||||
"get_memory": lua_get_memory,
|
||||
"set_memory": lua_set_memory,
|
||||
"get_authorsnote": lua_get_authorsnote,
|
||||
"set_authorsnote": lua_set_authorsnote,
|
||||
"compute_context": lua_compute_context,
|
||||
"get_numseqs": lua_get_numseqs,
|
||||
"set_numseqs": lua_set_numseqs,
|
||||
"has_setting": lua_has_setting,
|
||||
@ -1455,6 +1538,7 @@ def get_message(msg):
|
||||
# Submit action
|
||||
if(msg['cmd'] == 'submit'):
|
||||
if(vars.mode == "play"):
|
||||
vars.lua_koboldbridge.feedback = None
|
||||
actionsubmit(msg['data'], actionmode=msg['actionmode'])
|
||||
elif(vars.mode == "edit"):
|
||||
editsubmit(msg['data'])
|
||||
@ -1873,83 +1957,51 @@ def actionsubmit(data, actionmode=0, force_submit=False):
|
||||
# Ignore new submissions if the AI is currently busy
|
||||
if(vars.aibusy):
|
||||
return
|
||||
set_aibusy(1)
|
||||
|
||||
vars.recentback = False
|
||||
vars.recentedit = False
|
||||
vars.actionmode = actionmode
|
||||
while(True):
|
||||
set_aibusy(1)
|
||||
|
||||
# "Action" mode
|
||||
if(actionmode == 1):
|
||||
data = data.strip().lstrip('>')
|
||||
data = re.sub(r'\n+', ' ', data)
|
||||
if(len(data)):
|
||||
data = f"\n\n> {data}\n"
|
||||
|
||||
# If we're not continuing, store a copy of the raw input
|
||||
if(data != ""):
|
||||
vars.lastact = data
|
||||
|
||||
if(not vars.gamestarted):
|
||||
vars.submission = data
|
||||
execute_inmod()
|
||||
data = vars.submission
|
||||
if(not force_submit and len(data.strip()) == 0):
|
||||
assert False
|
||||
# Start the game
|
||||
vars.gamestarted = True
|
||||
if(not vars.noai and vars.lua_koboldbridge.generating and not vars.nopromptgen):
|
||||
# Save this first action as the prompt
|
||||
vars.prompt = data
|
||||
# Clear the startup text from game screen
|
||||
emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True)
|
||||
calcsubmit(data) # Run the first action through the generator
|
||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||
else:
|
||||
# Save this first action as the prompt
|
||||
vars.prompt = data
|
||||
for i in range(vars.numseqs):
|
||||
vars.lua_koboldbridge.outputs[i+1] = ""
|
||||
execute_outmod()
|
||||
if(vars.lua_koboldbridge.regeneration_required):
|
||||
vars.lua_koboldbridge.regeneration_required = False
|
||||
genout = []
|
||||
for i in range(vars.numseqs):
|
||||
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
|
||||
assert type(genout[-1]["generated_text"]) is str
|
||||
if(len(genout) == 1):
|
||||
genresult(genout[0]["generated_text"])
|
||||
else:
|
||||
genselect(genout)
|
||||
refresh_story()
|
||||
set_aibusy(0)
|
||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||
else:
|
||||
# Apply input formatting & scripts before sending to tokenizer
|
||||
if(vars.actionmode == 0):
|
||||
data = applyinputformatting(data)
|
||||
vars.submission = data
|
||||
execute_inmod()
|
||||
data = vars.submission
|
||||
# Dont append submission if it's a blank/continue action
|
||||
vars.recentback = False
|
||||
vars.recentedit = False
|
||||
vars.actionmode = actionmode
|
||||
|
||||
# "Action" mode
|
||||
if(actionmode == 1):
|
||||
data = data.strip().lstrip('>')
|
||||
data = re.sub(r'\n+', ' ', data)
|
||||
if(len(data)):
|
||||
data = f"\n\n> {data}\n"
|
||||
|
||||
# If we're not continuing, store a copy of the raw input
|
||||
if(data != ""):
|
||||
# Store the result in the Action log
|
||||
if(len(vars.prompt.strip()) == 0):
|
||||
vars.lastact = data
|
||||
|
||||
if(not vars.gamestarted):
|
||||
vars.submission = data
|
||||
execute_inmod()
|
||||
data = vars.submission
|
||||
if(not force_submit and len(data.strip()) == 0):
|
||||
assert False
|
||||
# Start the game
|
||||
vars.gamestarted = True
|
||||
if(not vars.noai and vars.lua_koboldbridge.generating and not vars.nopromptgen):
|
||||
# Save this first action as the prompt
|
||||
vars.prompt = data
|
||||
# Clear the startup text from game screen
|
||||
emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True)
|
||||
calcsubmit(data) # Run the first action through the generator
|
||||
if(vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
|
||||
data = ""
|
||||
force_submit = True
|
||||
continue
|
||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||
break
|
||||
else:
|
||||
vars.actions.append(data)
|
||||
update_story_chunk('last')
|
||||
|
||||
if(not vars.noai and vars.lua_koboldbridge.generating):
|
||||
# Off to the tokenizer!
|
||||
calcsubmit(data)
|
||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||
else:
|
||||
for i in range(vars.numseqs):
|
||||
vars.lua_koboldbridge.outputs[i+1] = ""
|
||||
execute_outmod()
|
||||
set_aibusy(0)
|
||||
if(vars.lua_koboldbridge.regeneration_required):
|
||||
# Save this first action as the prompt
|
||||
vars.prompt = data
|
||||
for i in range(vars.numseqs):
|
||||
vars.lua_koboldbridge.outputs[i+1] = ""
|
||||
execute_outmod()
|
||||
vars.lua_koboldbridge.regeneration_required = False
|
||||
genout = []
|
||||
for i in range(vars.numseqs):
|
||||
@ -1957,9 +2009,73 @@ def actionsubmit(data, actionmode=0, force_submit=False):
|
||||
assert type(genout[-1]["generated_text"]) is str
|
||||
if(len(genout) == 1):
|
||||
genresult(genout[0]["generated_text"])
|
||||
if(vars.lua_koboldbridge.restart_sequence is not None):
|
||||
refresh_story()
|
||||
data = ""
|
||||
force_submit = True
|
||||
continue
|
||||
else:
|
||||
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
|
||||
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
|
||||
refresh_story()
|
||||
data = ""
|
||||
force_submit = True
|
||||
continue
|
||||
genselect(genout)
|
||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||
refresh_story()
|
||||
set_aibusy(0)
|
||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||
break
|
||||
else:
|
||||
# Apply input formatting & scripts before sending to tokenizer
|
||||
if(vars.actionmode == 0):
|
||||
data = applyinputformatting(data)
|
||||
vars.submission = data
|
||||
execute_inmod()
|
||||
data = vars.submission
|
||||
# Dont append submission if it's a blank/continue action
|
||||
if(data != ""):
|
||||
# Store the result in the Action log
|
||||
if(len(vars.prompt.strip()) == 0):
|
||||
vars.prompt = data
|
||||
else:
|
||||
vars.actions.append(data)
|
||||
update_story_chunk('last')
|
||||
|
||||
if(not vars.noai and vars.lua_koboldbridge.generating):
|
||||
# Off to the tokenizer!
|
||||
calcsubmit(data)
|
||||
if(vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
|
||||
data = ""
|
||||
force_submit = True
|
||||
continue
|
||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||
break
|
||||
else:
|
||||
for i in range(vars.numseqs):
|
||||
vars.lua_koboldbridge.outputs[i+1] = ""
|
||||
execute_outmod()
|
||||
vars.lua_koboldbridge.regeneration_required = False
|
||||
genout = []
|
||||
for i in range(vars.numseqs):
|
||||
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
|
||||
assert type(genout[-1]["generated_text"]) is str
|
||||
if(len(genout) == 1):
|
||||
genresult(genout[0]["generated_text"])
|
||||
if(vars.lua_koboldbridge.restart_sequence is not None):
|
||||
data = ""
|
||||
force_submit = True
|
||||
continue
|
||||
else:
|
||||
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
|
||||
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
|
||||
data = ""
|
||||
force_submit = True
|
||||
continue
|
||||
genselect(genout)
|
||||
set_aibusy(0)
|
||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||
break
|
||||
|
||||
#==================================================================#
|
||||
#
|
||||
@ -1972,18 +2088,14 @@ def actionretry(data):
|
||||
return
|
||||
# Remove last action if possible and resubmit
|
||||
if(vars.gamestarted if vars.useprompt else len(vars.actions) > 0):
|
||||
set_aibusy(1)
|
||||
if(not vars.recentback and len(vars.actions) != 0 and len(vars.genseqs) == 0): # Don't pop if we're in the "Select sequence to keep" menu or if there are no non-prompt actions
|
||||
last_key = vars.actions.get_last_key()
|
||||
vars.actions.pop()
|
||||
remove_story_chunk(last_key + 1)
|
||||
vars.genseqs = []
|
||||
vars.submission = ""
|
||||
execute_inmod()
|
||||
calcsubmit(vars.submission)
|
||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||
vars.recentback = False
|
||||
vars.recentedit = False
|
||||
vars.lua_koboldbridge.feedback = None
|
||||
actionsubmit("", actionmode=vars.actionmode, force_submit=True)
|
||||
elif(not vars.useprompt):
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': "Please enable \"Always Add Prompt\" to retry with your prompt."})
|
||||
|
||||
@ -2223,9 +2335,10 @@ def _generate(txt, minimum, maximum, found_entries):
|
||||
model.kai_scanner_head_length = gen_in.shape[-1]
|
||||
model.kai_scanner_excluded_world_info = found_entries
|
||||
|
||||
actions = vars.actions
|
||||
vars._actions = vars.actions
|
||||
vars._prompt = vars.prompt
|
||||
if(vars.dynamicscan):
|
||||
actions = actions.copy()
|
||||
vars._actions = vars._actions.copy()
|
||||
|
||||
with torch.no_grad():
|
||||
already_generated = 0
|
||||
@ -2235,7 +2348,7 @@ def _generate(txt, minimum, maximum, found_entries):
|
||||
gen_in,
|
||||
do_sample=True,
|
||||
min_length=minimum,
|
||||
max_length=maximum-already_generated,
|
||||
max_length=int(2e9),
|
||||
repetition_penalty=vars.rep_pen,
|
||||
bad_words_ids=vars.badwordsids,
|
||||
use_cache=True,
|
||||
@ -2257,7 +2370,7 @@ def _generate(txt, minimum, maximum, found_entries):
|
||||
txt = tokenizer.decode(genout[i, -already_generated:])
|
||||
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
|
||||
found_entries[i].update(_found_entries)
|
||||
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
|
||||
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions)
|
||||
encoded.append(tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True)[0].long().to(genout.device))
|
||||
max_length = len(max(encoded, key=len))
|
||||
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
|
||||
@ -2333,7 +2446,10 @@ def generate(txt, minimum, maximum, found_entries=None):
|
||||
if(len(genout) == 1):
|
||||
genresult(genout[0]["generated_text"])
|
||||
else:
|
||||
genselect(genout)
|
||||
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
|
||||
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
|
||||
else:
|
||||
genselect(genout)
|
||||
|
||||
# Clear CUDA cache again if using GPU
|
||||
if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
|
||||
@ -2351,6 +2467,11 @@ def genresult(genout):
|
||||
|
||||
# Format output before continuing
|
||||
genout = applyoutputformatting(genout)
|
||||
|
||||
vars.lua_koboldbridge.feedback = genout
|
||||
|
||||
if(len(genout) == 0):
|
||||
return
|
||||
|
||||
# Add formatted text to Actions array and refresh the game screen
|
||||
if(len(vars.prompt.strip()) == 0):
|
||||
@ -2383,12 +2504,17 @@ def genselect(genout):
|
||||
def selectsequence(n):
|
||||
if(len(vars.genseqs) == 0):
|
||||
return
|
||||
vars.actions.append(vars.genseqs[int(n)]["generated_text"])
|
||||
update_story_chunk('last')
|
||||
emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() if len(vars.actions) else 0}, broadcast=True)
|
||||
vars.lua_koboldbridge.feedback = vars.genseqs[int(n)]["generated_text"]
|
||||
if(len(vars.lua_koboldbridge.feedback) != 0):
|
||||
vars.actions.append(vars.lua_koboldbridge.feedback)
|
||||
update_story_chunk('last')
|
||||
emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() if len(vars.actions) else 0}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
|
||||
vars.genseqs = []
|
||||
|
||||
if(vars.lua_koboldbridge.restart_sequence is not None):
|
||||
actionsubmit("", actionmode=vars.actionmode, force_submit=True)
|
||||
|
||||
#==================================================================#
|
||||
# Send transformers-style request to ngrok/colab host
|
||||
#==================================================================#
|
||||
@ -2447,7 +2573,10 @@ def sendtocolab(txt, min, max):
|
||||
seqs = []
|
||||
for seq in genout:
|
||||
seqs.append({"generated_text": seq})
|
||||
genselect(seqs)
|
||||
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
|
||||
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
|
||||
else:
|
||||
genselect(genout)
|
||||
|
||||
# Format output before continuing
|
||||
#genout = applyoutputformatting(getnewcontent(genout))
|
||||
@ -2544,7 +2673,10 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||
if(len(genout) == 1):
|
||||
genresult(genout[0]["generated_text"])
|
||||
else:
|
||||
genselect(genout)
|
||||
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
|
||||
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
|
||||
else:
|
||||
genselect(genout)
|
||||
|
||||
set_aibusy(0)
|
||||
|
||||
@ -2995,7 +3127,7 @@ def deletewifolder(uid):
|
||||
#==================================================================#
|
||||
# Look for WI keys in text to generator
|
||||
#==================================================================#
|
||||
def checkworldinfo(txt, force_use_txt=False):
|
||||
def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False):
|
||||
original_txt = txt
|
||||
|
||||
# Dont go any further if WI is empty
|
||||
@ -3038,6 +3170,11 @@ def checkworldinfo(txt, force_use_txt=False):
|
||||
wimem = ""
|
||||
found_entries = set()
|
||||
for wi in vars.worldinfo:
|
||||
if(allowed_entries is not None and wi["uid"] not in allowed_entries):
|
||||
continue
|
||||
if(allowed_folders is not None and wi["folder"] not in allowed_folders):
|
||||
continue
|
||||
|
||||
if(wi.get("constant", False)):
|
||||
wimem = wimem + wi["content"] + "\n"
|
||||
found_entries.add(id(wi))
|
||||
@ -3840,6 +3977,7 @@ def newGameRequest():
|
||||
def randomGameRequest(topic):
|
||||
newGameRequest()
|
||||
vars.memory = "You generate the following " + topic + " story concept :"
|
||||
vars.lua_koboldbridge.feedback = None
|
||||
actionsubmit("", force_submit=True)
|
||||
vars.memory = ""
|
||||
|
||||
|
Reference in New Issue
Block a user