Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Working Horde based image generation
aiserver.py (57 lines changed)
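In outline, this change wires the UI's "generate_image" event to Horde-based image generation: patch_transformers() now records the stock transformers functions it overrides in old_transfomers_functions, and UI_2_generate_image() builds an image prompt from matching world-info keys (falling back to a transformers summarization pipeline, with the stock stopping criteria temporarily swapped back in) before passing it to text2img(), which now returns the base64 image data. Below is a minimal, self-contained sketch of that save/patch/restore pattern; all names here (old_functions, FakeLibrary, custom_get_stopping_criteria, run_with_stock_criteria) are illustrative stand-ins, not the actual aiserver.py code, which patches transformers.generation_utils.GenerationMixin._get_stopping_criteria.

# Illustrative sketch only; stand-in names, not KoboldAI/transformers APIs.
old_functions = {}  # analogous to old_transfomers_functions

class FakeLibrary:
    def get_stopping_criteria(self):
        return "stock criteria"

def custom_get_stopping_criteria(self):
    return "custom dynamic-WI criteria"

def patch_library():
    # Remember the stock implementation before replacing it.
    old_functions['get_stopping_criteria'] = FakeLibrary.get_stopping_criteria
    FakeLibrary.get_stopping_criteria = custom_get_stopping_criteria

def run_with_stock_criteria(fn, *args, **kwargs):
    # Temporarily restore the saved original, run fn, then re-apply the patch.
    patched = FakeLibrary.get_stopping_criteria
    FakeLibrary.get_stopping_criteria = old_functions['get_stopping_criteria']
    try:
        return fn(*args, **kwargs)
    finally:
        FakeLibrary.get_stopping_criteria = patched

patch_library()
print(FakeLibrary().get_stopping_criteria())                                    # patched
print(run_with_stock_criteria(lambda: FakeLibrary().get_stopping_criteria()))   # stock
print(FakeLibrary().get_stopping_criteria())                                    # patched again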
@@ -1775,6 +1775,8 @@ def patch_transformers_download():
 
 def patch_transformers():
     global transformers
+    global old_transfomers_functions
+    old_transfomers_functions = {}
 
     patch_transformers_download()
 
@@ -1791,9 +1793,11 @@ def patch_transformers():
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
+    old_transfomers_functions['PreTrainedModel.from_pretrained'] = PreTrainedModel.from_pretrained
     PreTrainedModel.from_pretrained = new_from_pretrained
     if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
         old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+        old_transfomers_functions['modeling_utils.get_checkpoint_shard_files'] = old_get_checkpoint_shard_files
         def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
             utils.num_shards = utils.get_num_shards(index_filename)
             utils.from_pretrained_index_filename = index_filename
@@ -1821,6 +1825,7 @@ def patch_transformers():
             if max_pos > self.weights.size(0):
                 self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
             return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
+        old_transfomers_functions['XGLMSinusoidalPositionalEmbedding.forward'] = XGLMSinusoidalPositionalEmbedding.forward
         XGLMSinusoidalPositionalEmbedding.forward = new_forward
 
 
@@ -1840,6 +1845,7 @@ def patch_transformers():
             self.model = OPTModel(config)
             self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
             self.post_init()
+        old_transfomers_functions['OPTForCausalLM.__init__'] = OPTForCausalLM.__init__
         OPTForCausalLM.__init__ = new_init
 
 
@@ -2124,6 +2130,7 @@ def patch_transformers():
                         break
             return self.regeneration_required or self.halt
     old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
+    old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria'] = old_get_stopping_criteria
     def new_get_stopping_criteria(self, *args, **kwargs):
         stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
         global tokenizer
@@ -8147,14 +8154,48 @@ def UI_2_save_revision(data):
 #==================================================================#
 @socketio.on("generate_image")
 def UI_2_generate_image(data):
-    prompt = koboldai_vars.calc_ai_text(return_text=True)
-    print(prompt)
-    get_items_locations_from_text(prompt)
-    if 'art_guide' not in data:
-        art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting',
-    else:
-        art_guide = data['art_guide']
-    #text2img(prompt, art_guide = art_guide)
+    koboldai_vars.generating_image = True
+    #get latest action
+    if len(koboldai_vars.actions) > 0:
+        action = koboldai_vars.actions[-1]
+    else:
+        action = koboldai_vars.prompt
+    #Get matching world info entries
+    keys = []
+    for wi in koboldai_vars.worldinfo_v2:
+        for key in wi['key']:
+            if key in action:
+                #Check to make sure secondary keys are present if needed
+                if len(wi['keysecondary']) > 0:
+                    for keysecondary in wi['keysecondary']:
+                        if keysecondary in action:
+                            keys.append(key)
+                            break
+                    break
+                else:
+                    keys.append(key)
+                    break
+
+
+    #If we have 4 or more keys use those, otherwise use summarization
+    if len(keys) < 4:
+        from transformers import pipeline as summary_pipeline
+        summarizer = summary_pipeline("summarization")
+        #text to summarize:
+        if len(koboldai_vars.actions) < 5:
+            text = "".join(koboldai_vars.actions[:-5]+[koboldai_vars.prompt])
+        else:
+            text = "".join(koboldai_vars.actions[:-5])
+        global old_transfomers_functions
+        temp = transformers.generation_utils.GenerationMixin._get_stopping_criteria
+        transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']
+        keys = [summarizer(text, max_length=100, min_length=30, do_sample=False)[0]['summary_text']]
+        transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp
+
+    art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting'
+
+    b64_data = text2img(", ".join(keys), art_guide = art_guide)
+    emit("Action_Image", {'b64': b64_data, 'prompt': ", ".join(keys)})
 
 
 @logger.catch
@@ -8189,9 +8230,9 @@ def text2img(prompt,
             else:
                 final_filename = filename
             img.save(final_filename)
-            return(img)
             print("Saved Image")
             koboldai_vars.generating_image = False
+            return(b64img)
     else:
         koboldai_vars.generating_image = False
         print(submit_req.text)