Merge pull request #414 from one-some/submit-ctx-menu

Submit context menu
This commit is contained in:
henk717
2023-07-30 01:58:44 +02:00
committed by GitHub
8 changed files with 240 additions and 47 deletions

View File

@@ -12,6 +12,8 @@ import random
import shutil
import eventlet
from modeling.inference_model import GenerationMode
eventlet.monkey_patch(all=True, thread=False, os=False)
import os, inspect, contextlib, pickle
os.system("")
@@ -1730,7 +1732,9 @@ def load_model(model_backend, initial_load=False):
with use_custom_unpickler(RestrictedUnpickler):
model = model_backends[model_backend]
koboldai_vars.supported_gen_modes = [x.value for x in model.get_supported_gen_modes()]
model.load(initial_load=initial_load, save_model=not (args.colab or args.cacheonly) or args.savemodel)
koboldai_vars.model = model.model_name if "model_name" in vars(model) else model.id #Should have model_name, but it could be set to id depending on how it's setup
if koboldai_vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
koboldai_vars.model = os.path.basename(os.path.normpath(model.path))
@@ -3209,7 +3213,16 @@ def check_for_backend_compilation():
break
koboldai_vars.checking = False
def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False, ignore_aibusy=False):
def actionsubmit(
data,
actionmode=0,
force_submit=False,
force_prompt_gen=False,
disable_recentrng=False,
no_generate=False,
ignore_aibusy=False,
gen_mode=GenerationMode.STANDARD
):
# Ignore new submissions if the AI is currently busy
if(koboldai_vars.aibusy):
return
@@ -3301,7 +3314,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
koboldai_vars.prompt = data
# Clear the startup text from game screen
emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True, room="UI_1")
calcsubmit("") # Run the first action through the generator
calcsubmit("", gen_mode=gen_mode) # Run the first action through the generator
if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
data = ""
force_submit = True
@@ -3367,7 +3380,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
if(not no_generate and not koboldai_vars.noai and koboldai_vars.lua_koboldbridge.generating):
# Off to the tokenizer!
calcsubmit("")
calcsubmit("", gen_mode=gen_mode)
if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
data = ""
force_submit = True
@@ -3722,7 +3735,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
#==================================================================#
# Take submitted text and build the text to be given to generator
#==================================================================#
def calcsubmit(txt):
def calcsubmit(txt, gen_mode=GenerationMode.STANDARD):
anotetxt = "" # Placeholder for Author's Note text
forceanote = False # In case we don't have enough actions to hit A.N. depth
anoteadded = False # In case our budget runs out before we hit A.N. depth
@@ -3764,7 +3777,7 @@ def calcsubmit(txt):
logger.debug("Submit: experimental_features time {}s".format(time.time()-start_time))
start_time = time.time()
generate(subtxt, min, max, found_entries)
generate(subtxt, min, max, found_entries, gen_mode=gen_mode)
logger.debug("Submit: generate time {}s".format(time.time()-start_time))
attention_bias.attention_bias = None
@@ -3832,7 +3845,7 @@ class HordeException(Exception):
# Send text to generator and deal with output
#==================================================================#
def generate(txt, minimum, maximum, found_entries=None):
def generate(txt, minimum, maximum, found_entries=None, gen_mode=GenerationMode.STANDARD):
# Open up token stream
emit("stream_tokens", True, broadcast=True, room="UI_2")
@@ -3861,7 +3874,7 @@ def generate(txt, minimum, maximum, found_entries=None):
# Submit input text to generator
try:
start_time = time.time()
genout, already_generated = tpool.execute(model.core_generate, txt, found_entries)
genout, already_generated = tpool.execute(model.core_generate, txt, found_entries, gen_mode=gen_mode)
logger.debug("Generate: core_generate time {}s".format(time.time()-start_time))
except Exception as e:
if(issubclass(type(e), lupa.LuaError)):
@@ -6125,23 +6138,31 @@ def UI_2_delete_option(data):
@socketio.on('submit')
@logger.catch
def UI_2_submit(data):
if not koboldai_vars.noai and data['theme'] != "":
if not koboldai_vars.noai and data['theme']:
# Random prompt generation
logger.debug("doing random prompt")
memory = koboldai_vars.memory
koboldai_vars.memory = "{}\n\nYou generate the following {} story concept :".format(koboldai_vars.memory, data['theme'])
koboldai_vars.lua_koboldbridge.feedback = None
actionsubmit("", force_submit=True, force_prompt_gen=True)
koboldai_vars.memory = memory
else:
logger.debug("doing normal input")
koboldai_vars.actions.clear_unused_options()
koboldai_vars.lua_koboldbridge.feedback = None
koboldai_vars.recentrng = koboldai_vars.recentrngm = None
if koboldai_vars.actions.action_count == -1:
actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)
else:
actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)
return
logger.debug("doing normal input")
koboldai_vars.actions.clear_unused_options()
koboldai_vars.lua_koboldbridge.feedback = None
koboldai_vars.recentrng = koboldai_vars.recentrngm = None
gen_mode_name = data.get("gen_mode", None) or "standard"
try:
gen_mode = GenerationMode(gen_mode_name)
except ValueError:
# Invalid enum lookup!
gen_mode = GenerationMode.STANDARD
logger.warning(f"Unknown gen_mode '{gen_mode_name}', using STANDARD! Report this!")
actionsubmit(data['data'], actionmode=koboldai_vars.actionmode, gen_mode=gen_mode)
#==================================================================#
# Event triggered when user clicks the submit button
#==================================================================#

