Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)

Add basic support for some of the quick stoppers

aiserver.py
@@ -12,6 +12,8 @@ import random
 import shutil
 import eventlet

+from modeling.inference_model import GenerationMode
+
 eventlet.monkey_patch(all=True, thread=False, os=False)
 import os, inspect, contextlib, pickle
 os.system("")
@@ -3266,7 +3268,16 @@ def check_for_backend_compilation():
             break
     koboldai_vars.checking = False

-def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False, ignore_aibusy=False):
+def actionsubmit(
+    data,
+    actionmode=0,
+    force_submit=False,
+    force_prompt_gen=False,
+    disable_recentrng=False,
+    no_generate=False,
+    ignore_aibusy=False,
+    gen_mode=GenerationMode.STANDARD
+):
     # Ignore new submissions if the AI is currently busy
     if(koboldai_vars.aibusy):
         return
@@ -3424,7 +3435,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,

         if(not no_generate and not koboldai_vars.noai and koboldai_vars.lua_koboldbridge.generating):
             # Off to the tokenizer!
-            calcsubmit("")
+            calcsubmit("", gen_mode=gen_mode)
             if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
                 data = ""
                 force_submit = True
@@ -3779,7 +3790,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
 #==================================================================#
 # Take submitted text and build the text to be given to generator
 #==================================================================#
-def calcsubmit(txt):
+def calcsubmit(txt, gen_mode=GenerationMode.STANDARD):
     anotetxt = ""  # Placeholder for Author's Note text
     forceanote = False  # In case we don't have enough actions to hit A.N. depth
     anoteadded = False  # In case our budget runs out before we hit A.N. depth
@@ -3821,7 +3832,7 @@ def calcsubmit(txt):
         logger.debug("Submit: experimental_features time {}s".format(time.time()-start_time))

         start_time = time.time()
-        generate(subtxt, min, max, found_entries)
+        generate(subtxt, min, max, found_entries, gen_mode=gen_mode)
         logger.debug("Submit: generate time {}s".format(time.time()-start_time))
         attention_bias.attention_bias = None

@@ -3889,7 +3900,7 @@ class HordeException(Exception):
 # Send text to generator and deal with output
 #==================================================================#

-def generate(txt, minimum, maximum, found_entries=None):
+def generate(txt, minimum, maximum, found_entries=None, gen_mode=GenerationMode.STANDARD):
     koboldai_vars.generated_tkns = 0

     if(found_entries is None):
@@ -3911,7 +3922,7 @@ def generate(txt, minimum, maximum, found_entries=None):
     # Submit input text to generator
     try:
         start_time = time.time()
-        genout, already_generated = tpool.execute(model.core_generate, txt, found_entries)
+        genout, already_generated = tpool.execute(model.core_generate, txt, found_entries, gen_mode=gen_mode)
         logger.debug("Generate: core_generate time {}s".format(time.time()-start_time))
     except Exception as e:
         if(issubclass(type(e), lupa.LuaError)):
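The aiserver.py hunks above are plumbing: the same gen_mode value is threaded, unchanged, through the submission call chain until it reaches the model (parameter lists abbreviated here for readability):

    actionsubmit(data, ..., gen_mode=GenerationMode.STANDARD)
      -> calcsubmit(txt, gen_mode=gen_mode)
           -> generate(subtxt, min, max, found_entries, gen_mode=gen_mode)
                -> model.core_generate(txt, found_entries, gen_mode=gen_mode)
                     -> raw_generate(..., gen_mode=gen_mode)  # see the modeling/inference_model.py hunks below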
@@ -6168,22 +6179,43 @@ def UI_2_delete_option(data):
 @socketio.on('submit')
 @logger.catch
 def UI_2_submit(data):
-    if not koboldai_vars.noai and data['theme'] != "":
+    if not koboldai_vars.noai and data['theme']:
         # Random prompt generation
         logger.debug("doing random prompt")
         memory = koboldai_vars.memory
         koboldai_vars.memory = "{}\n\nYou generate the following {} story concept :".format(koboldai_vars.memory, data['theme'])
         koboldai_vars.lua_koboldbridge.feedback = None
         actionsubmit("", force_submit=True, force_prompt_gen=True)
         koboldai_vars.memory = memory
-    else:
-        logger.debug("doing normal input")
-        koboldai_vars.actions.clear_unused_options()
-        koboldai_vars.lua_koboldbridge.feedback = None
-        koboldai_vars.recentrng = koboldai_vars.recentrngm = None
-        if koboldai_vars.actions.action_count == -1:
-            actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)
-        else:
-            actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)
+        return
+
+    logger.debug("doing normal input")
+    koboldai_vars.actions.clear_unused_options()
+    koboldai_vars.lua_koboldbridge.feedback = None
+    koboldai_vars.recentrng = koboldai_vars.recentrngm = None
+
+    gen_mode_name = data.get("gen_mode", None)
+    gen_mode = {
+        # If we don't have a gen mode, or it's None (the default), just do a
+        # normal submission.
+        None: GenerationMode.STANDARD,
+
+        # NOTE: forever should be a no-op on models that don't support
+        # interrupting generation. This should be conveyed to the user by
+        # graying out the option in the context menu.
+        "forever": GenerationMode.FOREVER,
+
+        # The following gen modes require stopping criteria to be respected by
+        # the backend:
+        "until_eos": GenerationMode.UNTIL_EOS,
+        "until_newline": GenerationMode.UNTIL_NEWLINE,
+        "until_sentence_end": GenerationMode.UNTIL_SENTENCE_END,
+    }.get(gen_mode_name, None)
+
+    if not gen_mode:
+        raise RuntimeError(f"Unknown gen_mode '{gen_mode_name}'")
+
+    actionsubmit(data['data'], actionmode=koboldai_vars.actionmode, gen_mode=gen_mode)

 #==================================================================#
 # Event triggered when user clicks the submit button
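Below is a minimal, self-contained sketch of the mode-name resolution that UI_2_submit now performs. The resolve_gen_mode helper and the inline enum are illustrative stand-ins only (the real GenerationMode is added in modeling/inference_model.py further down); the mode strings and the error on unknown names mirror the mapping in the hunk above.

    from enum import Enum

    class GenerationMode(Enum):  # stand-in mirroring the enum added by this commit
        STANDARD = 0
        FOREVER = 1
        UNTIL_EOS = 2
        UNTIL_NEWLINE = 3
        UNTIL_SENTENCE_END = 4

    def resolve_gen_mode(gen_mode_name):
        """Hypothetical helper: map the optional 'gen_mode' payload field to an enum member."""
        mapping = {
            None: GenerationMode.STANDARD,  # no mode given -> normal submission
            "forever": GenerationMode.FOREVER,
            "until_eos": GenerationMode.UNTIL_EOS,
            "until_newline": GenerationMode.UNTIL_NEWLINE,
            "until_sentence_end": GenerationMode.UNTIL_SENTENCE_END,
        }
        gen_mode = mapping.get(gen_mode_name)
        if gen_mode is None:
            raise RuntimeError(f"Unknown gen_mode '{gen_mode_name}'")
        return gen_mode

    assert resolve_gen_mode(None) is GenerationMode.STANDARD
    assert resolve_gen_mode("until_newline") is GenerationMode.UNTIL_NEWLINE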
modeling/inference_model.py

@@ -3,6 +3,8 @@ from __future__ import annotations
 from dataclasses import dataclass
 import time
 from typing import List, Optional, Union
+
+from enum import Enum
 from logger import logger

 import torch
@@ -12,6 +14,7 @@ from transformers import (
     GPT2Tokenizer,
     AutoTokenizer,
 )
+from modeling.stoppers import Stoppers
 from modeling.tokenizer import GenericTokenizer
 from modeling import logits_processors

@@ -154,6 +157,12 @@ class ModelCapabilities:
     # Some models need to warm up the TPU before use
     uses_tpu: bool = False

+class GenerationMode(Enum):
+    STANDARD = 0
+    FOREVER = 1
+    UNTIL_EOS = 2
+    UNTIL_NEWLINE = 3
+    UNTIL_SENTENCE_END = 4

 class InferenceModel:
     """Root class for all models."""
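One detail worth noting about the new enum: plain Enum members are truthy regardless of their numeric value, so GenerationMode.STANDARD (value 0) still passes the `if not gen_mode` check in UI_2_submit above. A quick self-contained check with a stand-in enum:

    from enum import Enum

    class GenerationMode(Enum):  # stand-in with the same values as above
        STANDARD = 0
        FOREVER = 1

    # Enum does not define __bool__, so members are always truthy,
    # even when the underlying value is 0.
    assert bool(GenerationMode.STANDARD) is True
    assert bool(GenerationMode.STANDARD.value) is False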
@@ -256,6 +265,7 @@ class InferenceModel:
         self,
         text: list,
         found_entries: set,
+        gen_mode: GenerationMode = GenerationMode.STANDARD,
     ):
         """Generate story text. Heavily tied to story-specific parameters; if
         you are making a new generation-based feature, consider `generate_raw()`.
@@ -263,6 +273,7 @@ class InferenceModel:
         Args:
             text (list): Encoded input tokens
             found_entries (set): Entries found for Dynamic WI
+            gen_mode (GenerationMode): The GenerationMode to pass to raw_generate. Defaults to GenerationMode.STANDARD

         Raises:
             RuntimeError: if inconsistencies are detected with the internal state and Lua state -- sanity check
@@ -358,6 +369,7 @@ class InferenceModel:
                     seed=utils.koboldai_vars.seed
                     if utils.koboldai_vars.full_determinism
                     else None,
+                    gen_mode=gen_mode
                 )
                 logger.debug(
                     "core_generate: run raw_generate pass {} {}s".format(
@@ -532,6 +544,7 @@ class InferenceModel:
         found_entries: set = (),
         tpu_dynamic_inference: bool = False,
         seed: Optional[int] = None,
+        gen_mode: GenerationMode = GenerationMode.STANDARD,
         **kwargs,
     ) -> GenerationResult:
         """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
@@ -547,6 +560,7 @@ class InferenceModel:
             is_core (bool, optional): Whether this generation is a core story generation. Defaults to False.
             single_line (bool, optional): Generate one line only. Defaults to False.
             found_entries (set, optional): Entries found for Dynamic WI. Defaults to ().
+            gen_mode (GenerationMode): Special generation mode. Defaults to GenerationMode.STANDARD.

         Raises:
             ValueError: If prompt type is weird
@@ -568,6 +582,21 @@ class InferenceModel:
             "wi_scanner_excluded_keys", set()
         )

+        temp_stoppers = []
+
+        if gen_mode == GenerationMode.FOREVER:
+            raise NotImplementedError()
+        elif gen_mode == GenerationMode.UNTIL_EOS:
+            # Still need to unban
+            raise NotImplementedError()
+        elif gen_mode == GenerationMode.UNTIL_NEWLINE:
+            # TODO: Look into replacing `single_line` with `generation_mode`
+            temp_stoppers.append(Stoppers.newline_stopper)
+        elif gen_mode == GenerationMode.UNTIL_SENTENCE_END:
+            temp_stoppers.append(Stoppers.sentence_end_stopper)
+
+        self.stopper_hooks += temp_stoppers
+
         utils.koboldai_vars.inference_config.do_core = is_core
         gen_settings = GenerationSettings(*(generation_settings or {}))

@@ -604,6 +633,9 @@ class InferenceModel:
             f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, for an average rate of {tokens_per_second} tokens per second."
         )

+        for stopper in temp_stoppers:
+            self.stopper_hooks.remove(stopper)
+
         return result

     def generate(
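Taken together, the two raw_generate hunks above implement a temporary-hook pattern: mode-specific stoppers are appended to self.stopper_hooks before generation and detached again once it finishes. A condensed sketch of that lifecycle follows; the helper name and the try/finally are illustrative additions (the commit itself removes the hooks unconditionally after generation), while the imports are the ones this commit introduces.

    from modeling.inference_model import GenerationMode
    from modeling.stoppers import Stoppers

    def run_with_temp_stoppers(model, gen_mode, do_generate):
        """Hypothetical helper: attach mode-specific stoppers, run generation, then detach them."""
        temp_stoppers = []
        if gen_mode == GenerationMode.UNTIL_NEWLINE:
            temp_stoppers.append(Stoppers.newline_stopper)
        elif gen_mode == GenerationMode.UNTIL_SENTENCE_END:
            temp_stoppers.append(Stoppers.sentence_end_stopper)
        elif gen_mode in (GenerationMode.FOREVER, GenerationMode.UNTIL_EOS):
            raise NotImplementedError()  # stubbed out in this commit as well

        model.stopper_hooks += temp_stoppers
        try:
            return do_generate()
        finally:
            # Detach the hooks even if generation raises.
            for stopper in temp_stoppers:
                model.stopper_hooks.remove(stopper)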
modeling/stoppers.py

@@ -3,15 +3,12 @@ from __future__ import annotations
 import torch

 import utils
-from modeling.inference_model import (
-    InferenceModel,
-)
-
+from modeling import inference_model

 class Stoppers:
     @staticmethod
     def core_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_core:
@@ -62,7 +59,7 @@ class Stoppers:

     @staticmethod
     def dynamic_wi_scanner(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_dynamic_wi:
@@ -93,7 +90,7 @@ class Stoppers:

     @staticmethod
     def chat_mode_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.chatmode:
@@ -118,7 +115,7 @@ class Stoppers:

     @staticmethod
     def stop_sequence_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:

@@ -145,14 +142,22 @@ class Stoppers:

     @staticmethod
     def singleline_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
-        """If singleline mode is enabled, it's pointless to generate output beyond the first newline."""
+        """Stop on occurrences of newlines **if singleline is enabled**."""

         # It might be better just to do this further up the line
         if not utils.koboldai_vars.singleline:
             return False
+        return Stoppers.newline_stopper(model, input_ids)
+
+    @staticmethod
+    def newline_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stop on occurrences of newlines."""
         # Keep track of presence of newlines in each sequence; we cannot stop a
         # batch member individually, so we must wait for all of them to contain
         # a newline.
@@ -167,3 +172,30 @@ class Stoppers:
             del model.gen_state["newline_in_sequence"]
             return True
         return False
+
+    @staticmethod
+    def sentence_end_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stops at the end of sentences."""
+
+        # TODO: Make this more robust
+        SENTENCE_ENDS = [".", "?", "!"]
+
+        # We need to keep track of stopping for each batch, since we can't stop
+        # one individually.
+        if "sentence_end_sequence" not in model.gen_state:
+            model.gen_state["sentence_end_sequence"] = [False] * len(input_ids)
+
+        for sequence_idx, batch_sequence in enumerate(input_ids):
+            decoded = model.tokenizer.decode(batch_sequence[-1])
+            for end in SENTENCE_ENDS:
+                if end in decoded:
+                    model.gen_state["sentence_end_sequence"][sequence_idx] = True
+                    break
+
+        if all(model.gen_state["sentence_end_sequence"]):
+            del model.gen_state["sentence_end_sequence"]
+            return True
+        return False
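Both new stoppers share the same batch semantics: a sequence that has already hit its stop condition is only remembered in model.gen_state, and generation halts once every sequence in the batch has hit it, since individual batch members cannot be stopped on their own. A self-contained sketch of that bookkeeping, with plain strings standing in for decoded tokens (no model or tokenizer involved; names are illustrative):

    # Decoded last tokens for a batch of three sequences, one generation step at a time.
    steps = [
        ["The", " sky", " was"],
        [" end.", " still", " blue"],
        [" done", " dark.", " today."],
    ]

    SENTENCE_ENDS = [".", "?", "!"]
    gen_state = {}  # plays the role of model.gen_state

    def sentence_end_step(decoded_batch):
        """Return True only once *every* sequence in the batch has ended a sentence."""
        if "sentence_end_sequence" not in gen_state:
            gen_state["sentence_end_sequence"] = [False] * len(decoded_batch)
        for idx, decoded in enumerate(decoded_batch):
            if any(end in decoded for end in SENTENCE_ENDS):
                gen_state["sentence_end_sequence"][idx] = True
        return all(gen_state["sentence_end_sequence"])

    results = [sentence_end_step(batch) for batch in steps]
    assert results == [False, False, True]  # stop only after the last sequence finishes a sentence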
static/koboldai.js

@@ -149,13 +149,13 @@ const context_menu_actions = {
 		{label: "Use Generated Image", icon: "image", enabledOn: "GENERATED-IMAGE", click: wiImageUseGeneratedImage},
 	],
 	"submit-button": [
-		{label: "Generate", icon: "edit", enabledOn: "ALWAYS", click: function(){}},
+		{label: "Generate", icon: "edit", enabledOn: "ALWAYS", click: () => storySubmit()},
 		null,
-		{label: "Generate Forever", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}},
-		{label: "Generate Until EOS", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}},
+		{label: "Generate Forever", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("forever")},
+		{label: "Generate Until EOS", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("until_eos")},
 		null,
-		{label: "Finish Line", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}},
-		{label: "Finish Sentence", icon: "edit_off", enabledOn: "ALWAYS", click: function(){}},
+		{label: "Finish Line", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("until_newline")},
+		{label: "Finish Sentence", icon: "edit_off", enabledOn: "ALWAYS", click: () => storySubmit("until_sentence_end")},
 	],
 	"undo-button": [
 		{label: "Undo", icon: "undo", enabledOn: "ALWAYS", click: function(){}},
@@ -256,10 +256,17 @@ function disconnect() {
 	document.getElementById("disconnect_message").classList.remove("hidden");
 }

-function storySubmit() {
+function storySubmit(genMode=null) {
+	const textInput = document.getElementById("input_text");
+	const themeInput = document.getElementById("themetext");
 	disruptStoryState();
-	socket.emit('submit', {'data': document.getElementById('input_text').value, 'theme': document.getElementById('themetext').value});
-	document.getElementById('input_text').value = '';
+	socket.emit('submit', {
+		data: textInput.value,
+		theme: themeInput.value,
+		gen_mode: genMode,
+	});
+
+	textInput.value = '';
 	document.getElementById('themetext').value = '';
 }

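For reference, here is the shape of the payload each submit-button entry now emits over the 'submit' socket event, as UI_2_submit receives it on the Python side. The prompt text is made up; the gen_mode strings are the ones wired to the context-menu entries above.

    # Illustrative payloads only; the keys match the socket.emit call above.
    example_payloads = [
        {"data": "The knight drew her sword.", "theme": "", "gen_mode": None},  # Generate
        {"data": "", "theme": "", "gen_mode": "forever"},                       # Generate Forever
        {"data": "", "theme": "", "gen_mode": "until_eos"},                     # Generate Until EOS
        {"data": "", "theme": "", "gen_mode": "until_newline"},                 # Finish Line
        {"data": "", "theme": "", "gen_mode": "until_sentence_end"},            # Finish Sentence
    ]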