Merge pull request #476 from ebolam/Image_Gen

Image gen Enhancements
This commit is contained in:
henk717 2023-11-03 14:43:17 +01:00 committed by GitHub
commit 0a38167d1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 325 additions and 33 deletions

View File

@ -7586,6 +7586,19 @@ def UI2_clear_generated_image(data):
koboldai_vars.picture = ""
koboldai_vars.picture_prompt = ""
#==================================================================#
# Retrieve previous images
#==================================================================#
@socketio.on("get_story_image")
@logger.catch
def UI_2_get_story_image(data):
    """Return the generated image for a story action as a base64 string.

    data: dict with key 'action_id' identifying the action whose image is
    requested. Returns the base64-encoded image bytes (UTF-8 string), or
    None (implicitly) when no image exists for that action.
    """
    action_id = data['action_id']
    (filename, prompt) = koboldai_vars.actions.get_picture(action_id)
    # Leftover print() replaced with a proper debug log line.
    logger.debug("Story image for action {}: {}".format(action_id, filename))
    if filename is not None:
        with open(filename, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
#@logger.catch
def get_items_locations_from_text(text):
# load model and tokenizer
@ -7921,10 +7934,70 @@ def UI_2_audio():
start_time = time.time()
while not os.path.exists(filename) and time.time()-start_time < 60: #Waiting up to 60 seconds for the file to be generated
time.sleep(0.1)
return send_file(
filename,
mimetype="audio/ogg")
if os.path.exists(filename):
return send_file(
filename,
mimetype="audio/ogg")
show_error_notification("Error generating audio chunk", f"Something happened. Maybe check the log?")
#==================================================================#
# Download complete audio file
#==================================================================#
@socketio.on("gen_full_audio")
def UI_2_gen_full_audio(data):
    """Concatenate all per-action audio chunks into one complete.ogg file.

    Walks every action (including the prompt at id -1), preferring the
    slow/high-quality chunk when it exists, generating any missing chunk
    on the fly (waiting up to 60 seconds for it), then exports the
    combined audio. Returns True so the client callback knows the file
    is ready for download via /audio_full.
    """
    from pydub import AudioSegment
    if args.no_ui:
        return redirect('/api/latest')
    logger.info("Generating complete audio file")

    def append_chunk(combined, path):
        # Load one ogg chunk and append it to the running segment.
        segment = AudioSegment.from_file(path, format="ogg")
        return segment if combined is None else combined + segment

    combined_audio = None
    complete_filename = os.path.join(koboldai_vars.save_paths.generated_audio, "complete.ogg")
    for action_id in range(-1, koboldai_vars.actions.action_count+1):
        filename = os.path.join(koboldai_vars.save_paths.generated_audio, f"{action_id}.ogg")
        filename_slow = os.path.join(koboldai_vars.save_paths.generated_audio, f"{action_id}_slow.ogg")
        if os.path.exists(filename_slow):
            combined_audio = append_chunk(combined_audio, filename_slow)
        elif os.path.exists(filename):
            combined_audio = append_chunk(combined_audio, filename)
        else:
            logger.info("Action {} has no audio. Generating now".format(action_id))
            koboldai_vars.actions.gen_audio(action_id)
            # BUG FIX: start_time was never initialized in this function,
            # so reaching this branch raised a NameError before the fix.
            start_time = time.time()
            while not os.path.exists(filename) and time.time()-start_time < 60: #Waiting up to 60 seconds for the file to be generated
                time.sleep(0.1)
            if os.path.exists(filename):
                combined_audio = append_chunk(combined_audio, filename)
            else:
                show_error_notification("Error generating audio chunk", "Something happened. Maybe check the log?")
    logger.info("Sending audio file")
    combined_audio.export(complete_filename, format="ogg")
    return True
@app.route("/audio_full")
@require_allowed_ip
@logger.catch
def UI_2_audio_full():
    """Serve the combined story audio (complete.ogg) as a file download.

    The file is produced beforehand by UI_2_gen_full_audio. If it is
    missing, return an explicit 404 instead of falling through and
    returning None, which Flask would turn into a 500 error.
    """
    logger.info("Downloading complete audio file")
    complete_filename = os.path.join(koboldai_vars.save_paths.generated_audio, "complete.ogg")
    if os.path.exists(complete_filename):
        return send_file(
            complete_filename,
            as_attachment=True,
            download_name=koboldai_vars.story_name,
            mimetype="audio/ogg")
    return "Complete audio file not found. Generate it first.", 404
#==================================================================#
# Download of the image for an action

View File

@ -28,6 +28,7 @@ dependencies:
- termcolor
- Pillow
- psutil
- ffmpeg
- pip:
- flask-cloudflared==0.0.10
- flask-ngrok
@ -65,3 +66,4 @@ dependencies:
- pynvml
- xformers==0.0.21
- https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.0/flash_attn-2.3.0+cu118torch2.0cxx11abiFALSE-cp38-cp38-linux_x86_64.whl; sys_platform == 'linux'
- omegaconf

View File

@ -22,6 +22,7 @@ dependencies:
- termcolor
- Pillow
- psutil
- ffmpeg
- pip:
- --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
- torch==2.0.1a0; sys_platform == 'linux'
@ -58,4 +59,5 @@ dependencies:
- https://github.com/0cc4m/exllama/releases/download/0.0.7/exllama-0.0.7-cp38-cp38-linux_x86_64.whl; sys_platform == 'linux'
- https://github.com/0cc4m/exllama/releases/download/0.0.7/exllama-0.0.7-cp38-cp38-win_amd64.whl; sys_platform == 'win32'
- windows-curses; sys_platform == 'win32'
- pynvml
- pynvml
- omegaconf

View File

@ -22,6 +22,7 @@ dependencies:
- termcolor
- Pillow
- psutil
- ffmpeg
- pip:
- --extra-index-url https://download.pytorch.org/whl/rocm5.2
- torch==1.13.1+rocm5.2
@ -47,4 +48,5 @@ dependencies:
- peft==0.3.0
- windows-curses; sys_platform == 'win32'
- pynvml
- https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.4.2/auto_gptq-0.4.2+rocm5.4.2-cp38-cp38-linux_x86_64.whl
- https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.4.2/auto_gptq-0.4.2+rocm5.4.2-cp38-cp38-linux_x86_64.whl
- omegaconf

View File

@ -795,8 +795,6 @@ gensettingstf = [
"sub_path": "UI",
"classname": "story",
"name": "gen_audio",
"extra_classes": "var_sync_alt_system_experimental_features"
,
"ui_level": 1
},
{

View File

@ -21,6 +21,7 @@ queue = None
multi_story = False
global enable_whitelist
enable_whitelist = False
slow_tts_message_shown = False
if importlib.util.find_spec("tortoise") is not None:
from tortoise import api
@ -100,6 +101,28 @@ def process_variable_changes(socketio, classname, name, value, old_value, debug_
else:
socketio.emit("var_changed", {"classname": classname, "name": name, "old_value": "*" * len(old_value) if old_value is not None else "", "value": "*" * len(value) if value is not None else "", "transmit_time": transmit_time}, include_self=True, broadcast=True, room=room)
def basic_send(socketio, classname, event, data):
    """Emit a socketio event to the correct room, queueing when needed.

    Mirrors the room-selection logic of process_variable_changes: in
    multi-story mode, 'story'-class messages go to the session's story
    room; everything else broadcasts to the shared 'UI_2' room.
    Outside an HTTP request context the message is pushed onto the
    module-level queue instead of being emitted directly.
    """
    #Get which room we'll send the messages to
    global multi_story
    if multi_story:
        if classname != 'story':
            room = 'UI_2'
        else:
            if has_request_context():
                # Fall back to the 'default' story room when the session
                # has not been bound to a specific story yet.
                room = 'default' if 'story' not in session else session['story']
            else:
                logger.error("We tried to access the story register outside of an http context. Will not work in multi-story mode")
                return
    else:
        room = "UI_2"
    if not has_request_context():
        # No request context (e.g. background thread): hand the event to
        # the queue so the main loop can emit it later.
        if queue is not None:
            #logger.debug("Had to use queue")
            queue.put([event, data, {"broadcast":True, "room":room}])
    else:
        if socketio is not None:
            socketio.emit(event, data, include_self=True, broadcast=True, room=room)
class koboldai_vars(object):
def __init__(self, socketio):
self._model_settings = model_settings(socketio, self)
@ -1420,6 +1443,7 @@ class KoboldStoryRegister(object):
self.make_audio_thread_slow = None
self.make_audio_queue_slow = multiprocessing.Queue()
self.probability_buffer = None
self.audio_status = {}
for item in sequence:
self.append(item)
@ -1577,6 +1601,13 @@ class KoboldStoryRegister(object):
if "Original Text" not in json_data["actions"][item]:
json_data["actions"][item]["Original Text"] = json_data["actions"][item]["Selected Text"]
if "audio_gen" not in json_data["actions"][item]:
json_data["actions"][item]["audio_gen"] = 0
if "image_gen" not in json_data["actions"][item]:
json_data["actions"][item]["image_gen"] = False
temp[int(item)] = json_data['actions'][item]
if int(item) >= self.action_count-100: #sending last 100 items to UI
@ -2056,9 +2087,14 @@ class KoboldStoryRegister(object):
return action_text_split
def gen_audio(self, action_id=None, overwrite=True):
if self.story_settings.gen_audio and self._koboldai_vars.experimental_features:
if action_id is None:
action_id = self.action_count
if action_id is None:
action_id = self.action_count
if overwrite:
if action_id != -1:
self.actions[action_id]["audio_gen"] = 0
basic_send(self._socketio, "story", "set_audio_status", {"id": action_id, "action": self.actions[action_id]})
if self.story_settings.gen_audio:
if self.tts_model is None:
language = 'en'
@ -2075,33 +2111,41 @@ class KoboldStoryRegister(object):
if overwrite or not os.path.exists(filename):
if action_id == -1:
self.make_audio_queue.put((self._koboldai_vars.prompt, filename))
self.make_audio_queue.put((self._koboldai_vars.prompt, filename, action_id))
else:
self.make_audio_queue.put((self.actions[action_id]['Selected Text'], filename))
if self.make_audio_thread_slow is None or not self.make_audio_thread_slow.is_alive():
self.make_audio_thread_slow = threading.Thread(target=self.create_wave_slow, args=(self.make_audio_queue_slow, ))
self.make_audio_thread_slow.start()
self.make_audio_queue.put((self.actions[action_id]['Selected Text'], filename, action_id))
if self.make_audio_thread is None or not self.make_audio_thread.is_alive():
self.make_audio_thread = threading.Thread(target=self.create_wave, args=(self.make_audio_queue, ))
self.make_audio_thread.start()
elif not overwrite and os.path.exists(filename):
if action_id != -1:
self.actions[action_id]["audio_gen"] = 1
if overwrite or not os.path.exists(filename_slow):
if action_id == -1:
self.make_audio_queue_slow.put((self._koboldai_vars.prompt, filename_slow))
self.make_audio_queue_slow.put((self._koboldai_vars.prompt, filename_slow, action_id))
else:
self.make_audio_queue_slow.put((self.actions[action_id]['Selected Text'], filename_slow))
self.make_audio_queue_slow.put((self.actions[action_id]['Selected Text'], filename_slow, action_id))
if self.make_audio_thread_slow is None or not self.make_audio_thread_slow.is_alive():
self.make_audio_thread_slow = threading.Thread(target=self.create_wave_slow, args=(self.make_audio_queue_slow, ))
self.make_audio_thread_slow.start()
elif not overwrite and os.path.exists(filename_slow):
if action_id != -1:
self.actions[action_id]["audio_gen"] = 2
basic_send(self._socketio, "story", "set_audio_status", {"id": action_id, "action": self.actions[action_id]})
def create_wave(self, make_audio_queue):
import pydub
sample_rate = 24000
speaker = 'en_5'
while not make_audio_queue.empty():
(text, filename) = make_audio_queue.get()
(text, filename, action_id) = make_audio_queue.get()
logger.info("Creating audio for {}".format(os.path.basename(filename)))
if text.strip() == "":
shutil.copy("data/empty_audio.ogg", filename)
else:
if len(text) > 2000:
if len(text) > 1000:
text = self.sentence_re.findall(text)
else:
text = [text]
@ -2116,27 +2160,44 @@ class KoboldStoryRegister(object):
output = pydub.AudioSegment(np.int16(audio * 2 ** 15).tobytes(), frame_rate=sample_rate, sample_width=2, channels=channels)
else:
output = output + pydub.AudioSegment(np.int16(audio * 2 ** 15).tobytes(), frame_rate=sample_rate, sample_width=2, channels=channels)
output.export(filename, format="ogg", bitrate="16k")
if output is not None:
output.export(filename, format="ogg", bitrate="16k")
if action_id != -1 and self.actions[action_id]["audio_gen"] == 0:
self.actions[action_id]["audio_gen"] = 1
basic_send(self._socketio, "story", "set_audio_status", {"id": action_id, "action": self.actions[action_id]})
def create_wave_slow(self, make_audio_queue_slow):
import pydub
global slow_tts_message_shown
sample_rate = 24000
speaker = 'train_daws'
if importlib.util.find_spec("tortoise") is None and not slow_tts_message_shown:
logger.info("Disabling slow (and higher quality) tts as it's not installed")
slow_tts_message_shown=True
if self.tortoise is None and importlib.util.find_spec("tortoise") is not None:
self.tortoise=api.TextToSpeech()
self.tortoise=api.TextToSpeech(use_deepspeed=os.environ.get('deepspeed', "false").lower()=="true", kv_cache=os.environ.get('kv_cache', "true").lower()=="true", half=True)
if importlib.util.find_spec("tortoise") is not None:
voice_samples, conditioning_latents = load_voices([speaker])
while not make_audio_queue_slow.empty():
start_time = time.time()
(text, filename) = make_audio_queue_slow.get()
(text, filename, action_id) = make_audio_queue_slow.get()
text_length = len(text)
logger.info("Creating audio for {}".format(os.path.basename(filename)))
if text.strip() == "":
shutil.copy("data/empty_audio.ogg", filename)
else:
if len(text) > 20000:
if len(self.tortoise.tokenizer.encode(text)) > 400:
text = self.sentence_re.findall(text)
i=0
while i <= len(text)-2:
if len(self.tortoise.tokenizer.encode(text[i] + text[i+1])) < 400:
text[i] = text[i] + text[i+1]
del text[i+1]
else:
i+=1
else:
text = [text]
output = None
@ -2147,11 +2208,16 @@ class KoboldStoryRegister(object):
output = pydub.AudioSegment(np.int16(audio * 2 ** 15).tobytes(), frame_rate=sample_rate, sample_width=2, channels=channels)
else:
output = output + pydub.AudioSegment(np.int16(audio * 2 ** 15).tobytes(), frame_rate=sample_rate, sample_width=2, channels=channels)
output.export(filename, format="ogg", bitrate="16k")
if output is not None:
output.export(filename, format="ogg", bitrate="16k")
if action_id != -1:
self.actions[action_id]["audio_gen"] = 2
basic_send(self._socketio, "story", "set_audio_status", {"id": action_id, "action": self.actions[action_id]})
logger.info("Slow audio took {} for {} characters".format(time.time()-start_time, text_length))
def gen_all_audio(self, overwrite=False):
if self.story_settings.gen_audio and self._koboldai_vars.experimental_features:
if self.story_settings.gen_audio:
logger.info("Generating audio for any missing actions")
for i in reversed([-1]+list(self.actions.keys())):
self.gen_audio(i, overwrite=False)
#else:

View File

@ -51,4 +51,4 @@ pynvml
https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.0/flash_attn-2.3.0+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; sys_platform == 'linux' and python_version == '3.10'
https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.0/flash_attn-2.3.0+cu118torch2.0cxx11abiFALSE-cp38-cp38-linux_x86_64.whl; sys_platform == 'linux' and python_version == '3.8'
xformers==0.0.21
exllamav2==0.0.4
omegaconf

View File

@ -34,4 +34,5 @@ flask_compress
ijson
ftfy
pydub
sentencepiece
sentencepiece
omegaconf

View File

@ -1932,7 +1932,7 @@ body {
}
.tts_controls.hidden[story_gen_audio="true"] {
display: inherit !important;
display: flex !important;
}
.inputrow .tts_controls {
@ -1942,6 +1942,49 @@ body {
width: 100%;
text-align: center;
overflow: hidden;
flex-direction: row;
}
.inputrow .tts_controls div {
padding: 0px;
height: 100%;
width: 100%;
text-align: center;
overflow: hidden;
display: flex;
flex-direction: column;
/*flex-basis: 100%;*/
}
.audio_status_action {
flex-basis: 100%;
}
.audio_status_action[status="2"] {
background-color: green;
}
.audio_status_action[status="1"] {
background-color: yellow;
}
.audio_status_action[status="0"] {
background-color: red;
}
.audio_status_action[status="-1"] {
display: none;
}
.inputrow .tts_controls div button {
flex-basis: 100%;
}
.inputrow .tts_controls .audio_status {
padding: 0px;
height: 100%;
width: 2px;
display: flex;
flex-direction: column;
}
.inputrow .back {

View File

@ -45,6 +45,7 @@ socket.on("show_error_notification", function(data) { reportError(data.title, da
socket.on("generated_wi", showGeneratedWIData);
socket.on("stream_tokens", stream_tokens);
socket.on("show_options", show_options);
socket.on("set_audio_status", set_audio_status);
//socket.onAny(function(event_name, data) {console.log({"event": event_name, "class": data.classname, "data": data});});
// Must be done before any elements are made; we track their changes.
@ -148,6 +149,7 @@ const context_menu_actions = {
{label: "Add to World Info Entry", icon: "auto_stories", enabledOn: "SELECTION", click: push_selection_to_world_info},
{label: "Add as Bias", icon: "insights", enabledOn: "SELECTION", click: push_selection_to_phrase_bias},
{label: "Retry from here", icon: "refresh", enabledOn: "CARET", click: retry_from_here},
{label: "Generate image for here", icon: "image", enabledOn: "CARET", click: generate_image},
null,
{label: "Take Screenshot", icon: "screenshot_monitor", enabledOn: "SELECTION", click: screenshot_selection},
// Not implemented! See view_selection_probabiltiies
@ -637,6 +639,7 @@ function process_actions_data(data) {
actions_data[parseInt(action.id)] = action.action;
do_story_text_updates(action);
create_options(action);
set_audio_status(action);
}
clearTimeout(game_text_scroll_timeout);
@ -648,6 +651,26 @@ function process_actions_data(data) {
}
function set_audio_status(action) {
	// Update (creating on first use) the per-action audio status indicator.
	// status attribute values map to CSS: -1 hidden, 0 pending (red),
	// 1 fast TTS done (yellow), 2 slow TTS done (green).
	if (!('audio_gen' in action.action)) {
		action.action.audio_gen = 0;
	}
	if (!(document.getElementById("audio_gen_status_"+action.id))) {
		// BUG FIX: sp was an implicit global; declare it locally.
		let sp = document.createElement("SPAN");
		sp.id = "audio_gen_status_"+action.id;
		sp.classList.add("audio_status_action");
		sp.setAttribute("status", -1);
		document.getElementById("audio_status").appendChild(sp);
	}
	document.getElementById("audio_gen_status_"+action.id).setAttribute("status", action.action.audio_gen);
	// Hide the indicator for empty (deleted) actions; debug logging removed.
	if (action.action['Selected Text'] == "") {
		document.getElementById("audio_gen_status_"+action.id).setAttribute("status", -1);
	}
}
function parseChatMessages(text) {
let messages = [];
@ -703,9 +726,9 @@ function do_story_text_updates(action) {
item.classList.add("rawtext");
item.setAttribute("chunk", action.id);
item.setAttribute("tabindex", parseInt(action.id)+1);
//item.addEventListener("focus", (event) => {
// set_edit(event.target);
//});
item.addEventListener("focus", (event) => {
set_image_action(action.id);
});
//need to find the closest element
closest_element = document.getElementById("story_prompt");
@ -1387,6 +1410,29 @@ function redrawPopup() {
}
this.parentElement.classList.add("selected");
};
td.ondblclick = function () {
let accept = document.getElementById("popup_accept");
if (this.getAttribute("valid") == "true") {
accept.classList.remove("disabled");
accept.disabled = false;
accept.setAttribute("selected_value", this.id);
socket.emit(document.getElementById("popup_accept").getAttribute("emit"), this.id);
closePopups();
} else {
accept.setAttribute("selected_value", "");
accept.classList.add("disabled");
accept.disabled = true;
if (this.getAttribute("folder") == "true") {
socket.emit("popup_change_folder", this.id);
}
}
let popup_list = document.getElementById('popup_list').getElementsByClassName("selected");
for (item of popup_list) {
item.classList.remove("selected");
}
this.parentElement.classList.add("selected");
};
tr.append(td);
}
@ -3420,7 +3466,7 @@ function fix_dirty_game_text() {
if (dirty_chunks.includes("game_text")) {
dirty_chunks = dirty_chunks.filter(item => item != "game_text");
console.log("Firing Fix messed up text");
//console.log("Firing Fix messed up text");
//Fixing text outside of chunks
for (node of game_text.childNodes) {
if ((!(node instanceof HTMLElement) || !node.hasAttribute("chunk")) && (node.textContent.trim() != "")) {
@ -3767,6 +3813,21 @@ function stop_tts() {
}
}
function download_tts() {
// Swap the download button icon to a spinner, then ask the server to
// assemble the full audio file; download_actual_file_tts runs when done.
document.getElementById("download_tts").innerText = "hourglass_empty";
socket.emit("gen_full_audio", {}, download_actual_file_tts);
}
function download_actual_file_tts(data) {
	// Callback for gen_full_audio: once the server reports success (data
	// truthy), trigger a browser download of /audio_full and restore the
	// download button icon.
	if (data) {
		var link = document.createElement("a");
		// BUG FIX: generic elements have no .text property (it was
		// undefined, yielding "undefined.ogg"); use textContent instead.
		link.download = document.getElementsByClassName("var_sync_story_story_name ")[0].textContent+".ogg";
		link.href = "/audio_full";
		link.click();
		document.getElementById("download_tts").innerText = "download";
	}
}
function finished_tts() {
next_action = parseInt(document.getElementById("reader").getAttribute("action_id"))+1;
action_count = parseInt(document.getElementById("action_count").textContent);
@ -3800,6 +3861,29 @@ function tts_playing() {
}
}
function set_image_action(action_id) {
	// Ask the server for the image generated for this action; the
	// change_image callback renders it. Debug console.log removed.
	socket.emit("get_story_image", {action_id: action_id}, change_image);
}
function change_image(data) {
	// Render (or clear) the generated image for the focused action.
	// data is a base64-encoded PNG from the server, or undefined when the
	// action has no image.
	// BUG FIX: image_area was an implicit global; declare it locally.
	let image_area = document.getElementById("action image");
	let maybeImage = image_area.getElementsByClassName("action_image")[0];
	if (maybeImage) maybeImage.remove();
	$el("#image-loading").classList.add("hidden");
	if (data != undefined) {
		var image = new Image();
		image.src = 'data:image/png;base64,'+data;
		image.classList.add("action_image");
		image.setAttribute("context-menu", "generated-image");
		image.addEventListener("click", imgGenView);
		image_area.appendChild(image);
	}
}
function view_selection_probabilities() {
// Not quite sure how this should work yet. Probabilities are obviously on
// the token level, which we have no UI representation of. There are other
@ -7027,6 +7111,22 @@ $el("#generate-image-button").addEventListener("click", function() {
socket.emit("generate_image", {});
});
function generate_image() {
// Context-menu action: find the story chunk currently being edited and
// ask the server to generate an image for it. The prompt chunk uses the
// sentinel id -1.
let chunk = null;
for (element of document.getElementsByClassName("editing")) {
if (element.id == 'story_prompt') {
chunk = -1
} else {
// NOTE(review): assumes editing-chunk ids end in a space-separated
// numeric token (e.g. "Selected Text <n>") — confirm against the markup.
chunk = parseInt(element.id.split(" ").at(-1));
}
}
// If several chunks are marked editing, the last one wins; no emit when
// nothing is being edited.
if (chunk != null) {
socket.emit("generate_image", {action_id: chunk});
}
}
/* -- Shiny New Chat -- */
function addMessage(author, content, actionId, afterMsgEl=null, time=null) {
if (!time) time = Number(new Date());

View File

@ -54,7 +54,7 @@
</div>
<div class="gametext" id="Selected Text" contenteditable="false" tabindex="0" onkeyup="return set_edit(event);">
<span id="story_prompt" class="var_sync_story_prompt var_sync_alt_story_prompt_in_ai rawtext" chunk="-1"></span></div><!--don't move the /div down or it'll cause odd spacing issues in the UI--->
<span id="story_prompt" class="var_sync_story_prompt var_sync_alt_story_prompt_in_ai rawtext" chunk="-1" onfocus='set_image_action(-1);'></span></div><!--don't move the /div down or it'll cause odd spacing issues in the UI--->
</div>
<!------------ Sequences --------------------->
@ -107,8 +107,13 @@
</div>
</div><br>
<span class="tts_controls hidden var_sync_alt_story_gen_audio">
<div class="audio_status" id="audio_status">
</div>
<div>
<button type="button" class="btn action_button" style="width: 30px; padding: 0px;" onclick='play_pause_tts()' aria-label="play"><span id="play_tts" class="material-icons-outlined" style="font-size: 1.4em;">play_arrow</span></button>
<button type="button" class="btn action_button" style="width: 30px; padding: 0px;" onclick='stop_tts()' aria-label="play"><span id="stop_tts" class="material-icons-outlined" style="font-size: 1.4em;">stop</span></button>
<button type="button" class="btn action_button" style="width: 30px; padding: 0px;" onclick='download_tts()' aria-label="play"><span id="download_tts" class="material-icons-outlined" style="font-size: 1.4em;">download</span></button>
</div>
</span>
<button type="button" class="btn action_button submit var_sync_alt_system_aibusy" system_aibusy=False id="btnsubmit" onclick="storySubmit();" context-menu="submit-button">Submit</button>
<button type="button" class="btn action_button submited var_sync_alt_system_aibusy" system_aibusy=False id="btnsent"><img id="thinking" src="static/thinking.gif" class="force_center" onclick="socket.emit('abort','');"></button>