diff --git a/aiserver.py b/aiserver.py index c65a9e84..1b2e87d7 100644 --- a/aiserver.py +++ b/aiserver.py @@ -64,7 +64,7 @@ from utils import debounce import utils import koboldai_settings import torch -from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, modeling_utils +from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, modeling_utils, AutoModelForTokenClassification from transformers import __version__ as transformers_version import transformers try: @@ -73,6 +73,11 @@ except: pass import transformers.generation_utils +# Text2img +import base64 +from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps +from io import BytesIO + global tpu_mtj_backend @@ -1618,8 +1623,8 @@ def get_cluster_models(msg): # If the client settings file doesn't exist, create it # Write API key to file os.makedirs('settings', exist_ok=True) - if path.exists(get_config_filename(koboldai_vars.model_selected)): - with open(get_config_filename(koboldai_vars.model_selected), "r") as file: + if path.exists(get_config_filename(model)): + with open(get_config_filename(model), "r") as file: js = json.load(file) if 'online_model' in js: online_model = js['online_model'] @@ -1630,7 +1635,7 @@ def get_cluster_models(msg): changed=True if changed: js={} - with open(get_config_filename(koboldai_vars.model_selected), "w") as file: + with open(get_config_filename(model), "w") as file: js["apikey"] = koboldai_vars.oaiapikey file.write(json.dumps(js, indent=3)) @@ -1674,7 +1679,7 @@ def patch_transformers_download(): if bar != "": try: - print(bar, end="\r") + print(bar, end="\n") emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1") eventlet.sleep(seconds=0) except: @@ -1712,10 +1717,12 @@ def patch_transformers_download(): desc=f"Downloading {file_name}" if file_name is not None else "Downloading", file=Send_to_socketio(), ) + koboldai_vars.total_download_chunks = total for chunk in r.iter_content(chunk_size=1024): if chunk: # filter out keep-alive new chunks if url[-11:] != 'config.json': progress.update(len(chunk)) + koboldai_vars.downloaded_chunks += len(chunk) temp_file.write(chunk) if url[-11:] != 'config.json': progress.close() @@ -1768,6 +1775,8 @@ def patch_transformers_download(): def patch_transformers(): global transformers + global old_transfomers_functions + old_transfomers_functions = {} patch_transformers_download() @@ -1784,9 +1793,11 @@ def patch_transformers(): if not args.no_aria2: utils.aria2_hook(pretrained_model_name_or_path, **kwargs) return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs) + old_transfomers_functions['PreTrainedModel.from_pretrained'] = PreTrainedModel.from_pretrained PreTrainedModel.from_pretrained = new_from_pretrained if(hasattr(modeling_utils, "get_checkpoint_shard_files")): old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files + old_transfomers_functions['modeling_utils.get_checkpoint_shard_files'] = old_get_checkpoint_shard_files def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs): utils.num_shards = utils.get_num_shards(index_filename) utils.from_pretrained_index_filename = index_filename @@ -1814,6 +1825,7 @@ def patch_transformers(): if max_pos > self.weights.size(0): self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach() + old_transfomers_functions['XGLMSinusoidalPositionalEmbedding.forward'] = XGLMSinusoidalPositionalEmbedding.forward XGLMSinusoidalPositionalEmbedding.forward = new_forward @@ -1833,6 +1845,7 @@ def patch_transformers(): self.model = OPTModel(config) self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) self.post_init() + old_transfomers_functions['OPTForCausalLM.__init__'] = OPTForCausalLM.__init__ OPTForCausalLM.__init__ = new_init @@ -2117,6 +2130,7 @@ def patch_transformers(): break return self.regeneration_required or self.halt old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria + old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria'] = old_get_stopping_criteria def new_get_stopping_criteria(self, *args, **kwargs): stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs) global tokenizer @@ -2171,7 +2185,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal if not utils.HAS_ACCELERATE: disk_layers = None koboldai_vars.reset_model() - koboldai_vars.cluster_requested_models = online_model + koboldai_vars.cluster_requested_models = [online_model] if isinstance(online_model, str) else online_model koboldai_vars.noai = False if not use_breakmodel_args: set_aibusy(True) @@ -8134,6 +8148,138 @@ def get_model_size(model_name): def UI_2_save_revision(data): koboldai_vars.save_revision() + +#==================================================================# +# Generate Image +#==================================================================# +@socketio.on("generate_image") +def UI_2_generate_image(data): + koboldai_vars.generating_image = True + #get latest action + if len(koboldai_vars.actions) > 0: + action = koboldai_vars.actions[-1] + else: + action = koboldai_vars.prompt + #Get matching world info entries + keys = [] + for wi in koboldai_vars.worldinfo_v2: + for key in wi['key']: + if key in action: + #Check to make sure secondary keys are present if needed + if len(wi['keysecondary']) > 0: + for keysecondary in wi['keysecondary']: + if keysecondary in action: + keys.append(key) + break + break + else: + keys.append(key) + break + + + #If we have > 4 keys, use those otherwise use sumarization + if len(keys) < 4: + from transformers import pipeline as summary_pipeline + summarizer = summary_pipeline("summarization") + #text to summarize: + if len(koboldai_vars.actions) < 5: + text = "".join(koboldai_vars.actions[:-5]+[koboldai_vars.prompt]) + else: + text = "".join(koboldai_vars.actions[:-5]) + global old_transfomers_functions + temp = transformers.generation_utils.GenerationMixin._get_stopping_criteria + transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria'] + keys = [summarizer(text, max_length=100, min_length=30, do_sample=False)[0]['summary_text']] + transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp + + art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting', + + b64_data = text2img(", ".join(keys), art_guide = art_guide) + emit("Action_Image", {'b64': b64_data, 'prompt': ", ".join(keys)}) + + +@logger.catch +def text2img(prompt, + art_guide = 'fantasy illustration, artstation, by jason felix by steve argyle by tyler jacobson by peter mohrbacher, cinematic lighting', + filename = "story_art.png"): + print("Generating Image") + koboldai_vars.generating_image = True + final_imgen_params = { + "n": 1, + "width": 512, + "height": 512, + "steps": 50, + } + + final_submit_dict = { + "prompt": "{}, {}".format(prompt, art_guide), + "api_key": koboldai_vars.sh_apikey if koboldai_vars.sh_apikey != '' else "0000000000", + "params": final_imgen_params, + } + logger.debug(final_submit_dict) + submit_req = requests.post('https://stablehorde.net/api/v1/generate/sync', json = final_submit_dict) + if submit_req.ok: + results = submit_req.json() + for iter in range(len(results)): + b64img = results[iter]["img"] + base64_bytes = b64img.encode('utf-8') + img_bytes = base64.b64decode(base64_bytes) + img = Image.open(BytesIO(img_bytes)) + if len(results) > 1: + final_filename = f"{iter}_{filename}" + else: + final_filename = filename + img.save(final_filename) + print("Saved Image") + koboldai_vars.generating_image = False + return(b64img) + else: + koboldai_vars.generating_image = False + print(submit_req.text) + +def get_items_locations_from_text(text): + # load model and tokenizer + tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER") + model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER") + nlp = transformers.pipeline("ner", model=model, tokenizer=tokenizer) + # input example sentence + ner_results = nlp(text) + orgs = [] + last_org_position = -2 + loc = [] + last_loc_position = -2 + per = [] + last_per_position = -2 + for i, result in enumerate(ner_results): + if result['entity'] in ('B-ORG', 'I-ORG'): + if result['start']-1 <= last_org_position: + if result['start'] != last_org_position: + orgs[-1] = "{} ".format(orgs[-1]) + orgs[-1] = "{}{}".format(orgs[-1], result['word'].replace("##", "")) + else: + orgs.append(result['word']) + last_org_position = result['end'] + elif result['entity'] in ('B-LOC', 'I-LOC'): + if result['start']-1 <= last_loc_position: + if result['start'] != last_loc_position: + loc[-1] = "{} ".format(loc[-1]) + loc[-1] = "{}{}".format(loc[-1], result['word'].replace("##", "")) + else: + loc.append(result['word']) + last_loc_position = result['end'] + elif result['entity'] in ('B-PER', 'I-PER'): + if result['start']-1 <= last_per_position: + if result['start'] != last_per_position: + per[-1] = "{} ".format(per[-1]) + per[-1] = "{}{}".format(per[-1], result['word'].replace("##", "")) + else: + per.append(result['word']) + last_per_position = result['end'] + + print("Orgs: {}".format(orgs)) + print("Locations: {}".format(loc)) + print("People: {}".format(per)) + #==================================================================# # Test #==================================================================# @@ -10919,6 +11065,7 @@ if __name__ == "__main__": try: cloudflare = str(localtunnel.stdout.readline()) cloudflare = (re.search("(?Phttps?:\/\/[^\s]+loca.lt)", cloudflare).group("url")) + koboldai_vars.cloudflare_link = cloudflare break except: attempts += 1 @@ -10928,12 +11075,15 @@ if __name__ == "__main__": print("LocalTunnel could not be created, falling back to cloudflare...") from flask_cloudflared import _run_cloudflared cloudflare = _run_cloudflared(port) + koboldai_vars.cloudflare_link = cloudflare elif(args.ngrok): from flask_ngrok import _run_ngrok cloudflare = _run_ngrok() + koboldai_vars.cloudflare_link = cloudflare elif(args.remote): from flask_cloudflared import _run_cloudflared cloudflare = _run_cloudflared(port) + koboldai_vars.cloudflare_link = cloudflare if(args.localtunnel or args.ngrok or args.remote): with open('cloudflare.log', 'w') as cloudflarelog: cloudflarelog.write("KoboldAI has finished loading and is available at the following link : " + cloudflare) diff --git a/environments/huggingface.yml b/environments/huggingface.yml index 25894ec8..72871c81 100644 --- a/environments/huggingface.yml +++ b/environments/huggingface.yml @@ -20,6 +20,7 @@ dependencies: - marshmallow>=3.13 - apispec-webframeworks - loguru + - Pillow - pip: - flask-cloudflared - flask-ngrok diff --git a/environments/rocm.yml b/environments/rocm.yml index fef38892..a0334a9a 100644 --- a/environments/rocm.yml +++ b/environments/rocm.yml @@ -17,6 +17,7 @@ dependencies: - marshmallow>=3.13 - apispec-webframeworks - loguru + - Pillow - pip: - --find-links https://download.pytorch.org/whl/rocm4.2/torch_stable.html - torch==1.10.* diff --git a/koboldai_settings.py b/koboldai_settings.py index a68ecbbe..83f538a7 100644 --- a/koboldai_settings.py +++ b/koboldai_settings.py @@ -111,7 +111,7 @@ class koboldai_vars(object): def reset_model(self): self._model_settings.reset_for_model_load() - def calc_ai_text(self, submitted_text="", method=2): + def calc_ai_text(self, submitted_text="", method=2, return_text=False): context = [] token_budget = self.max_length used_world_info = [] @@ -285,6 +285,8 @@ class koboldai_vars(object): tokens = self.tokenizer.encode(text) self.context = context + if return_text: + return text return tokens, used_tokens, used_tokens+self.genamt def __setattr__(self, name, value): @@ -493,13 +495,16 @@ class model_settings(settings): if self.tqdm.format_dict['rate'] is not None: self.tqdm_rem_time = str(datetime.timedelta(seconds=int(float(self.total_layers-self.loaded_layers)/self.tqdm.format_dict['rate']))) #Setup TQDP for model downloading + elif name == "total_download_chunks" and 'tqdm' in self.__dict__: + self.tqdm.reset(total=value) + self.tqdm_progress = 0 elif name == "downloaded_chunks" and 'tqdm' in self.__dict__: if value == 0: self.tqdm.reset(total=self.total_download_chunks) self.tqdm_progress = 0 else: self.tqdm.update(value-old_value) - self.tqdm_progress = round(float(self.downloaded_chunks)/float(self.total_download_chunks)*100, 1) + self.tqdm_progress = 0 if self.total_download_chunks==0 else round(float(self.downloaded_chunks)/float(self.total_download_chunks)*100, 1) if self.tqdm.format_dict['rate'] is not None: self.tqdm_rem_time = str(datetime.timedelta(seconds=int(float(self.total_download_chunks-self.downloaded_chunks)/self.tqdm.format_dict['rate']))) @@ -738,7 +743,6 @@ class system_settings(settings): self.userscripts = [] # List of userscripts to load self.last_userscripts = [] # List of previous userscript filenames from the previous time userscripts were send via usstatitems self.corescript = "default.lua" # Filename of corescript to load - self.gpu_device = 0 # Which PyTorch device to use when using pure GPU generation self.savedir = os.getcwd()+"\\stories" self.hascuda = False # Whether torch has detected CUDA on the system @@ -794,6 +798,8 @@ class system_settings(settings): print("Colab Check: {}".format(self.on_colab)) self.horde_share = False self._horde_pid = None + self.sh_apikey = "" # API key to use for txt2img from the Stable Horde. + self.generating_image = False #The current status of image generation self.cookies = {} #cookies for colab since colab's URL changes, cookies are lost @@ -877,6 +883,8 @@ class KoboldStoryRegister(object): temp = [self.actions[x]["Selected Text"] for x in list(self.actions)[i]] return temp else: + if i < 0: + return self.actions[self.action_count+i+1]["Selected Text"] return self.actions[i]["Selected Text"] def __setitem__(self, i, text): diff --git a/requirements.txt b/requirements.txt index 3f49630f..c9b316c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,4 +15,5 @@ accelerate flask_session marshmallow>=3.13 apispec-webframeworks -loguru \ No newline at end of file +loguru +Pillow \ No newline at end of file diff --git a/requirements_mtj.txt b/requirements_mtj.txt index 2067273a..4c1c4ca7 100644 --- a/requirements_mtj.txt +++ b/requirements_mtj.txt @@ -19,4 +19,5 @@ bleach==4.1.0 flask-session marshmallow>=3.13 apispec-webframeworks -loguru \ No newline at end of file +loguru +Pillow \ No newline at end of file diff --git a/static/koboldai.css b/static/koboldai.css index f4b8fa56..c7ed66b7 100644 --- a/static/koboldai.css +++ b/static/koboldai.css @@ -2201,6 +2201,10 @@ button.disabled { color: red; } +.italics { + font-style: italic; +} + .within_max_length { color: var(--text_to_ai_color); font-weight: bold; @@ -2370,4 +2374,14 @@ h2 .material-icons-outlined { input[type='range'] { border: none !important; +} + +.settings_button[system_generating_image="true"] { + filter: brightness(35%); + cursor: not-allowed; + pointer-events:none; +} + +.action_image { + width: var(--flyout_menu_width); } \ No newline at end of file diff --git a/static/koboldai.js b/static/koboldai.js index 24202c12..5c103df9 100644 --- a/static/koboldai.js +++ b/static/koboldai.js @@ -29,6 +29,7 @@ socket.on('load_cookies', function(data){load_cookies(data)}); socket.on('load_tweaks', function(data){load_tweaks(data);}); socket.on("wi_results", updateWISearchListings); socket.on("request_prompt_config", configurePrompt); +socket.on("Action_Image", function(data){Action_Image(data);}); //socket.onAny(function(event_name, data) {console.log({"event": event_name, "class": data.classname, "data": data});}); var presets = {}; @@ -1429,8 +1430,8 @@ function load_model() { for (item of document.getElementById("oaimodel").selectedOptions) { selected_models.push(item.value); } - if (selected_models == []) { - selected_models = ""; + if (selected_models == ['']) { + selected_models = []; } else if (selected_models.length == 1) { selected_models = selected_models[0]; } @@ -1958,6 +1959,18 @@ function load_cookies(data) { } } +function Action_Image(data) { + var image = new Image(); + image.src = 'data:image/png;base64,'+data['b64']; + image.setAttribute("title", data['prompt']); + image.classList.add("action_image"); + image_area = document.getElementById("action image"); + while (image_area.firstChild) { + image_area.removeChild(image_area.firstChild); + } + image_area.appendChild(image); +} + //--------------------------------------------UI to Server Functions---------------------------------- function unload_userscripts() { files_to_unload = document.getElementById('loaded_userscripts'); @@ -2917,7 +2930,8 @@ function assign_world_info_to_action(action_item, uid) { //console.log(null); var before_span = document.createElement("span"); before_span.textContent = before_highlight_text; - var hightlight_span = document.createElement("i"); + var hightlight_span = document.createElement("span"); + hightlight_span.classList.add("italics"); hightlight_span.textContent = highlight_text; hightlight_span.title = worldinfo['content']; var after_span = document.createElement("span"); @@ -2977,7 +2991,8 @@ function assign_world_info_to_action(action_item, uid) { //console.log(null); var before_span = document.createElement("span"); before_span.textContent = before_highlight_text; - var hightlight_span = document.createElement("i"); + var hightlight_span = document.createElement("span"); + hightlight_span.classList.add("italics"); hightlight_span.textContent = highlight_text; hightlight_span.title = worldinfo['content']; var after_span = document.createElement("span"); diff --git a/templates/settings flyout.html b/templates/settings flyout.html index 71ef97a4..49895852 100644 --- a/templates/settings flyout.html +++ b/templates/settings flyout.html @@ -105,6 +105,10 @@ Download debug dump +
+ +
+