diff --git a/aiserver.py b/aiserver.py index e1db156e..15d7e539 100644 --- a/aiserver.py +++ b/aiserver.py @@ -111,8 +111,9 @@ class vars: lua_koboldbridge = None # `koboldbridge` from bridge.lua lua_kobold = None # `kobold` from` bridge.lua lua_koboldcore = None # `koboldcore` from bridge.lua - lua_warper = None # Transformers logits warper controllable from Lua lua_logname = ... # Name of previous userscript that logged to terminal + lua_edited = set() # Set of chunk numbers that were edited from a Lua generation modifier + lua_deleted = set() # Set of chunk numbers that were deleted from a Lua generation modifier userscripts = [] # List of userscripts to load corescript = "default.lua" # Filename of corescript to load # badwords = [] # Array of str/chr values that should be removed from output @@ -615,7 +616,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme # Patch transformers to use our custom logit warpers - from transformers import LogitsProcessorList, LogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper + from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper class TailFreeLogitsWarper(LogitsWarper): def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): @@ -658,7 +659,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores - class LuaLogitsWarper(LogitsWarper): + class LuaLogitsProcessor(LogitsProcessor): def __init__(self): pass @@ -686,6 +687,13 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme assert scores.shape == scores_shape return scores + + def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList: + processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs) + processors.insert(0, LuaLogitsProcessor()) + return processors + new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor + transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor def new_get_logits_warper( top_k: int = None, @@ -703,18 +711,16 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme warper_list.append(TailFreeLogitsWarper(tfs=tfs, min_tokens_to_keep=1 + (beams > 1))) if(temp is not None and temp != 1.0): warper_list.append(TemperatureLogitsWarper(temperature=temp)) - vars.lua_warper = LuaLogitsWarper() - warper_list.append(vars.lua_warper) return warper_list def new_sample(self, *args, **kwargs): assert kwargs.pop("logits_warper", None) is not None kwargs["logits_warper"] = new_get_logits_warper( - vars.top_k, - vars.top_p, - vars.tfs, - vars.temp, - 1, + top_k=vars.top_k, + top_p=vars.top_p, + tfs=vars.tfs, + temp=vars.temp, + beams=1, ) return new_sample.old_sample(self, *args, **kwargs) new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample @@ -1042,7 +1048,7 @@ def lua_warn(msg): if(vars.lua_logname != vars.lua_koboldbridge.logging_name): vars.lua_logname = vars.lua_koboldbridge.logging_name print(colors.BLUE + lua_log_format_name(vars.lua_logname) + ":" + colors.END, file=sys.stderr) - print(colors.RED + msg.replace("\033", "") + colors.END) + print(colors.YELLOW + msg.replace("\033", "") + colors.END) #==================================================================# # Decode tokens into a string using current tokenizer @@ -1062,12 +1068,36 @@ def lua_decode(tokens): def lua_encode(string): assert type(string) is str if("tokenizer" not in globals()): - thinking = False from transformers import GPT2TokenizerFast global tokenizer tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") return tokenizer.encode(string, max_length=int(4e9), truncation=True) +#==================================================================# +# Computes context given a submission, Lua array of entry UIDs and a Lua array +# of folder UIDs +#==================================================================# +def lua_compute_context(submission, entries, folders): + assert type(submission) is str + actions = vars._actions if vars.lua_koboldbridge.userstate == "genmod" else vars.actions + allowed_entries = None + allowed_folders = None + if(entries is not None): + allowed_entries = set() + i = 1 + while(entries[i] is not None): + allowed_entries.add(int(entries[i])) + i += 1 + if(folders is not None): + allowed_folders = set() + i = 1 + while(folders[i] is not None): + allowed_folders.add(int(folders[i])) + i += 1 + winfo, mem, anotetxt, _ = calcsubmitbudgetheader(submission, allowed_entries=allowed_entries, allowed_folders=allowed_folders, force_use_txt=True) + txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions) + return txt + #==================================================================# # Get property of a world info entry given its UID and property name #==================================================================# @@ -1162,22 +1192,38 @@ def lua_set_numseqs(numseqs): #==================================================================# def lua_has_setting(setting): return setting in ( + "anotedepth", "settemp", "settopp", "settopk", "settfs", "setreppen", "settknmax", - "anotedepth", "setwidepth", "setuseprompt", "setadventure", "setdynamicscan", + "setnopromptgen", + "temp", + "topp", + "topk", + "tfs", + "reppen", + "tknmax", + "widepth", + "useprompt", + "adventure", + "dynamicscan", "nopromptgen", "frmttriminc", "frmtrmblln", "frmtrmspch", "frmtadsnsp", + "frmtsingleline", + "triminc", + "rmblln", + "rmspch", + "adsnsp", "singleline", ) @@ -1185,23 +1231,23 @@ def lua_has_setting(setting): # Return the setting with the given name if it exists #==================================================================# def lua_get_setting(setting): - if(setting == "settemp"): return vars.temp - if(setting == "settopp"): return vars.top_p - if(setting == "settopk"): return vars.top_k - if(setting == "settfs"): return vars.tfs - if(setting == "setreppen"): return vars.rep_pen - if(setting == "settknmax"): return vars.max_length + if(setting in ("settemp", "temp")): return vars.temp + if(setting in ("settopp", "topp")): return vars.top_p + if(setting in ("settopk", "topk")): return vars.top_k + if(setting in ("settfs", "tfs")): return vars.tfs + if(setting in ("setreppen", "reppen")): return vars.rep_pen + if(setting in ("settknmax", "tknmax")): return vars.max_length if(setting == "anotedepth"): return vars.andepth - if(setting == "setwidepth"): return vars.widepth - if(setting == "setuseprompt"): return vars.useprompt - if(setting == "setadventure"): return vars.adventure - if(setting == "setdynamicscan"): return vars.dynamicscan - if(setting == "nopromptgen"): return vars.nopromptgen - if(setting == "frmttriminc"): return vars.formatoptns["frmttriminc"] - if(setting == "frmtrmblln"): return vars.formatoptns["frmttriminc"] - if(setting == "frmtrmspch"): return vars.formatoptns["frmttriminc"] - if(setting == "frmtadsnsp"): return vars.formatoptns["frmttriminc"] - if(setting == "singleline"): return vars.formatoptns["frmttriminc"] + if(setting in ("setwidepth", "widepth")): return vars.widepth + if(setting in ("setuseprompt", "useprompt")): return vars.useprompt + if(setting in ("setadventure", "adventure")): return vars.adventure + if(setting in ("setdynamicscan", "dynamicscan")): return vars.dynamicscan + if(setting in ("setnopromptgen", "nopromptgen")): return vars.nopromptgen + if(setting in ("frmttriminc", "triminc")): return vars.formatoptns["frmttriminc"] + if(setting in ("frmtrmblln", "rmblln")): return vars.formatoptns["frmttrmblln"] + if(setting in ("frmtrmspch", "rmspch")): return vars.formatoptns["frmttrmspch"] + if(setting in ("frmtadsnsp", "adsnsp")): return vars.formatoptns["frmtadsnsp"] + if(setting in ("frmtsingleline", "singleline")): return vars.formatoptns["singleline"] #==================================================================# # Set the setting with the given name if it exists @@ -1211,25 +1257,25 @@ def lua_set_setting(setting, v): assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float)) v = actual_type(v) print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set {setting} to {v}" + colors.END) - if(setting == "setadventure" and v): + if(setting in ("setadventure", "adventure") and v): vars.actionmode = 1 - if(setting == "settemp"): vars.temp = v - if(setting == "settopp"): vars.top_p = v - if(setting == "settopk"): vars.top_k = v - if(setting == "settfs"): vars.tfs = v - if(setting == "setreppen"): vars.rep_pen = v - if(setting == "settknmax"): vars.max_length = v - if(setting == "anotedepth"): vars.andepth = v - if(setting == "setwidepth"): vars.widepth = v - if(setting == "setuseprompt"): vars.useprompt = v - if(setting == "setadventure"): vars.adventure = v - if(setting == "setdynamicscan"): vars.dynamicscan = v - if(setting == "setnopromptgen"): vars.nopromptgen = v - if(setting == "frmttriminc"): vars.formatoptns["frmttriminc"] = v - if(setting == "frmtrmblln"): vars.formatoptns["frmttriminc"] = v - if(setting == "frmtrmspch"): vars.formatoptns["frmttriminc"] = v - if(setting == "frmtadsnsp"): vars.formatoptns["frmttriminc"] = v - if(setting == "singleline"): vars.formatoptns["frmttriminc"] = v + if(setting in ("settemp", "temp")): vars.temp = v + if(setting in ("settopp", "topp")): vars.top_p = v + if(setting in ("settopk", "topk")): vars.top_k = v + if(setting in ("settfs", "tfs")): vars.tfs = v + if(setting in ("setreppen", "reppen")): vars.rep_pen = v + if(setting in ("settknmax", "tknmax")): vars.max_length = v; return True + if(setting == "anotedepth"): vars.andepth = v; return True + if(setting in ("setwidepth", "widepth")): vars.widepth = v; return True + if(setting in ("setuseprompt", "useprompt")): vars.useprompt = v; return True + if(setting in ("setadventure", "adventure")): vars.adventure = v + if(setting in ("setdynamicscan", "dynamicscan")): vars.dynamicscan = v + if(setting in ("setnopromptgen", "nopromptgen")): vars.nopromptgen = v + if(setting in ("frmttriminc", "triminc")): vars.formatoptns["frmttriminc"] = v + if(setting in ("frmtrmblln", "rmblln")): vars.formatoptns["frmttrmblln"] = v + if(setting in ("frmtrmspch", "rmspch")): vars.formatoptns["frmttrmspch"] = v + if(setting in ("frmtadsnsp", "adsnsp")): vars.formatoptns["frmtadsnsp"] = v + if(setting in ("frmtsingleline", "singleline")): vars.formatoptns["singleline"] = v #==================================================================# # Get contents of memory @@ -1244,6 +1290,19 @@ def lua_set_memory(m): assert type(m) is str vars.memory = m +#==================================================================# +# Get contents of author's note +#==================================================================# +def lua_get_authorsnote(): + return vars.authornote + +#==================================================================# +# Set contents of author's note +#==================================================================# +def lua_set_authorsnote(m): + assert type(m) is str + vars.authornote = m + #==================================================================# # Save settings and send them to client #==================================================================# @@ -1260,13 +1319,28 @@ def lua_set_chunk(k, v): assert k != 0 or len(v) != 0 if(len(v) == 0): print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} deleted story chunk {k}" + colors.END) - inlinedelete(k) + chunk = int(k) + if(vars.lua_koboldbridge.userstate == "genmod"): + del vars._actions[chunk-1] + vars.lua_deleted.add(chunk) + if(vars._actions is not vars.actions): + del vars.actions[chunk-1] else: if(k == 0): print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} edited prompt chunk" + colors.END) else: print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} edited story chunk {k}" + colors.END) - inlineedit(k, v) + chunk = int(k) + if(chunk == 0): + if(vars.lua_koboldbridge.userstate == "genmod"): + vars._prompt = v + vars.lua_edited.add(chunk) + vars.prompt = v + else: + if(vars.lua_koboldbridge.userstate == "genmod"): + vars._actions[chunk-1] = v + vars.lua_edited.add(chunk) + vars.actions[chunk-1] = v #==================================================================# # Get model type as "gpt-2-xl", "gpt-neo-2.7B", etc. @@ -1288,8 +1362,8 @@ def lua_get_modeltype(): return "gpt2-xl" if(vars.model == "NeoCustom" and hidden_size == 768): return "gpt-neo-125M" - if(vars.model in ("EleutherAI/gpt-neo-1.3M",) or (vars.model == "NeoCustom" and hidden_size == 2048)): - return "gpt-neo-1.3M" + if(vars.model in ("EleutherAI/gpt-neo-1.3B",) or (vars.model == "NeoCustom" and hidden_size == 2048)): + return "gpt-neo-1.3B" if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model == "NeoCustom" and hidden_size == 2560)): return "gpt-neo-2.7B" if(vars.model in ("EleutherAI/gpt-j-6B",) or (vars.model == "NeoCustom" and hidden_size == 4096) or (vars.model == "TPUMeshTransformerGPTJ" and tpu_mtj_backend.params["d_model"] == 4096)): @@ -1330,6 +1404,8 @@ def execute_inmod(): set_aibusy(0) def execute_genmod(): + vars.lua_edited = set() + vars.lua_deleted = set() vars.lua_koboldbridge.execute_genmod() def execute_outmod(): @@ -1345,6 +1421,10 @@ def execute_outmod(): if(vars.lua_koboldbridge.resend_settings_required): vars.lua_koboldbridge.resend_settings_required = False lua_resend_settings() + for k in vars.lua_edited: + inlineedit(k, vars.actions[k]) + for k in vars.lua_deleted: + inlinedelete(k) #==================================================================# # Lua runtime startup @@ -1373,6 +1453,9 @@ bridged = { "set_genamt": lua_set_genamt, "get_memory": lua_get_memory, "set_memory": lua_set_memory, + "get_authorsnote": lua_get_authorsnote, + "set_authorsnote": lua_set_authorsnote, + "compute_context": lua_compute_context, "get_numseqs": lua_get_numseqs, "set_numseqs": lua_set_numseqs, "has_setting": lua_has_setting, @@ -1455,6 +1538,7 @@ def get_message(msg): # Submit action if(msg['cmd'] == 'submit'): if(vars.mode == "play"): + vars.lua_koboldbridge.feedback = None actionsubmit(msg['data'], actionmode=msg['actionmode']) elif(vars.mode == "edit"): editsubmit(msg['data']) @@ -1873,83 +1957,51 @@ def actionsubmit(data, actionmode=0, force_submit=False): # Ignore new submissions if the AI is currently busy if(vars.aibusy): return - set_aibusy(1) - vars.recentback = False - vars.recentedit = False - vars.actionmode = actionmode + while(True): + set_aibusy(1) - # "Action" mode - if(actionmode == 1): - data = data.strip().lstrip('>') - data = re.sub(r'\n+', ' ', data) - if(len(data)): - data = f"\n\n> {data}\n" - - # If we're not continuing, store a copy of the raw input - if(data != ""): - vars.lastact = data - - if(not vars.gamestarted): - vars.submission = data - execute_inmod() - data = vars.submission - if(not force_submit and len(data.strip()) == 0): - assert False - # Start the game - vars.gamestarted = True - if(not vars.noai and vars.lua_koboldbridge.generating and not vars.nopromptgen): - # Save this first action as the prompt - vars.prompt = data - # Clear the startup text from game screen - emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True) - calcsubmit(data) # Run the first action through the generator - emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) - else: - # Save this first action as the prompt - vars.prompt = data - for i in range(vars.numseqs): - vars.lua_koboldbridge.outputs[i+1] = "" - execute_outmod() - if(vars.lua_koboldbridge.regeneration_required): - vars.lua_koboldbridge.regeneration_required = False - genout = [] - for i in range(vars.numseqs): - genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]}) - assert type(genout[-1]["generated_text"]) is str - if(len(genout) == 1): - genresult(genout[0]["generated_text"]) - else: - genselect(genout) - refresh_story() - set_aibusy(0) - emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) - else: - # Apply input formatting & scripts before sending to tokenizer - if(vars.actionmode == 0): - data = applyinputformatting(data) - vars.submission = data - execute_inmod() - data = vars.submission - # Dont append submission if it's a blank/continue action + vars.recentback = False + vars.recentedit = False + vars.actionmode = actionmode + + # "Action" mode + if(actionmode == 1): + data = data.strip().lstrip('>') + data = re.sub(r'\n+', ' ', data) + if(len(data)): + data = f"\n\n> {data}\n" + + # If we're not continuing, store a copy of the raw input if(data != ""): - # Store the result in the Action log - if(len(vars.prompt.strip()) == 0): + vars.lastact = data + + if(not vars.gamestarted): + vars.submission = data + execute_inmod() + data = vars.submission + if(not force_submit and len(data.strip()) == 0): + assert False + # Start the game + vars.gamestarted = True + if(not vars.noai and vars.lua_koboldbridge.generating and not vars.nopromptgen): + # Save this first action as the prompt vars.prompt = data + # Clear the startup text from game screen + emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True) + calcsubmit(data) # Run the first action through the generator + if(vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0): + data = "" + force_submit = True + continue + emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) + break else: - vars.actions.append(data) - update_story_chunk('last') - - if(not vars.noai and vars.lua_koboldbridge.generating): - # Off to the tokenizer! - calcsubmit(data) - emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) - else: - for i in range(vars.numseqs): - vars.lua_koboldbridge.outputs[i+1] = "" - execute_outmod() - set_aibusy(0) - if(vars.lua_koboldbridge.regeneration_required): + # Save this first action as the prompt + vars.prompt = data + for i in range(vars.numseqs): + vars.lua_koboldbridge.outputs[i+1] = "" + execute_outmod() vars.lua_koboldbridge.regeneration_required = False genout = [] for i in range(vars.numseqs): @@ -1957,9 +2009,73 @@ def actionsubmit(data, actionmode=0, force_submit=False): assert type(genout[-1]["generated_text"]) is str if(len(genout) == 1): genresult(genout[0]["generated_text"]) + if(vars.lua_koboldbridge.restart_sequence is not None): + refresh_story() + data = "" + force_submit = True + continue else: + if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0): + genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"]) + refresh_story() + data = "" + force_submit = True + continue genselect(genout) - emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) + refresh_story() + set_aibusy(0) + emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) + break + else: + # Apply input formatting & scripts before sending to tokenizer + if(vars.actionmode == 0): + data = applyinputformatting(data) + vars.submission = data + execute_inmod() + data = vars.submission + # Dont append submission if it's a blank/continue action + if(data != ""): + # Store the result in the Action log + if(len(vars.prompt.strip()) == 0): + vars.prompt = data + else: + vars.actions.append(data) + update_story_chunk('last') + + if(not vars.noai and vars.lua_koboldbridge.generating): + # Off to the tokenizer! + calcsubmit(data) + if(vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0): + data = "" + force_submit = True + continue + emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) + break + else: + for i in range(vars.numseqs): + vars.lua_koboldbridge.outputs[i+1] = "" + execute_outmod() + vars.lua_koboldbridge.regeneration_required = False + genout = [] + for i in range(vars.numseqs): + genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]}) + assert type(genout[-1]["generated_text"]) is str + if(len(genout) == 1): + genresult(genout[0]["generated_text"]) + if(vars.lua_koboldbridge.restart_sequence is not None): + data = "" + force_submit = True + continue + else: + if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0): + genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"]) + data = "" + force_submit = True + continue + genselect(genout) + set_aibusy(0) + emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) + break #==================================================================# # @@ -1972,18 +2088,14 @@ def actionretry(data): return # Remove last action if possible and resubmit if(vars.gamestarted if vars.useprompt else len(vars.actions) > 0): - set_aibusy(1) if(not vars.recentback and len(vars.actions) != 0 and len(vars.genseqs) == 0): # Don't pop if we're in the "Select sequence to keep" menu or if there are no non-prompt actions last_key = vars.actions.get_last_key() vars.actions.pop() remove_story_chunk(last_key + 1) - vars.genseqs = [] - vars.submission = "" - execute_inmod() - calcsubmit(vars.submission) - emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) vars.recentback = False vars.recentedit = False + vars.lua_koboldbridge.feedback = None + actionsubmit("", actionmode=vars.actionmode, force_submit=True) elif(not vars.useprompt): emit('from_server', {'cmd': 'errmsg', 'data': "Please enable \"Always Add Prompt\" to retry with your prompt."}) @@ -2223,9 +2335,10 @@ def _generate(txt, minimum, maximum, found_entries): model.kai_scanner_head_length = gen_in.shape[-1] model.kai_scanner_excluded_world_info = found_entries - actions = vars.actions + vars._actions = vars.actions + vars._prompt = vars.prompt if(vars.dynamicscan): - actions = actions.copy() + vars._actions = vars._actions.copy() with torch.no_grad(): already_generated = 0 @@ -2235,7 +2348,7 @@ def _generate(txt, minimum, maximum, found_entries): gen_in, do_sample=True, min_length=minimum, - max_length=maximum-already_generated, + max_length=int(2e9), repetition_penalty=vars.rep_pen, bad_words_ids=vars.badwordsids, use_cache=True, @@ -2257,7 +2370,7 @@ def _generate(txt, minimum, maximum, found_entries): txt = tokenizer.decode(genout[i, -already_generated:]) winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True) found_entries[i].update(_found_entries) - txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions) + txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions) encoded.append(tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True)[0].long().to(genout.device)) max_length = len(max(encoded, key=len)) encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded)) @@ -2333,7 +2446,10 @@ def generate(txt, minimum, maximum, found_entries=None): if(len(genout) == 1): genresult(genout[0]["generated_text"]) else: - genselect(genout) + if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0): + genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"]) + else: + genselect(genout) # Clear CUDA cache again if using GPU if(vars.hascuda and (vars.usegpu or vars.breakmodel)): @@ -2351,6 +2467,11 @@ def genresult(genout): # Format output before continuing genout = applyoutputformatting(genout) + + vars.lua_koboldbridge.feedback = genout + + if(len(genout) == 0): + return # Add formatted text to Actions array and refresh the game screen if(len(vars.prompt.strip()) == 0): @@ -2383,12 +2504,17 @@ def genselect(genout): def selectsequence(n): if(len(vars.genseqs) == 0): return - vars.actions.append(vars.genseqs[int(n)]["generated_text"]) - update_story_chunk('last') - emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() if len(vars.actions) else 0}, broadcast=True) + vars.lua_koboldbridge.feedback = vars.genseqs[int(n)]["generated_text"] + if(len(vars.lua_koboldbridge.feedback) != 0): + vars.actions.append(vars.lua_koboldbridge.feedback) + update_story_chunk('last') + emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() if len(vars.actions) else 0}, broadcast=True) emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True) vars.genseqs = [] + if(vars.lua_koboldbridge.restart_sequence is not None): + actionsubmit("", actionmode=vars.actionmode, force_submit=True) + #==================================================================# # Send transformers-style request to ngrok/colab host #==================================================================# @@ -2447,7 +2573,10 @@ def sendtocolab(txt, min, max): seqs = [] for seq in genout: seqs.append({"generated_text": seq}) - genselect(seqs) + if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0): + genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"]) + else: + genselect(genout) # Format output before continuing #genout = applyoutputformatting(getnewcontent(genout)) @@ -2544,7 +2673,10 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): if(len(genout) == 1): genresult(genout[0]["generated_text"]) else: - genselect(genout) + if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0): + genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"]) + else: + genselect(genout) set_aibusy(0) @@ -2995,7 +3127,7 @@ def deletewifolder(uid): #==================================================================# # Look for WI keys in text to generator #==================================================================# -def checkworldinfo(txt, force_use_txt=False): +def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False): original_txt = txt # Dont go any further if WI is empty @@ -3038,6 +3170,11 @@ def checkworldinfo(txt, force_use_txt=False): wimem = "" found_entries = set() for wi in vars.worldinfo: + if(allowed_entries is not None and wi["uid"] not in allowed_entries): + continue + if(allowed_folders is not None and wi["folder"] not in allowed_folders): + continue + if(wi.get("constant", False)): wimem = wimem + wi["content"] + "\n" found_entries.add(id(wi)) @@ -3840,6 +3977,7 @@ def newGameRequest(): def randomGameRequest(topic): newGameRequest() vars.memory = "You generate the following " + topic + " story concept :" + vars.lua_koboldbridge.feedback = None actionsubmit("", force_submit=True) vars.memory = "" diff --git a/bridge.lua b/bridge.lua index 77178916..3399d8a8 100644 --- a/bridge.lua +++ b/bridge.lua @@ -207,12 +207,14 @@ return function(_python, _bridged) koboldbridge.regeneration_required = false koboldbridge.resend_settings_required = false koboldbridge.generating = true + koboldbridge.restart_sequence = nil koboldbridge.userstate = nil koboldbridge.logits = {} koboldbridge.vocab_size = 0 koboldbridge.generated = {} koboldbridge.generated_cols = 0 koboldbridge.outputs = {} + koboldbridge.feedback = nil ---@type string|nil ---@return nil local function maybe_require_regeneration() @@ -285,12 +287,18 @@ return function(_python, _bridged) return _python.as_attrgetter(bridged.vars.worldinfo_u).get(rawget(self, "_uid")) ~= nil end + ---@param submission? string ---@return string - function KoboldWorldInfoEntry:compute_context() + function KoboldWorldInfoEntry:compute_context(submission) if not check_validity(self) then return "" + elseif submission == nil then + submission = kobold.submission + elseif type(submission) ~= "string" then + error("`compute_context` takes a string or nil as argument #1, but got a " .. type(submission)) + return "" end - return bridged.compute_context({self.uid}) + return bridged.compute_context(submission, {self.uid}, nil) end ---@generic K @@ -335,7 +343,7 @@ return function(_python, _bridged) error("`"..rawget(t, "_name").."."..k.."` must be a "..KoboldWorldInfoEntry_fieldtypes[k].."; you attempted to set it to a "..type(v)) return else - if k ~= "comment" then + if k ~= "comment" and not (t.selective and k == "keysecondary") then maybe_require_regeneration() end bridged.set_attr(t.uid, k, v) @@ -379,26 +387,34 @@ return function(_python, _bridged) return entry end + ---@param submission? string ---@param entries? KoboldWorldInfoEntry|table ---@return string - function KoboldWorldInfoFolder:compute_context(entries) + function KoboldWorldInfoFolder:compute_context(submission, entries) if not check_validity(self) then - return - end - if entries ~= nil and type(entries) ~= "table" or (entries.name ~= nil and entries.name ~= "KoboldWorldInfoEntry") then - error("`compute_context` takes a KoboldWorldInfoEntry, table of KoboldWorldInfoEntries or nil as argument, but got a " .. type(entries)) + return "" + elseif submission == nil then + submission = kobold.submission + elseif type(submission) ~= "string" then + error("`compute_context` takes a string or nil as argument #1, but got a " .. type(submission)) return "" end - if entries.name == "KoboldWorldInfoEntry" then - entries = {entries} - end local _entries - for k, v in pairs(entries) do - if type(v) == "table" and v.name == "KoboldWorldInfoEntry" and (rawget(self, "_name") ~= "KoboldWorldInfoFolder" or self.uid == v.uid) and v:is_valid() then - _entries[k] = v.uid + if entries ~= nil then + if type(entries) ~= "table" or (entries.name ~= nil and entries.name ~= "KoboldWorldInfoEntry") then + error("`compute_context` takes a KoboldWorldInfoEntry, table of KoboldWorldInfoEntries or nil as argument #2, but got a " .. type(entries)) + return "" + elseif entries.name == "KoboldWorldInfoEntry" then + _entries = {entries} + else + for k, v in pairs(entries) do + if type(v) == "table" and v.name == "KoboldWorldInfoEntry" and v:is_valid() then + _entries[k] = v.uid + end + end end end - return bridged.compute_context(_entries) + return bridged.compute_context(submission, _entries, rawget(self, "_uid")) end ---@return boolean @@ -607,10 +623,12 @@ return function(_python, _bridged) if k == "content" then if rawget(t, "_num") == 0 then if bridged.vars.gamestarted then - return bridged.vars.prompt + local prompt = koboldbridge.userstate == "genmod" and bridged.vars._prompt or bridged.vars.prompt + return prompt end end - return _python.as_attrgetter(bridged.vars.actions).get(math.tointeger(rawget(t, "_num")) - 1) + local actions = koboldbridge.userstate == "genmod" and bridged.vars._actions or bridged.vars.actions + return _python.as_attrgetter(actions).get(math.tointeger(rawget(t, "_num")) - 1) end end @@ -631,7 +649,8 @@ return function(_python, _bridged) error("Attempted to set the prompt chunk's content to the empty string; this is not allowed") return end - if _k ~= 0 and _python.as_attrgetter(bridged.vars.actions).get(_k-1) == nil then + local actions = koboldbridge.userstate == "genmod" and bridged.vars._actions or bridged.vars.actions + if _k ~= 0 and _python.as_attrgetter(actions).get(_k-1) == nil then return end bridged.set_chunk(_k, v) @@ -655,7 +674,8 @@ return function(_python, _bridged) ---@return fun(): KoboldStoryChunk, table, nil function KoboldStory:forward_iter() - local nxt, iterator = _python.iter(bridged.vars.actions) + local actions = koboldbridge.userstate == "genmod" and bridged.vars._actions or bridged.vars.actions + local nxt, iterator = _python.iter(actions) local run_once = false local f = function() if not bridged.vars.gamestarted then @@ -682,7 +702,8 @@ return function(_python, _bridged) ---@return fun(): KoboldStoryChunk, table, nil function KoboldStory:reverse_iter() - local nxt, iterator = _python.iter(_python.builtins.reversed(bridged.vars.actions)) + local actions = koboldbridge.userstate == "genmod" and bridged.vars._actions or bridged.vars.actions + local nxt, iterator = _python.iter(_python.builtins.reversed(actions)) local last_run = false local f = function() if not bridged.vars.gamestarted or last_run then @@ -738,21 +759,38 @@ return function(_python, _bridged) ---@class KoboldSettings : KoboldSettings_base ---@field numseqs integer ---@field genamt integer + ---@field anotedepth integer ---@field settemp number ---@field settopp number ---@field settopk integer ---@field settfs number ---@field setreppen number ---@field settknmax integer - ---@field anotedepth integer ---@field setwidepth integer ---@field setuseprompt boolean ---@field setadventure boolean ---@field setdynamicscan boolean + ---@field setnopromptgen boolean + ---@field temp number + ---@field topp number + ---@field topk integer + ---@field tfs number + ---@field reppen number + ---@field tknmax integer + ---@field widepth integer + ---@field useprompt boolean + ---@field adventure boolean + ---@field dynamicscan boolean + ---@field nopromptgen boolean ---@field frmttriminc boolean ---@field frmtrmblln boolean ---@field frmtrmspch boolean ---@field frmtadsnsp boolean + ---@field frmtsingleline boolean + ---@field triminc boolean + ---@field rmblln boolean + ---@field rmspch boolean + ---@field adsnsp boolean ---@field singleline boolean local KoboldSettings = setmetatable({ _name = "KoboldSettings", @@ -783,9 +821,9 @@ return function(_python, _bridged) if type(k) ~= "string" then return end - if k == "genamt" then + if k == "genamt" or k == "output" or k == "setoutput" then return math.tointeger(bridged.get_genamt()), true - elseif k == "numseqs" then + elseif k == "numseqs" or k == "numseq" or k == "setnumseq" then return math.tointeger(bridged.get_numseqs()), true elseif bridged.has_setting(k) then return bridged.get_setting(k), true @@ -796,11 +834,10 @@ return function(_python, _bridged) ---@param t KoboldSettings_base function KoboldSettings_mt.__newindex(t, k, v) - if k == "genamt" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 0 then + if (k == "genamt" or k == "output" or k == "setoutput") and type(v) == "number" and math.tointeger(v) ~= nil and v >= 0 then bridged.set_genamt(v) - maybe_require_regeneration() koboldbridge.resend_settings_required = true - elseif k == "numseqs" and type(v) == "number" and math.tointeger(v) ~= nil and v >= 1 then + elseif (k == "numseqs" or k == "numseq" or k == "setnumseq") and type(v) == "number" and math.tointeger(v) ~= nil and v >= 1 then if koboldbridge.userstate == "genmod" then error("Cannot set numseqs from a generation modifier") return @@ -808,10 +845,9 @@ return function(_python, _bridged) bridged.set_numseqs(v) koboldbridge.resend_settings_required = true elseif type(k) == "string" and bridged.has_setting(k) and type(v) == type(bridged.get_setting(k)) then - if k == "settknmax" or k == "anotedepth" or k == "setwidepth" or k == "setuseprompt" then + if bridged.set_setting(k, v) == true then maybe_require_regeneration() end - bridged.set_setting(k, v) koboldbridge.resend_settings_required = true end return t @@ -821,7 +857,7 @@ return function(_python, _bridged) --========================================================================== - -- Userscript API: Memory + -- Userscript API: Memory / Author's Note --========================================================================== ---@param t KoboldLib @@ -842,6 +878,24 @@ return function(_python, _bridged) bridged.set_memory(v) end + ---@param t KoboldLib + ---@return string + function KoboldLib_getters.authorsnote(t) + return bridged.get_authorsnote() + end + + ---@param t KoboldLib + ---@param v string + ---@return KoboldLib + function KoboldLib_setters.authorsnote(t, v) + if type(v) ~= "string" then + error("`KoboldLib.authorsnote` must be a string; you attempted to set it to a "..type(v)) + return + end + maybe_require_regeneration() + bridged.set_authorsnote(v) + end + --========================================================================== -- Userscript API: User-submitted text (after applying input formatting) @@ -992,7 +1046,7 @@ return function(_python, _bridged) error("Cannot write to `KoboldLib.logits` from outside of a generation modifer") return elseif type(v) ~= "table" then - error("`KoboldLib.logits` must be a 2D list (table) of numbers; you attempted to set it to a " .. type(v)) + error("`KoboldLib.logits` must be a 2D array of numbers; you attempted to set it to a " .. type(v)) return end koboldbridge.logits = v @@ -1009,6 +1063,8 @@ return function(_python, _bridged) local backend = kobold.modelbackend if backend == "readonly" or backend == "api" then return 0 + elseif koboldbridge.userstate == "outmod" then + return koboldbridge.num_outputs end return kobold.settings.numseqs end @@ -1058,7 +1114,7 @@ return function(_python, _bridged) error("Cannot write to `KoboldLib.generated` from outside of a generation modifier") return elseif type(v) ~= "table" then - error("`KoboldLib.generated` must be a 2D list (table) of integers; you attempted to set it to a " .. type(v)) + error("`KoboldLib.generated` must be a 2D array of integers; you attempted to set it to a " .. type(v)) return end koboldbridge.generated = v @@ -1072,14 +1128,13 @@ return function(_python, _bridged) ---@param t KoboldLib ---@return integer function KoboldLib_getters.num_outputs(t) - local backend = kobold.modelbackend - if backend == "readonly" then - return 0 - end local model = kobold.model if model == "OAI" or model == "InferKit" then return 1 end + if koboldbridge.userstate == "outmod" then + return koboldbridge.num_outputs + end return kobold.settings.numseqs end @@ -1105,7 +1160,7 @@ return function(_python, _bridged) error("Cannot write to `KoboldLib.generated` from outside of an output modifier") return elseif type(v) ~= "table" then - error("`KoboldLib.generated` must be a list (table) of strings; you attempted to set it to a " .. type(v)) + error("`KoboldLib.generated` must be a 1D array of strings; you attempted to set it to a " .. type(v)) return end koboldbridge.outputs = v @@ -1158,6 +1213,41 @@ return function(_python, _bridged) koboldbridge.generating = false end + ---@param sequence? integer + ---@return nil + function kobold.restart_generation(sequence) + if sequence == nil then + sequence = 0 + end + sequence_type = type(sequence) + sequence = math.tointeger(sequence) + if sequence_type ~= "number" then + error("`kobold.restart_generation` takes an integer greater than or equal to 0 or nil as argument, but got a " .. sequence_type) + return + elseif sequence < 0 then + error("`kobold.restart_generation` takes an integer greater than or equal to 0 or nil as argument, but got `" .. sequence .. "`") + return + end + if koboldbridge.userstate ~= "outmod" then + error("Can only call `kobold.restart_generation()` from an output modifier") + return + end + koboldbridge.restart_sequence = sequence + end + + ---@param t KoboldCoreLib + ---@return string + function KoboldLib_getters.feedback(t) + return koboldbridge.feedback + end + + ---@param t KoboldCoreLib + ---@param v string + ---@return KoboldCoreLib + function KoboldLib_setters.feedback(t, v) + error("`KoboldLib.feedback` is a read-only attribute") + end + --========================================================================== -- Core script API @@ -1383,17 +1473,17 @@ return function(_python, _bridged) end local function redirected_print(...) - local args = {...} - for k, v in ipairs(args) do - args[k] = tostring(v) + local args = table.pack(...) + for i = 1, args.n do + args[i] = tostring(args[i]) end bridged.print(table.concat(args, "\t")) end local function redirected_warn(...) - local args = {...} - for k, v in ipairs(args) do - args[k] = tostring(v) + local args = table.pack(...) + for i = 1, args.n do + args[i] = tostring(args[i]) end bridged.warn(table.concat(args, "\t")) end @@ -1579,8 +1669,8 @@ return function(_python, _bridged) koboldbridge.num_userscripts = 0 for i, filename in _python.enumerate(filenames) do bridged.load_callback(filename, modulenames[i]) - ---@type KoboldUserScript koboldbridge.logging_name = modulenames[i] + ---@type KoboldUserScript local _userscript = old_loadfile(join_folder_and_filename(bridged.userscript_path, filename), "t", koboldbridge.get_universe(filename))() koboldbridge.logging_name = nil local userscript = deepcopy(KoboldUserScriptModule) @@ -1606,9 +1696,10 @@ return function(_python, _bridged) function koboldbridge.execute_inmod() local r + koboldbridge.restart_sequence = nil + koboldbridge.userstate = "inmod" koboldbridge.regeneration_required = false koboldbridge.generating = true - koboldbridge.userstate = "inmod" koboldbridge.generated_cols = 0 koboldbridge.generated = {} for i = 1, kobold.settings.numseqs do @@ -1665,13 +1756,14 @@ return function(_python, _bridged) local r koboldbridge.generating = false koboldbridge.userstate = "outmod" + koboldbridge.num_outputs = kobold.settings.numseqs if koboldbridge.outmod ~= nil then local _outputs = deepcopy(koboldbridge.outputs) r = koboldbridge.outmod() setmetatable(koboldbridge.outputs, nil) for k, v in old_next, koboldbridge.outputs, nil do if type(v) ~= "string" then - error("`kobold.outputs` must be a 1D list of strings, but found a non-string element at index " .. k) + error("`kobold.outputs` must be a 1D array of strings, but found a non-string element at index " .. k) return r end if v ~= _outputs[k] then diff --git a/cores/default.lua b/cores/default.lua index c6ad6d33..2b56c51c 100644 --- a/cores/default.lua +++ b/cores/default.lua @@ -1,6 +1,6 @@ -- Default core script --- Runs all input modifiers and generation modifiers in forward order, and --- runs all output modifiers in reverse order +-- Runs all generation modifiers and output modifiers in forward order, and +-- runs all input modifiers in reverse order kobold, koboldcore = require("bridge")() -- This line is optional and is only for EmmyLua type annotations diff --git a/static/application.js b/static/application.js index dc289a73..3296348e 100644 --- a/static/application.js +++ b/static/application.js @@ -646,6 +646,7 @@ function hideMessage() { } function showWaitAnimation() { + hideWaitAnimation(); $("#inputrowright").append(""); } diff --git a/templates/index.html b/templates/index.html index dcd31bfe..8451bbae 100644 --- a/templates/index.html +++ b/templates/index.html @@ -7,10 +7,10 @@ - +