diff --git a/koboldai_settings.py b/koboldai_settings.py index f1ea538e..6a7ef81c 100644 --- a/koboldai_settings.py +++ b/koboldai_settings.py @@ -101,6 +101,28 @@ def process_variable_changes(socketio, classname, name, value, old_value, debug_ else: socketio.emit("var_changed", {"classname": classname, "name": name, "old_value": "*" * len(old_value) if old_value is not None else "", "value": "*" * len(value) if value is not None else "", "transmit_time": transmit_time}, include_self=True, broadcast=True, room=room) +def basic_send(socketio, classname, event, data): + #Get which room we'll send the messages to + global multi_story + if multi_story: + if classname != 'story': + room = 'UI_2' + else: + if has_request_context(): + room = 'default' if 'story' not in session else session['story'] + else: + logger.error("We tried to access the story register outside of an http context. Will not work in multi-story mode") + return + else: + room = "UI_2" + if not has_request_context(): + if queue is not None: + #logger.debug("Had to use queue") + queue.put([event, data, {"broadcast":True, "room":room}]) + else: + if socketio is not None: + socketio.emit(event, data, include_self=True, broadcast=True, room=room) + class koboldai_vars(object): def __init__(self, socketio): self._model_settings = model_settings(socketio, self) @@ -1419,6 +1441,7 @@ class KoboldStoryRegister(object): self.make_audio_thread_slow = None self.make_audio_queue_slow = multiprocessing.Queue() self.probability_buffer = None + self.audio_status = {} for item in sequence: self.append(item) @@ -1576,6 +1599,13 @@ class KoboldStoryRegister(object): if "Original Text" not in json_data["actions"][item]: json_data["actions"][item]["Original Text"] = json_data["actions"][item]["Selected Text"] + + if "audio_gen" not in json_data["actions"][item]: + json_data["actions"][item]["audio_gen"] = 0 + + if "image_gen" not in json_data["actions"][item]: + json_data["actions"][item]["image_gen"] = False + temp[int(item)] = json_data['actions'][item] if int(item) >= self.action_count-100: #sending last 100 items to UI @@ -2055,9 +2085,14 @@ class KoboldStoryRegister(object): return action_text_split def gen_audio(self, action_id=None, overwrite=True): + if action_id is None: + action_id = self.action_count + if overwrite: + if action_id != -1: + self.actions[action_id]["audio_gen"] = 0 + basic_send(self._socketio, "story", "set_audio_status", {"id": action_id, "action": self.actions[action_id]}) if self.story_settings.gen_audio: - if action_id is None: - action_id = self.action_count + if self.tts_model is None: language = 'en' @@ -2074,28 +2109,36 @@ class KoboldStoryRegister(object): if overwrite or not os.path.exists(filename): if action_id == -1: - self.make_audio_queue.put((self._koboldai_vars.prompt, filename)) + self.make_audio_queue.put((self._koboldai_vars.prompt, filename, action_id)) else: - self.make_audio_queue.put((self.actions[action_id]['Selected Text'], filename)) + self.make_audio_queue.put((self.actions[action_id]['Selected Text'], filename, action_id)) if self.make_audio_thread is None or not self.make_audio_thread.is_alive(): self.make_audio_thread = threading.Thread(target=self.create_wave, args=(self.make_audio_queue, )) self.make_audio_thread.start() + elif not overwrite and os.path.exists(filename): + if action_id != -1: + self.actions[action_id]["audio_gen"] = 1 if overwrite or not os.path.exists(filename_slow): if action_id == -1: - self.make_audio_queue_slow.put((self._koboldai_vars.prompt, filename_slow)) + self.make_audio_queue_slow.put((self._koboldai_vars.prompt, filename_slow, action_id)) else: - self.make_audio_queue_slow.put((self.actions[action_id]['Selected Text'], filename_slow)) + self.make_audio_queue_slow.put((self.actions[action_id]['Selected Text'], filename_slow, action_id)) if self.make_audio_thread_slow is None or not self.make_audio_thread_slow.is_alive(): self.make_audio_thread_slow = threading.Thread(target=self.create_wave_slow, args=(self.make_audio_queue_slow, )) self.make_audio_thread_slow.start() + elif not overwrite and os.path.exists(filename_slow): + if action_id != -1: + self.actions[action_id]["audio_gen"] = 2 + basic_send(self._socketio, "story", "set_audio_status", {"id": action_id, "action": self.actions[action_id]}) + def create_wave(self, make_audio_queue): import pydub sample_rate = 24000 speaker = 'en_5' while not make_audio_queue.empty(): - (text, filename) = make_audio_queue.get() + (text, filename, action_id) = make_audio_queue.get() logger.info("Creating audio for {}".format(os.path.basename(filename))) if text.strip() == "": shutil.copy("data/empty_audio.ogg", filename) @@ -2117,6 +2160,9 @@ class KoboldStoryRegister(object): output = output + pydub.AudioSegment(np.int16(audio * 2 ** 15).tobytes(), frame_rate=sample_rate, sample_width=2, channels=channels) if output is not None: output.export(filename, format="ogg", bitrate="16k") + if action_id != -1 and self.actions[action_id]["audio_gen"] == 0: + self.actions[action_id]["audio_gen"] = 1 + basic_send(self._socketio, "story", "set_audio_status", {"id": action_id, "action": self.actions[action_id]}) def create_wave_slow(self, make_audio_queue_slow): import pydub @@ -2133,7 +2179,7 @@ class KoboldStoryRegister(object): voice_samples, conditioning_latents = load_voices([speaker]) while not make_audio_queue_slow.empty(): start_time = time.time() - (text, filename) = make_audio_queue_slow.get() + (text, filename, action_id) = make_audio_queue_slow.get() text_length = len(text) logger.info("Creating audio for {}".format(os.path.basename(filename))) if text.strip() == "": @@ -2162,6 +2208,9 @@ class KoboldStoryRegister(object): output = output + pydub.AudioSegment(np.int16(audio * 2 ** 15).tobytes(), frame_rate=sample_rate, sample_width=2, channels=channels) if output is not None: output.export(filename, format="ogg", bitrate="16k") + if action_id != -1: + self.actions[action_id]["audio_gen"] = 2 + basic_send(self._socketio, "story", "set_audio_status", {"id": action_id, "action": self.actions[action_id]}) logger.info("Slow audio took {} for {} characters".format(time.time()-start_time, text_length)) def gen_all_audio(self, overwrite=False): diff --git a/static/koboldai.css b/static/koboldai.css index a419a4f3..9ee96976 100644 --- a/static/koboldai.css +++ b/static/koboldai.css @@ -1932,7 +1932,7 @@ body { } .tts_controls.hidden[story_gen_audio="true"] { - display: inherit !important; + display: flex !important; } .inputrow .tts_controls { @@ -1942,6 +1942,49 @@ body { width: 100%; text-align: center; overflow: hidden; + flex-direction: row; +} +.inputrow .tts_controls div { + padding: 0px; + height: 100%; + width: 100%; + text-align: center; + overflow: hidden; + display: flex; + flex-direction: column; + /*flex-basis: 100%;*/ +} + +.audio_status_action { + flex-basis: 100%; +} + +.audio_status_action[status="2"] { + background-color: green; +} + +.audio_status_action[status="1"] { + background-color: yellow; +} + +.audio_status_action[status="0"] { + background-color: red; +} + +.audio_status_action[status="-1"] { + display: none; +} + +.inputrow .tts_controls div button { + flex-basis: 100%; +} + +.inputrow .tts_controls .audio_status { + padding: 0px; + height: 100%; + width: 2px; + display: flex; + flex-direction: column; } .inputrow .back { diff --git a/static/koboldai.js b/static/koboldai.js index 971e1788..fa0180dc 100644 --- a/static/koboldai.js +++ b/static/koboldai.js @@ -45,6 +45,7 @@ socket.on("show_error_notification", function(data) { reportError(data.title, da socket.on("generated_wi", showGeneratedWIData); socket.on("stream_tokens", stream_tokens); socket.on("show_options", show_options); +socket.on("set_audio_status", set_audio_status); //socket.onAny(function(event_name, data) {console.log({"event": event_name, "class": data.classname, "data": data});}); // Must be done before any elements are made; we track their changes. @@ -601,6 +602,7 @@ function create_options(action) { function process_actions_data(data) { start_time = Date.now(); + console.log(data); if (Array.isArray(data.value)) { actions = data.value; } else { @@ -637,6 +639,7 @@ function process_actions_data(data) { actions_data[parseInt(action.id)] = action.action; do_story_text_updates(action); create_options(action); + set_audio_status(action); } clearTimeout(game_text_scroll_timeout); @@ -648,6 +651,27 @@ function process_actions_data(data) { } +function set_audio_status(action) { + if (!('audio_gen' in action.action)) { + action.action.audio_gen = 0; + } + if (!(document.getElementById("audio_gen_status_"+action.id))) { + sp = document.createElement("SPAN"); + sp.id = "audio_gen_status_"+action.id + sp.classList.add("audio_status_action"); + sp.setAttribute("status", -1); + document.getElementById("audio_status").appendChild(sp); + } + document.getElementById("audio_gen_status_"+action.id).setAttribute("status", action.action.audio_gen); + + //Delete empty actions + if (action.action['Selected Text'] == "") { + console.log("disabling status"); + document.getElementById("audio_gen_status_"+action.id).setAttribute("status", -1); + } + console.log("Setting " + action.id + " to " + action.action.audio_gen); +} + function parseChatMessages(text) { let messages = []; diff --git a/templates/index_new.html b/templates/index_new.html index e1931078..920471ec 100644 --- a/templates/index_new.html +++ b/templates/index_new.html @@ -107,9 +107,13 @@ + + + play_arrow stop download + Submit