Add basic support for some of the quick stoppers

somebody
2023-07-21 13:27:30 -05:00
parent 6cf63f781a
commit 3a43b254b8
4 changed files with 137 additions and 34 deletions

View File

@@ -12,6 +12,8 @@ import random
 import shutil
 import eventlet
+
+from modeling.inference_model import GenerationMode
 eventlet.monkey_patch(all=True, thread=False, os=False)
 import os, inspect, contextlib, pickle
 os.system("")
@@ -3266,7 +3268,16 @@ def check_for_backend_compilation():
             break
     koboldai_vars.checking = False

-def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False, ignore_aibusy=False):
+def actionsubmit(
+    data,
+    actionmode=0,
+    force_submit=False,
+    force_prompt_gen=False,
+    disable_recentrng=False,
+    no_generate=False,
+    ignore_aibusy=False,
+    gen_mode=GenerationMode.STANDARD
+):
     # Ignore new submissions if the AI is currently busy
     if(koboldai_vars.aibusy):
         return
@@ -3424,7 +3435,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
     if(not no_generate and not koboldai_vars.noai and koboldai_vars.lua_koboldbridge.generating):
         # Off to the tokenizer!
-        calcsubmit("")
+        calcsubmit("", gen_mode=gen_mode)

         if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
             data = ""
             force_submit = True
@@ -3779,7 +3790,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
 #==================================================================#
 # Take submitted text and build the text to be given to generator
 #==================================================================#
-def calcsubmit(txt):
+def calcsubmit(txt, gen_mode=GenerationMode.STANDARD):
     anotetxt    = ""     # Placeholder for Author's Note text
     forceanote  = False  # In case we don't have enough actions to hit A.N. depth
     anoteadded  = False  # In case our budget runs out before we hit A.N. depth
@@ -3821,7 +3832,7 @@ def calcsubmit(txt):
logger.debug("Submit: experimental_features time {}s".format(time.time()-start_time))
start_time = time.time()
generate(subtxt, min, max, found_entries)
generate(subtxt, min, max, found_entries, gen_mode=gen_mode)
logger.debug("Submit: generate time {}s".format(time.time()-start_time))
attention_bias.attention_bias = None
@@ -3889,7 +3900,7 @@ class HordeException(Exception):
 # Send text to generator and deal with output
 #==================================================================#
-def generate(txt, minimum, maximum, found_entries=None):
+def generate(txt, minimum, maximum, found_entries=None, gen_mode=GenerationMode.STANDARD):
     koboldai_vars.generated_tkns = 0

     if(found_entries is None):
@@ -3911,7 +3922,7 @@ def generate(txt, minimum, maximum, found_entries=None):
     # Submit input text to generator
     try:
         start_time = time.time()
-        genout, already_generated = tpool.execute(model.core_generate, txt, found_entries)
+        genout, already_generated = tpool.execute(model.core_generate, txt, found_entries, gen_mode=gen_mode)
         logger.debug("Generate: core_generate time {}s".format(time.time()-start_time))
     except Exception as e:
         if(issubclass(type(e), lupa.LuaError)):
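A note on the change above: eventlet's `tpool.execute(fn, *args, **kwargs)` forwards keyword arguments to the target callable, which is how the new `gen_mode` reaches `model.core_generate` without further plumbing. A minimal sketch (the `fake_core_generate` stand-in is hypothetical):

```python
from eventlet import tpool

# tpool.execute runs the callable in a real OS thread, keeping the
# eventlet hub responsive during the blocking model call; positional
# and keyword arguments pass through unchanged.
def fake_core_generate(txt, found_entries, gen_mode=None):
    return f"generated with {gen_mode}", 0

genout, already_generated = tpool.execute(
    fake_core_generate, "prompt", set(), gen_mode="until_newline"
)
```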
@@ -6168,22 +6179,43 @@ def UI_2_delete_option(data):
 @socketio.on('submit')
 @logger.catch
 def UI_2_submit(data):
-    if not koboldai_vars.noai and data['theme'] != "":
+    if not koboldai_vars.noai and data['theme']:
         # Random prompt generation
         logger.debug("doing random prompt")
         memory = koboldai_vars.memory
         koboldai_vars.memory = "{}\n\nYou generate the following {} story concept :".format(koboldai_vars.memory, data['theme'])
         koboldai_vars.lua_koboldbridge.feedback = None
         actionsubmit("", force_submit=True, force_prompt_gen=True)
         koboldai_vars.memory = memory
-    else:
-        logger.debug("doing normal input")
-        koboldai_vars.actions.clear_unused_options()
-        koboldai_vars.lua_koboldbridge.feedback = None
-        koboldai_vars.recentrng = koboldai_vars.recentrngm = None
-        if koboldai_vars.actions.action_count == -1:
-            actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)
-        else:
-            actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)
+        return
+
+    logger.debug("doing normal input")
+    koboldai_vars.actions.clear_unused_options()
+    koboldai_vars.lua_koboldbridge.feedback = None
+    koboldai_vars.recentrng = koboldai_vars.recentrngm = None
+
+    gen_mode_name = data.get("gen_mode", None)
+    gen_mode = {
+        # If we don't have a gen mode, or it's None (the default), just do a
+        # normal submission.
+        None: GenerationMode.STANDARD,
+
+        # NOTE: forever should be a no-op on models that don't support
+        # interrupting generation. This should be conveyed to the user by
+        # graying out the option in the context menu.
+        "forever": GenerationMode.FOREVER,
+
+        # The following gen modes require stopping criteria to be respected by
+        # the backend:
+        "until_eos": GenerationMode.UNTIL_EOS,
+        "until_newline": GenerationMode.UNTIL_NEWLINE,
+        "until_sentence_end": GenerationMode.UNTIL_SENTENCE_END,
+    }.get(gen_mode_name, None)
+
+    if not gen_mode:
+        raise RuntimeError(f"Unknown gen_mode '{gen_mode_name}'")
+
+    actionsubmit(data['data'], actionmode=koboldai_vars.actionmode, gen_mode=gen_mode)

 #==================================================================#
 # Event triggered when user clicks the submit button
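The dict lookup above relies on `Enum` members always being truthy: `if not gen_mode` fires only when `.get()` fell through to its `None` default for an unrecognized name. A standalone sketch of the same lookup (the `resolve_gen_mode` helper is hypothetical, for illustration):

```python
from modeling.inference_model import GenerationMode

def resolve_gen_mode(gen_mode_name):
    """Map a client-supplied mode name to a GenerationMode, as above."""
    modes = {
        None: GenerationMode.STANDARD,
        "forever": GenerationMode.FOREVER,
        "until_eos": GenerationMode.UNTIL_EOS,
        "until_newline": GenerationMode.UNTIL_NEWLINE,
        "until_sentence_end": GenerationMode.UNTIL_SENTENCE_END,
    }
    # Membership test instead of a truthiness check: explicit about
    # rejecting unknown names rather than falsy values.
    if gen_mode_name not in modes:
        raise RuntimeError(f"Unknown gen_mode '{gen_mode_name}'")
    return modes[gen_mode_name]

assert resolve_gen_mode(None) is GenerationMode.STANDARD
assert resolve_gen_mode("until_sentence_end") is GenerationMode.UNTIL_SENTENCE_END
```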

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
 from dataclasses import dataclass
 import time
 from typing import List, Optional, Union
+from enum import Enum
+
 from logger import logger
 import torch
@@ -12,6 +14,7 @@ from transformers import (
     GPT2Tokenizer,
     AutoTokenizer,
 )
+from modeling.stoppers import Stoppers
 from modeling.tokenizer import GenericTokenizer

 from modeling import logits_processors
@@ -154,6 +157,12 @@ class ModelCapabilities:
     # Some models need to warm up the TPU before use
     uses_tpu: bool = False

+class GenerationMode(Enum):
+    STANDARD = 0
+    FOREVER = 1
+    UNTIL_EOS = 2
+    UNTIL_NEWLINE = 3
+    UNTIL_SENTENCE_END = 4

 class InferenceModel:
     """Root class for all models."""
@@ -256,6 +265,7 @@ class InferenceModel:
         self,
         text: list,
         found_entries: set,
+        gen_mode: GenerationMode = GenerationMode.STANDARD,
     ):
         """Generate story text. Heavily tied to story-specific parameters; if
         you are making a new generation-based feature, consider `generate_raw()`.
@@ -263,6 +273,7 @@ class InferenceModel:
         Args:
             text (list): Encoded input tokens
             found_entries (set): Entries found for Dynamic WI
+            gen_mode (GenerationMode): The GenerationMode to pass to raw_generate. Defaults to GenerationMode.STANDARD.

         Raises:
             RuntimeError: if inconsistencies are detected with the internal state and Lua state -- sanity check
@@ -358,6 +369,7 @@ class InferenceModel:
                 seed=utils.koboldai_vars.seed
                 if utils.koboldai_vars.full_determinism
                 else None,
+                gen_mode=gen_mode
             )
             logger.debug(
                 "core_generate: run raw_generate pass {} {}s".format(
@@ -532,6 +544,7 @@ class InferenceModel:
         found_entries: set = (),
         tpu_dynamic_inference: bool = False,
         seed: Optional[int] = None,
+        gen_mode: GenerationMode = GenerationMode.STANDARD,
         **kwargs,
     ) -> GenerationResult:
         """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
@@ -547,6 +560,7 @@ class InferenceModel:
             is_core (bool, optional): Whether this generation is a core story generation. Defaults to False.
             single_line (bool, optional): Generate one line only. Defaults to False.
             found_entries (set, optional): Entries found for Dynamic WI. Defaults to ().
+            gen_mode (GenerationMode): Special generation mode. Defaults to GenerationMode.STANDARD.

         Raises:
             ValueError: If prompt type is weird
@@ -568,6 +582,21 @@ class InferenceModel:
"wi_scanner_excluded_keys", set()
)
temp_stoppers = []
if gen_mode == GenerationMode.FOREVER:
raise NotImplementedError()
elif gen_mode == GenerationMode.UNTIL_EOS:
# Still need to unban
raise NotImplementedError()
elif gen_mode == GenerationMode.UNTIL_NEWLINE:
# TODO: Look into replacing `single_line` with `generation_mode`
temp_stoppers.append(Stoppers.newline_stopper)
elif gen_mode == GenerationMode.UNTIL_SENTENCE_END:
temp_stoppers.append(Stoppers.sentence_end_stopper)
self.stopper_hooks += temp_stoppers
utils.koboldai_vars.inference_config.do_core = is_core
gen_settings = GenerationSettings(*(generation_settings or {}))
@@ -604,6 +633,9 @@ class InferenceModel:
f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, for an average rate of {tokens_per_second} tokens per second."
)
for stopper in temp_stoppers:
self.stopper_hooks.remove(stopper)
return result
def generate(
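One design note on the two hunks above: the temporary stoppers are only removed after a successful generation, so an exception inside the raw generation call would leave them attached to `self.stopper_hooks`. A sketch of the same add/remove lifecycle hardened with `try`/`finally` (arguments to `_raw_generate` elided; `temp_stoppers` selected per `gen_mode` as above):

```python
self.stopper_hooks += temp_stoppers
try:
    result = self._raw_generate(...)  # arguments elided; same call as above
finally:
    # Runs even if generation raises, so no stale hooks accumulate
    # across requests.
    for stopper in temp_stoppers:
        self.stopper_hooks.remove(stopper)
```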

View File

@@ -3,15 +3,12 @@ from __future__ import annotations
 import torch

 import utils
-from modeling.inference_model import (
-    InferenceModel,
-)
+from modeling import inference_model

 class Stoppers:
     @staticmethod
     def core_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_core:
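The switch from importing `InferenceModel` directly to importing the module object is most likely there to break an import cycle: `modeling/inference_model.py` now imports `Stoppers` at load time, so `modeling/stoppers.py` can no longer pull names out of it while it is still loading. Importing the module instead defers the attribute lookup, and with this file's existing `from __future__ import annotations`, the annotations are never evaluated at import time at all:

```python
# Sketch of the cycle being avoided (module names from the diff above):
#   modeling/inference_model.py:  from modeling.stoppers import Stoppers
#   modeling/stoppers.py:         from modeling.inference_model import InferenceModel
# The second import would fail while the first module is mid-load.
from __future__ import annotations

from modeling import inference_model  # attribute resolved lazily

def my_stopper(model: inference_model.InferenceModel, input_ids) -> bool:
    ...
```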
@@ -62,7 +59,7 @@ class Stoppers:
     @staticmethod
     def dynamic_wi_scanner(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_dynamic_wi:
@@ -93,7 +90,7 @@ class Stoppers:
     @staticmethod
     def chat_mode_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.chatmode:
@@ -118,7 +115,7 @@ class Stoppers:
     @staticmethod
     def stop_sequence_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
@@ -145,14 +142,22 @@ class Stoppers:
     @staticmethod
     def singleline_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
-        """If singleline mode is enabled, it's pointless to generate output beyond the first newline."""
+        """Stop on occurrences of newlines **if singleline is enabled**."""

+        # It might be better just to do this further up the line
         if not utils.koboldai_vars.singleline:
             return False
+        return Stoppers.newline_stopper(model, input_ids)
+
+    @staticmethod
+    def newline_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stop on occurrences of newlines."""

         # Keep track of presence of newlines in each sequence; we cannot stop a
         # batch member individually, so we must wait for all of them to contain
         # a newline.
@@ -167,3 +172,30 @@ class Stoppers:
             del model.gen_state["newline_in_sequence"]
             return True
         return False
+
+    @staticmethod
+    def sentence_end_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stops at the end of sentences."""
+
+        # TODO: Make this more robust
+        SENTENCE_ENDS = [".", "?", "!"]
+
+        # We need to keep track of stopping for each batch member, since we
+        # can't stop one individually.
+        if "sentence_end_sequence" not in model.gen_state:
+            model.gen_state["sentence_end_sequence"] = [False] * len(input_ids)
+
+        for sequence_idx, batch_sequence in enumerate(input_ids):
+            decoded = model.tokenizer.decode(batch_sequence[-1])
+            for end in SENTENCE_ENDS:
+                if end in decoded:
+                    model.gen_state["sentence_end_sequence"][sequence_idx] = True
+                    break
+
+        if all(model.gen_state["sentence_end_sequence"]):
+            del model.gen_state["sentence_end_sequence"]
+            return True
+        return False
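A self-contained sketch of the behavior above: the stopper fires only once every sequence in the batch has produced a sentence-ending token, with per-sequence flags persisting in `gen_state` between steps (note the init check and the stored key must use the same name for this persistence to work, as fixed above). The stub classes are hypothetical stand-ins; real hooks receive an `InferenceModel` and `torch.LongTensor` batches:

```python
from modeling.stoppers import Stoppers

class _StubTokenizer:
    def decode(self, token):
        return token  # tokens are plain strings in this sketch

class _StubModel:
    def __init__(self):
        self.gen_state = {}
        self.tokenizer = _StubTokenizer()

model = _StubModel()

# Step 1: row 0 just emitted ".", but row 1 has not ended a sentence.
print(Stoppers.sentence_end_stopper(
    model, [["It", " works", "."], ["Try", " it", " now"]]))  # False

# Step 2: row 1 catches up with "!"; row 0's flag persisted from step 1.
print(Stoppers.sentence_end_stopper(
    model, [["It", " works", ".", " Also"], ["Try", " it", " now", "!"]]))  # True
```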

View File

@@ -149,13 +149,13 @@ const context_menu_actions = {
{label: "Use Generated Image", icon: "image", enabledOn: "GENERATED-IMAGE", click: wiImageUseGeneratedImage},
],
"submit-button": [
{label: "Generate", icon: "edit", enabledOn: "ALWAYS", click: function(){}},
{label: "Generate", icon: "edit", enabledOn: "ALWAYS", click: () => storySubmit()},
null,
{label: "Generate Forever", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}},
{label: "Generate Until EOS", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}},
{label: "Generate Forever", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("forever")},
{label: "Generate Until EOS", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("until_eos")},
null,
{label: "Finish Line", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}},
{label: "Finish Sentence", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}},
{label: "Finish Line", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("until_newline")},
{label: "Finish Sentence", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("until_sentence_end")},
],
"undo-button": [
{label: "Undo", icon: "undo", enabledOn: "ALWAYS", click: function(){}},
@@ -256,10 +256,17 @@ function disconnect() {
document.getElementById("disconnect_message").classList.remove("hidden");
}
function storySubmit() {
function storySubmit(genMode=null) {
const textInput = document.getElementById("input_text");
const themeInput = document.getElementById("themetext");
disruptStoryState();
socket.emit('submit', {'data': document.getElementById('input_text').value, 'theme': document.getElementById('themetext').value});
document.getElementById('input_text').value = '';
socket.emit('submit', {
data: textInput.value,
theme: themeInput.value,
gen_mode: genMode,
});
textInput.value = '';
document.getElementById('themetext').value = '';
}
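For reference, the payload emitted above is exactly what `UI_2_submit` in the first file consumes. An illustrative Python view of one such message (values hypothetical):

```python
# What the server-side 'submit' handler receives for "Finish Sentence":
payload = {
    "data": "You enter the cave.",     # input_text contents
    "theme": "",                       # a non-empty theme triggers random prompt generation
    "gen_mode": "until_sentence_end",  # null/None for the plain Generate action
}
```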