From 3a43b254b86733a637a2286bf8a3c9421674771a Mon Sep 17 00:00:00 2001
From: somebody
Date: Fri, 21 Jul 2023 13:27:30 -0500
Subject: [PATCH] Add basic support for some of the quick stoppers

---
 aiserver.py                 | 64 +++++++++++++++++++++++++++----------
 modeling/inference_model.py | 32 +++++++++++++++++++
 modeling/stoppers.py        | 52 ++++++++++++++++++++++++------
 static/koboldai.js          | 23 ++++++++-----
 4 files changed, 137 insertions(+), 34 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 0aa9bd4c..1cb9146e 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -12,6 +12,8 @@ import random
 import shutil
 import eventlet
 
+from modeling.inference_model import GenerationMode
+
 eventlet.monkey_patch(all=True, thread=False, os=False)
 import os, inspect, contextlib, pickle
 os.system("")
@@ -3266,7 +3268,16 @@ def check_for_backend_compilation():
             break
     koboldai_vars.checking = False
 
-def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False, ignore_aibusy=False):
+def actionsubmit(
+    data,
+    actionmode=0,
+    force_submit=False,
+    force_prompt_gen=False,
+    disable_recentrng=False,
+    no_generate=False,
+    ignore_aibusy=False,
+    gen_mode=GenerationMode.STANDARD
+):
     # Ignore new submissions if the AI is currently busy
     if(koboldai_vars.aibusy):
         return
@@ -3424,7 +3435,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
 
         if(not no_generate and not koboldai_vars.noai and koboldai_vars.lua_koboldbridge.generating):
             # Off to the tokenizer!
-            calcsubmit("")
+            calcsubmit("", gen_mode=gen_mode)
             if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
                 data = ""
                 force_submit = True
@@ -3779,7 +3790,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
 #==================================================================#
 # Take submitted text and build the text to be given to generator
 #==================================================================#
-def calcsubmit(txt):
+def calcsubmit(txt, gen_mode=GenerationMode.STANDARD):
     anotetxt     = ""     # Placeholder for Author's Note text
     forceanote   = False  # In case we don't have enough actions to hit A.N. depth
     anoteadded   = False  # In case our budget runs out before we hit A.N. depth
@@ -3821,7 +3832,7 @@ def calcsubmit(txt):
             logger.debug("Submit: experimental_features time {}s".format(time.time()-start_time))
 
             start_time = time.time()
-            generate(subtxt, min, max, found_entries)
+            generate(subtxt, min, max, found_entries, gen_mode=gen_mode)
             logger.debug("Submit: generate time {}s".format(time.time()-start_time))
 
             attention_bias.attention_bias = None
@@ -3889,7 +3900,7 @@ class HordeException(Exception):
 # Send text to generator and deal with output
 #==================================================================#
 
-def generate(txt, minimum, maximum, found_entries=None):
+def generate(txt, minimum, maximum, found_entries=None, gen_mode=GenerationMode.STANDARD):
     koboldai_vars.generated_tkns = 0
 
     if(found_entries is None):
@@ -3911,7 +3922,7 @@ def generate(txt, minimum, maximum, found_entries=None):
     # Submit input text to generator
     try:
         start_time = time.time()
-        genout, already_generated = tpool.execute(model.core_generate, txt, found_entries)
+        genout, already_generated = tpool.execute(model.core_generate, txt, found_entries, gen_mode=gen_mode)
         logger.debug("Generate: core_generate time {}s".format(time.time()-start_time))
     except Exception as e:
         if(issubclass(type(e), lupa.LuaError)):
@@ -6168,22 +6179,43 @@ def UI_2_delete_option(data):
 @socketio.on('submit')
 @logger.catch
 def UI_2_submit(data):
-    if not koboldai_vars.noai and data['theme'] != "":
+    if not koboldai_vars.noai and data['theme']:
+        # Random prompt generation
         logger.debug("doing random prompt")
         memory = koboldai_vars.memory
         koboldai_vars.memory = "{}\n\nYou generate the following {} story concept :".format(koboldai_vars.memory, data['theme'])
         koboldai_vars.lua_koboldbridge.feedback = None
         actionsubmit("", force_submit=True, force_prompt_gen=True)
         koboldai_vars.memory = memory
-    else:
-        logger.debug("doing normal input")
-        koboldai_vars.actions.clear_unused_options()
-        koboldai_vars.lua_koboldbridge.feedback = None
-        koboldai_vars.recentrng = koboldai_vars.recentrngm = None
-        if koboldai_vars.actions.action_count == -1:
-            actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)
-        else:
-            actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)
+        return
+
+    logger.debug("doing normal input")
+    koboldai_vars.actions.clear_unused_options()
+    koboldai_vars.lua_koboldbridge.feedback = None
+    koboldai_vars.recentrng = koboldai_vars.recentrngm = None
+
+    gen_mode_name = data.get("gen_mode", None)
+    gen_mode = {
+        # If we don't have a gen mode, or it's None (the default), just do a
+        # normal submission.
+        None: GenerationMode.STANDARD,
+
+        # NOTE: forever should be a no-op on models that don't support
+        # interrupting generation. This should be conveyed to the user by
+        # graying out the option in the context menu.
+        "forever": GenerationMode.FOREVER,
+
+        # The following gen modes require stopping criteria to be respected by
+        # the backend:
+        "until_eos": GenerationMode.UNTIL_EOS,
+        "until_newline": GenerationMode.UNTIL_NEWLINE,
+        "until_sentence_end": GenerationMode.UNTIL_SENTENCE_END,
+    }.get(gen_mode_name, None)
+
+    if not gen_mode:
+        raise RuntimeError(f"Unknown gen_mode '{gen_mode_name}'")
+
+    actionsubmit(data['data'], actionmode=koboldai_vars.actionmode, gen_mode=gen_mode)
 
 #==================================================================#
 # Event triggered when user clicks the submit button
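
For reference, the mode lookup in UI_2_submit above boils down to the following standalone sketch; resolve_gen_mode and the sample payload are illustrative only and not part of the patch:

    from modeling.inference_model import GenerationMode

    def resolve_gen_mode(gen_mode_name):
        # Mirrors UI_2_submit: a missing/None gen_mode falls back to STANDARD,
        # an unrecognised name is rejected.
        gen_mode = {
            None: GenerationMode.STANDARD,
            "forever": GenerationMode.FOREVER,
            "until_eos": GenerationMode.UNTIL_EOS,
            "until_newline": GenerationMode.UNTIL_NEWLINE,
            "until_sentence_end": GenerationMode.UNTIL_SENTENCE_END,
        }.get(gen_mode_name)
        if gen_mode is None:
            raise RuntimeError(f"Unknown gen_mode '{gen_mode_name}'")
        return gen_mode

    # A 'submit' payload like the one emitted by storySubmit() further down:
    payload = {"data": "The knight drew his sword.", "theme": "", "gen_mode": "until_sentence_end"}
    assert resolve_gen_mode(payload["gen_mode"]) is GenerationMode.UNTIL_SENTENCE_END
    assert resolve_gen_mode(None) is GenerationMode.STANDARD
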
+ "forever": GenerationMode.FOREVER, + + # The following gen modes require stopping criteria to be respected by + # the backend: + "until_eos": GenerationMode.UNTIL_EOS, + "until_newline": GenerationMode.UNTIL_NEWLINE, + "until_sentence_end": GenerationMode.UNTIL_SENTENCE_END, + }.get(gen_mode_name, None) + + if not gen_mode: + raise RuntimeError(f"Unknown gen_mode '{gen_mode_name}'") + + actionsubmit(data['data'], actionmode=koboldai_vars.actionmode, gen_mode=gen_mode) #==================================================================# # Event triggered when user clicks the submit button diff --git a/modeling/inference_model.py b/modeling/inference_model.py index a2d4fa63..1d285576 100644 --- a/modeling/inference_model.py +++ b/modeling/inference_model.py @@ -3,6 +3,8 @@ from __future__ import annotations from dataclasses import dataclass import time from typing import List, Optional, Union + +from enum import Enum from logger import logger import torch @@ -12,6 +14,7 @@ from transformers import ( GPT2Tokenizer, AutoTokenizer, ) +from modeling.stoppers import Stoppers from modeling.tokenizer import GenericTokenizer from modeling import logits_processors @@ -154,6 +157,12 @@ class ModelCapabilities: # Some models need to warm up the TPU before use uses_tpu: bool = False +class GenerationMode(Enum): + STANDARD = 0 + FOREVER = 1 + UNTIL_EOS = 2 + UNTIL_NEWLINE = 3 + UNTIL_SENTENCE_END = 4 class InferenceModel: """Root class for all models.""" @@ -256,6 +265,7 @@ class InferenceModel: self, text: list, found_entries: set, + gen_mode: GenerationMode = GenerationMode.STANDARD, ): """Generate story text. Heavily tied to story-specific parameters; if you are making a new generation-based feature, consider `generate_raw()`. @@ -263,6 +273,7 @@ class InferenceModel: Args: text (list): Encoded input tokens found_entries (set): Entries found for Dynamic WI + gen_mode (GenerationMode): The GenerationMode to pass to raw_generate. Defaults to GenerationMode.STANDARD Raises: RuntimeError: if inconsistancies are detected with the internal state and Lua state -- sanity check @@ -358,6 +369,7 @@ class InferenceModel: seed=utils.koboldai_vars.seed if utils.koboldai_vars.full_determinism else None, + gen_mode=gen_mode ) logger.debug( "core_generate: run raw_generate pass {} {}s".format( @@ -532,6 +544,7 @@ class InferenceModel: found_entries: set = (), tpu_dynamic_inference: bool = False, seed: Optional[int] = None, + gen_mode: GenerationMode = GenerationMode.STANDARD, **kwargs, ) -> GenerationResult: """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story. @@ -547,6 +560,7 @@ class InferenceModel: is_core (bool, optional): Whether this generation is a core story generation. Defaults to False. single_line (bool, optional): Generate one line only.. Defaults to False. found_entries (set, optional): Entries found for Dynamic WI. Defaults to (). + gen_mode (GenerationMode): Special generation mode. Defaults to GenerationMode.STANDARD. 
diff --git a/modeling/stoppers.py b/modeling/stoppers.py
index 94c09e85..02c1ce48 100644
--- a/modeling/stoppers.py
+++ b/modeling/stoppers.py
@@ -3,15 +3,12 @@ from __future__ import annotations
 import torch
 
 import utils
-from modeling.inference_model import (
-    InferenceModel,
-)
-
+from modeling import inference_model
 
 class Stoppers:
     @staticmethod
     def core_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_core:
@@ -62,7 +59,7 @@ class Stoppers:
 
     @staticmethod
     def dynamic_wi_scanner(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_dynamic_wi:
@@ -93,7 +90,7 @@ class Stoppers:
 
     @staticmethod
     def chat_mode_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.chatmode:
@@ -118,7 +115,7 @@ class Stoppers:
 
     @staticmethod
     def stop_sequence_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
 
@@ -145,14 +142,22 @@ class Stoppers:
 
     @staticmethod
     def singleline_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
-        """If singleline mode is enabled, it's pointless to generate output beyond the first newline."""
+        """Stop on occurrences of newlines **if singleline is enabled**."""
 
+        # It might be better just to do this further up the line
         if not utils.koboldai_vars.singleline:
             return False
+        return Stoppers.newline_stopper(model, input_ids)
 
+    @staticmethod
+    def newline_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stop on occurrences of newlines."""
         # Keep track of presence of newlines in each sequence; we cannot stop a
         # batch member individually, so we must wait for all of them to contain
         # a newline.
@@ -167,3 +172,30 @@ class Stoppers:
             del model.gen_state["newline_in_sequence"]
             return True
         return False
+
+    @staticmethod
+    def sentence_end_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stops at the end of sentences."""
+
+        # TODO: Make this more robust
+        SENTENCE_ENDS = [".", "?", "!"]
+
+        # We need to keep track of stopping for each batch, since we can't stop
+        # one individually.
+        if "sentence_end_sequence" not in model.gen_state:
+            model.gen_state["sentence_end_sequence"] = [False] * len(input_ids)
+
+        for sequence_idx, batch_sequence in enumerate(input_ids):
+            decoded = model.tokenizer.decode(batch_sequence[-1])
+            for end in SENTENCE_ENDS:
+                if end in decoded:
+                    model.gen_state["sentence_end_sequence"][sequence_idx] = True
+                    break
+
+        if all(model.gen_state["sentence_end_sequence"]):
+            del model.gen_state["sentence_end_sequence"]
+            return True
+        return False
\ No newline at end of file
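
Because a single batch member cannot be stopped on its own, sentence_end_stopper (like newline_stopper) keeps per-sequence flags in model.gen_state and only fires once every sequence has produced a sentence-ending character. A minimal, hypothetical check of that behaviour; the fake model and tokenizer below are stand-ins and not part of the patch:

    import torch
    from types import SimpleNamespace
    from modeling.stoppers import Stoppers

    class FakeTokenizer:
        def decode(self, token):
            # Map single token ids to text for the sake of the example.
            return {0: " the", 1: ".", 2: "!"}[int(token)]

    model = SimpleNamespace(gen_state={}, tokenizer=FakeTokenizer())

    step1 = torch.tensor([[1], [0]])        # sequence 0 ends a sentence, sequence 1 does not
    assert Stoppers.sentence_end_stopper(model, step1) is False

    step2 = torch.tensor([[1, 0], [0, 2]])  # now sequence 1 ends one as well
    assert Stoppers.sentence_end_stopper(model, step2) is True
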
+ if "sentence_end_in_sequence" not in model.gen_state: + model.gen_state["sentence_end_sequence"] = [False] * len(input_ids) + + for sequence_idx, batch_sequence in enumerate(input_ids): + decoded = model.tokenizer.decode(batch_sequence[-1]) + for end in SENTENCE_ENDS: + if end in decoded: + model.gen_state["sentence_end_sequence"][sequence_idx] = True + break + + if all(model.gen_state["sentence_end_sequence"]): + del model.gen_state["sentence_end_sequence"] + return True + return False \ No newline at end of file diff --git a/static/koboldai.js b/static/koboldai.js index 75563df2..320ec927 100644 --- a/static/koboldai.js +++ b/static/koboldai.js @@ -149,13 +149,13 @@ const context_menu_actions = { {label: "Use Generated Image", icon: "image", enabledOn: "GENERATED-IMAGE", click: wiImageUseGeneratedImage}, ], "submit-button": [ - {label: "Generate", icon: "edit", enabledOn: "ALWAYS", click: function(){}}, + {label: "Generate", icon: "edit", enabledOn: "ALWAYS", click: () => storySubmit()}, null, - {label: "Generate Forever", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}}, - {label: "Generate Until EOS", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}}, + {label: "Generate Forever", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("forever")}, + {label: "Generate Until EOS", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("until_eos")}, null, - {label: "Finish Line", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}}, - {label: "Finish Sentence", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}}, + {label: "Finish Line", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("until_newline")}, + {label: "Finish Sentence", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("until_sentence_end")}, ], "undo-button": [ {label: "Undo", icon: "undo", enabledOn: "ALWAYS", click: function(){}}, @@ -256,10 +256,17 @@ function disconnect() { document.getElementById("disconnect_message").classList.remove("hidden"); } -function storySubmit() { +function storySubmit(genMode=null) { + const textInput = document.getElementById("input_text"); + const themeInput = document.getElementById("themetext"); disruptStoryState(); - socket.emit('submit', {'data': document.getElementById('input_text').value, 'theme': document.getElementById('themetext').value}); - document.getElementById('input_text').value = ''; + socket.emit('submit', { + data: textInput.value, + theme: themeInput.value, + gen_mode: genMode, + }); + + textInput.value = ''; document.getElementById('themetext').value = ''; }