Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05)
Streamline image generation and save to directory
Changed file: aiserver.py (162 changed lines)
@@ -9231,8 +9231,10 @@ def UI_2_generate_image_from_story(data):
     #get latest action
     if len(koboldai_vars.actions) > 0:
         action = koboldai_vars.actions[-1]
+        action_id = len(koboldai_vars.actions) - 1
     else:
         action = koboldai_vars.prompt
+        action_id = -1
     #Get matching world info entries
     keys = []
     for wi in koboldai_vars.worldinfo_v2:
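Illustrative note (not part of the diff): the new action_id later feeds the file prefix used when generated images are written to disk, in the generate_story_image call further down. Roughly:

# action_id == 7  -> saved as "action_7_<timestamp>.png"
# action_id == -1 -> saved as "action_-1_<timestamp>.png" (image generated from the prompt itself)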
@@ -9273,64 +9275,122 @@ def UI_2_generate_image_from_story(data):
         keys = [summarize(text, max_length=max_length)]
         logger.debug("Text from summarizer: {}".format(keys[0]))
 
-    generate_story_image(", ".join(keys), art_guide=art_guide)
+    prompt = ", ".join(keys)
+    generate_story_image(
+        ", ".join([part for part in [prompt, art_guide] if part]),
+        file_prefix=f"action_{action_id}",
+        display_prompt=prompt,
+        log_data={"actionId": action_id}
+    )
 
 @socketio.on("generate_image_from_prompt")
 @logger.catch
 def UI_2_generate_image_from_prompt(prompt: str):
     eventlet.sleep(0)
-    generate_story_image(prompt)
+    generate_story_image(prompt, file_prefix="prompt", generation_type="direct_prompt")
 
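A note on the story handler above: the join/filter idiom keeps the art guide out of the prompt when it is empty, so the generated prompt never ends with a dangling ", ". A minimal sketch of that behaviour (illustrative values, not from this commit):

prompt = "knight, castle, dragon"        # keys joined from world info / the summarizer
for art_guide in ("digital painting, trending on artstation", ""):
    full_prompt = ", ".join(part for part in [prompt, art_guide] if part)
    print(full_prompt)
# -> "knight, castle, dragon, digital painting, trending on artstation"
# -> "knight, castle, dragon"   (no trailing ", " when the guide is empty)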
-def generate_story_image(prompt: str, art_guide: str = "") -> None:
+def log_image_generation(
+    prompt: str,
+    display_prompt: str,
+    file_name: str,
+    generation_type: str,
+    other_data: Optional[dict] = None
+) -> None:
+    # In the future it might be nice to have some UI where you can search past
+    # generations or something like that
+    db_path = os.path.join(koboldai_vars.save_paths.generated_images, "db.json")
+
+    try:
+        with open(db_path, "r") as file:
+            j = json.load(file)
+    except FileNotFoundError:
+        j = []
+
+    if not isinstance(j, list):
+        logger.warning("Image database is corrupted! Will not add new entry.")
+        return
+
+    log_data = {
+        "prompt": prompt,
+        "fileName": file_name,
+        "type": generation_type or None,
+        "displayPrompt": display_prompt
+    }
+    log_data.update(other_data or {})
+    j.append(log_data)
+
+    with open(db_path, "w") as file:
+        json.dump(j, file)
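For reference, an entry appended by log_image_generation would look roughly like the sketch below once a caller passes log_data={"actionId": ...}. The db.json location and key names come from the code above; the concrete values are made up.

# Hypothetical content of <save folder>/generated_images/db.json after two generations:
# [
#   {"prompt": "knight, castle, digital painting", "fileName": "action_12_1671234567.png",
#    "type": null, "displayPrompt": "knight, castle", "actionId": 12},
#   {"prompt": "portrait of the narrator", "fileName": "prompt_1671234890.png",
#    "type": "direct_prompt", "displayPrompt": "portrait of the narrator"}
# ]

# Reading it back (e.g. for the search UI mentioned in the comment) could be as simple as:
import json, os
with open(os.path.join(koboldai_vars.save_paths.generated_images, "db.json")) as file:
    for entry in json.load(file):
        print(entry["fileName"], "-", entry["displayPrompt"])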
+
+def generate_story_image(
+    prompt: str,
+    file_prefix: str = "image",
+    generation_type: str = "",
+    display_prompt: Optional[str] = None,
+    log_data: Optional[dict] = None
+) -> None:
+    # This function is a wrapper around generate_image() that integrates the
+    # result with the story (read: puts it in the corner of the screen).
+
+    if not display_prompt:
+        display_prompt = prompt
+    koboldai_vars.picture_prompt = display_prompt
+
     start_time = time.time()
     koboldai_vars.generating_image = True
 
-    b64_data = generate_image(prompt, art_guide=art_guide)
+    image = generate_image(prompt)
+    koboldai_vars.generating_image = False
+
+    if not image:
+        return
+
+    if os.path.exists(koboldai_vars.save_paths.generated_images):
+        # Only save image if this is a saved story
+        file_name = f"{file_prefix}_{int(time.time())}.png"
+        image.save(os.path.join(koboldai_vars.save_paths.generated_images, file_name))
+        log_image_generation(prompt, display_prompt, file_name, generation_type, log_data)
 
     logger.debug("Time to Generate Image {}".format(time.time()-start_time))
 
-    koboldai_vars.picture = b64_data
-    koboldai_vars.picture_prompt = prompt
-    koboldai_vars.generating_image = False
+    buffer = BytesIO()
+    image.save(buffer, format="JPEG")
+    b64_data = base64.b64encode(buffer.getvalue()).decode("ascii")
+
+    koboldai_vars.picture = b64_data
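The reworked wrapper now does three things: pushes the picture to the UI as a base64 JPEG, persists a PNG to the story's generated_images folder when the story has a save path, and records the generation in db.json. A hedged usage sketch of the new keyword arguments (values are illustrative, not from this commit):

# Art for the latest story action, traceable back to it via the file prefix and the log entry:
generate_story_image(
    "knight, castle, digital painting",
    file_prefix="action_12",
    display_prompt="knight, castle",
    log_data={"actionId": 12},
)

# Art for a user-typed prompt, tagged so it can be told apart in db.json:
generate_story_image("portrait of the narrator",
                     file_prefix="prompt",
                     generation_type="direct_prompt")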
 
-def generate_image(prompt: str, art_guide: str = "") -> Optional[str]:
+
+def generate_image(prompt: str) -> Optional[Image.Image]:
     if koboldai_vars.img_gen_priority == 4:
         # Check if stable-diffusion-webui API option selected and use that if found.
-        return text2img_api(prompt, art_guide=art_guide)
+        return text2img_api(prompt)
     elif ((not koboldai_vars.hascuda or not os.path.exists("models/stable-diffusion-v1-4")) and koboldai_vars.img_gen_priority != 0) or koboldai_vars.img_gen_priority == 3:
         # If we don't have a GPU, use horde if we're allowed to
-        return text2img_horde(prompt, art_guide=art_guide)
+        return text2img_horde(prompt)
 
     memory = torch.cuda.get_device_properties(0).total_memory
 
     # We aren't being forced to use horde, so now let's figure out if we should use local
     if memory - torch.cuda.memory_reserved(0) >= 6000000000:
         # We have enough vram, just do it locally
-        return text2img_local(prompt, art_guide=art_guide)
+        return text2img_local(prompt)
     elif memory > 6000000000 and koboldai_vars.img_gen_priority <= 1:
         # We could do it locally by swapping the model out
         print("Could do local or online")
-        return text2img_horde(prompt, art_guide=art_guide)
+        return text2img_horde(prompt)
     elif koboldai_vars.img_gen_priority != 0:
-        return text2img_horde(prompt, art_guide=art_guide)
+        return text2img_horde(prompt)
 
     raise RuntimeError("Unable to decide image generation backend. Please report this.")
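A rough reading of the backend dispatch above (my summary of the conditions, not wording from the commit):

# img_gen_priority == 4 -> always call a local stable-diffusion-webui API
# img_gen_priority == 3, or no CUDA / no local model (unless priority is 0) -> Stable Horde
# otherwise -> run locally when roughly 6 GB of VRAM is free, else fall back to the Horde
#   (with priority 0, the no-GPU fallback and the final fallback never go to the Horde)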
 
 
 @logger.catch
-def text2img_local(prompt,
-                   art_guide="",
-                   filename="new.png"):
+def text2img_local(prompt: str) -> Optional[Image.Image]:
     start_time = time.time()
     logger.debug("Generating Image")
     koboldai_vars.aibusy = True
     koboldai_vars.generating_image = True
     from diffusers import StableDiffusionPipeline
     import base64
     from io import BytesIO
     if koboldai_vars.image_pipeline is None:
         pipe = tpool.execute(StableDiffusionPipeline.from_pretrained, "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, cache="models/stable-diffusion-v1-4").to("cuda")
     else:
@@ -9343,9 +9403,6 @@ def text2img_local(prompt,
         with autocast("cuda"):
             return pipe(prompt, num_inference_steps=num_inference_steps).images[0]
     image = tpool.execute(get_image, pipe, prompt, num_inference_steps=koboldai_vars.img_gen_steps)
-    buffered = BytesIO()
-    image.save(buffered, format="JPEG")
-    img_str = base64.b64encode(buffered.getvalue()).decode('ascii')
     logger.debug("time to generate: {}".format(time.time() - start_time))
     start_time = time.time()
     if koboldai_vars.keep_img_gen_in_memory:
@@ -9356,61 +9413,52 @@ def text2img_local(prompt,
         koboldai_vars.image_pipeline = None
         del pipe
         torch.cuda.empty_cache()
     koboldai_vars.generating_image = False
     koboldai_vars.aibusy = False
     logger.debug("time to unload: {}".format(time.time() - start_time))
-    return img_str
+    return image
 
 @logger.catch
-def text2img_horde(prompt,
-                   art_guide = "",
-                   filename = "story_art.png"):
+def text2img_horde(prompt: str) -> Optional[Image.Image]:
     logger.debug("Generating Image using Horde")
     koboldai_vars.generating_image = True
 
     final_submit_dict = {
-        "prompt": "{}, {}".format(prompt, art_guide),
+        "prompt": prompt,
         "trusted_workers": False,
         "models": [
             "stable_diffusion"
         ],
         "params": {
-            "n":1,
+            "n": 1,
             "nsfw": True,
             "sampler_name": "k_euler_a",
             "karras": True,
             "cfg_scale": koboldai_vars.img_gen_cfg_scale,
-            "steps":koboldai_vars.img_gen_steps,
-            "width":512,
-            "height":512}
+            "steps": koboldai_vars.img_gen_steps,
+            "width": 512,
+            "height": 512
+        }
     }
 
     cluster_headers = {'apikey': koboldai_vars.sh_apikey if koboldai_vars.sh_apikey != '' else "0000000000",}
 
     logger.debug(final_submit_dict)
-    submit_req = requests.post('https://stablehorde.net/api/v2/generate/sync', json = final_submit_dict, headers=cluster_headers)
-    if submit_req.ok:
-        results = submit_req.json()
-        for iter in range(len(results['generations'])):
-            b64img = results['generations'][iter]["img"]
-            base64_bytes = b64img.encode('utf-8')
-            img_bytes = base64.b64decode(base64_bytes)
-            img = Image.open(BytesIO(img_bytes))
-            if len(results) > 1:
-                final_filename = f"{iter}_{filename}"
-            else:
-                final_filename = filename
-            img.save(final_filename)
-            logger.debug("Saved Image")
-            koboldai_vars.generating_image = False
-            return(b64img)
-    else:
-        koboldai_vars.generating_image = False
+    submit_req = requests.post('https://stablehorde.net/api/v2/generate/sync', json=final_submit_dict, headers=cluster_headers)
+
+    if not submit_req.ok:
+        logger.error(submit_req.text)
+        return
+
+    results = submit_req.json()
+    if len(results["generations"]) > 1:
+        logger.warning(f"Got too many generations, discarding extras. Got {len(results['generations'])}, expected 1.")
+
+    b64img = results["generations"][0]["img"]
+    base64_bytes = b64img.encode("utf-8")
+    img_bytes = base64.b64decode(base64_bytes)
+    img = Image.open(BytesIO(img_bytes))
+    return img
 
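For context on the new Horde path above: the /api/v2/generate/sync endpoint answers with a JSON body whose generations list carries base64-encoded images, which is why a single index plus b64decode is enough. An illustrative, trimmed response shape (only the "img" field is used by the code; real responses carry extra metadata per generation):

import base64

results = {
    "generations": [
        {"img": "<base64-encoded image data>"}
    ]
}
img_bytes = base64.b64decode(results["generations"][0]["img"].encode("utf-8"))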
 @logger.catch
-def text2img_api(prompt, art_guide=""):
+def text2img_api(prompt, art_guide="") -> Image.Image:
     logger.debug("Generating Image using Local SD-WebUI API")
     koboldai_vars.generating_image = True
     #The following list are valid properties with their defaults, to add/modify in final_imgen_params. Will refactor configuring values into UI element in future.
@@ -9444,7 +9492,7 @@ def text2img_api(prompt, art_guide=""):
         #"override_settings": {},
         #"sampler_index": "Euler"
     final_imgen_params = {
-        "prompt": ", ".join(filter(bool, [prompt, art_guide])),
+        "prompt": prompt,
         "n_iter": 1,
         "width": 512,
         "height": 512,
@@ -9493,7 +9541,7 @@ def text2img_api(prompt, art_guide=""):
         show_error_notification("SD Web API Failure", "SD Web API returned no images", do_log=True)
         return None
 
-    return base64_image
+    return Image.open(BytesIO(base64.b64decode(base64_image)))
 
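The webui branch now also hands a PIL Image back to the caller instead of raw base64. A minimal sketch of the kind of request this function wraps, assuming a default AUTOMATIC1111 stable-diffusion-webui instance on 127.0.0.1:7860 (the endpoint and field names are the webui's; everything else is illustrative):

import base64
from io import BytesIO

import requests
from PIL import Image

# txt2img returns {"images": ["<base64>", ...], ...}; the code above uses one image.
resp = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img",
                     json={"prompt": "a lighthouse at night", "steps": 20, "width": 512, "height": 512})
b64_image = resp.json()["images"][0]
Image.open(BytesIO(base64.b64decode(b64_image))).save("lighthouse.png")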
 @socketio.on("clear_generated_image")
 @logger.catch