diff --git a/aiserver.py b/aiserver.py
index 34cf9fd5..7416fc20 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -40,7 +40,6 @@ import packaging
 import packaging.version
 import contextlib
 import traceback
-import threading
 import markdown
 import bleach
 import itertools
@@ -63,7 +62,6 @@ import sys
 import gc

 import lupa
-import importlib

 # KoboldAI
 import fileops
@@ -83,7 +81,7 @@ import transformers.generation_utils

 # Text2img
 import base64
-from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps, PngImagePlugin
+from PIL import Image
 from io import BytesIO

 global tpu_mtj_backend
@@ -2220,28 +2218,90 @@ def patch_transformers():
             # There were no matches, so just begin at the beginning.
             return 0

+        def _allow_leftwards_tampering(self, phrase: str) -> bool:
+            """Determines if a phrase should be tampered with from the left in
+            the "soft" token encoding mode."""
+
+            if phrase[0] in [".", "?", "!", ";", ":", "\n"]:
+                return False
+            return True
+
+        def _get_token_sequence(self, phrase: str) -> List[List]:
+            """Convert the phrase string into a list of encoded biases, each
+            one being a list of tokens. How this is done is determined by the
+            phrase's format:
+
+            - If the phrase is surrounded by square brackets ([]), the tokens
+              will be the phrase split by commas (,). If a "token" isn't
+              actually a number, it will be skipped. NOTE: Tokens output by
+              this may not be in the model's vocabulary, and such tokens
+              should be ignored later in the pipeline.
+            - If the phrase is surrounded by curly brackets ({}), the phrase
+              will be directly encoded with no synonym biases and no fancy
+              tricks.
+            - Otherwise, the phrase will be encoded, with close deviations
+              being included as synonym biases.
+            """
+
+            # TODO: Cache these tokens, invalidate when model or bias is
+            # changed.
+
+            # Handle direct token id input
+            if phrase.startswith("[") and phrase.endswith("]"):
+                no_brackets = phrase[1:-1]
+                ret = []
+                for token_id in no_brackets.split(","):
+                    try:
+                        ret.append(int(token_id))
+                    except ValueError:
+                        # Ignore non-numbers. Rascals!
+                        pass
+                return [ret]
+
+            # Handle direct phrases
+            if phrase.startswith("{") and phrase.endswith("}"):
+                no_brackets = phrase[1:-1]
+                return [tokenizer.encode(no_brackets)]
+
+            # Handle untamperable phrases
+            if not self._allow_leftwards_tampering(phrase):
+                return [tokenizer.encode(phrase)]
+
+            # Handle slight alterations to original phrase
+            phrase = phrase.strip(" ")
+            ret = []
+
+            for alt_phrase in [phrase, f" {phrase}"]:
+                ret.append(tokenizer.encode(alt_phrase))
+
+            return ret
+
         def _get_biased_tokens(self, input_ids: List) -> Dict:
             # TODO: Different "bias slopes"?
             ret = {}
             for phrase, _bias in koboldai_vars.biases.items():
                 bias_score, completion_threshold = _bias
-                # TODO: Cache these tokens, invalidate when model or bias is
-                # changed.
-                token_seq = tokenizer.encode(phrase)
-                bias_index = self._find_intersection(input_ids, token_seq)
+                token_seqs = self._get_token_sequence(phrase)
+                variant_deltas = {}
+                for token_seq in token_seqs:
+                    bias_index = self._find_intersection(input_ids, token_seq)

-                # Ensure completion after completion_threshold tokens
-                # Only provide a positive bias when the base bias score is positive.
-                if bias_score > 0 and bias_index + 1 > completion_threshold:
-                    bias_score = 999
+                    # Ensure completion after completion_threshold tokens
+                    # Only provide a positive bias when the base bias score is positive.
+                    if bias_score > 0 and bias_index + 1 > completion_threshold:
+                        bias_score = 999

-                token_to_bias = token_seq[bias_index]
-                # If multiple phrases bias the same token, add the modifiers together.
-                if token_to_bias in ret:
-                    ret[token_to_bias] += bias_score
-                else:
-                    ret[token_to_bias] = bias_score
+                    token_to_bias = token_seq[bias_index]
+                    variant_deltas[token_to_bias] = bias_score
+
+                # If multiple phrases bias the same token, add the modifiers
+                # together. This should NOT be applied to automatic variants
+                for token_to_bias, bias_score in variant_deltas.items():
+                    if token_to_bias in ret:
+                        ret[token_to_bias] += bias_score
+                    else:
+                        ret[token_to_bias] = bias_score

             return ret

         def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
@@ -7847,7 +7907,6 @@ def final_startup():
             return
         utils.decodenewlines(tokenizer.decode([25678, 559]))
         tokenizer.encode(utils.encodenewlines("eunoia"))
-    #threading.Thread(target=__preempt_tokenizer).start()
     tpool.execute(__preempt_tokenizer)

     # Load soft prompt specified by the settings file, if applicable
@@ -7865,18 +7924,6 @@ def final_startup():
     if(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
         soft_tokens = tpumtjgetsofttokens()
         if(koboldai_vars.dynamicscan or (not koboldai_vars.nogenmod and koboldai_vars.has_genmod)):
-            #threading.Thread(
-            #    target=tpu_mtj_backend.infer_dynamic,
-            #    args=(np.tile(np.uint32((23403, 727, 20185)), (koboldai_vars.numseqs, 1)),),
-            #    kwargs={
-            #        "soft_embeddings": koboldai_vars.sp,
-            #        "soft_tokens": soft_tokens,
-            #        "gen_len": 1,
-            #        "use_callback": False,
-            #        "numseqs": koboldai_vars.numseqs,
-            #        "excluded_world_info": list(set() for _ in range(koboldai_vars.numseqs)),
-            #    },
-            #).start()
             tpool.execute(tpu_mtj_backend.infer_dynamic, np.tile(np.uint32((23403, 727, 20185)), (koboldai_vars.numseqs, 1)),
                 soft_embeddings= koboldai_vars.sp,
                 soft_tokens= soft_tokens,
@@ -7886,16 +7933,6 @@ def final_startup():
                 excluded_world_info= list(set() for _ in range(koboldai_vars.numseqs))
                 )
         else:
-            #threading.Thread(
-            #    target=tpu_mtj_backend.infer_static,
-            #    args=(np.uint32((23403, 727, 20185)),),
-            #    kwargs={
-            #        "soft_embeddings": koboldai_vars.sp,
-            #        "soft_tokens": soft_tokens,
-            #        "gen_len": 1,
-            #        "numseqs": koboldai_vars.numseqs,
-            #    },
-            #).start()
             tpool.execute(
                 tpu_mtj_backend.infer_static,
                 np.uint32((23403, 727, 20185)),
@@ -9634,6 +9671,12 @@ def log_image_generation(
     with open(db_path, "w") as file:
         json.dump(j, file, indent="\t")

+@socketio.on("retry_generated_image")
+@logger.catch
+def UI2_retry_generated_image():
+    eventlet.sleep(0)
+    generate_story_image(koboldai_vars.picture_prompt)
+
 def generate_story_image(
     prompt: str,
     file_prefix: str = "image",
@@ -9683,9 +9726,6 @@ def generate_story_image(

     b64_data = base64.b64encode(buffer.getvalue()).decode("ascii")
     koboldai_vars.picture = b64_data

-
-
-
 def generate_image(prompt: str) -> Optional[Image.Image]:
     if koboldai_vars.img_gen_priority == 4:
diff --git a/environments/huggingface.yml b/environments/huggingface.yml
index 9fd2ce5f..485ac338 100644
--- a/environments/huggingface.yml
+++ b/environments/huggingface.yml
@@ -22,7 +22,6 @@ dependencies:
   - loguru
   - termcolor
   - Pillow
-  - diffusers
   - psutil
   - pip:
     - flask-cloudflared
@@ -40,4 +39,5 @@ dependencies:
     - ijson
     - bitsandbytes
     - ftfy
-    - pydub
\ No newline at end of file
+    - pydub
+    - diffusers
diff --git a/environments/rocm.yml b/environments/rocm.yml
index ad39a3a6..a0e23177 100644
--- a/environments/rocm.yml
+++ b/environments/rocm.yml
@@ -19,12 +19,10 @@ dependencies:
   - loguru
   - termcolor
   - Pillow
-  - diffusers
   - psutil
   - pip:
     - --extra-index-url https://download.pytorch.org/whl/rocm5.1.1
-    - torch
-    - torchvision
+    - torch==1.11.*
     - flask-cloudflared
     - flask-ngrok
     - lupa==1.10
@@ -38,3 +36,4 @@ dependencies:
     - ijson
     - ftfy
     - pydub
+    - diffusers
diff --git a/koboldai_settings.py b/koboldai_settings.py
index 520b59ca..47246cc5 100644
--- a/koboldai_settings.py
+++ b/koboldai_settings.py
@@ -7,15 +7,12 @@ import shutil
 from typing import List, Union
 from io import BytesIO
 from flask import has_request_context, session
-from flask_socketio import SocketIO, join_room, leave_room
+from flask_socketio import join_room, leave_room
 from collections import OrderedDict
 import multiprocessing
 from logger import logger
-import eventlet
 import torch
 import numpy as np
-import inspect
-import ctypes
 import random

 serverstarted = False
diff --git a/static/koboldai.css b/static/koboldai.css
index 947771bf..64827944 100644
--- a/static/koboldai.css
+++ b/static/koboldai.css
@@ -279,11 +279,6 @@ border-top-right-radius: var(--tabs_rounding);
     margin: 2px;
 }

-#biasing {
-    margin-left: 5px;
-    margin-right: 5px;
-}
-
 .setting_container.var_sync_alt_system_alt_gen[system_alt_gen="true"] {
     display: none;
 }
@@ -708,85 +703,79 @@ border-top-right-radius: var(--tabs_rounding);
 }

 /* Bias */
-#biases_label {
-    cursor: pointer;
+#biasing {
+    padding: 0px 10px;
 }

-.bias {
-    display: grid;
-    grid-template-areas: "phrase percent max";
-    grid-template-columns: auto 100px 100px;
+.bias_card {
+    background-color: var(--setting_background);
+    padding: 8px;
+    margin-bottom: 12px;
+}
+
+.bias_top {
+    display: flex;
+    justify-content: space-between;
+}
+
+.bias_top > .close_button {
+    display: flex;
+    justify-content: center;
+    align-items: center;
+    cursor: pointer;
+    margin-left: 7px;
+}
+
+.bias_card input {
+    width: 100%;
+    border: none;
 }

 .bias_phrase {
-    grid-area: phrase;
-    margin-right: 5px;
+    height: 36px;
 }

-.bias_phrase input {
-    width: 100%;
+.bias_slider_labels {
+    display: flex;
+    justify-content: space-between;
+    position: relative;
 }

-.bias_score {
-    grid-area: percent;
-    margin-right: 5px;
-}
-
-.bias_comp_threshold {
-    grid-area: max;
-    margin-right: 5px;
-}
-
-.bias_slider {
-    display: grid;
-    grid-template-areas: "bar bar bar"
-                         "min cur max";
-    grid-template-columns: 33px 34px 33px;
-}
-
-.bias_slider_bar {
-    grid-area: bar;
-}
-
-.bias_slider_min {
-    grid-area: min;
+.bias_slider_min, .bias_slider_max{
     font-size: calc(0.8em + var(--font_size_adjustment));
-    margin-right: 5px;
-    margin-left: 5px;
+    user-select: none;
 }

 .bias_slider_cur {
-    grid-area: cur;
     text-align: center;
+    outline: none;
 }

-.bias_slider_max {
-    grid-area: max;
-    font-size: calc(0.8em + var(--font_size_adjustment));
-    text-align: right;
-    margin-right: 5px;
-    margin-left: 5px;
+.bias_score, .bias_top {
+    margin-bottom: 12px;
 }

-.bias_header {
-    display: grid;
-    grid-template-areas: "phrase percent max";
-    grid-template-columns: auto 100px 100px;
+.bias_slider_centerlayer {
+    /* Yeah its a bit of a hack */
+    position: absolute;
+    left: 0px;
+    width: 100%;
+    display: flex;
+    justify-content: center;
 }

-.bias_header_phrase {
-    grid-area: phrase;
-    font-size: calc(1.1em + var(--font_size_adjustment));
+#bias-add {
+    width: 100%;
+    cursor: pointer;
+    display: flex;
+    justify-content: center;
+    align-items: center;
+    padding: 4px;
 }

-.bias_header_score {
-    grid-area: percent;
-    font-size: calc(1.1em + var(--font_size_adjustment));
-}
-
-.bias_header_max {
-    grid-area: max;
-    font-size: calc(1.1em + var(--font_size_adjustment));
+#bias-add:hover {
+    width: 100%;
+    background-color: rgba(255, 255, 255, 0.05);
 }

 /* Theme */
diff --git a/static/koboldai.js b/static/koboldai.js
index 042e3e61..f367d684 100644
--- a/static/koboldai.js
+++ b/static/koboldai.js
@@ -56,6 +56,7 @@ var shift_down = false;
 var world_info_data = {};
 var world_info_folder_data = {};
 var saved_settings = {};
+var biases_data = {};
 var finder_selection_index = -1;
 var colab_cookies = null;
 var wi_finder_data = [];
@@ -2862,36 +2863,43 @@ function save_as_story(response) {
     if (response === "overwrite?") openPopup("save-confirm");
 }

-function save_bias(item) {
-
-    var have_blank = false;
+function save_bias() {
     var biases = {};
     //get all of our biases
-    for (bias of document.getElementsByClassName("bias")) {
+
+    for (const biasCard of document.getElementsByClassName("bias_card")) {
         //phrase
-        var phrase = bias.querySelector(".bias_phrase").querySelector("input").value;
+        var phrase = biasCard.querySelector(".bias_phrase").value;
+        if (!phrase) continue;

         //Score
-        var percent = parseFloat(bias.querySelector(".bias_score").querySelector("input").value);
+        var score = parseFloat(biasCard.querySelector(".bias_score input").value);

         //completion threshold
-        var comp_threshold = parseInt(bias.querySelector(".bias_comp_threshold").querySelector("input").value);
+        var compThreshold = parseInt(biasCard.querySelector(".bias_comp_threshold input").value);

-        if (phrase != "") {
-            biases[phrase] = [percent, comp_threshold];
-        }
-        bias.classList.add("pulse");
+        biases[phrase] = [score, compThreshold];
     }
-
+
+    // Because of course JS couldn't just support comparison in a core type
+    // that would be silly and foolish
+    if (JSON.stringify(biases) === JSON.stringify(biases_data)) {
+        // No changes. :(
+        return;
+    }
+
+    biases_data = biases;
+    console.info("saving biases", biases)
+
     //send the biases to the backend
     socket.emit("phrase_bias_update", biases);
-
 }

 function sync_to_server(item) {
     //get value
-    value = null;
-    name = null;
+    let value = null;
+    let name = null;
+
     if ((item.tagName.toLowerCase() === 'checkbox') || (item.tagName.toLowerCase() === 'input') || (item.tagName.toLowerCase() === 'select') || (item.tagName.toLowerCase() == 'textarea')) {
         if (item.getAttribute("type") == "checkbox") {
             value = item.checked;
@@ -3293,7 +3301,9 @@ function finished_tts() {
     } else {
         action = document.getElementById("Selected Text Chunk "+(next_action-1));
     }
-    action.classList.remove("tts_playing");
+    if (action) {
+        action.classList.remove("tts_playing");
+    }
     if (next_action <= action_count) {
         document.getElementById("reader").src = "/audio?id="+next_action;
         document.getElementById("reader").setAttribute("action_id", next_action);
@@ -3311,7 +3321,9 @@ function tts_playing() {
     } else {
         action = document.getElementById("Selected Text Chunk "+action_id);
     }
-    action.classList.add("tts_playing");
+    if (action) {
+        action.classList.add("tts_playing");
+    }
 }

 function view_selection_probabilities() {
@@ -3632,42 +3644,156 @@ function options_on_right(data) {
     }
 }

-function do_biases(data) {
-    //console.log(data);
-    //clear out our old bias lines
-    let bias_list = Object.assign([], document.getElementsByClassName("bias"));
-    for (item of bias_list) {
-        //console.log(item);
-        item.parentNode.removeChild(item);
+function makeBiasCard(phrase="", score=0, compThreshold=10) {
+    function updateBias(origin, input, save=true) {
+        const textInput = input.closest(".bias_slider").querySelector(".bias_slider_cur");
+        let value = (origin === "slider") ? input.value : parseFloat(textInput.innerText);
+        textInput.innerText = value;
+        input.value = value;
+
+        // Only save on "commitful" actions like blur or mouseup to not spam
+        // the poor server
+        if (save) save_bias();
     }
-
+
+    const biasContainer = $el("#bias-container");
+    const biasCard = $el(".bias_card.template").cloneNode(true);
+    biasCard.classList.remove("template");
+
+    const closeButton = biasCard.querySelector(".close_button");
+    closeButton.addEventListener("click", function(event) {
+        biasCard.remove();
+
+        // We just deleted the last bias, we probably don't want to keep seeing
+        // them pop up
+        if (!biasContainer.firstChild) biasContainer.setAttribute(
+            "please-stop-adding-biases-whenever-i-delete-them",
+            "i mean it"
+        );
+        save_bias();
+    });
+
+    const phraseInput = biasCard.querySelector(".bias_phrase");
+    phraseInput.addEventListener("change", save_bias);
+
+    const scoreInput = biasCard.querySelector(".bias_score input");
+    const compThresholdInput = biasCard.querySelector(".bias_comp_threshold input");
+
+    phraseInput.value = phrase;
+    scoreInput.value = score;
+    compThresholdInput.value = compThreshold;
+
+    for (const input of [scoreInput, compThresholdInput]) {
+        // Init sync
+        updateBias("slider", input, false);
+
+        // Visual update on each value change
+        input.addEventListener("input", function() { updateBias("slider", this, false) });
+
+        // Only when we leave do we sync to server
+        input.addEventListener("change", function() { updateBias("slider", this) });
+
+        // Personally I don't want to press a key 100 times to add one
+        const nudge = parseFloat(input.getAttribute("keyboard-step") ?? input.getAttribute("step"));
+        const min = parseFloat(input.getAttribute("min"));
+        const max = parseFloat(input.getAttribute("max"));
+
+        const currentHitbox = input.closest(".hitbox");
+        const currentLabel = input.closest(".bias_slider").querySelector(".bias_slider_cur");
+
+        // TODO: Prevent paste of just non-number characters
+        currentLabel.addEventListener("paste", function(event) { event.preventDefault(); })
+
+        currentLabel.addEventListener("keydown", function(event) {
+            // Nothing special for numbers
+            if (
+                [".", "-", "ArrowLeft", "ArrowRight", "Backspace", "Delete"].includes(event.key)
+                || event.ctrlKey
+                || (parseInt(event.key) || parseInt(event.key) === 0)
+            ) return;
+
+            // Either we are special keys or forbidden keys
+            event.preventDefault();
+
+            switch (event.key) {
+                case "Enter":
+                    currentLabel.blur();
+                    break;
+                // This feels very nice :^)
+                case "ArrowDown":
+                case "ArrowUp":
+                    let delta = (event.key === "ArrowUp") ? nudge : -nudge;
+                    let currentValue = parseFloat(currentLabel.innerText);
+
+                    event.preventDefault();
+                    if (!currentValue && currentValue !== 0) return;
+
+                    // toFixed because 1.1 + 0.1 !== 1.2 yay rounding errors.
+                    // Although the added decimal place(s) look cleaner now
+                    // that I think about it.
+                    let value = Math.min(max, Math.max(min, currentValue + delta));
+                    currentLabel.innerText = value.toFixed(2);
+
+                    updateBias("text", input, false);
+                    break;
+            }
+        });
+
+        currentHitbox.addEventListener("wheel", function(event) {
+            // Only when focused! (May drop this requirement later, browsers seem to behave when scrolling :] )
+            if (currentLabel !== document.activeElement) return;
+            if (event.deltaY === 0) return;
+
+            let delta = (event.deltaY > 0) ? -nudge : nudge;
+            let currentValue = parseFloat(currentLabel.innerText);
+
+            event.preventDefault();
+            if (!currentValue && currentValue !== 0) return;
+            let value = Math.min(max, Math.max(min, currentValue + delta));
+            currentLabel.innerText = value.toFixed(2);
+
+            updateBias("text", input, false);
+        });
+
+        currentLabel.addEventListener("blur", function(event) {
+            updateBias("text", input);
+        });
+    }
+
+    biasContainer.appendChild(biasCard);
+    return biasCard;
+}
+
+$el("#bias-add").addEventListener("click", function(event) {
+    const card = makeBiasCard();
+    card.querySelector(".bias_phrase").focus();
+});
+
+function do_biases(data) {
+    console.info("Taking inventory of biases")
+    biases_data = data.value;
+
+    // Clear out our old bias cards, weird recursion because remove sometimes
+    // doesn't work (???)
+    const biasContainer = $el("#bias-container");
+    for (let i=0;i<10000;i++) {
+        if (!biasContainer.firstChild) break;
+        biasContainer.firstChild.remove();
+    }
+    if(biasContainer.firstChild) reportError("We are doomed", "Undead zombie bias, please report this");
+
     //add our bias lines
     for (const [key, value] of Object.entries(data.value)) {
-        bias_line = document.getElementById("empty_bias").cloneNode(true);
-        bias_line.id = "";
-        bias_line.classList.add("bias");
-        bias_line.querySelector(".bias_phrase").querySelector("input").value = key;
-        bias_line.querySelector(".bias_score").querySelector("input").value = value[0];
-        update_bias_slider_value(bias_line.querySelector(".bias_score").querySelector("input"));
-        bias_line.querySelector(".bias_comp_threshold").querySelector("input").value = value[1];
-        update_bias_slider_value(bias_line.querySelector(".bias_comp_threshold").querySelector("input"));
-        document.getElementById('biasing').append(bias_line);
+        makeBiasCard(key, value[0], value[1]);
     }
-
-    //add another bias line if this is the phrase and it's not blank
-    bias_line = document.getElementById("empty_bias").cloneNode(true);
-    bias_line.id = "";
-    bias_line.classList.add("bias");
-    bias_line.querySelector(".bias_phrase").querySelector("input").value = "";
-    bias_line.querySelector(".bias_phrase").querySelector("input").id = "empty_bias_phrase";
-    bias_line.querySelector(".bias_score").querySelector("input").value = 1;
-    bias_line.querySelector(".bias_comp_threshold").querySelector("input").value = 50;
-    document.getElementById('biasing').append(bias_line);
+
+    // Add seed card if we have no bias cards and we didn't just delete the
+    // last bias card
+    if (
+        !biasContainer.firstChild &&
+        !biasContainer.getAttribute("please-stop-adding-biases-whenever-i-delete-them")
+    ) makeBiasCard();
 }

-function update_bias_slider_value(slider) {
-    slider.parentElement.parentElement.querySelector(".bias_slider_cur").textContent = slider.value;
-}
-
 function distortColor(rgb) {
     // rgb are 0..255, NOT NORMALIZED!!!!!!
@@ -6672,7 +6798,7 @@ function imgGenRetry() {
     const image = $el(".action_image");
     if (!image) return;
     $el("#image-loading").classList.remove("hidden");
-    socket.emit("generate_image", {'action_id': image.getAttribute("chunk")});
+    socket.emit("retry_generated_image");
 }

 /* Genres */
diff --git a/templates/settings flyout.html b/templates/settings flyout.html
index 3ac450d0..b5307461 100644
--- a/templates/settings flyout.html
+++ b/templates/settings flyout.html
@@ -244,14 +244,8 @@
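
For reference, the three phrase formats accepted by the new _get_token_sequence method can be sketched in isolation as a minimal, standalone example. This is not part of the patch: it assumes the Hugging Face GPT-2 tokenizer as a stand-in for KoboldAI's active model tokenizer, and it omits the _allow_leftwards_tampering punctuation check for brevity.

# Standalone sketch of the phrase-format handling (assumption: GPT-2 tokenizer
# stands in for the model tokenizer that aiserver.py uses at runtime).
from typing import List
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

def get_token_sequences(phrase: str) -> List[List[int]]:
    # "[1820, 995, junk]" -> raw token ids; non-numeric entries are skipped
    if phrase.startswith("[") and phrase.endswith("]"):
        ids = []
        for token_id in phrase[1:-1].split(","):
            try:
                ids.append(int(token_id))
            except ValueError:
                pass
        return [ids]
    # "{the world}" -> encode the phrase verbatim, no variants
    if phrase.startswith("{") and phrase.endswith("}"):
        return [tokenizer.encode(phrase[1:-1])]
    # Plain phrase -> encode both "word" and " word", because BPE tokenizers
    # assign different tokens depending on the leading space
    phrase = phrase.strip(" ")
    return [tokenizer.encode(phrase), tokenizer.encode(f" {phrase}")]

print(get_token_sequences("[1820, 995, junk]"))  # [[1820, 995]]
print(get_token_sequences("{the world}"))        # one verbatim encoding
print(get_token_sequences("the world"))          # two variants, with and without leading space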