#==================================================================#
# Retrieve previous images
#==================================================================#
@socketio.on("get_story_image")
@logger.catch
def UI_2_get_story_image(data):
    """Return the stored picture for one story action as base64 text.

    Args:
        data: Event payload; ``data['action_id']`` selects the action.

    Returns:
        The image file's bytes, base64-encoded and decoded to a UTF-8
        string, or ``None`` when the action has no picture or the file
        cannot be read.
    """
    action_id = data['action_id']
    # get_picture yields (filename, prompt); only the filename is needed here.
    filename, _prompt = koboldai_vars.actions.get_picture(action_id)
    if filename is None:
        return None
    try:
        with open(filename, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
    except OSError as err:
        # A stale or unreadable path is reported and answered with "no image"
        # instead of surfacing as an unhandled exception.
        logger.warning("Could not read image {} for action {}: {}".format(filename, action_id, err))
        return None
#==================================================================#
# Download complete audio file
#==================================================================#
@socketio.on("gen_full_audio")
def UI_2_gen_full_audio(data):
    """Concatenate the audio of every action into ``complete.ogg``.

    For each action id from -1 up to and including ``action_count`` the
    ``<id>_slow.ogg`` file is preferred, then ``<id>.ogg``. When neither
    exists, generation is requested and the function waits up to 60 seconds
    for ``<id>.ogg`` to appear.

    Args:
        data: Event payload (unused).

    Returns:
        ``True`` once the combined file has been written, ``False`` when no
        audio segment could be collected at all.
    """
    from pydub import AudioSegment
    if args.no_ui:
        # NOTE(review): a redirect is returned from a socket event handler
        # here; kept as in the original - confirm this is intended.
        return redirect('/api/latest')

    logger.info("Generating complete audio file")
    audio_dir = koboldai_vars.save_paths.generated_audio
    complete_filename = os.path.join(audio_dir, "complete.ogg")
    combined_audio = None
    for action_id in range(-1, koboldai_vars.actions.action_count+1):
        filename = os.path.join(audio_dir, f"{action_id}.ogg")
        filename_slow = os.path.join(audio_dir, f"{action_id}_slow.ogg")

        if os.path.exists(filename_slow):
            source = filename_slow
        elif os.path.exists(filename):
            source = filename
        else:
            logger.info("Action {} has no audio. Generating now".format(action_id))
            koboldai_vars.actions.gen_audio(action_id)
            # The timeout is measured per action, from the moment generation
            # was requested (the original read an unassigned start_time).
            start_time = time.time()
            while not os.path.exists(filename) and time.time()-start_time < 60: #Waiting up to 60 seconds for the file to be generated
                time.sleep(0.1)
            if not os.path.exists(filename):
                show_error_notification("Error generating audio chunk", "Something happened. Maybe check the log?")
                continue
            source = filename

        segment = AudioSegment.from_file(source, format="ogg")
        combined_audio = segment if combined_audio is None else combined_audio + segment

    if combined_audio is None:
        # Nothing could be loaded or generated, so there is nothing to export.
        logger.error("No audio segments available; complete audio file was not created")
        return False

    logger.info("Sending audio file")
    # export() hands back the opened output file; close it so the handle is
    # not leaked and the data is flushed before the download route reads it.
    combined_audio.export(complete_filename, format="ogg").close()
    return True


@app.route("/audio_full")
@require_allowed_ip
@logger.catch
def UI_2_audio_full():
    """Send the previously generated ``complete.ogg`` as a download.

    Returns:
        The file as an ``audio/ogg`` attachment named after the story, or a
        404 response when the combined file has not been generated.
    """
    logger.info("Downloading complete audio file")
    complete_filename = os.path.join(koboldai_vars.save_paths.generated_audio, "complete.ogg")
    if not os.path.exists(complete_filename):
        # A view must not return None; answer with an explicit 404 instead.
        return "Complete audio file has not been generated", 404
    return send_file(
        complete_filename,
        as_attachment=True,
        # Include the extension so the saved file is recognised as audio.
        download_name=f"{koboldai_vars.story_name}.ogg",
        mimetype="audio/ogg")
def basic_send(socketio, classname, event, data):
    """Emit ``event`` with ``data`` to the room that matches ``classname``.

    In multi-story mode, events for the 'story' class go to the story named
    in the session ('default' when the session has none); every other case
    uses the 'UI_2' room. Outside a request context the message is put on
    the module-level queue (when one exists) instead of being emitted.
    """
    # Work out which room receives the message.
    target_room = "UI_2"
    if multi_story and classname == 'story':
        if not has_request_context():
            logger.error("We tried to access the story register outside of an http context. Will not work in multi-story mode")
            return
        target_room = session['story'] if 'story' in session else 'default'

    if has_request_context():
        if socketio is not None:
            socketio.emit(event, data, include_self=True, broadcast=True, room=target_room)
        return

    if queue is not None:
        queue.put([event, data, {"broadcast":True, "room":target_room}])
os.path.exists(filename): if action_id == -1: - self.make_audio_queue.put((self._koboldai_vars.prompt, filename)) + self.make_audio_queue.put((self._koboldai_vars.prompt, filename, action_id)) else: - self.make_audio_queue.put((self.actions[action_id]['Selected Text'], filename)) - if self.make_audio_thread_slow is None or not self.make_audio_thread_slow.is_alive(): - self.make_audio_thread_slow = threading.Thread(target=self.create_wave_slow, args=(self.make_audio_queue_slow, )) - self.make_audio_thread_slow.start() + self.make_audio_queue.put((self.actions[action_id]['Selected Text'], filename, action_id)) + if self.make_audio_thread is None or not self.make_audio_thread.is_alive(): + self.make_audio_thread = threading.Thread(target=self.create_wave, args=(self.make_audio_queue, )) + self.make_audio_thread.start() + elif not overwrite and os.path.exists(filename): + if action_id != -1: + self.actions[action_id]["audio_gen"] = 1 if overwrite or not os.path.exists(filename_slow): if action_id == -1: - self.make_audio_queue_slow.put((self._koboldai_vars.prompt, filename_slow)) + self.make_audio_queue_slow.put((self._koboldai_vars.prompt, filename_slow, action_id)) else: - self.make_audio_queue_slow.put((self.actions[action_id]['Selected Text'], filename_slow)) + self.make_audio_queue_slow.put((self.actions[action_id]['Selected Text'], filename_slow, action_id)) if self.make_audio_thread_slow is None or not self.make_audio_thread_slow.is_alive(): self.make_audio_thread_slow = threading.Thread(target=self.create_wave_slow, args=(self.make_audio_queue_slow, )) self.make_audio_thread_slow.start() + elif not overwrite and os.path.exists(filename_slow): + if action_id != -1: + self.actions[action_id]["audio_gen"] = 2 + basic_send(self._socketio, "story", "set_audio_status", {"id": action_id, "action": self.actions[action_id]}) + def create_wave(self, make_audio_queue): import pydub sample_rate = 24000 speaker = 'en_5' while not make_audio_queue.empty(): - (text, filename) 
= make_audio_queue.get() + (text, filename, action_id) = make_audio_queue.get() logger.info("Creating audio for {}".format(os.path.basename(filename))) if text.strip() == "": shutil.copy("data/empty_audio.ogg", filename) else: - if len(text) > 2000: + if len(text) > 1000: text = self.sentence_re.findall(text) else: text = [text] @@ -2116,27 +2160,44 @@ class KoboldStoryRegister(object): output = pydub.AudioSegment(np.int16(audio * 2 ** 15).tobytes(), frame_rate=sample_rate, sample_width=2, channels=channels) else: output = output + pydub.AudioSegment(np.int16(audio * 2 ** 15).tobytes(), frame_rate=sample_rate, sample_width=2, channels=channels) - output.export(filename, format="ogg", bitrate="16k") + if output is not None: + output.export(filename, format="ogg", bitrate="16k") + if action_id != -1 and self.actions[action_id]["audio_gen"] == 0: + self.actions[action_id]["audio_gen"] = 1 + basic_send(self._socketio, "story", "set_audio_status", {"id": action_id, "action": self.actions[action_id]}) def create_wave_slow(self, make_audio_queue_slow): import pydub + global slow_tts_message_shown sample_rate = 24000 speaker = 'train_daws' + if importlib.util.find_spec("tortoise") is None and not slow_tts_message_shown: + logger.info("Disabling slow (and higher quality) tts as it's not installed") + slow_tts_message_shown=True if self.tortoise is None and importlib.util.find_spec("tortoise") is not None: - self.tortoise=api.TextToSpeech() + self.tortoise=api.TextToSpeech(use_deepspeed=os.environ.get('deepspeed', "false").lower()=="true", kv_cache=os.environ.get('kv_cache', "true").lower()=="true", half=True) if importlib.util.find_spec("tortoise") is not None: voice_samples, conditioning_latents = load_voices([speaker]) while not make_audio_queue_slow.empty(): start_time = time.time() - (text, filename) = make_audio_queue_slow.get() + (text, filename, action_id) = make_audio_queue_slow.get() text_length = len(text) logger.info("Creating audio for 
{}".format(os.path.basename(filename))) if text.strip() == "": shutil.copy("data/empty_audio.ogg", filename) else: - if len(text) > 20000: + if len(self.tortoise.tokenizer.encode(text)) > 400: text = self.sentence_re.findall(text) + i=0 + while i <= len(text)-2: + if len(self.tortoise.tokenizer.encode(text[i] + text[i+1])) < 400: + text[i] = text[i] + text[i+1] + del text[i+1] + else: + i+=1 + + else: text = [text] output = None @@ -2147,11 +2208,16 @@ class KoboldStoryRegister(object): output = pydub.AudioSegment(np.int16(audio * 2 ** 15).tobytes(), frame_rate=sample_rate, sample_width=2, channels=channels) else: output = output + pydub.AudioSegment(np.int16(audio * 2 ** 15).tobytes(), frame_rate=sample_rate, sample_width=2, channels=channels) - output.export(filename, format="ogg", bitrate="16k") + if output is not None: + output.export(filename, format="ogg", bitrate="16k") + if action_id != -1: + self.actions[action_id]["audio_gen"] = 2 + basic_send(self._socketio, "story", "set_audio_status", {"id": action_id, "action": self.actions[action_id]}) logger.info("Slow audio took {} for {} characters".format(time.time()-start_time, text_length)) def gen_all_audio(self, overwrite=False): - if self.story_settings.gen_audio and self._koboldai_vars.experimental_features: + if self.story_settings.gen_audio: + logger.info("Generating audio for any missing actions") for i in reversed([-1]+list(self.actions.keys())): self.gen_audio(i, overwrite=False) #else: diff --git a/requirements.txt b/requirements.txt index 7eb3b66c..f668dd88 100644 --- a/requirements.txt +++ b/requirements.txt @@ -51,4 +51,4 @@ pynvml https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.0/flash_attn-2.3.0+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; sys_platform == 'linux' and python_version == '3.10' https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.0/flash_attn-2.3.0+cu118torch2.0cxx11abiFALSE-cp38-cp38-linux_x86_64.whl; sys_platform == 'linux' and 
python_version == '3.8' xformers==0.0.21 -exllamav2==0.0.4 \ No newline at end of file +omegaconf diff --git a/requirements_mtj.txt b/requirements_mtj.txt index 710d6cef..1e30b1cb 100644 --- a/requirements_mtj.txt +++ b/requirements_mtj.txt @@ -34,4 +34,5 @@ flask_compress ijson ftfy pydub -sentencepiece \ No newline at end of file +sentencepiece +omegaconf \ No newline at end of file diff --git a/static/koboldai.css b/static/koboldai.css index a419a4f3..9ee96976 100644 --- a/static/koboldai.css +++ b/static/koboldai.css @@ -1932,7 +1932,7 @@ body { } .tts_controls.hidden[story_gen_audio="true"] { - display: inherit !important; + display: flex !important; } .inputrow .tts_controls { @@ -1942,6 +1942,49 @@ body { width: 100%; text-align: center; overflow: hidden; + flex-direction: row; +} +.inputrow .tts_controls div { + padding: 0px; + height: 100%; + width: 100%; + text-align: center; + overflow: hidden; + display: flex; + flex-direction: column; + /*flex-basis: 100%;*/ +} + +.audio_status_action { + flex-basis: 100%; +} + +.audio_status_action[status="2"] { + background-color: green; +} + +.audio_status_action[status="1"] { + background-color: yellow; +} + +.audio_status_action[status="0"] { + background-color: red; +} + +.audio_status_action[status="-1"] { + display: none; +} + +.inputrow .tts_controls div button { + flex-basis: 100%; +} + +.inputrow .tts_controls .audio_status { + padding: 0px; + height: 100%; + width: 2px; + display: flex; + flex-direction: column; } .inputrow .back { diff --git a/static/koboldai.js b/static/koboldai.js index 038b6e87..3762be6c 100644 --- a/static/koboldai.js +++ b/static/koboldai.js @@ -45,6 +45,7 @@ socket.on("show_error_notification", function(data) { reportError(data.title, da socket.on("generated_wi", showGeneratedWIData); socket.on("stream_tokens", stream_tokens); socket.on("show_options", show_options); +socket.on("set_audio_status", set_audio_status); //socket.onAny(function(event_name, data) {console.log({"event": 
/**
 * Reflect an action's audio generation state in the #audio_status strip.
 *
 * Creates the per-action indicator span on first use, then sets its
 * "status" attribute to action.action.audio_gen (defaulting to 0 when the
 * field is missing). Actions whose 'Selected Text' is empty get status -1.
 *
 * @param {{id: *, action: Object}} action - action id plus its data.
 */
function set_audio_status(action) {
	if (!('audio_gen' in action.action)) {
		action.action.audio_gen = 0;
	}

	const indicator_id = "audio_gen_status_"+action.id;
	let indicator = document.getElementById(indicator_id);
	if (!indicator) {
		// Declared locally; the original assigned to an undeclared global.
		indicator = document.createElement("SPAN");
		indicator.id = indicator_id;
		indicator.classList.add("audio_status_action");
		indicator.setAttribute("status", -1);
		document.getElementById("audio_status").appendChild(indicator);
	}

	//Empty actions get status -1 instead of a generation state
	const is_empty = (action.action['Selected Text'] == "");
	indicator.setAttribute("status", is_empty ? -1 : action.action.audio_gen);
}
/**
 * Ask the server to build the combined audio file; the acknowledgement is
 * handled by download_actual_file_tts.
 */
function download_tts() {
	document.getElementById("download_tts").innerText = "hourglass_empty";
	socket.emit("gen_full_audio", {}, download_actual_file_tts);
}

/**
 * Acknowledgement handler for "gen_full_audio": start the browser download
 * when the server reports success.
 *
 * @param {*} data - truthy when the combined file was generated.
 */
function download_actual_file_tts(data) {
	if (data) {
		// textContent works on any element; .text only exists on a few tags.
		const name_element = document.getElementsByClassName("var_sync_story_story_name ")[0];
		const story_name = name_element ? name_element.textContent : "story";
		const link = document.createElement("a");
		link.download = story_name+".ogg";
		link.href = "/audio_full";
		link.click();
	}
	// Always restore the icon so a failed generation does not leave the
	// button stuck on the hourglass.
	document.getElementById("download_tts").innerText = "download";
}

/**
 * Request the stored image for an action; the reply goes to change_image.
 *
 * @param {*} action_id - id of the action whose image should be shown.
 */
function set_image_action(action_id) {
	socket.emit("get_story_image", {action_id: action_id}, change_image);
}

/**
 * Replace the displayed action image with the one returned by the server.
 *
 * @param {?string} data - base64 image data, or null/undefined for none.
 */
function change_image(data) {
	const image_area = document.getElementById("action image");

	const maybeImage = image_area.getElementsByClassName("action_image")[0];
	if (maybeImage) maybeImage.remove();

	$el("#image-loading").classList.add("hidden");

	// Loose comparison on purpose: covers both null and undefined.
	if (data != undefined) {
		const image = new Image();
		image.src = 'data:image/png;base64,'+data;
		image.classList.add("action_image");
		image.setAttribute("context-menu", "generated-image");
		image.addEventListener("click", imgGenView);
		image_area.appendChild(image);
	}
}
/**
 * Request image generation for the chunk currently being edited.
 *
 * Looks at every element with class "editing": the story prompt maps to
 * action id -1, any other element uses the number after the last space in
 * its id. The last such element wins. Nothing is emitted when no usable id
 * is found.
 */
function generate_image() {
	let chunk = null;
	// Loop variable declared locally; the original leaked a global.
	for (const element of document.getElementsByClassName("editing")) {
		if (element.id == 'story_prompt') {
			chunk = -1;
		} else {
			chunk = parseInt(element.id.split(" ").at(-1), 10);
		}
	}
	// Skip ids that could not be parsed rather than sending NaN to the server.
	if (chunk != null && !Number.isNaN(chunk)) {
		socket.emit("generate_image", {action_id: chunk});
	}
}
-
+ @@ -107,8 +107,13 @@