mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Fix for world info highlighting
This commit is contained in:
42
aiserver.py
42
aiserver.py
@@ -8180,7 +8180,7 @@ def UI_2_generate_image(data):
|
||||
#If we have > 4 keys, use those otherwise use sumarization
|
||||
if len(keys) < 4:
|
||||
from transformers import pipeline as summary_pipeline
|
||||
summarizer = summary_pipeline("summarization")
|
||||
summarizer = summary_pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
|
||||
#text to summarize:
|
||||
if len(koboldai_vars.actions) < 5:
|
||||
text = "".join(koboldai_vars.actions[:-5]+[koboldai_vars.prompt])
|
||||
@@ -8191,18 +8191,54 @@ def UI_2_generate_image(data):
|
||||
transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']
|
||||
keys = [summarizer(text, max_length=100, min_length=30, do_sample=False)[0]['summary_text']]
|
||||
transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp
|
||||
del summarizer
|
||||
|
||||
art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting',
|
||||
|
||||
#If we don't have a GPU, use horde
|
||||
if not koboldai_vars.hascuda:
|
||||
b64_data = text2img(", ".join(keys), art_guide = art_guide)
|
||||
emit("Action_Image", {'b64': b64_data, 'prompt': ", ".join(keys)})
|
||||
else:
|
||||
if torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved(0) >= 6000000000:
|
||||
#He have enough vram, just do it locally
|
||||
b64_data = text2img_local(", ".join(keys), art_guide = art_guide)
|
||||
elif torch.cuda.get_device_properties(0).total_memory > 6000000000:
|
||||
#We could do it locally by swapping the model out
|
||||
print("Could do local or online")
|
||||
else:
|
||||
b64_data = text2img_local(", ".join(keys), art_guide = art_guide)
|
||||
koboldai_vars.picture = b64_data
|
||||
koboldai_vars.picture_prompt = ", ".join(keys)
|
||||
#emit("Action_Image", {'b64': b64_data, 'prompt': ", ".join(keys)})
|
||||
|
||||
|
||||
@logger.catch
|
||||
def text2img_local(prompt, art_guide="", filename="new.png"):
|
||||
start_time = time.time()
|
||||
print("Generating Image")
|
||||
koboldai_vars.generating_image = True
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import base64
|
||||
from io import BytesIO
|
||||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, cache="./stable-diffusion-v1-4").to("cuda")
|
||||
print("time to load: {}".format(time.time() - start_time))
|
||||
from torch import autocast
|
||||
with autocast("cuda"):
|
||||
image = pipe(prompt)["sample"][0]
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode('ascii')
|
||||
print("time to generate: {}".format(time.time() - start_time))
|
||||
pipe.to("cpu")
|
||||
koboldai_vars.generating_image = False
|
||||
print("time to unload: {}".format(time.time() - start_time))
|
||||
return img_str
|
||||
|
||||
@logger.catch
|
||||
def text2img(prompt,
|
||||
art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting',
|
||||
filename = "story_art.png"):
|
||||
print("Generating Image")
|
||||
print("Generating Image using Horde")
|
||||
koboldai_vars.generating_image = True
|
||||
final_imgen_params = {
|
||||
"n": 1,
|
||||
|
@@ -21,6 +21,7 @@ dependencies:
|
||||
- apispec-webframeworks
|
||||
- loguru
|
||||
- Pillow
|
||||
- diffusers
|
||||
- pip:
|
||||
- flask-cloudflared
|
||||
- flask-ngrok
|
||||
|
@@ -18,6 +18,7 @@ dependencies:
|
||||
- apispec-webframeworks
|
||||
- loguru
|
||||
- Pillow
|
||||
- diffusers
|
||||
- pip:
|
||||
- --find-links https://download.pytorch.org/whl/rocm4.2/torch_stable.html
|
||||
- torch==1.10.*
|
||||
|
@@ -576,6 +576,8 @@ class story_settings(settings):
|
||||
self.context = []
|
||||
self.last_story_load = None
|
||||
self.revisions = []
|
||||
self.picture = "" #base64 of the image shown for the story
|
||||
self.picture_prompt = "" #Prompt used to create picture
|
||||
|
||||
#must be at bottom
|
||||
self.no_save = False #Temporary disable save (doesn't save with the file)
|
||||
|
@@ -17,3 +17,4 @@ marshmallow>=3.13
|
||||
apispec-webframeworks
|
||||
loguru
|
||||
Pillow
|
||||
diffusers
|
@@ -21,3 +21,4 @@ marshmallow>=3.13
|
||||
apispec-webframeworks
|
||||
loguru
|
||||
Pillow
|
||||
diffusers
|
@@ -2205,6 +2205,10 @@ button.disabled {
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.wi_match {
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.within_max_length {
|
||||
color: var(--text_to_ai_color);
|
||||
font-weight: bold;
|
||||
|
@@ -29,7 +29,6 @@ socket.on('load_cookies', function(data){load_cookies(data)});
|
||||
socket.on('load_tweaks', function(data){load_tweaks(data);});
|
||||
socket.on("wi_results", updateWISearchListings);
|
||||
socket.on("request_prompt_config", configurePrompt);
|
||||
socket.on("Action_Image", function(data){Action_Image(data);});
|
||||
//socket.onAny(function(event_name, data) {console.log({"event": event_name, "class": data.classname, "data": data});});
|
||||
|
||||
var presets = {};
|
||||
@@ -520,6 +519,22 @@ function var_changed(data) {
|
||||
} else {
|
||||
button.childNodes[1].textContent = "Story";
|
||||
}
|
||||
//Special Case for story picture
|
||||
} else if (data.classname == "story" && data.name == "picture") {
|
||||
image_area = document.getElementById("action image");
|
||||
while (image_area.firstChild) {
|
||||
image_area.removeChild(image_area.firstChild);
|
||||
}
|
||||
if (data.value != "") {
|
||||
var image = new Image();
|
||||
image.src = 'data:image/png;base64,'+data.value;
|
||||
image.classList.add("action_image");
|
||||
image_area.appendChild(image);
|
||||
}
|
||||
} else if (data.classname == "story" && data.name == "picture_prompt") {
|
||||
if (document.getElementById("action image").firstChild) {
|
||||
document.getElementById("action image").firstChild.setAttribute("title", data.value);
|
||||
}
|
||||
//Basic Data Syncing
|
||||
} else {
|
||||
var elements_to_change = document.getElementsByClassName("var_sync_"+data.classname.replace(" ", "_")+"_"+data.name.replace(" ", "_"));
|
||||
@@ -1959,18 +1974,6 @@ function load_cookies(data) {
|
||||
}
|
||||
}
|
||||
|
||||
function Action_Image(data) {
|
||||
var image = new Image();
|
||||
image.src = 'data:image/png;base64,'+data['b64'];
|
||||
image.setAttribute("title", data['prompt']);
|
||||
image.classList.add("action_image");
|
||||
image_area = document.getElementById("action image");
|
||||
while (image_area.firstChild) {
|
||||
image_area.removeChild(image_area.firstChild);
|
||||
}
|
||||
image_area.appendChild(image);
|
||||
}
|
||||
|
||||
//--------------------------------------------UI to Server Functions----------------------------------
|
||||
function unload_userscripts() {
|
||||
files_to_unload = document.getElementById('loaded_userscripts');
|
||||
@@ -2877,9 +2880,8 @@ function assign_world_info_to_action(action_item, uid) {
|
||||
|
||||
for (action of actions) {
|
||||
//First check to see if we have a key in the text
|
||||
var words = action.textContent.split(" ");
|
||||
for (const [key, worldinfo] of Object.entries(worldinfo_to_check)) {
|
||||
//remove any world info tags
|
||||
//remove any world info tags on the overall chunk
|
||||
for (tag of action.getElementsByClassName("tag_uid_"+uid)) {
|
||||
tag.classList.remove("tag_uid_"+uid);
|
||||
tag.removeAttribute("title");
|
||||
@@ -2893,75 +2895,49 @@ function assign_world_info_to_action(action_item, uid) {
|
||||
if (worldinfo['keysecondary'].length > 0) {
|
||||
for (second_key of worldinfo['keysecondary']) {
|
||||
if (action.textContent.replace(/[^0-9a-z \'\"]/gi, '').includes(second_key)) {
|
||||
//First let's assign our world info id to the action so we know to count the tokens for the world info
|
||||
current_ids = action.getAttribute("world_info_uids")?action.getAttribute("world_info_uids").split(','):[];
|
||||
if (!(current_ids.includes(uid))) {
|
||||
current_ids.push(uid);
|
||||
}
|
||||
action.setAttribute("world_info_uids", current_ids.join(","));
|
||||
//OK we have the phrase in our action. Let's see if we can identify the word(s) that are triggering
|
||||
for (var i = 0; i < words.length; i++) {
|
||||
key_words = keyword.split(" ").length;
|
||||
var to_check = words.slice(i, i+key_words).join("").replace(/[^0-9a-z \'\"]/gi, '').trim();
|
||||
if (keyword == to_check) {
|
||||
var start_word = i;
|
||||
var end_word = i+len_of_keyword;
|
||||
var passed_words = 0;
|
||||
for (span of action.childNodes) {
|
||||
if (passed_words + span.textContent.split(" ").length < start_word) {
|
||||
passed_words += span.textContent.trim().split(" ").length;
|
||||
} else if (passed_words < end_word) {
|
||||
//OK, we have text that matches, let's do the highlighting
|
||||
//we can skip the highlighting if it's already done though
|
||||
if (span.tagName != "I") {
|
||||
var span_text = span.textContent.trim().split(" ");
|
||||
var before_highlight_text = span_text.slice(0, start_word-passed_words).join(" ")+" ";
|
||||
var highlight_text = span_text.slice(start_word-passed_words, end_word-passed_words).join(" ");
|
||||
if (end_word-passed_words <= span_text.length) {
|
||||
highlight_text += " ";
|
||||
}
|
||||
var after_highlight_text = span_text.slice((end_word-passed_words)).join(" ");
|
||||
//console.log(span.textContent);
|
||||
//console.log(keyword);
|
||||
//console.log(before_highlight_text);
|
||||
//console.log(highlight_text);
|
||||
//console.log(after_highlight_text);
|
||||
//console.log("passed: "+passed_words+" start:" + start_word + " end: "+end_word+" continue: "+(end_word-passed_words));
|
||||
//console.log(null);
|
||||
var before_span = document.createElement("span");
|
||||
before_span.textContent = before_highlight_text;
|
||||
var hightlight_span = document.createElement("span");
|
||||
hightlight_span.classList.add("italics");
|
||||
hightlight_span.textContent = highlight_text;
|
||||
hightlight_span.title = worldinfo['content'];
|
||||
var after_span = document.createElement("span");
|
||||
after_span.textContent = after_highlight_text;
|
||||
action.insertBefore(before_span, span);
|
||||
action.insertBefore(hightlight_span, span);
|
||||
action.insertBefore(after_span, span);
|
||||
span.remove();
|
||||
}
|
||||
passed_words += span.textContent.trim().split(" ").length;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
highlight_world_info_text_in_chunk(action, worldinfo);
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
highlight_world_info_text_in_chunk(action, worldinfo);
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function highlight_world_info_text_in_chunk(action, wi) {
|
||||
//First let's assign our world info id to the action so we know to count the tokens for the world info
|
||||
let uid = wi['uid'];
|
||||
let words = action.textContent.split(" ");
|
||||
current_ids = action.getAttribute("world_info_uids")?action.getAttribute("world_info_uids").split(','):[];
|
||||
if (!(current_ids.includes(uid))) {
|
||||
current_ids.push(uid);
|
||||
}
|
||||
action.setAttribute("world_info_uids", current_ids.join(","));
|
||||
//OK we have the phrase in our action. Let's see if we can identify the word(s) that are triggering
|
||||
var len_of_keyword = keyword.split(" ").length;
|
||||
//OK we have the phrase in our action.
|
||||
//First let's find the largest key that matches
|
||||
let largest_key = "";
|
||||
for (keyword of wi['key']) {
|
||||
if ((keyword.length > largest_key.length) && (action.textContent.replace(/[^0-9a-z \'\"]/gi, '').includes(keyword))) {
|
||||
largest_key = keyword;
|
||||
}
|
||||
}
|
||||
//console.log(largest_key);
|
||||
|
||||
|
||||
//Let's see if we can identify the word(s) that are triggering
|
||||
var len_of_keyword = largest_key.split(" ").length;
|
||||
//go through each word to see where we get a match
|
||||
for (var i = 0; i < words.length; i++) {
|
||||
//get the words from the ith word to the i+len_of_keyword. Get rid of non-letters/numbers/'/"
|
||||
var to_check = words.slice(i, i+len_of_keyword).join(" ").replace(/[^0-9a-z \'\"]/gi, '').trim();
|
||||
if (keyword == to_check) {
|
||||
if (largest_key == to_check) {
|
||||
var start_word = i;
|
||||
var end_word = i+len_of_keyword;
|
||||
var passed_words = 0;
|
||||
@@ -2971,7 +2947,7 @@ function assign_world_info_to_action(action_item, uid) {
|
||||
} else if (passed_words < end_word) {
|
||||
//OK, we have text that matches, let's do the highlighting
|
||||
//we can skip the highlighting if it's already done though
|
||||
if (span.tagName != "I") {
|
||||
if (~(span.classList.contains('wi_match'))) {
|
||||
var span_text = span.textContent.trim().split(" ");
|
||||
var before_highlight_text = span_text.slice(0, start_word-passed_words).join(" ")+" ";
|
||||
var highlight_text = span_text.slice(start_word-passed_words, end_word-passed_words).join(" ");
|
||||
@@ -2982,24 +2958,25 @@ function assign_world_info_to_action(action_item, uid) {
|
||||
if (after_highlight_text[0] == ' ') {
|
||||
after_highlight_text = after_highlight_text.substring(1);
|
||||
}
|
||||
//console.log("'"+span.textContent+"'");
|
||||
//console.log(keyword);
|
||||
if (before_highlight_text != "") {
|
||||
//console.log("'"+before_highlight_text+"'");
|
||||
//console.log("'"+highlight_text+"'");
|
||||
//console.log("'"+after_highlight_text+"'");
|
||||
//console.log("passed: "+passed_words+" start:" + start_word + " end: "+end_word+" continue: "+(end_word-passed_words));
|
||||
//console.log(null);
|
||||
var before_span = document.createElement("span");
|
||||
before_span.textContent = before_highlight_text;
|
||||
action.insertBefore(before_span, span);
|
||||
}
|
||||
//console.log("'"+highlight_text+"'");
|
||||
var hightlight_span = document.createElement("span");
|
||||
hightlight_span.classList.add("italics");
|
||||
hightlight_span.classList.add("wi_match");
|
||||
hightlight_span.textContent = highlight_text;
|
||||
hightlight_span.title = worldinfo['content'];
|
||||
hightlight_span.title = wi['content'];
|
||||
action.insertBefore(hightlight_span, span);
|
||||
if (after_highlight_text != "") {
|
||||
//console.log("'"+after_highlight_text+"'");
|
||||
var after_span = document.createElement("span");
|
||||
after_span.textContent = after_highlight_text;
|
||||
action.insertBefore(before_span, span);
|
||||
action.insertBefore(hightlight_span, span);
|
||||
action.insertBefore(after_span, span);
|
||||
}
|
||||
//console.log("Done");
|
||||
span.remove();
|
||||
}
|
||||
passed_words += span.textContent.trim().split(" ").length;
|
||||
@@ -3009,13 +2986,6 @@ function assign_world_info_to_action(action_item, uid) {
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function update_token_lengths() {
|
||||
clearTimeout(calc_token_usage_timeout);
|
||||
calc_token_usage_timeout = setTimeout(calc_token_usage, 200);
|
||||
|
Reference in New Issue
Block a user