Working Horde based image generation

This commit is contained in:
ebolam
2022-09-19 14:29:45 -04:00
parent 37d28950c0
commit 0df594f0e5
5 changed files with 76 additions and 9 deletions

View File

@@ -1775,6 +1775,8 @@ def patch_transformers_download():
def patch_transformers():
global transformers
global old_transfomers_functions
old_transfomers_functions = {}
patch_transformers_download()
@@ -1791,9 +1793,11 @@ def patch_transformers():
if not args.no_aria2:
utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
old_transfomers_functions['PreTrainedModel.from_pretrained'] = PreTrainedModel.from_pretrained
PreTrainedModel.from_pretrained = new_from_pretrained
if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
old_transfomers_functions['modeling_utils.get_checkpoint_shard_files'] = old_get_checkpoint_shard_files
def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
utils.num_shards = utils.get_num_shards(index_filename)
utils.from_pretrained_index_filename = index_filename
@@ -1821,6 +1825,7 @@ def patch_transformers():
if max_pos > self.weights.size(0):
self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
old_transfomers_functions['XGLMSinusoidalPositionalEmbedding.forward'] = XGLMSinusoidalPositionalEmbedding.forward
XGLMSinusoidalPositionalEmbedding.forward = new_forward
@@ -1840,6 +1845,7 @@ def patch_transformers():
self.model = OPTModel(config)
self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
self.post_init()
old_transfomers_functions['OPTForCausalLM.__init__'] = OPTForCausalLM.__init__
OPTForCausalLM.__init__ = new_init
@@ -2124,6 +2130,7 @@ def patch_transformers():
break
return self.regeneration_required or self.halt
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria'] = old_get_stopping_criteria
def new_get_stopping_criteria(self, *args, **kwargs):
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
global tokenizer
@@ -8147,14 +8154,48 @@ def UI_2_save_revision(data):
#==================================================================#
@socketio.on("generate_image")
def UI_2_generate_image(data):
prompt = koboldai_vars.calc_ai_text(return_text=True)
print(prompt)
get_items_locations_from_text(prompt)
if 'art_guide' not in data:
art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting',
koboldai_vars.generating_image = True
#get latest action
if len(koboldai_vars.actions) > 0:
action = koboldai_vars.actions[-1]
else:
art_guide = data['art_guide']
#text2img(prompt, art_guide = art_guide)
action = koboldai_vars.prompt
#Get matching world info entries
keys = []
for wi in koboldai_vars.worldinfo_v2:
for key in wi['key']:
if key in action:
#Check to make sure secondary keys are present if needed
if len(wi['keysecondary']) > 0:
for keysecondary in wi['keysecondary']:
if keysecondary in action:
keys.append(key)
break
break
else:
keys.append(key)
break
#If we have > 4 keys, use those otherwise use sumarization
if len(keys) < 4:
from transformers import pipeline as summary_pipeline
summarizer = summary_pipeline("summarization")
#text to summarize:
if len(koboldai_vars.actions) < 5:
text = "".join(koboldai_vars.actions[:-5]+[koboldai_vars.prompt])
else:
text = "".join(koboldai_vars.actions[:-5])
global old_transfomers_functions
temp = transformers.generation_utils.GenerationMixin._get_stopping_criteria
transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']
keys = [summarizer(text, max_length=100, min_length=30, do_sample=False)[0]['summary_text']]
transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp
art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting',
b64_data = text2img(", ".join(keys), art_guide = art_guide)
emit("Action_Image", {'b64': b64_data, 'prompt': ", ".join(keys)})
@logger.catch
@@ -8189,9 +8230,9 @@ def text2img(prompt,
else:
final_filename = filename
img.save(final_filename)
return(img)
print("Saved Image")
koboldai_vars.generating_image = False
return(b64img)
else:
koboldai_vars.generating_image = False
print(submit_req.text)