Add basic support for some of the quick stoppers

somebody
2023-07-21 13:27:30 -05:00
parent 6cf63f781a
commit 3a43b254b8
4 changed files with 137 additions and 34 deletions

View File

@@ -12,6 +12,8 @@ import random
 import shutil
 import eventlet
 
+from modeling.inference_model import GenerationMode
+
 eventlet.monkey_patch(all=True, thread=False, os=False)
 import os, inspect, contextlib, pickle
 os.system("")
@@ -3266,7 +3268,16 @@ def check_for_backend_compilation():
             break
     koboldai_vars.checking = False
 
-def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False, ignore_aibusy=False):
+def actionsubmit(
+    data,
+    actionmode=0,
+    force_submit=False,
+    force_prompt_gen=False,
+    disable_recentrng=False,
+    no_generate=False,
+    ignore_aibusy=False,
+    gen_mode=GenerationMode.STANDARD
+):
     # Ignore new submissions if the AI is currently busy
     if(koboldai_vars.aibusy):
         return
@@ -3424,7 +3435,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
         if(not no_generate and not koboldai_vars.noai and koboldai_vars.lua_koboldbridge.generating):
             # Off to the tokenizer!
-            calcsubmit("")
+            calcsubmit("", gen_mode=gen_mode)
 
             if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
                 data = ""
                 force_submit = True
@@ -3779,7 +3790,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
 #==================================================================#
 # Take submitted text and build the text to be given to generator
 #==================================================================#
-def calcsubmit(txt):
+def calcsubmit(txt, gen_mode=GenerationMode.STANDARD):
     anotetxt   = ""     # Placeholder for Author's Note text
     forceanote = False  # In case we don't have enough actions to hit A.N. depth
     anoteadded = False  # In case our budget runs out before we hit A.N. depth
@@ -3821,7 +3832,7 @@ def calcsubmit(txt):
         logger.debug("Submit: experimental_features time {}s".format(time.time()-start_time))
         start_time = time.time()
-        generate(subtxt, min, max, found_entries)
+        generate(subtxt, min, max, found_entries, gen_mode=gen_mode)
         logger.debug("Submit: generate time {}s".format(time.time()-start_time))
 
         attention_bias.attention_bias = None
@@ -3889,7 +3900,7 @@ class HordeException(Exception):
 #==================================================================#
 # Send text to generator and deal with output
 #==================================================================#
-def generate(txt, minimum, maximum, found_entries=None):
+def generate(txt, minimum, maximum, found_entries=None, gen_mode=GenerationMode.STANDARD):
     koboldai_vars.generated_tkns = 0
 
     if(found_entries is None):
@@ -3911,7 +3922,7 @@ def generate(txt, minimum, maximum, found_entries=None):
     # Submit input text to generator
     try:
         start_time = time.time()
-        genout, already_generated = tpool.execute(model.core_generate, txt, found_entries)
+        genout, already_generated = tpool.execute(model.core_generate, txt, found_entries, gen_mode=gen_mode)
         logger.debug("Generate: core_generate time {}s".format(time.time()-start_time))
     except Exception as e:
         if(issubclass(type(e), lupa.LuaError)):
@@ -6168,22 +6179,43 @@ def UI_2_delete_option(data):
 @socketio.on('submit')
 @logger.catch
 def UI_2_submit(data):
-    if not koboldai_vars.noai and data['theme'] != "":
+    if not koboldai_vars.noai and data['theme']:
+        # Random prompt generation
         logger.debug("doing random prompt")
         memory = koboldai_vars.memory
         koboldai_vars.memory = "{}\n\nYou generate the following {} story concept :".format(koboldai_vars.memory, data['theme'])
         koboldai_vars.lua_koboldbridge.feedback = None
         actionsubmit("", force_submit=True, force_prompt_gen=True)
         koboldai_vars.memory = memory
-    else:
-        logger.debug("doing normal input")
-        koboldai_vars.actions.clear_unused_options()
-        koboldai_vars.lua_koboldbridge.feedback = None
-        koboldai_vars.recentrng = koboldai_vars.recentrngm = None
-        if koboldai_vars.actions.action_count == -1:
-            actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)
-        else:
-            actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)
+        return
+
+    logger.debug("doing normal input")
+    koboldai_vars.actions.clear_unused_options()
+    koboldai_vars.lua_koboldbridge.feedback = None
+    koboldai_vars.recentrng = koboldai_vars.recentrngm = None
+
+    gen_mode_name = data.get("gen_mode", None)
+    gen_mode = {
+        # If we don't have a gen mode, or it's None (the default), just do a
+        # normal submission.
+        None: GenerationMode.STANDARD,
+
+        # NOTE: forever should be a no-op on models that don't support
+        # interrupting generation. This should be conveyed to the user by
+        # graying out the option in the context menu.
+        "forever": GenerationMode.FOREVER,
+
+        # The following gen modes require stopping criteria to be respected by
+        # the backend:
+        "until_eos": GenerationMode.UNTIL_EOS,
+        "until_newline": GenerationMode.UNTIL_NEWLINE,
+        "until_sentence_end": GenerationMode.UNTIL_SENTENCE_END,
+    }.get(gen_mode_name, None)
+
+    if not gen_mode:
+        raise RuntimeError(f"Unknown gen_mode '{gen_mode_name}'")
+
+    actionsubmit(data['data'], actionmode=koboldai_vars.actionmode, gen_mode=gen_mode)
 
 #==================================================================#
 # Event triggered when user clicks the submit button
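
For orientation, a minimal sketch of a caller using the new parameter; `user_text` is a stand-in for whatever arrived as `data['data']`, and the surrounding server state (`koboldai_vars`) is assumed to be set up as usual:

from modeling.inference_model import GenerationMode

user_text = "She reached for the door"  # stand-in for data['data']

# Finish the current sentence only, rather than generating a full action.
actionsubmit(
    user_text,
    actionmode=koboldai_vars.actionmode,
    gen_mode=GenerationMode.UNTIL_SENTENCE_END,
)
# gen_mode is threaded through unchanged:
#   actionsubmit() -> calcsubmit() -> generate() -> model.core_generate()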

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
 from dataclasses import dataclass
 import time
 from typing import List, Optional, Union
+from enum import Enum
 
 from logger import logger
 
 import torch
@@ -12,6 +14,7 @@ from transformers import (
     GPT2Tokenizer,
     AutoTokenizer,
 )
+from modeling.stoppers import Stoppers
 from modeling.tokenizer import GenericTokenizer
 
 from modeling import logits_processors
@@ -154,6 +157,12 @@ class ModelCapabilities:
     # Some models need to warm up the TPU before use
     uses_tpu: bool = False
 
+class GenerationMode(Enum):
+    STANDARD = 0
+    FOREVER = 1
+    UNTIL_EOS = 2
+    UNTIL_NEWLINE = 3
+    UNTIL_SENTENCE_END = 4
 
 class InferenceModel:
     """Root class for all models."""
@@ -256,6 +265,7 @@ class InferenceModel:
         self,
         text: list,
         found_entries: set,
+        gen_mode: GenerationMode = GenerationMode.STANDARD,
     ):
         """Generate story text. Heavily tied to story-specific parameters; if
         you are making a new generation-based feature, consider `generate_raw()`.
@@ -263,6 +273,7 @@ class InferenceModel:
         Args:
             text (list): Encoded input tokens
             found_entries (set): Entries found for Dynamic WI
+            gen_mode (GenerationMode): The GenerationMode to pass to raw_generate. Defaults to GenerationMode.STANDARD
 
         Raises:
             RuntimeError: if inconsistancies are detected with the internal state and Lua state -- sanity check
@@ -358,6 +369,7 @@ class InferenceModel:
                     seed=utils.koboldai_vars.seed
                     if utils.koboldai_vars.full_determinism
                     else None,
+                    gen_mode=gen_mode
                 )
                 logger.debug(
                     "core_generate: run raw_generate pass {} {}s".format(
@@ -532,6 +544,7 @@ class InferenceModel:
         found_entries: set = (),
         tpu_dynamic_inference: bool = False,
         seed: Optional[int] = None,
+        gen_mode: GenerationMode = GenerationMode.STANDARD,
         **kwargs,
     ) -> GenerationResult:
         """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
@@ -547,6 +560,7 @@ class InferenceModel:
             is_core (bool, optional): Whether this generation is a core story generation. Defaults to False.
             single_line (bool, optional): Generate one line only.. Defaults to False.
             found_entries (set, optional): Entries found for Dynamic WI. Defaults to ().
+            gen_mode (GenerationMode): Special generation mode. Defaults to GenerationMode.STANDARD.
 
         Raises:
             ValueError: If prompt type is weird
@@ -568,6 +582,21 @@ class InferenceModel:
             "wi_scanner_excluded_keys", set()
         )
 
+        temp_stoppers = []
+
+        if gen_mode == GenerationMode.FOREVER:
+            raise NotImplementedError()
+        elif gen_mode == GenerationMode.UNTIL_EOS:
+            # Still need to unban
+            raise NotImplementedError()
+        elif gen_mode == GenerationMode.UNTIL_NEWLINE:
+            # TODO: Look into replacing `single_line` with `generation_mode`
+            temp_stoppers.append(Stoppers.newline_stopper)
+        elif gen_mode == GenerationMode.UNTIL_SENTENCE_END:
+            temp_stoppers.append(Stoppers.sentence_end_stopper)
+
+        self.stopper_hooks += temp_stoppers
+
         utils.koboldai_vars.inference_config.do_core = is_core
         gen_settings = GenerationSettings(*(generation_settings or {}))
@@ -604,6 +633,9 @@ class InferenceModel:
             f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, for an average rate of {tokens_per_second} tokens per second."
         )
 
+        for stopper in temp_stoppers:
+            self.stopper_hooks.remove(stopper)
+
         return result
 
     def generate(
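
The temporary stoppers attached above follow the same hook contract as the existing entries in `stopper_hooks`: a callable that receives the model and the current batch of token ids and returns True to halt. A rough sketch of attaching a custom stopper around a `raw_generate()` call; the token-budget stopper, the `max_new` keyword, and the `model` instance are illustrative assumptions, not part of the commit:

def token_budget_stopper(model, input_ids) -> bool:
    """Hypothetical stopper: halt once a sequence (prompt + generation) reaches 32 tokens."""
    return input_ids.shape[-1] >= 32

model.stopper_hooks.append(token_budget_stopper)
try:
    result = model.raw_generate(
        "Once upon a time",
        max_new=80,  # assumed keyword; check raw_generate's full signature
        gen_mode=GenerationMode.STANDARD,
    )
finally:
    # Mirror the cleanup raw_generate() performs for its own temp_stoppers.
    model.stopper_hooks.remove(token_budget_stopper)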

View File

@@ -3,15 +3,12 @@ from __future__ import annotations
 import torch
 
 import utils
-from modeling.inference_model import (
-    InferenceModel,
-)
+from modeling import inference_model
 
 
 class Stoppers:
     @staticmethod
     def core_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_core:
@@ -62,7 +59,7 @@ class Stoppers:
     @staticmethod
     def dynamic_wi_scanner(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_dynamic_wi:
@@ -93,7 +90,7 @@ class Stoppers:
     @staticmethod
     def chat_mode_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.chatmode:
@@ -118,7 +115,7 @@ class Stoppers:
     @staticmethod
     def stop_sequence_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
@@ -145,14 +142,22 @@ class Stoppers:
     @staticmethod
     def singleline_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
-        """If singleline mode is enabled, it's pointless to generate output beyond the first newline."""
+        """Stop on occurrences of newlines **if singleline is enabled**."""
 
+        # It might be better just to do this further up the line
         if not utils.koboldai_vars.singleline:
             return False
 
+        return Stoppers.newline_stopper(model, input_ids)
+
+    @staticmethod
+    def newline_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stop on occurrences of newlines."""
+
         # Keep track of presence of newlines in each sequence; we cannot stop a
         # batch member individually, so we must wait for all of them to contain
         # a newline.
@@ -167,3 +172,30 @@ class Stoppers:
             del model.gen_state["newline_in_sequence"]
             return True
         return False
+
+    @staticmethod
+    def sentence_end_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stops at the end of sentences."""
+
+        # TODO: Make this more robust
+        SENTENCE_ENDS = [".", "?", "!"]
+
+        # We need to keep track of stopping for each batch, since we can't stop
+        # one individually.
+        if "sentence_end_sequence" not in model.gen_state:
+            model.gen_state["sentence_end_sequence"] = [False] * len(input_ids)
+
+        for sequence_idx, batch_sequence in enumerate(input_ids):
+            decoded = model.tokenizer.decode(batch_sequence[-1])
+            for end in SENTENCE_ENDS:
+                if end in decoded:
+                    model.gen_state["sentence_end_sequence"][sequence_idx] = True
+                    break
+
+        if all(model.gen_state["sentence_end_sequence"]):
+            del model.gen_state["sentence_end_sequence"]
+            return True
+
+        return False
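
Both newline_stopper and sentence_end_stopper only signal a stop once every sequence in the batch has met the condition, tracking progress in `model.gen_state` between steps. A small illustration with stand-in model and tokenizer objects (both hypothetical; a real InferenceModel is not required to exercise the logic):

import torch

from modeling.stoppers import Stoppers

class FakeTokenizer:
    """Maps two made-up token ids to text for the sake of the example."""
    def decode(self, token_id):
        return {0: " word", 1: "."}[int(token_id)]

class FakeModel:
    tokenizer = FakeTokenizer()
    gen_state = {}

model = FakeModel()

# Step 1: only the first batch member has ended its sentence -> keep going.
print(Stoppers.sentence_end_stopper(model, torch.tensor([[1], [0]])))  # False

# Step 2: both members have now ended a sentence -> stop and reset state.
print(Stoppers.sentence_end_stopper(model, torch.tensor([[1], [1]])))  # True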

View File

@@ -149,13 +149,13 @@ const context_menu_actions = {
         {label: "Use Generated Image", icon: "image", enabledOn: "GENERATED-IMAGE", click: wiImageUseGeneratedImage},
     ],
     "submit-button": [
-        {label: "Generate", icon: "edit", enabledOn: "ALWAYS", click: function(){}},
+        {label: "Generate", icon: "edit", enabledOn: "ALWAYS", click: () => storySubmit()},
         null,
-        {label: "Generate Forever", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}},
-        {label: "Generate Until EOS", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}},
+        {label: "Generate Forever", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("forever")},
+        {label: "Generate Until EOS", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("until_eos")},
         null,
-        {label: "Finish Line", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}},
-        {label: "Finish Sentence", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}},
+        {label: "Finish Line", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("until_newline")},
+        {label: "Finish Sentence", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("until_sentence_end")},
     ],
     "undo-button": [
         {label: "Undo", icon: "undo", enabledOn: "ALWAYS", click: function(){}},
@@ -256,10 +256,17 @@ function disconnect() {
     document.getElementById("disconnect_message").classList.remove("hidden");
 }
 
-function storySubmit() {
+function storySubmit(genMode=null) {
+    const textInput = document.getElementById("input_text");
+    const themeInput = document.getElementById("themetext");
+
     disruptStoryState();
-    socket.emit('submit', {'data': document.getElementById('input_text').value, 'theme': document.getElementById('themetext').value});
-    document.getElementById('input_text').value = '';
+
+    socket.emit('submit', {
+        data: textInput.value,
+        theme: themeInput.value,
+        gen_mode: genMode,
+    });
+
+    textInput.value = '';
     document.getElementById('themetext').value = '';