Merge pull request #414 from one-some/submit-ctx-menu

Submit context menu
henk717 authored 2023-07-30 01:58:44 +02:00 (committed via GitHub)
8 changed files with 240 additions and 47 deletions

View File: aiserver.py

@@ -12,6 +12,8 @@ import random
 import shutil
 import eventlet
+from modeling.inference_model import GenerationMode
+
 eventlet.monkey_patch(all=True, thread=False, os=False)
 import os, inspect, contextlib, pickle
 os.system("")
@@ -1730,7 +1732,9 @@ def load_model(model_backend, initial_load=False):
     with use_custom_unpickler(RestrictedUnpickler):
         model = model_backends[model_backend]
+        koboldai_vars.supported_gen_modes = [x.value for x in model.get_supported_gen_modes()]
+
         model.load(initial_load=initial_load, save_model=not (args.colab or args.cacheonly) or args.savemodel)
     koboldai_vars.model = model.model_name if "model_name" in vars(model) else model.id #Should have model_name, but it could be set to id depending on how it's setup
     if koboldai_vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
         koboldai_vars.model = os.path.basename(os.path.normpath(model.path))
@@ -3209,7 +3213,16 @@ def check_for_backend_compilation():
             break
     koboldai_vars.checking = False

-def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False, ignore_aibusy=False):
+def actionsubmit(
+    data,
+    actionmode=0,
+    force_submit=False,
+    force_prompt_gen=False,
+    disable_recentrng=False,
+    no_generate=False,
+    ignore_aibusy=False,
+    gen_mode=GenerationMode.STANDARD
+):
     # Ignore new submissions if the AI is currently busy
     if(koboldai_vars.aibusy):
         return
@@ -3301,7 +3314,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
             koboldai_vars.prompt = data
             # Clear the startup text from game screen
             emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True, room="UI_1")
-            calcsubmit("") # Run the first action through the generator
+            calcsubmit("", gen_mode=gen_mode) # Run the first action through the generator
             if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
                 data = ""
                 force_submit = True
@@ -3367,7 +3380,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
                 if(not no_generate and not koboldai_vars.noai and koboldai_vars.lua_koboldbridge.generating):
                     # Off to the tokenizer!
-                    calcsubmit("")
+                    calcsubmit("", gen_mode=gen_mode)
                     if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
                         data = ""
                         force_submit = True
@@ -3722,7 +3735,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
 #==================================================================#
 # Take submitted text and build the text to be given to generator
 #==================================================================#
-def calcsubmit(txt):
+def calcsubmit(txt, gen_mode=GenerationMode.STANDARD):
     anotetxt = ""     # Placeholder for Author's Note text
     forceanote = False  # In case we don't have enough actions to hit A.N. depth
     anoteadded = False  # In case our budget runs out before we hit A.N. depth
@@ -3764,7 +3777,7 @@ def calcsubmit(txt):
         logger.debug("Submit: experimental_features time {}s".format(time.time()-start_time))

         start_time = time.time()
-        generate(subtxt, min, max, found_entries)
+        generate(subtxt, min, max, found_entries, gen_mode=gen_mode)
         logger.debug("Submit: generate time {}s".format(time.time()-start_time))

     attention_bias.attention_bias = None
@@ -3832,7 +3845,7 @@ class HordeException(Exception):
 #==================================================================#
 # Send text to generator and deal with output
 #==================================================================#
-def generate(txt, minimum, maximum, found_entries=None):
+def generate(txt, minimum, maximum, found_entries=None, gen_mode=GenerationMode.STANDARD):

     # Open up token stream
     emit("stream_tokens", True, broadcast=True, room="UI_2")
@@ -3861,7 +3874,7 @@ def generate(txt, minimum, maximum, found_entries=None):
     # Submit input text to generator
     try:
         start_time = time.time()
-        genout, already_generated = tpool.execute(model.core_generate, txt, found_entries)
+        genout, already_generated = tpool.execute(model.core_generate, txt, found_entries, gen_mode=gen_mode)
         logger.debug("Generate: core_generate time {}s".format(time.time()-start_time))
     except Exception as e:
         if(issubclass(type(e), lupa.LuaError)):
@@ -6125,23 +6138,31 @@ def UI_2_delete_option(data):
 @socketio.on('submit')
 @logger.catch
 def UI_2_submit(data):
-    if not koboldai_vars.noai and data['theme'] != "":
+    if not koboldai_vars.noai and data['theme']:
+        # Random prompt generation
         logger.debug("doing random prompt")
         memory = koboldai_vars.memory
         koboldai_vars.memory = "{}\n\nYou generate the following {} story concept :".format(koboldai_vars.memory, data['theme'])
         koboldai_vars.lua_koboldbridge.feedback = None
         actionsubmit("", force_submit=True, force_prompt_gen=True)
         koboldai_vars.memory = memory
-    else:
-        logger.debug("doing normal input")
-        koboldai_vars.actions.clear_unused_options()
-        koboldai_vars.lua_koboldbridge.feedback = None
-        koboldai_vars.recentrng = koboldai_vars.recentrngm = None
-        if koboldai_vars.actions.action_count == -1:
-            actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)
-        else:
-            actionsubmit(data['data'], actionmode=koboldai_vars.actionmode)
+        return
+
+    logger.debug("doing normal input")
+    koboldai_vars.actions.clear_unused_options()
+    koboldai_vars.lua_koboldbridge.feedback = None
+    koboldai_vars.recentrng = koboldai_vars.recentrngm = None
+
+    gen_mode_name = data.get("gen_mode", None) or "standard"
+    try:
+        gen_mode = GenerationMode(gen_mode_name)
+    except ValueError:
+        # Invalid enum lookup!
+        gen_mode = GenerationMode.STANDARD
+        logger.warning(f"Unknown gen_mode '{gen_mode_name}', using STANDARD! Report this!")
+
+    actionsubmit(data['data'], actionmode=koboldai_vars.actionmode, gen_mode=gen_mode)

 #==================================================================#
 # Event triggered when user clicks the submit button
 #==================================================================#
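The server side of this feature is deliberately defensive: the UI sends `gen_mode` as a plain string (or `null` for the ordinary submit button), and `UI_2_submit` maps it onto the `GenerationMode` enum, degrading to `STANDARD` on anything unrecognized. Below is a minimal, self-contained sketch of that round trip; `parse_gen_mode` is a hypothetical helper for illustration, not a function added by this PR:

```python
from enum import Enum

class GenerationMode(Enum):  # mirrors the enum added in modeling/inference_model.py
    STANDARD = "standard"
    FOREVER = "forever"
    UNTIL_EOS = "until_eos"
    UNTIL_NEWLINE = "until_newline"
    UNTIL_SENTENCE_END = "until_sentence_end"

def parse_gen_mode(payload: dict) -> GenerationMode:
    # The plain submit button sends gen_mode=None, so fall back to
    # "standard" before attempting the enum lookup by value.
    name = payload.get("gen_mode", None) or "standard"
    try:
        return GenerationMode(name)
    except ValueError:
        # Same degradation the handler performs (plus a warning).
        return GenerationMode.STANDARD

assert parse_gen_mode({"gen_mode": "forever"}) is GenerationMode.FOREVER
assert parse_gen_mode({"gen_mode": None}) is GenerationMode.STANDARD
assert parse_gen_mode({"gen_mode": "bogus"}) is GenerationMode.STANDARD
```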

View File: koboldai_settings.py

@@ -685,6 +685,7 @@ class model_settings(settings):
         self._koboldai_vars = koboldai_vars
         self.alt_multi_gen = False
         self.bit_8_available = None
+        self.supported_gen_modes = []

     def reset_for_model_load(self):
         self.simple_randomness = 0 #Set first as this affects other outputs
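Note that `supported_gen_modes` holds the enum *values* (strings) rather than enum members, so the setting can be synced to the browser as ordinary JSON. Continuing the sketch above (the exact sync payload shape is an assumption based on the `var_changed` handler later in this diff):

```python
# What load_model stores: plain strings, serializable without a custom encoder.
supported_gen_modes = [x.value for x in (GenerationMode.STANDARD, GenerationMode.FOREVER)]
assert supported_gen_modes == ["standard", "forever"]

# Roughly the shape the UI's var_changed() receives for this setting:
sync_payload = {"classname": "model", "name": "supported_gen_modes",
                "value": supported_gen_modes}
```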

View File: modeling/inference_model.py

@@ -3,6 +3,8 @@ from __future__ import annotations
 from dataclasses import dataclass
 import time
 from typing import List, Optional, Union
+from enum import Enum
+
 from logger import logger

 import torch
@@ -12,6 +14,7 @@ from transformers import (
     GPT2Tokenizer,
     AutoTokenizer,
 )
+from modeling.stoppers import Stoppers
 from modeling.tokenizer import GenericTokenizer

 from modeling import logits_processors
@@ -144,7 +147,10 @@ class GenerationSettings:
 class ModelCapabilities:
     embedding_manipulation: bool = False
     post_token_hooks: bool = False
+
+    # Used to gauge if manual stopping is possible
     stopper_hooks: bool = False
+
     # TODO: Support non-live probabilities from APIs
     post_token_probs: bool = False
@@ -154,6 +160,12 @@ class ModelCapabilities:
     # Some models need to warm up the TPU before use
     uses_tpu: bool = False

+class GenerationMode(Enum):
+    STANDARD = "standard"
+    FOREVER = "forever"
+    UNTIL_EOS = "until_eos"
+    UNTIL_NEWLINE = "until_newline"
+    UNTIL_SENTENCE_END = "until_sentence_end"
+
 class InferenceModel:
     """Root class for all models."""
@@ -256,6 +268,7 @@
         self,
         text: list,
         found_entries: set,
+        gen_mode: GenerationMode = GenerationMode.STANDARD,
     ):
         """Generate story text. Heavily tied to story-specific parameters; if
         you are making a new generation-based feature, consider `generate_raw()`.
@@ -263,6 +276,7 @@
         Args:
             text (list): Encoded input tokens
             found_entries (set): Entries found for Dynamic WI
+            gen_mode (GenerationMode): The GenerationMode to pass to raw_generate. Defaults to GenerationMode.STANDARD

         Raises:
             RuntimeError: if inconsistancies are detected with the internal state and Lua state -- sanity check
@@ -358,6 +372,7 @@
                     seed=utils.koboldai_vars.seed
                     if utils.koboldai_vars.full_determinism
                     else None,
+                    gen_mode=gen_mode
                 )
                 logger.debug(
                     "core_generate: run raw_generate pass {} {}s".format(
@@ -532,6 +547,7 @@
         found_entries: set = (),
         tpu_dynamic_inference: bool = False,
         seed: Optional[int] = None,
+        gen_mode: GenerationMode = GenerationMode.STANDARD,
         **kwargs,
     ) -> GenerationResult:
         """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
@@ -547,6 +563,7 @@
             is_core (bool, optional): Whether this generation is a core story generation. Defaults to False.
             single_line (bool, optional): Generate one line only.. Defaults to False.
             found_entries (set, optional): Entries found for Dynamic WI. Defaults to ().
+            gen_mode (GenerationMode): Special generation mode. Defaults to GenerationMode.STANDARD.

         Raises:
             ValueError: If prompt type is weird
@@ -568,6 +585,29 @@
             "wi_scanner_excluded_keys", set()
         )

+        self.gen_state["allow_eos"] = False
+
+        temp_stoppers = []
+
+        if gen_mode not in self.get_supported_gen_modes():
+            logger.warning(f"User requested unsupported GenerationMode '{gen_mode}'!")
+            gen_mode = GenerationMode.STANDARD
+
+        if gen_mode == GenerationMode.FOREVER:
+            self.gen_state["stop_at_genamt"] = False
+            max_new = 1e7
+        elif gen_mode == GenerationMode.UNTIL_EOS:
+            self.gen_state["allow_eos"] = True
+            self.gen_state["stop_at_genamt"] = False
+            max_new = 1e7
+        elif gen_mode == GenerationMode.UNTIL_NEWLINE:
+            # TODO: Look into replacing `single_line` with `generation_mode`
+            temp_stoppers.append(Stoppers.newline_stopper)
+        elif gen_mode == GenerationMode.UNTIL_SENTENCE_END:
+            temp_stoppers.append(Stoppers.sentence_end_stopper)
+
+        self.stopper_hooks += temp_stoppers
+
         utils.koboldai_vars.inference_config.do_core = is_core
         gen_settings = GenerationSettings(*(generation_settings or {}))
@@ -604,6 +644,9 @@
             f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, for an average rate of {tokens_per_second} tokens per second."
         )

+        for stopper in temp_stoppers:
+            self.stopper_hooks.remove(stopper)
+
         return result

     def generate(
@@ -620,3 +663,19 @@
     def _post_token_gen(self, input_ids: torch.LongTensor) -> None:
         for hook in self.post_token_hooks:
             hook(self, input_ids)
+
+    def get_supported_gen_modes(self) -> List[GenerationMode]:
+        """Returns a list of compatible `GenerationMode`s for the current model.
+
+        Returns:
+            List[GenerationMode]: A list of compatible `GenerationMode`s.
+        """
+        ret = [GenerationMode.STANDARD]
+        if self.capabilties.stopper_hooks:
+            ret += [
+                GenerationMode.FOREVER,
+                GenerationMode.UNTIL_NEWLINE,
+                GenerationMode.UNTIL_SENTENCE_END,
+            ]
+        return ret
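In `raw_generate`, the modes split into two families: FOREVER and UNTIL_EOS disable the `stop_at_genamt` cutoff and raise the token ceiling to an effectively unbounded 1e7, while UNTIL_NEWLINE and UNTIL_SENTENCE_END attach temporary stoppers that are removed again after the run, so they never leak into later generations. A standalone sketch of that dispatch, assuming the `GenerationMode` enum from the earlier sketch (`apply_gen_mode` and `GenState` are illustrative names, not this PR's API):

```python
from dataclasses import dataclass, field
from typing import Callable, List

@dataclass
class GenState:
    allow_eos: bool = False
    stop_at_genamt: bool = True
    max_new: int = 80                 # normal token budget; stand-in value
    temp_stoppers: List[Callable] = field(default_factory=list)

def apply_gen_mode(mode: GenerationMode, supported: List[GenerationMode],
                   newline_stopper: Callable, sentence_end_stopper: Callable) -> GenState:
    state = GenState()
    if mode not in supported:
        mode = GenerationMode.STANDARD    # degrade rather than fail
    if mode is GenerationMode.FOREVER:
        state.stop_at_genamt = False
        state.max_new = int(1e7)
    elif mode is GenerationMode.UNTIL_EOS:
        state.allow_eos = True            # consumed by the hf_torch sampler patch below
        state.stop_at_genamt = False
        state.max_new = int(1e7)
    elif mode is GenerationMode.UNTIL_NEWLINE:
        state.temp_stoppers.append(newline_stopper)
    elif mode is GenerationMode.UNTIL_SENTENCE_END:
        state.temp_stoppers.append(sentence_end_stopper)
    return state

state = apply_gen_mode(GenerationMode.FOREVER,
                       [GenerationMode.STANDARD, GenerationMode.FOREVER],
                       newline_stopper=lambda *_: False,
                       sentence_end_stopper=lambda *_: False)
assert not state.stop_at_genamt and state.max_new == 10_000_000
```

The temporary stoppers are appended to `self.stopper_hooks` before generation and removed afterwards, which is why the base class only advertises the stopper-based modes when `stopper_hooks` is in the model's capabilities.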

View File: modeling/inference_models/hf_torch.py

@@ -34,6 +34,7 @@ from modeling.stoppers import Stoppers
 from modeling.post_token_hooks import PostTokenHooks
 from modeling.inference_models.hf import HFInferenceModel
 from modeling.inference_model import (
+    GenerationMode,
     GenerationResult,
     GenerationSettings,
     ModelCapabilities,
@@ -253,7 +254,10 @@ class HFTorchInferenceModel(HFInferenceModel):
                 assert kwargs.pop("logits_warper", None) is not None
                 kwargs["logits_warper"] = KoboldLogitsWarperList()

-            if utils.koboldai_vars.newlinemode in ["s", "ns"]:
+            if (
+                utils.koboldai_vars.newlinemode in ["s", "ns"]
+                and not m_self.gen_state["allow_eos"]
+            ):
                 kwargs["eos_token_id"] = -1
                 kwargs.setdefault("pad_token_id", 2)
             return new_sample.old_sample(self, *args, **kwargs)
@@ -604,3 +608,9 @@
         self.breakmodel = False
         self.usegpu = False
         return
+
+    def get_supported_gen_modes(self) -> List[GenerationMode]:
+        # UNTIL_EOS is supported here because the patched torch sampler above
+        # stops banning the EOS token when gen_state["allow_eos"] is set.
+        return super().get_supported_gen_modes() + [
+            GenerationMode.UNTIL_EOS
+        ]
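The reason UNTIL_EOS is advertised only by the torch backend sits in the sampler patch above: in the "s"/"ns" newline modes the patch normally bans the EOS token by forcing `eos_token_id` to -1, and `gen_state["allow_eos"]` is the escape hatch. The guard reduces to a small predicate (a sketch with hypothetical names):

```python
def should_ban_eos(newlinemode: str, gen_state: dict) -> bool:
    # Mirrors the patched condition: EOS stays banned in the "s"/"ns"
    # newline modes unless the UNTIL_EOS path set allow_eos for this run.
    return newlinemode in ["s", "ns"] and not gen_state.get("allow_eos", False)

assert should_ban_eos("s", {"allow_eos": False}) is True
assert should_ban_eos("s", {"allow_eos": True}) is False
assert should_ban_eos("n", {}) is False
```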

View File: modeling/stoppers.py

@@ -3,15 +3,12 @@ from __future__ import annotations
 import torch

 import utils
-from modeling.inference_model import (
-    InferenceModel,
-)
+from modeling import inference_model

 class Stoppers:
     @staticmethod
     def core_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_core:
@@ -62,7 +59,7 @@
     @staticmethod
     def dynamic_wi_scanner(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_dynamic_wi:
@@ -93,7 +90,7 @@
     @staticmethod
     def chat_mode_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.chatmode:
@@ -118,7 +115,7 @@
     @staticmethod
     def stop_sequence_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
@@ -154,14 +151,22 @@
     @staticmethod
     def singleline_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
-        """If singleline mode is enabled, it's pointless to generate output beyond the first newline."""
+        """Stop on occurrences of newlines **if singleline is enabled**."""

+        # It might be better just to do this further up the line
         if not utils.koboldai_vars.singleline:
             return False

+        return Stoppers.newline_stopper(model, input_ids)
+
+    @staticmethod
+    def newline_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stop on occurrences of newlines."""
         # Keep track of presence of newlines in each sequence; we cannot stop a
         # batch member individually, so we must wait for all of them to contain
         # a newline.
@@ -176,3 +181,30 @@
             del model.gen_state["newline_in_sequence"]
             return True
         return False
+
+    @staticmethod
+    def sentence_end_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stops at the end of sentences."""
+
+        # TODO: Make this more robust
+        SENTENCE_ENDS = [".", "?", "!"]
+
+        # We need to keep track of stopping for each batch, since we can't stop
+        # one individually.
+        if "sentence_end_sequence" not in model.gen_state:
+            model.gen_state["sentence_end_sequence"] = [False] * len(input_ids)
+
+        for sequence_idx, batch_sequence in enumerate(input_ids):
+            decoded = model.tokenizer.decode(batch_sequence[-1])
+            for end in SENTENCE_ENDS:
+                if end in decoded:
+                    model.gen_state["sentence_end_sequence"][sequence_idx] = True
+                    break
+
+        if all(model.gen_state["sentence_end_sequence"]):
+            del model.gen_state["sentence_end_sequence"]
+            return True
+        return False
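Two details here are easy to miss. First, the switch to `from modeling import inference_model` sidesteps the circular import created by `inference_model.py` now importing `Stoppers` directly. Second, both new stoppers share batch-wide semantics: generation can only stop for the whole batch at once, so each stopper records which batch members have hit the condition and only fires when all of them have. That accumulation pattern in isolation (a sketch; `batch_stop` is illustrative, and a plain dict stands in for `model.gen_state`):

```python
from typing import Callable, Dict, List

def batch_stop(gen_state: Dict[str, List[bool]], last_tokens: List[str],
               hit: Callable[[str], bool], key: str) -> bool:
    # Record per batch member whether its newest token satisfies the stop
    # condition; stop only once every member has satisfied it at some point.
    if key not in gen_state:
        gen_state[key] = [False] * len(last_tokens)
    for i, token_text in enumerate(last_tokens):
        if hit(token_text):
            gen_state[key][i] = True
    if all(gen_state[key]):
        del gen_state[key]  # reset state for the next generation
        return True
    return False

state: Dict[str, List[bool]] = {}
ends_sentence = lambda text: any(c in text for c in ".?!")
assert batch_stop(state, ["Hi.", " and"], ends_sentence, "done") is False  # member 1 not done yet
assert batch_stop(state, [" more", "end!"], ends_sentence, "done") is True  # all members have ended
```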

View File: static/koboldai.css

@@ -2730,13 +2730,14 @@ body {
 #context-menu > hr {
     /* Division Color*/
     border-top: 2px solid var(--context_menu_division);
-    margin: 5px 5px;
+    margin: 3px 5px;
 }

 .context-menu-item {
-    padding: 5px;
+    padding: 4px;
     padding-right: 25px;
     min-width: 100px;
+    white-space: nowrap;
 }

 .context-menu-item:hover {
@@ -2747,11 +2748,16 @@
 .context-menu-item > .material-icons-outlined {
     position: relative;
-    top: 2px;
+    top: 3px;
     font-size: 15px;
     margin-right: 5px;
 }

+.context-menu-item > .context-menu-label {
+    position: relative;
+    top: 1px;
+}
+
 /* Substitutions */
 #Substitutions {
     margin-left: 10px;

View File: static/koboldai.js

@@ -85,6 +85,7 @@ let story_id = -1;
 var dirty_chunks = [];
 var initial_socketio_connection_occured = false;
 var selected_model_data;
+var supported_gen_modes = [];

 var privacy_mode_enabled = false;
 var ai_busy = false;
 var can_show_options = false;
@@ -162,7 +163,36 @@ const context_menu_actions = {
     "wi-img-upload-button": [
         {label: "Upload Image", icon: "file_upload", enabledOn: "ALWAYS", click: wiImageReplace},
         {label: "Use Generated Image", icon: "image", enabledOn: "GENERATED-IMAGE", click: wiImageUseGeneratedImage},
-    ]
+    ],
+    "submit-button": [
+        {label: "Generate", icon: "edit", enabledOn: "ALWAYS", click: () => storySubmit()},
+        null,
+        {
+            label: "Generate Forever",
+            icon: "edit_off",
+            enabledOn: () => supported_gen_modes.includes("forever"),
+            click: () => storySubmit("forever")
+        },
+        {
+            label: "Generate Until EOS",
+            icon: "edit_off",
+            enabledOn: () => supported_gen_modes.includes("until_eos"),
+            click: () => storySubmit("until_eos")
+        },
+        null,
+        {
+            label: "Finish Line",
+            icon: "edit_off",
+            enabledOn: () => supported_gen_modes.includes("until_newline"),
+            click: () => storySubmit("until_newline")
+        },
+        {
+            label: "Finish Sentence",
+            icon: "edit_off",
+            enabledOn: () => supported_gen_modes.includes("until_sentence_end"),
+            click: () => storySubmit("until_sentence_end")
+        },
+    ],
 };

 let context_menu_cache = [];
@@ -254,10 +284,17 @@ function disconnect() {
     document.getElementById("disconnect_message").classList.remove("hidden");
 }

-function storySubmit() {
+function storySubmit(genMode=null) {
+    const textInput = document.getElementById("input_text");
+    const themeInput = document.getElementById("themetext");
+
     disruptStoryState();
-    socket.emit('submit', {'data': document.getElementById('input_text').value, 'theme': document.getElementById('themetext').value});
-    document.getElementById('input_text').value = '';
+    socket.emit('submit', {
+        data: textInput.value,
+        theme: themeInput.value,
+        gen_mode: genMode,
+    });
+
+    textInput.value = '';
     document.getElementById('themetext').value = '';
 }
@@ -1009,6 +1046,9 @@ function var_changed(data) {
     //special case for welcome text since we want to allow HTML
     } else if (data.classname == 'model' && data.name == 'welcome') {
         document.getElementById('welcome_text').innerHTML = data.value;
+    //Special case for permitted generation modes
+    } else if (data.classname == 'model' && data.name == 'supported_gen_modes') {
+        supported_gen_modes = data.value;
     //Basic Data Syncing
     } else {
         var elements_to_change = document.getElementsByClassName("var_sync_"+data.classname.replace(" ", "_")+"_"+data.name.replace(" ", "_"));
@@ -5976,8 +6016,21 @@ function position_context_menu(contextMenu, x, y) {
         right: x + width,
     };

+    // Slide over if running against the window bounds.
     if (farMenuBounds.right > bounds.right) x -= farMenuBounds.right - bounds.right;
-    if (farMenuBounds.bottom > bounds.bottom) y -= farMenuBounds.bottom - bounds.bottom;
+
+    if (farMenuBounds.bottom > bounds.bottom) {
+        // We've hit the bottom. The old algorithm pushed the menu against the
+        // wall, similar to what's done on the x-axis:
+        //   y -= farMenuBounds.bottom - bounds.bottom;
+        // But now, we make the box change its emission direction from the cursor:
+        y -= (height + 5);
+        // The main advantage of this approach is that the cursor is never directly
+        // placed above a context menu item immediately after activating the context
+        // menu. (Thus the 5px offset also added)
+    }

     contextMenu.style.left = `${x}px`;
     contextMenu.style.top = `${y}px`;
@@ -6252,21 +6305,23 @@ process_cookies();
             continue;
         }

-        let item = $e("div", contextMenu, {
+        const enableCriteriaIsFunction = typeof action.enabledOn === "function";
+
+        const itemEl = $e("div", contextMenu, {
             classes: ["context-menu-item", "noselect", `context-menu-${key}`],
-            "enabled-on": action.enabledOn,
+            "enabled-on": enableCriteriaIsFunction ? "CALLBACK" : action.enabledOn,
             "cache-index": context_menu_cache.length
         });
+        itemEl.enabledOnCallback = action.enabledOn;
+
         context_menu_cache.push({shouldShow: action.shouldShow});

-        let icon = $e("span", item, {classes: ["material-icons-outlined"], innerText: action.icon});
-        item.append(action.label);
+        const icon = $e("span", itemEl, {classes: ["material-icons-outlined"], innerText: action.icon});
+        $e("span", itemEl, {classes: ["context-menu-label"], innerText: action.label});

-        item.addEventListener("mousedown", e => e.preventDefault());
+        itemEl.addEventListener("mousedown", e => e.preventDefault());

         // Expose the "summonEvent" to enable access to original context menu target.
-        item.addEventListener("click", () => action.click(summonEvent));
+        itemEl.addEventListener("click", () => action.click(summonEvent));
     }
 }
@@ -6289,6 +6344,10 @@
     // Show only applicable actions in the context menu
     let contextMenuType = target.getAttribute("context-menu");

+    // If context menu is not present, return
+    if (!context_menu_actions[contextMenuType]) return;
+
     for (const contextMenuItem of contextMenu.childNodes) {
         let shouldShow = contextMenuItem.classList.contains(`context-menu-${contextMenuType}`);
@@ -6316,10 +6375,10 @@
     // Disable non-applicable items
     $(".context-menu-item").addClass("disabled");

     // A selection is made
     if (getSelectionText()) $(".context-menu-item[enabled-on=SELECTION]").removeClass("disabled");

     // The caret is placed
     if (get_caret_position(target) !== null) $(".context-menu-item[enabled-on=CARET]").removeClass("disabled");
@@ -6328,6 +6387,11 @@
     $(".context-menu-item[enabled-on=ALWAYS]").removeClass("disabled");

+    for (const contextMenuItem of document.querySelectorAll(".context-menu-item[enabled-on=CALLBACK]")) {
+        if (!contextMenuItem.enabledOnCallback()) continue;
+        contextMenuItem.classList.remove("disabled");
+    }
+
     // Make sure hr isn't first or last visible element
     let visibles = [];
     for (const item of contextMenu.children) {
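On the client, `enabledOn` may now be either one of the existing keyword strings or a callable; callables are stashed on the element, tagged `enabled-on="CALLBACK"`, and re-evaluated every time the menu opens, which is how the submit menu tracks the model's `supported_gen_modes` live. The dispatch, re-expressed in Python purely for illustration (names here are hypothetical; the authoritative logic is the JavaScript above):

```python
def item_enabled(enabled_on, ui_state: dict) -> bool:
    # Keyword criteria ("ALWAYS", "SELECTION", "CARET", ...) are looked up
    # against the current UI state; callables are simply invoked.
    if callable(enabled_on):
        return bool(enabled_on())
    return ui_state.get(enabled_on, False)

supported_gen_modes = ["standard", "forever"]
assert item_enabled("ALWAYS", {"ALWAYS": True})
assert item_enabled(lambda: "forever" in supported_gen_modes, {})
assert not item_enabled(lambda: "until_eos" in supported_gen_modes, {})
```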

View File: templates/index_new.html

@@ -110,9 +110,9 @@
             <button type="button" class="btn action_button" style="width: 30px; padding: 0px;" onclick='play_pause_tts()' aria-label="play"><span id="play_tts" class="material-icons-outlined" style="font-size: 1.4em;">play_arrow</span></button>
             <button type="button" class="btn action_button" style="width: 30px; padding: 0px;" onclick='stop_tts()' aria-label="play"><span id="stop_tts" class="material-icons-outlined" style="font-size: 1.4em;">stop</span></button>
         </span>
-        <button type="button" class="btn action_button submit var_sync_alt_system_aibusy" system_aibusy=False id="btnsubmit" onclick="storySubmit();">Submit</button>
+        <button type="button" class="btn action_button submit var_sync_alt_system_aibusy" system_aibusy=False id="btnsubmit" onclick="storySubmit();" context-menu="submit-button">Submit</button>
         <button type="button" class="btn action_button submited var_sync_alt_system_aibusy" system_aibusy=False id="btnsent"><img id="thinking" src="static/thinking.gif" class="force_center" onclick="socket.emit('abort','');"></button>
-        <button type="button" class="btn action_button back var_sync_alt_system_aibusy" system_aibusy=False onclick="storyBack();" aria-label="undo"><span class="material-icons-outlined" style="font-size: 1.4em;">replay</span></button>
+        <button type="button" class="btn action_button back var_sync_alt_system_aibusy" system_aibusy=False onclick="storyBack();" aria-label="undo" context-menu="undo-button"><span class="material-icons-outlined" style="font-size: 1.4em;">replay</span></button>
         <button type="button" class="btn action_button redo var_sync_alt_system_aibusy" system_aibusy=False onclick="storyRedo();" aria-label="redo"><span class="material-icons-outlined" style="font-size: 1.4em;">arrow_forward</span></button>
         <button type="button" class="btn action_button retry var_sync_alt_system_aibusy" system_aibusy=False onclick="storyRetry();" aria-label="retry"><span class="material-icons-outlined" style="font-size: 1.4em;">autorenew</span></button>
     </div>