From fd695977c63f56a9715c2f3cad01c5ebfd326083 Mon Sep 17 00:00:00 2001 From: ebolam Date: Tue, 20 Sep 2022 09:21:57 -0400 Subject: [PATCH] Fix for world info highlighting --- aiserver.py | 44 ++++++- environments/huggingface.yml | 1 + environments/rocm.yml | 1 + koboldai_settings.py | 2 + requirements.txt | 3 +- requirements_mtj.txt | 3 +- static/koboldai.css | 4 + static/koboldai.js | 222 +++++++++++++++-------------------- 8 files changed, 148 insertions(+), 132 deletions(-) diff --git a/aiserver.py b/aiserver.py index 1b2e87d7..6315b3cd 100644 --- a/aiserver.py +++ b/aiserver.py @@ -8180,7 +8180,7 @@ def UI_2_generate_image(data): #If we have > 4 keys, use those otherwise use sumarization if len(keys) < 4: from transformers import pipeline as summary_pipeline - summarizer = summary_pipeline("summarization") + summarizer = summary_pipeline("summarization", model="sshleifer/distilbart-cnn-12-6") #text to summarize: if len(koboldai_vars.actions) < 5: text = "".join(koboldai_vars.actions[:-5]+[koboldai_vars.prompt]) @@ -8191,18 +8191,54 @@ def UI_2_generate_image(data): transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria'] keys = [summarizer(text, max_length=100, min_length=30, do_sample=False)[0]['summary_text']] transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp + del summarizer art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting', - b64_data = text2img(", ".join(keys), art_guide = art_guide) - emit("Action_Image", {'b64': b64_data, 'prompt': ", ".join(keys)}) + #If we don't have a GPU, use horde + if not koboldai_vars.hascuda: + b64_data = text2img(", ".join(keys), art_guide = art_guide) + else: + if torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved(0) >= 6000000000: + #He have enough vram, just do it locally + b64_data = text2img_local(", ".join(keys), art_guide = art_guide) + elif torch.cuda.get_device_properties(0).total_memory > 6000000000: + #We could do it locally by swapping the model out + print("Could do local or online") + else: + b64_data = text2img_local(", ".join(keys), art_guide = art_guide) + koboldai_vars.picture = b64_data + koboldai_vars.picture_prompt = ", ".join(keys) + #emit("Action_Image", {'b64': b64_data, 'prompt': ", ".join(keys)}) +@logger.catch +def text2img_local(prompt, art_guide="", filename="new.png"): + start_time = time.time() + print("Generating Image") + koboldai_vars.generating_image = True + from diffusers import StableDiffusionPipeline + import base64 + from io import BytesIO + pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, cache="./stable-diffusion-v1-4").to("cuda") + print("time to load: {}".format(time.time() - start_time)) + from torch import autocast + with autocast("cuda"): + image = pipe(prompt)["sample"][0] + buffered = BytesIO() + image.save(buffered, format="JPEG") + img_str = base64.b64encode(buffered.getvalue()).decode('ascii') + print("time to generate: {}".format(time.time() - start_time)) + pipe.to("cpu") + koboldai_vars.generating_image = False + print("time to unload: {}".format(time.time() - start_time)) + return img_str + @logger.catch def text2img(prompt, art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting', filename = "story_art.png"): - print("Generating Image") + print("Generating Image using Horde") koboldai_vars.generating_image = True final_imgen_params = { "n": 1, diff --git a/environments/huggingface.yml b/environments/huggingface.yml index 72871c81..25706993 100644 --- a/environments/huggingface.yml +++ b/environments/huggingface.yml @@ -21,6 +21,7 @@ dependencies: - apispec-webframeworks - loguru - Pillow + - diffusers - pip: - flask-cloudflared - flask-ngrok diff --git a/environments/rocm.yml b/environments/rocm.yml index a0334a9a..c6bf06d2 100644 --- a/environments/rocm.yml +++ b/environments/rocm.yml @@ -18,6 +18,7 @@ dependencies: - apispec-webframeworks - loguru - Pillow + - diffusers - pip: - --find-links https://download.pytorch.org/whl/rocm4.2/torch_stable.html - torch==1.10.* diff --git a/koboldai_settings.py b/koboldai_settings.py index 83f538a7..b30c69ca 100644 --- a/koboldai_settings.py +++ b/koboldai_settings.py @@ -576,6 +576,8 @@ class story_settings(settings): self.context = [] self.last_story_load = None self.revisions = [] + self.picture = "" #base64 of the image shown for the story + self.picture_prompt = "" #Prompt used to create picture #must be at bottom self.no_save = False #Temporary disable save (doesn't save with the file) diff --git a/requirements.txt b/requirements.txt index c9b316c0..02972a2e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,4 +16,5 @@ flask_session marshmallow>=3.13 apispec-webframeworks loguru -Pillow \ No newline at end of file +Pillow +diffusers \ No newline at end of file diff --git a/requirements_mtj.txt b/requirements_mtj.txt index 4c1c4ca7..12af394d 100644 --- a/requirements_mtj.txt +++ b/requirements_mtj.txt @@ -20,4 +20,5 @@ flask-session marshmallow>=3.13 apispec-webframeworks loguru -Pillow \ No newline at end of file +Pillow +diffusers \ No newline at end of file diff --git a/static/koboldai.css b/static/koboldai.css index c7ed66b7..c75968e1 100644 --- a/static/koboldai.css +++ b/static/koboldai.css @@ -2205,6 +2205,10 @@ button.disabled { font-style: italic; } +.wi_match { + font-style: italic; +} + .within_max_length { color: var(--text_to_ai_color); font-weight: bold; diff --git a/static/koboldai.js b/static/koboldai.js index 5c103df9..b4ee3e8d 100644 --- a/static/koboldai.js +++ b/static/koboldai.js @@ -29,7 +29,6 @@ socket.on('load_cookies', function(data){load_cookies(data)}); socket.on('load_tweaks', function(data){load_tweaks(data);}); socket.on("wi_results", updateWISearchListings); socket.on("request_prompt_config", configurePrompt); -socket.on("Action_Image", function(data){Action_Image(data);}); //socket.onAny(function(event_name, data) {console.log({"event": event_name, "class": data.classname, "data": data});}); var presets = {}; @@ -520,6 +519,22 @@ function var_changed(data) { } else { button.childNodes[1].textContent = "Story"; } + //Special Case for story picture + } else if (data.classname == "story" && data.name == "picture") { + image_area = document.getElementById("action image"); + while (image_area.firstChild) { + image_area.removeChild(image_area.firstChild); + } + if (data.value != "") { + var image = new Image(); + image.src = 'data:image/png;base64,'+data.value; + image.classList.add("action_image"); + image_area.appendChild(image); + } + } else if (data.classname == "story" && data.name == "picture_prompt") { + if (document.getElementById("action image").firstChild) { + document.getElementById("action image").firstChild.setAttribute("title", data.value); + } //Basic Data Syncing } else { var elements_to_change = document.getElementsByClassName("var_sync_"+data.classname.replace(" ", "_")+"_"+data.name.replace(" ", "_")); @@ -1959,18 +1974,6 @@ function load_cookies(data) { } } -function Action_Image(data) { - var image = new Image(); - image.src = 'data:image/png;base64,'+data['b64']; - image.setAttribute("title", data['prompt']); - image.classList.add("action_image"); - image_area = document.getElementById("action image"); - while (image_area.firstChild) { - image_area.removeChild(image_area.firstChild); - } - image_area.appendChild(image); -} - //--------------------------------------------UI to Server Functions---------------------------------- function unload_userscripts() { files_to_unload = document.getElementById('loaded_userscripts'); @@ -2877,9 +2880,8 @@ function assign_world_info_to_action(action_item, uid) { for (action of actions) { //First check to see if we have a key in the text - var words = action.textContent.split(" "); for (const [key, worldinfo] of Object.entries(worldinfo_to_check)) { - //remove any world info tags + //remove any world info tags on the overall chunk for (tag of action.getElementsByClassName("tag_uid_"+uid)) { tag.classList.remove("tag_uid_"+uid); tag.removeAttribute("title"); @@ -2893,120 +2895,13 @@ function assign_world_info_to_action(action_item, uid) { if (worldinfo['keysecondary'].length > 0) { for (second_key of worldinfo['keysecondary']) { if (action.textContent.replace(/[^0-9a-z \'\"]/gi, '').includes(second_key)) { - //First let's assign our world info id to the action so we know to count the tokens for the world info - current_ids = action.getAttribute("world_info_uids")?action.getAttribute("world_info_uids").split(','):[]; - if (!(current_ids.includes(uid))) { - current_ids.push(uid); - } - action.setAttribute("world_info_uids", current_ids.join(",")); - //OK we have the phrase in our action. Let's see if we can identify the word(s) that are triggering - for (var i = 0; i < words.length; i++) { - key_words = keyword.split(" ").length; - var to_check = words.slice(i, i+key_words).join("").replace(/[^0-9a-z \'\"]/gi, '').trim(); - if (keyword == to_check) { - var start_word = i; - var end_word = i+len_of_keyword; - var passed_words = 0; - for (span of action.childNodes) { - if (passed_words + span.textContent.split(" ").length < start_word) { - passed_words += span.textContent.trim().split(" ").length; - } else if (passed_words < end_word) { - //OK, we have text that matches, let's do the highlighting - //we can skip the highlighting if it's already done though - if (span.tagName != "I") { - var span_text = span.textContent.trim().split(" "); - var before_highlight_text = span_text.slice(0, start_word-passed_words).join(" ")+" "; - var highlight_text = span_text.slice(start_word-passed_words, end_word-passed_words).join(" "); - if (end_word-passed_words <= span_text.length) { - highlight_text += " "; - } - var after_highlight_text = span_text.slice((end_word-passed_words)).join(" "); - //console.log(span.textContent); - //console.log(keyword); - //console.log(before_highlight_text); - //console.log(highlight_text); - //console.log(after_highlight_text); - //console.log("passed: "+passed_words+" start:" + start_word + " end: "+end_word+" continue: "+(end_word-passed_words)); - //console.log(null); - var before_span = document.createElement("span"); - before_span.textContent = before_highlight_text; - var hightlight_span = document.createElement("span"); - hightlight_span.classList.add("italics"); - hightlight_span.textContent = highlight_text; - hightlight_span.title = worldinfo['content']; - var after_span = document.createElement("span"); - after_span.textContent = after_highlight_text; - action.insertBefore(before_span, span); - action.insertBefore(hightlight_span, span); - action.insertBefore(after_span, span); - span.remove(); - } - passed_words += span.textContent.trim().split(" ").length; - } - } - } - } + highlight_world_info_text_in_chunk(action, worldinfo); + break; } } } else { - //First let's assign our world info id to the action so we know to count the tokens for the world info - current_ids = action.getAttribute("world_info_uids")?action.getAttribute("world_info_uids").split(','):[]; - if (!(current_ids.includes(uid))) { - current_ids.push(uid); - } - action.setAttribute("world_info_uids", current_ids.join(",")); - //OK we have the phrase in our action. Let's see if we can identify the word(s) that are triggering - var len_of_keyword = keyword.split(" ").length; - //go through each word to see where we get a match - for (var i = 0; i < words.length; i++) { - //get the words from the ith word to the i+len_of_keyword. Get rid of non-letters/numbers/'/" - var to_check = words.slice(i, i+len_of_keyword).join(" ").replace(/[^0-9a-z \'\"]/gi, '').trim(); - if (keyword == to_check) { - var start_word = i; - var end_word = i+len_of_keyword; - var passed_words = 0; - for (span of action.childNodes) { - if (passed_words + span.textContent.split(" ").length < start_word) { - passed_words += span.textContent.trim().split(" ").length; - } else if (passed_words < end_word) { - //OK, we have text that matches, let's do the highlighting - //we can skip the highlighting if it's already done though - if (span.tagName != "I") { - var span_text = span.textContent.trim().split(" "); - var before_highlight_text = span_text.slice(0, start_word-passed_words).join(" ")+" "; - var highlight_text = span_text.slice(start_word-passed_words, end_word-passed_words).join(" "); - if (end_word-passed_words <= span_text.length) { - highlight_text += " "; - } - var after_highlight_text = span_text.slice((end_word-passed_words)).join(" ")+" "; - if (after_highlight_text[0] == ' ') { - after_highlight_text = after_highlight_text.substring(1); - } - //console.log("'"+span.textContent+"'"); - //console.log(keyword); - //console.log("'"+before_highlight_text+"'"); - //console.log("'"+highlight_text+"'"); - //console.log("'"+after_highlight_text+"'"); - //console.log("passed: "+passed_words+" start:" + start_word + " end: "+end_word+" continue: "+(end_word-passed_words)); - //console.log(null); - var before_span = document.createElement("span"); - before_span.textContent = before_highlight_text; - var hightlight_span = document.createElement("span"); - hightlight_span.classList.add("italics"); - hightlight_span.textContent = highlight_text; - hightlight_span.title = worldinfo['content']; - var after_span = document.createElement("span"); - after_span.textContent = after_highlight_text; - action.insertBefore(before_span, span); - action.insertBefore(hightlight_span, span); - action.insertBefore(after_span, span); - span.remove(); - } - passed_words += span.textContent.trim().split(" ").length; - } - } - } - } + highlight_world_info_text_in_chunk(action, worldinfo); + break; } } @@ -3016,6 +2911,81 @@ function assign_world_info_to_action(action_item, uid) { } } +function highlight_world_info_text_in_chunk(action, wi) { + //First let's assign our world info id to the action so we know to count the tokens for the world info + let uid = wi['uid']; + let words = action.textContent.split(" "); + current_ids = action.getAttribute("world_info_uids")?action.getAttribute("world_info_uids").split(','):[]; + if (!(current_ids.includes(uid))) { + current_ids.push(uid); + } + action.setAttribute("world_info_uids", current_ids.join(",")); + //OK we have the phrase in our action. + //First let's find the largest key that matches + let largest_key = ""; + for (keyword of wi['key']) { + if ((keyword.length > largest_key.length) && (action.textContent.replace(/[^0-9a-z \'\"]/gi, '').includes(keyword))) { + largest_key = keyword; + } + } + //console.log(largest_key); + + + //Let's see if we can identify the word(s) that are triggering + var len_of_keyword = largest_key.split(" ").length; + //go through each word to see where we get a match + for (var i = 0; i < words.length; i++) { + //get the words from the ith word to the i+len_of_keyword. Get rid of non-letters/numbers/'/" + var to_check = words.slice(i, i+len_of_keyword).join(" ").replace(/[^0-9a-z \'\"]/gi, '').trim(); + if (largest_key == to_check) { + var start_word = i; + var end_word = i+len_of_keyword; + var passed_words = 0; + for (span of action.childNodes) { + if (passed_words + span.textContent.split(" ").length < start_word) { + passed_words += span.textContent.trim().split(" ").length; + } else if (passed_words < end_word) { + //OK, we have text that matches, let's do the highlighting + //we can skip the highlighting if it's already done though + if (~(span.classList.contains('wi_match'))) { + var span_text = span.textContent.trim().split(" "); + var before_highlight_text = span_text.slice(0, start_word-passed_words).join(" ")+" "; + var highlight_text = span_text.slice(start_word-passed_words, end_word-passed_words).join(" "); + if (end_word-passed_words <= span_text.length) { + highlight_text += " "; + } + var after_highlight_text = span_text.slice((end_word-passed_words)).join(" ")+" "; + if (after_highlight_text[0] == ' ') { + after_highlight_text = after_highlight_text.substring(1); + } + if (before_highlight_text != "") { + //console.log("'"+before_highlight_text+"'"); + var before_span = document.createElement("span"); + before_span.textContent = before_highlight_text; + action.insertBefore(before_span, span); + } + //console.log("'"+highlight_text+"'"); + var hightlight_span = document.createElement("span"); + hightlight_span.classList.add("wi_match"); + hightlight_span.textContent = highlight_text; + hightlight_span.title = wi['content']; + action.insertBefore(hightlight_span, span); + if (after_highlight_text != "") { + //console.log("'"+after_highlight_text+"'"); + var after_span = document.createElement("span"); + after_span.textContent = after_highlight_text; + action.insertBefore(after_span, span); + } + //console.log("Done"); + span.remove(); + } + passed_words += span.textContent.trim().split(" ").length; + } + } + } + } +} + function update_token_lengths() { clearTimeout(calc_token_usage_timeout); calc_token_usage_timeout = setTimeout(calc_token_usage, 200);