View File

@@ -685,6 +685,7 @@ class model_settings(settings):
self._koboldai_vars = koboldai_vars
self.alt_multi_gen = False
self.bit_8_available = None
self.supported_gen_modes = []
def reset_for_model_load(self):
self.simple_randomness = 0 #Set first as this affects other outputs

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
from dataclasses import dataclass
import time
from typing import List, Optional, Union
from enum import Enum
from logger import logger
import torch
@@ -12,6 +14,7 @@ from transformers import (
GPT2Tokenizer,
AutoTokenizer,
)
from modeling.stoppers import Stoppers
from modeling.tokenizer import GenericTokenizer
from modeling import logits_processors
@@ -144,7 +147,10 @@ class GenerationSettings:
class ModelCapabilities:
embedding_manipulation: bool = False
post_token_hooks: bool = False
# Used to gauge if manual stopping is possible
stopper_hooks: bool = False
# TODO: Support non-live probabilities from APIs
post_token_probs: bool = False
@@ -154,6 +160,12 @@ class ModelCapabilities:
# Some models need to warm up the TPU before use
uses_tpu: bool = False
class GenerationMode(Enum):
STANDARD = "standard"
FOREVER = "forever"
UNTIL_EOS = "until_eos"
UNTIL_NEWLINE = "until_newline"
UNTIL_SENTENCE_END = "until_sentence_end"
class InferenceModel:
"""Root class for all models."""
@@ -256,6 +268,7 @@ class InferenceModel:
self,
text: list,
found_entries: set,
gen_mode: GenerationMode = GenerationMode.STANDARD,
):
"""Generate story text. Heavily tied to story-specific parameters; if
you are making a new generation-based feature, consider `generate_raw()`.
@@ -263,6 +276,7 @@ class InferenceModel:
Args:
text (list): Encoded input tokens
found_entries (set): Entries found for Dynamic WI
gen_mode (GenerationMode): The GenerationMode to pass to raw_generate. Defaults to GenerationMode.STANDARD
Raises:
RuntimeError: if inconsistancies are detected with the internal state and Lua state -- sanity check
@@ -358,6 +372,7 @@ class InferenceModel:
seed=utils.koboldai_vars.seed
if utils.koboldai_vars.full_determinism
else None,
gen_mode=gen_mode
)
logger.debug(
"core_generate: run raw_generate pass {} {}s".format(
@@ -532,6 +547,7 @@ class InferenceModel:
found_entries: set = (),
tpu_dynamic_inference: bool = False,
seed: Optional[int] = None,
gen_mode: GenerationMode = GenerationMode.STANDARD,
**kwargs,
) -> GenerationResult:
"""A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
@@ -547,6 +563,7 @@ class InferenceModel:
is_core (bool, optional): Whether this generation is a core story generation. Defaults to False.
single_line (bool, optional): Generate one line only.. Defaults to False.
found_entries (set, optional): Entries found for Dynamic WI. Defaults to ().
gen_mode (GenerationMode): Special generation mode. Defaults to GenerationMode.STANDARD.
Raises:
ValueError: If prompt type is weird
@@ -568,6 +585,29 @@ class InferenceModel:
"wi_scanner_excluded_keys", set()
)
self.gen_state["allow_eos"] = False
temp_stoppers = []
if gen_mode not in self.get_supported_gen_modes():
gen_mode = GenerationMode.STANDARD
logger.warning(f"User requested unsupported GenerationMode '{gen_mode}'!")
if gen_mode == GenerationMode.FOREVER:
self.gen_state["stop_at_genamt"] = False
max_new = 1e7
elif gen_mode == GenerationMode.UNTIL_EOS:
self.gen_state["allow_eos"] = True
self.gen_state["stop_at_genamt"] = False
max_new = 1e7
elif gen_mode == GenerationMode.UNTIL_NEWLINE:
# TODO: Look into replacing `single_line` with `generation_mode`
temp_stoppers.append(Stoppers.newline_stopper)
elif gen_mode == GenerationMode.UNTIL_SENTENCE_END:
temp_stoppers.append(Stoppers.sentence_end_stopper)
self.stopper_hooks += temp_stoppers
utils.koboldai_vars.inference_config.do_core = is_core
gen_settings = GenerationSettings(*(generation_settings or {}))
@@ -604,6 +644,9 @@ class InferenceModel:
f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, for an average rate of {tokens_per_second} tokens per second."
)
for stopper in temp_stoppers:
self.stopper_hooks.remove(stopper)
return result
def generate(
@@ -620,3 +663,19 @@ class InferenceModel:
def _post_token_gen(self, input_ids: torch.LongTensor) -> None:
for hook in self.post_token_hooks:
hook(self, input_ids)
def get_supported_gen_modes(self) -> List[GenerationMode]:
"""Returns a list of compatible `GenerationMode`s for the current model.
Returns:
List[GenerationMode]: A list of compatible `GenerationMode`s.
"""
ret = [GenerationMode.STANDARD]
if self.capabilties.stopper_hooks:
ret += [
GenerationMode.FOREVER,
GenerationMode.UNTIL_NEWLINE,
GenerationMode.UNTIL_SENTENCE_END,
]
return ret

View File

@@ -34,6 +34,7 @@ from modeling.stoppers import Stoppers
from modeling.post_token_hooks import PostTokenHooks
from modeling.inference_models.hf import HFInferenceModel
from modeling.inference_model import (
GenerationMode,
GenerationResult,
GenerationSettings,
ModelCapabilities,
@@ -253,7 +254,10 @@ class HFTorchInferenceModel(HFInferenceModel):
assert kwargs.pop("logits_warper", None) is not None
kwargs["logits_warper"] = KoboldLogitsWarperList()
if utils.koboldai_vars.newlinemode in ["s", "ns"]:
if (
utils.koboldai_vars.newlinemode in ["s", "ns"]
and not m_self.gen_state["allow_eos"]
):
kwargs["eos_token_id"] = -1
kwargs.setdefault("pad_token_id", 2)
return new_sample.old_sample(self, *args, **kwargs)
@@ -604,3 +608,9 @@ class HFTorchInferenceModel(HFInferenceModel):
self.breakmodel = False
self.usegpu = False
return
def get_supported_gen_modes(self) -> List[GenerationMode]:
# This changes a torch patch to disallow eos as a bad word.
return super().get_supported_gen_modes() + [
GenerationMode.UNTIL_EOS
]

View File

@@ -3,15 +3,12 @@ from __future__ import annotations
import torch
import utils
from modeling.inference_model import (
InferenceModel,
)
from modeling import inference_model
class Stoppers:
@staticmethod
def core_stopper(
model: InferenceModel,
model: inference_model.InferenceModel,
input_ids: torch.LongTensor,
) -> bool:
if not utils.koboldai_vars.inference_config.do_core:
@@ -62,7 +59,7 @@ class Stoppers:
@staticmethod
def dynamic_wi_scanner(
model: InferenceModel,
model: inference_model.InferenceModel,
input_ids: torch.LongTensor,
) -> bool:
if not utils.koboldai_vars.inference_config.do_dynamic_wi:
@@ -93,7 +90,7 @@ class Stoppers:
@staticmethod
def chat_mode_stopper(
model: InferenceModel,
model: inference_model.InferenceModel,
input_ids: torch.LongTensor,
) -> bool:
if not utils.koboldai_vars.chatmode:
@@ -118,7 +115,7 @@ class Stoppers:
@staticmethod
def stop_sequence_stopper(
model: InferenceModel,
model: inference_model.InferenceModel,
input_ids: torch.LongTensor,
) -> bool:
@@ -154,14 +151,22 @@ class Stoppers:
@staticmethod
def singleline_stopper(
model: InferenceModel,
model: inference_model.InferenceModel,
input_ids: torch.LongTensor,
) -> bool:
"""If singleline mode is enabled, it's pointless to generate output beyond the first newline."""
"""Stop on occurances of newlines **if singleline is enabled**."""
# It might be better just to do this further up the line
if not utils.koboldai_vars.singleline:
return False
return Stoppers.newline_stopper(model, input_ids)
@staticmethod
def newline_stopper(
model: inference_model.InferenceModel,
input_ids: torch.LongTensor,
) -> bool:
"""Stop on occurances of newlines."""
# Keep track of presence of newlines in each sequence; we cannot stop a
# batch member individually, so we must wait for all of them to contain
# a newline.
@@ -176,3 +181,30 @@ class Stoppers:
del model.gen_state["newline_in_sequence"]
return True
return False
@staticmethod
def sentence_end_stopper(
model: inference_model.InferenceModel,
input_ids: torch.LongTensor,
) -> bool:
"""Stops at the end of sentences."""
# TODO: Make this more robust
SENTENCE_ENDS = [".", "?", "!"]
# We need to keep track of stopping for each batch, since we can't stop
# one individually.
if "sentence_end_in_sequence" not in model.gen_state:
model.gen_state["sentence_end_sequence"] = [False] * len(input_ids)
for sequence_idx, batch_sequence in enumerate(input_ids):
decoded = model.tokenizer.decode(batch_sequence[-1])
for end in SENTENCE_ENDS:
if end in decoded:
model.gen_state["sentence_end_sequence"][sequence_idx] = True
break
if all(model.gen_state["sentence_end_sequence"]):
del model.gen_state["sentence_end_sequence"]
return True
return False

View File

@@ -2730,13 +2730,14 @@ body {
#context-menu > hr {
/* Division Color*/
border-top: 2px solid var(--context_menu_division);
margin: 5px 5px;
margin: 3px 5px;
}
.context-menu-item {
padding: 5px;
padding: 4px;
padding-right: 25px;
min-width: 100px;
white-space: nowrap;
}
.context-menu-item:hover {
@@ -2747,11 +2748,16 @@ body {
.context-menu-item > .material-icons-outlined {
position: relative;
top: 2px;
top: 3px;
font-size: 15px;
margin-right: 5px;
}
.context-menu-item > .context-menu-label {
position: relative;
top: 1px;
}
/* Substitutions */
#Substitutions {
margin-left: 10px;

View File

@@ -85,6 +85,7 @@ let story_id = -1;
var dirty_chunks = [];
var initial_socketio_connection_occured = false;
var selected_model_data;
var supported_gen_modes = [];
var privacy_mode_enabled = false;
var ai_busy = false;
var can_show_options = false;
@@ -162,7 +163,36 @@ const context_menu_actions = {
"wi-img-upload-button": [
{label: "Upload Image", icon: "file_upload", enabledOn: "ALWAYS", click: wiImageReplace},
{label: "Use Generated Image", icon: "image", enabledOn: "GENERATED-IMAGE", click: wiImageUseGeneratedImage},
]
],
"submit-button": [
{label: "Generate", icon: "edit", enabledOn: "ALWAYS", click: () => storySubmit()},
null,
{
label: "Generate Forever",
icon: "edit_off",
enabledOn: () => supported_gen_modes.includes("forever"),
click: () => storySubmit("forever")
},
{
label: "Generate Until EOS",
icon: "edit_off",
enabledOn: () => supported_gen_modes.includes("until_eos"),
click: () => storySubmit("until_eos")
},
null,
{
label: "Finish Line",
icon: "edit_off",
enabledOn: () => supported_gen_modes.includes("until_newline"),
click: () => storySubmit("until_newline")
},
{
label: "Finish Sentence",
icon: "edit_off",
enabledOn: () => supported_gen_modes.includes("until_sentence_end"),
click: () => storySubmit("until_sentence_end")
},
],
};
let context_menu_cache = [];
@@ -254,10 +284,17 @@ function disconnect() {
document.getElementById("disconnect_message").classList.remove("hidden");
}
function storySubmit() {
function storySubmit(genMode=null) {
const textInput = document.getElementById("input_text");
const themeInput = document.getElementById("themetext");
disruptStoryState();
socket.emit('submit', {'data': document.getElementById('input_text').value, 'theme': document.getElementById('themetext').value});
document.getElementById('input_text').value = '';
socket.emit('submit', {
data: textInput.value,
theme: themeInput.value,
gen_mode: genMode,
});
textInput.value = '';
document.getElementById('themetext').value = '';
}
@@ -1009,6 +1046,9 @@ function var_changed(data) {
//special case for welcome text since we want to allow HTML
} else if (data.classname == 'model' && data.name == 'welcome') {
document.getElementById('welcome_text').innerHTML = data.value;
//Special case for permitted generation modes
} else if (data.classname == 'model' && data.name == 'supported_gen_modes') {
supported_gen_modes = data.value;
//Basic Data Syncing
} else {
var elements_to_change = document.getElementsByClassName("var_sync_"+data.classname.replace(" ", "_")+"_"+data.name.replace(" ", "_"));
@@ -5976,8 +6016,21 @@ function position_context_menu(contextMenu, x, y) {
right: x + width,
};
// Slide over if running against the window bounds.
if (farMenuBounds.right > bounds.right) x -= farMenuBounds.right - bounds.right;
if (farMenuBounds.bottom > bounds.bottom) y -= farMenuBounds.bottom - bounds.bottom;
if (farMenuBounds.bottom > bounds.bottom) {
// We've hit the bottom.
// The old algorithm pushed the menu against the wall, similar to what's
// done on the x-axis:
// y -= farMenuBounds.bottom - bounds.bottom;
// But now, we make the box change its emission direction from the cursor:
y -= (height + 5);
// The main advantage of this approach is that the cursor is never directly
// placed above a context menu item immediately after activating the context
// menu. (Thus the 5px offset also added)
}
contextMenu.style.left = `${x}px`;
contextMenu.style.top = `${y}px`;
@@ -6252,21 +6305,23 @@ process_cookies();
continue;
}
const enableCriteriaIsFunction = typeof action.enabledOn === "function"
let item = $e("div", contextMenu, {
const itemEl = $e("div", contextMenu, {
classes: ["context-menu-item", "noselect", `context-menu-${key}`],
"enabled-on": action.enabledOn,
"enabled-on": enableCriteriaIsFunction ? "CALLBACK" : action.enabledOn,
"cache-index": context_menu_cache.length
});
itemEl.enabledOnCallback = action.enabledOn;
context_menu_cache.push({shouldShow: action.shouldShow});
let icon = $e("span", item, {classes: ["material-icons-outlined"], innerText: action.icon});
item.append(action.label);
const icon = $e("span", itemEl, {classes: ["material-icons-outlined"], innerText: action.icon});
$e("span", itemEl, {classes: ["context-menu-label"], innerText: action.label});
item.addEventListener("mousedown", e => e.preventDefault());
itemEl.addEventListener("mousedown", e => e.preventDefault());
// Expose the "summonEvent" to enable access to original context menu target.
item.addEventListener("click", () => action.click(summonEvent));
itemEl.addEventListener("click", () => action.click(summonEvent));
}
}
@@ -6289,6 +6344,10 @@ process_cookies();
// Show only applicable actions in the context menu
let contextMenuType = target.getAttribute("context-menu");
// If context menu is not present, return
if (!context_menu_actions[contextMenuType]) return;
for (const contextMenuItem of contextMenu.childNodes) {
let shouldShow = contextMenuItem.classList.contains(`context-menu-${contextMenuType}`);
@@ -6316,10 +6375,10 @@ process_cookies();
// Disable non-applicable items
$(".context-menu-item").addClass("disabled");
// A selection is made
if (getSelectionText()) $(".context-menu-item[enabled-on=SELECTION]").removeClass("disabled");
// The caret is placed
if (get_caret_position(target) !== null) $(".context-menu-item[enabled-on=CARET]").removeClass("disabled");
@@ -6328,6 +6387,11 @@ process_cookies();
$(".context-menu-item[enabled-on=ALWAYS]").removeClass("disabled");
for (const contextMenuItem of document.querySelectorAll(".context-menu-item[enabled-on=CALLBACK]")) {
if (!contextMenuItem.enabledOnCallback()) continue;
contextMenuItem.classList.remove("disabled");
}
// Make sure hr isn't first or last visible element
let visibles = [];
for (const item of contextMenu.children) {

View File

@@ -110,9 +110,9 @@
<button type="button" class="btn action_button" style="width: 30px; padding: 0px;" onclick='play_pause_tts()' aria-label="play"><span id="play_tts" class="material-icons-outlined" style="font-size: 1.4em;">play_arrow</span></button>
<button type="button" class="btn action_button" style="width: 30px; padding: 0px;" onclick='stop_tts()' aria-label="play"><span id="stop_tts" class="material-icons-outlined" style="font-size: 1.4em;">stop</span></button>
</span>
<button type="button" class="btn action_button submit var_sync_alt_system_aibusy" system_aibusy=False id="btnsubmit" onclick="storySubmit();">Submit</button>
<button type="button" class="btn action_button submit var_sync_alt_system_aibusy" system_aibusy=False id="btnsubmit" onclick="storySubmit();" context-menu="submit-button">Submit</button>
<button type="button" class="btn action_button submited var_sync_alt_system_aibusy" system_aibusy=False id="btnsent"><img id="thinking" src="static/thinking.gif" class="force_center" onclick="socket.emit('abort','');"></button>
<button type="button" class="btn action_button back var_sync_alt_system_aibusy" system_aibusy=False onclick="storyBack();" aria-label="undo"><span class="material-icons-outlined" style="font-size: 1.4em;">replay</span></button>
<button type="button" class="btn action_button back var_sync_alt_system_aibusy" system_aibusy=False onclick="storyBack();" aria-label="undo" context-menu="undo-button"><span class="material-icons-outlined" style="font-size: 1.4em;">replay</span></button>
<button type="button" class="btn action_button redo var_sync_alt_system_aibusy" system_aibusy=False onclick="storyRedo();" aria-label="redo"><span class="material-icons-outlined" style="font-size: 1.4em;">arrow_forward</span></button>
<button type="button" class="btn action_button retry var_sync_alt_system_aibusy" system_aibusy=False onclick="storyRetry();" aria-label="retry"><span class="material-icons-outlined" style="font-size: 1.4em;">autorenew</span></button>
</div>