Better Image Generation

ebolam
2022-09-20 14:49:18 -04:00
parent 791f863791
commit 8dfe0eba8a
3 changed files with 45 additions and 19 deletions

View File

@@ -1228,6 +1228,9 @@ def general_startup(override_args=None):
parser.add_argument("--savemodel", action='store_true', help="Saves the model to the models folder even if --colab is used (Allows you to save models to Google Drive)") parser.add_argument("--savemodel", action='store_true', help="Saves the model to the models folder even if --colab is used (Allows you to save models to Google Drive)")
parser.add_argument("--customsettings", help="Preloads arguements from json file. You only need to provide the location of the json file. Use customsettings.json template file. It can be renamed if you wish so that you can store multiple configurations. Leave any settings you want as default as null. Any values you wish to set need to be in double quotation marks") parser.add_argument("--customsettings", help="Preloads arguements from json file. You only need to provide the location of the json file. Use customsettings.json template file. It can be renamed if you wish so that you can store multiple configurations. Leave any settings you want as default as null. Any values you wish to set need to be in double quotation marks")
parser.add_argument("--no_ui", action='store_true', default=False, help="Disables the GUI and Socket.IO server while leaving the API server running.") parser.add_argument("--no_ui", action='store_true', default=False, help="Disables the GUI and Socket.IO server while leaving the API server running.")
parser.add_argument("--summarizer_model", action='store', default="philschmid/bart-large-cnn-samsum", help="Huggingface model to use for summarization. Defaults to sshleifer/distilbart-cnn-12-6")
parser.add_argument("--max_summary_length", action='store', default=100, help="Maximum size for summary to send to image generation")
parser.add_argument('-v', '--verbosity', action='count', default=0, help="The default logging level is ERROR or higher. This value increases the amount of logging seen in your screen") parser.add_argument('-v', '--verbosity', action='count', default=0, help="The default logging level is ERROR or higher. This value increases the amount of logging seen in your screen")
parser.add_argument('-q', '--quiesce', action='count', default=0, help="The default logging level is ERROR or higher. This value decreases the amount of logging seen in your screen") parser.add_argument('-q', '--quiesce', action='count', default=0, help="The default logging level is ERROR or higher. This value decreases the amount of logging seen in your screen")
@@ -1281,6 +1284,8 @@ def general_startup(override_args=None):
     old_emit = socketio.emit
     socketio.emit = new_emit
+    args.max_summary_length = int(args.max_summary_length)
     koboldai_vars.model = args.model;
     koboldai_vars.revision = args.revision
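
The cast is needed because argparse delivers command-line values as strings when no type= is given, while the unparsed default stays an int. A minimal, standalone sketch of the two new flags and the normalization (not the full KoboldAI parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--summarizer_model", action='store',
                    default="philschmid/bart-large-cnn-samsum",
                    help="Huggingface model to use for summarization.")
parser.add_argument("--max_summary_length", action='store', default=100,
                    help="Maximum size for summary to send to image generation")

args = parser.parse_args(["--max_summary_length", "80"])
# A value given on the command line arrives as the string "80";
# the default stays the int 100, hence the explicit cast.
args.max_summary_length = int(args.max_summary_length)
print(args.summarizer_model, args.max_summary_length)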
@@ -3639,11 +3644,11 @@ def do_connect():
if request.args.get("rely") == "true": if request.args.get("rely") == "true":
return return
join_room("UI_{}".format(request.args.get('ui'))) join_room("UI_{}".format(request.args.get('ui')))
print("Joining Room UI_{}".format(request.args.get('ui'))) logger.debug("Joining Room UI_{}".format(request.args.get('ui')))
if request.args.get("ui") == "2": if request.args.get("ui") == "2":
ui2_connect() ui2_connect()
return return
print("{0}Client connected!{1}".format(colors.GREEN, colors.END)) logger.debug("{0}Client connected!{1}".format(colors.GREEN, colors.END))
emit('from_server', {'cmd': 'setchatname', 'data': koboldai_vars.chatname}, room="UI_1") emit('from_server', {'cmd': 'setchatname', 'data': koboldai_vars.chatname}, room="UI_1")
emit('from_server', {'cmd': 'setanotetemplate', 'data': koboldai_vars.authornotetemplate}, room="UI_1") emit('from_server', {'cmd': 'setanotetemplate', 'data': koboldai_vars.authornotetemplate}, room="UI_1")
emit('from_server', {'cmd': 'connected', 'smandelete': koboldai_vars.smandelete, 'smanrename': koboldai_vars.smanrename, 'modelname': getmodelname()}, room="UI_1") emit('from_server', {'cmd': 'connected', 'smandelete': koboldai_vars.smandelete, 'smanrename': koboldai_vars.smanrename, 'modelname': getmodelname()}, room="UI_1")
@@ -8157,6 +8162,9 @@ def UI_2_save_revision(data):
 #==================================================================#
 @socketio.on("generate_image")
 def UI_2_generate_image(data):
+    socketio.start_background_task(generate_image_in_background)
+
+def generate_image_in_background():
     koboldai_vars.generating_image = True
     #get latest action
     if len(koboldai_vars.actions) > 0:
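
Wrapping the work in socketio.start_background_task lets the generate_image handler return immediately instead of blocking the Socket.IO worker while the summarizer and image model run. A minimal sketch of the pattern, assuming flask-socketio (the sleep stands in for the real generation work):

from flask import Flask
from flask_socketio import SocketIO

app = Flask(__name__)
socketio = SocketIO(app)

def generate_image_in_background():
    socketio.sleep(5)  # stand-in for the slow summarize-and-render work
    socketio.emit("image_ready", {"status": "done"})

@socketio.on("generate_image")
def on_generate_image(data):
    # Hand off and return right away; the task runs on the framework's worker.
    socketio.start_background_task(generate_image_in_background)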
@@ -8183,8 +8191,14 @@ def UI_2_generate_image(data):
     #If we have > 4 keys, use those, otherwise use summarization
     if len(keys) < 4:
         from transformers import pipeline as summary_pipeline
-        summarizer = summary_pipeline("summarization", model="sshleifer/distilbart-xsum-12-1")
+        start_time = time.time()
+        if koboldai_vars.summarizer is None:
+            koboldai_vars.summarizer = summary_pipeline("summarization", model=args.summarizer_model)
+            #koboldai_vars.summarizer = summary_pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", cache="./cache")
+        logger.debug("Time to load summarizer: {}".format(time.time()-start_time))
         #text to summarize:
+        start_time = time.time()
         if len(koboldai_vars.actions) < 5:
             text = "".join(koboldai_vars.actions[:-5]+[koboldai_vars.prompt])
         else:
@@ -8192,12 +8206,17 @@ def UI_2_generate_image(data):
         global old_transfomers_functions
         temp = transformers.generation_utils.GenerationMixin._get_stopping_criteria
         transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']
-        keys = [summarizer(text, max_length=100, min_length=30, do_sample=False)[0]['summary_text']]
+        keys = [koboldai_vars.summarizer(text, max_length=args.max_summary_length, min_length=30, do_sample=False)[0]['summary_text']]
         transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp
+        logger.debug("Time to summarize: {}".format(time.time()-start_time))
+        logger.debug("Original Text: {}".format(text))
+        logger.debug("Summarized Text: {}".format(keys[0]))
     art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting',
     #If we don't have a GPU, use horde if we're allowed to
+    start_time = time.time()
     if (not koboldai_vars.hascuda and koboldai_vars.img_gen_priority != 0) or koboldai_vars.img_gen_priority == 3:
         b64_data = text2img_horde(", ".join(keys), art_guide = art_guide)
     else:
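
The summarizer pipeline is now created once and cached on koboldai_vars.summarizer, so later image requests skip the multi-second model load, and the summary length follows --max_summary_length. A standalone sketch of the load-once flow, assuming the transformers library and the commit's default model:

import time
from transformers import pipeline

summarizer = None  # stands in for koboldai_vars.summarizer

def summarize(text, max_summary_length=100):
    global summarizer
    if summarizer is None:  # first call pays the load cost, later calls reuse it
        start_time = time.time()
        summarizer = pipeline("summarization", model="philschmid/bart-large-cnn-samsum")
        print("Time to load summarizer: {}".format(time.time() - start_time))
    return summarizer(text, max_length=max_summary_length,
                      min_length=30, do_sample=False)[0]['summary_text']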
@@ -8212,16 +8231,16 @@ def UI_2_generate_image(data):
             b64_data = text2img_horde(", ".join(keys), art_guide = art_guide)
         elif koboldai_vars.img_gen_priority != 0:
             b64_data = text2img_horde(", ".join(keys), art_guide = art_guide)
+    logger.debug("Time to Generate Image {}".format(time.time()-start_time))
     koboldai_vars.picture = b64_data
     koboldai_vars.picture_prompt = ", ".join(keys)
     koboldai_vars.generating_image = False
-    #emit("Action_Image", {'b64': b64_data, 'prompt': ", ".join(keys)})

 @logger.catch
 def text2img_local(prompt, art_guide="", filename="new.png"):
     start_time = time.time()
-    print("Generating Image")
+    logger.debug("Generating Image")
     koboldai_vars.aibusy = True
     koboldai_vars.generating_image = True
     from diffusers import StableDiffusionPipeline
@@ -8230,15 +8249,15 @@ def text2img_local(prompt, art_guide="", filename="new.png"):
     if koboldai_vars.image_pipeline is None:
         pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, cache="./stable-diffusion-v1-4").to("cuda")
     else:
-        pipe = koboldai_vars.image_pipeline
-    print("time to load: {}".format(time.time() - start_time))
+        pipe = koboldai_vars.image_pipeline.to("cuda")
+    logger.debug("time to load: {}".format(time.time() - start_time))
     from torch import autocast
     with autocast("cuda"):
-        image = pipe(prompt)["sample"][0]
+        image = pipe(prompt, num_inference_steps=35)["sample"][0]
     buffered = BytesIO()
     image.save(buffered, format="JPEG")
     img_str = base64.b64encode(buffered.getvalue()).decode('ascii')
-    print("time to generate: {}".format(time.time() - start_time))
+    logger.debug("time to generate: {}".format(time.time() - start_time))
     if koboldai_vars.keep_img_gen_in_memory:
         pipe.to("cpu")
         if koboldai_vars.image_pipeline is None:
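
For reference, a condensed sketch of the fp16 text-to-image call as the diffusers API looked around this commit (authentication and safety-checker details omitted); the ["sample"] return key is specific to those early releases, while current diffusers returns .images, and num_inference_steps=35 is the step count the commit pins:

import base64
from io import BytesIO

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
).to("cuda")

with torch.autocast("cuda"):
    image = pipe("fantasy illustration, cinematic lighting", num_inference_steps=35)["sample"][0]

# Encode the PIL image as base64 JPEG, as the server does before pushing it to the UI.
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode('ascii')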
@@ -8249,14 +8268,14 @@ def text2img_local(prompt, art_guide="", filename="new.png"):
         torch.cuda.empty_cache()
     koboldai_vars.generating_image = False
     koboldai_vars.aibusy = False
-    print("time to unload: {}".format(time.time() - start_time))
+    logger.debug("time to unload: {}".format(time.time() - start_time))
     return img_str

 @logger.catch
 def text2img_horde(prompt,
                    art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting',
                    filename = "story_art.png"):
-    print("Generating Image using Horde")
+    logger.debug("Generating Image using Horde")
     koboldai_vars.generating_image = True
     final_imgen_params = {
         "n": 1,
@@ -8284,12 +8303,12 @@ def text2img_horde(prompt,
         else:
             final_filename = filename
         img.save(final_filename)
-        print("Saved Image")
+        logger.debug("Saved Image")
         koboldai_vars.generating_image = False
         return(b64img)
     else:
         koboldai_vars.generating_image = False
-        print(submit_req.text)
+        logger.error(submit_req.text)

 def get_items_locations_from_text(text):
     # load model and tokenizer
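
When keep_img_gen_in_memory is on, text2img_local parks the pipeline on the CPU between generations instead of reloading it from disk, and moves it back to the GPU on the next request; when it is off, the reference is dropped and the VRAM reclaimed. A toy sketch of that trade-off, with a small torch module standing in for the Stable Diffusion pipeline:

import torch

cached_model = None  # stands in for koboldai_vars.image_pipeline

def load_model():
    return torch.nn.Linear(4, 4)  # stand-in for an expensive from_pretrained()

def acquire():
    # Reuse the parked copy when we have one; move it to the GPU for the run.
    model = cached_model if cached_model is not None else load_model()
    return model.to("cuda") if torch.cuda.is_available() else model

def release(model, keep_in_memory):
    global cached_model
    if keep_in_memory:
        cached_model = model.to("cpu")  # park in system RAM, cheap to reuse
    else:
        cached_model = None
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # actually return the VRAM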

View File

@@ -549,7 +549,7 @@ gensettingstf = [
"tooltip": "If enabled, the system will keep the model in memory speeding up image generation times", "tooltip": "If enabled, the system will keep the model in memory speeding up image generation times",
"menu_path": "Interface", "menu_path": "Interface",
"sub_path": "Images", "sub_path": "Images",
"classname": "user", "classname": "system",
"name": "keep_img_gen_in_memory" "name": "keep_img_gen_in_memory"
}, },
] ]

View File

@@ -506,7 +506,10 @@ class model_settings(settings):
                 self.tqdm_progress = 0
             else:
                 self.tqdm.update(value-old_value)
-            self.tqdm_progress = 0 if self.total_download_chunks==0 else round(float(self.downloaded_chunks)/float(self.total_download_chunks)*100, 1)
+            if self.total_download_chunks is not None:
+                self.tqdm_progress = 0 if self.total_download_chunks==0 else round(float(self.downloaded_chunks)/float(self.total_download_chunks)*100, 1)
+            else:
+                self.tqdm_progress = 0
             if self.tqdm.format_dict['rate'] is not None:
                 self.tqdm_rem_time = str(datetime.timedelta(seconds=int(float(self.total_download_chunks-self.downloaded_chunks)/self.tqdm.format_dict['rate'])))
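
The new guard matters because total_download_chunks is None until a download reports its total, and the old one-liner would raise a TypeError on float(None). A tiny None-safe version of the same computation:

def progress_percent(downloaded_chunks, total_download_chunks):
    # Mirrors the guarded update above: unknown or zero total reads as 0%.
    if total_download_chunks is None or total_download_chunks == 0:
        return 0
    return round(float(downloaded_chunks) / float(total_download_chunks) * 100, 1)

assert progress_percent(50, None) == 0
assert progress_percent(50, 200) == 25.0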
@@ -715,7 +718,6 @@ class user_settings(settings):
         self.show_probs = False # Whether or not to show token probabilities
         self.beep_on_complete = False
         self.img_gen_priority = 1
-        self.keep_img_gen_in_memory = False

     def __setattr__(self, name, value):
@@ -727,8 +729,8 @@ class user_settings(settings):
             process_variable_changes(self.socketio, self.__class__.__name__.replace("_settings", ""), name, value, old_value)

 class system_settings(settings):
-    local_only_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'regex_sl', 'acregex_ai', 'acregex_ui', 'comregex_ai', 'comregex_ui', 'sp', '_horde_pid', 'image_pipeline']
-    no_save_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'sp', '_horde_pid', 'horde_share', 'aibusy', 'serverstarted', 'image_pipeline']
+    local_only_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'regex_sl', 'acregex_ai', 'acregex_ui', 'comregex_ai', 'comregex_ui', 'sp', '_horde_pid', 'image_pipeline', 'summarizer']
+    no_save_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'sp', '_horde_pid', 'horde_share', 'aibusy', 'serverstarted', 'image_pipeline', 'summarizer']
     settings_name = "system"
     def __init__(self, socketio):
         self.socketio = socketio
@@ -807,6 +809,8 @@ class system_settings(settings):
         self.sh_apikey = "" # API key to use for txt2img from the Stable Horde.
         self.generating_image = False #The current status of image generation
         self.image_pipeline = None
+        self.summarizer = None
+        self.keep_img_gen_in_memory = False
         self.cookies = {} #cookies for colab since colab's URL changes, cookies are lost
@@ -830,6 +834,9 @@ class system_settings(settings):
                 self.socketio.emit('from_server', {'cmd': 'spstatitems', 'data': {self.spfilename: self.spmeta} if self.allowsp and len(self.spfilename) else {}}, namespace=None, broadcast=True, room="UI_1")
                 super().__setattr__("sp_changed", False)
+            if name == 'keep_img_gen_in_memory' and value == False:
+                self.image_pipeline = None
             if name == 'horde_share':
                 if self.on_colab == False:
                     if os.path.exists("./KoboldAI-Horde"):
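
Because the settings classes route every attribute write through __setattr__, turning keep_img_gen_in_memory off immediately releases the cached pipeline. A reduced sketch of the hook in isolation:

class SystemSettings:
    def __init__(self):
        self.image_pipeline = None
        self.keep_img_gen_in_memory = True

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        if name == 'keep_img_gen_in_memory' and value == False:
            # Dropping the reference lets the pipeline be garbage collected.
            self.image_pipeline = None

s = SystemSettings()
s.image_pipeline = "loaded pipeline"  # stand-in for the real object
s.keep_img_gen_in_memory = False
assert s.image_pipeline is None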