diff --git a/aiserver.py b/aiserver.py
index 96d0c23c..b38bfd8e 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -113,6 +113,34 @@ def new_pretrainedtokenizerbase_from_pretrained(cls, *args, **kwargs):
     return tokenizer
 PreTrainedTokenizerBase.from_pretrained = new_pretrainedtokenizerbase_from_pretrained
 
+# We only want to use logit manipulations and such on our core text model
+class use_core_manipulations:
+    # These must be set by wherever they get setup
+    get_logits_processor: callable
+    sample: callable
+    get_stopping_criteria: callable
+
+    # We set these automatically
+    old_get_logits_processor: callable
+    old_sample: callable
+    old_get_stopping_criteria: callable
+
+    def __enter__(self):
+        use_core_manipulations.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
+        transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.get_logits_processor
+
+        use_core_manipulations.old_sample = transformers.generation_utils.GenerationMixin.sample
+        transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.sample
+
+        use_core_manipulations.old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
+        transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.get_stopping_criteria
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        transformers.generation_utils.GenerationMixin._get_logits_processor = use_core_manipulations.old_get_logits_processor
+        transformers.generation_utils.GenerationMixin.sample = use_core_manipulations.old_sample
+        transformers.generation_utils.GenerationMixin._get_stopping_criteria = use_core_manipulations.old_get_stopping_criteria
+
 #==================================================================#
 # Variables & Storage
 #==================================================================#
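The class above is a monkey-patching context manager: `__enter__` saves the stock `transformers` hooks and installs KoboldAI's replacements, `__exit__` puts the originals back. A minimal, self-contained sketch of the same pattern, using illustrative names rather than the real hooks:

```python
# Minimal sketch of the save/patch/restore pattern; `Greeter` and
# `loud_hello` are illustrative stand-ins, not KoboldAI code.
class Greeter:
    def hello(self) -> str:
        return "hello"

def loud_hello(self) -> str:
    return "HELLO"

class patch_hello:
    """Swap Greeter.hello for the duration of a `with` block, then restore it."""
    def __enter__(self):
        self.old_hello = Greeter.hello   # save the original method
        Greeter.hello = loud_hello       # install the replacement
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        Greeter.hello = self.old_hello   # always restore, even on error

g = Greeter()
assert g.hello() == "hello"
with patch_hello():
    assert g.hello() == "HELLO"          # patched inside the block
assert g.hello() == "hello"              # restored afterwards
```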
@@ -1910,8 +1938,6 @@ def patch_transformers_download():
 
 def patch_transformers():
     global transformers
-    global old_transfomers_functions
-    old_transfomers_functions = {}
 
     patch_transformers_download()
 
@@ -1933,7 +1959,6 @@ def patch_transformers():
         PreTrainedModel._kai_patched = True
     if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
         old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
-        old_transfomers_functions['modeling_utils.get_checkpoint_shard_files'] = old_get_checkpoint_shard_files
         def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
             utils.num_shards = utils.get_num_shards(index_filename)
             utils.from_pretrained_index_filename = index_filename
@@ -1961,7 +1986,6 @@ def patch_transformers():
                 if max_pos > self.weights.size(0):
                     self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
                 return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
-            old_transfomers_functions['XGLMSinusoidalPositionalEmbedding.forward'] = XGLMSinusoidalPositionalEmbedding.forward
             XGLMSinusoidalPositionalEmbedding.forward = new_forward
 
@@ -1981,7 +2005,6 @@ def patch_transformers():
                 self.model = OPTModel(config)
                 self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
                 self.post_init()
-            old_transfomers_functions['OPTForCausalLM.__init__'] = OPTForCausalLM.__init__
             OPTForCausalLM.__init__ = new_init
 
@@ -2170,8 +2193,8 @@ def patch_transformers():
         processors.append(PhraseBiasLogitsProcessor())
         processors.append(ProbabilityVisualizerLogitsProcessor())
         return processors
+    use_core_manipulations.get_logits_processor = new_get_logits_processor
     new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
-    transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
 
     class KoboldLogitsWarperList(LogitsProcessorList):
         def __init__(self, beams: int = 1, **kwargs):
@@ -2204,9 +2227,9 @@ def patch_transformers():
             kwargs["eos_token_id"] = -1
             kwargs.setdefault("pad_token_id", 2)
         return new_sample.old_sample(self, *args, **kwargs)
-    new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
-    transformers.generation_utils.GenerationMixin.sample = new_sample
+
+    new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
+    use_core_manipulations.sample = new_sample
 
     # Allow bad words filter to ban <|endoftext|> token
     import transformers.generation_logits_process
@@ -2374,7 +2397,7 @@ def patch_transformers():
 
     old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
-    old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria'] = old_get_stopping_criteria
+
     def new_get_stopping_criteria(self, *args, **kwargs):
         global tokenizer
         stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
@@ -2392,7 +2415,7 @@ def patch_transformers():
             stopping_criteria.insert(0, token_streamer)
         stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
         return stopping_criteria
-    transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
+    use_core_manipulations.get_stopping_criteria = new_get_stopping_criteria
 
 def reset_model_settings():
     koboldai_vars.socketio = socketio
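With the hooks now registered on `use_core_manipulations` instead of patched in permanently, the custom behavior is scoped to the `with` block in `raw_generate` (next hunk). Because `__exit__` also runs when the body raises, this behaves roughly like the following try/finally; the `...` body is a placeholder, not the actual call sites:

```python
# Rough try/finally equivalent of `with use_core_manipulations():`.
ucm = use_core_manipulations()
ucm.__enter__()                     # install the patched hooks
try:
    ...                             # tpu/torch/rwkv generation runs here
finally:
    ucm.__exit__(None, None, None)  # stock transformers restored, even on error
```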
@@ -5395,56 +5418,57 @@ def raw_generate(
     result: GenerationResult
     time_start = time.time()
 
-    if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
-        batch_encoded = tpu_raw_generate(
-            prompt_tokens=prompt_tokens,
-            max_new=max_new,
-            batch_count=batch_count,
-            gen_settings=gen_settings
-        )
-        result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
-        )
-    elif koboldai_vars.model in model_functions:
-        batch_encoded = model_functions[koboldai_vars.model](
-            prompt_tokens=prompt_tokens,
-            max_new=max_new,
-            batch_count=batch_count,
-            gen_settings=gen_settings
-        )
-        result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
-        )
-    elif koboldai_vars.model.startswith("RWKV"):
-        batch_encoded = rwkv_raw_generate(
-            prompt_tokens=prompt_tokens,
-            max_new=max_new,
-            batch_count=batch_count,
-            gen_settings=gen_settings
-        )
-        result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
-        )
-    else:
-        # Torch HF
-        start_time = time.time()
-        batch_encoded = torch_raw_generate(
-            prompt_tokens=prompt_tokens,
-            max_new=max_new if not bypass_hf_maxlength else int(2e9),
-            do_streaming=do_streaming,
-            do_dynamic_wi=do_dynamic_wi,
-            batch_count=batch_count,
-            gen_settings=gen_settings
-        )
-        logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
-        start_time = time.time()
-        result = GenerationResult(
-            out_batches=batch_encoded,
-            prompt=prompt_tokens,
-            is_whole_generation=False,
-            output_includes_prompt=True,
-        )
-        logger.debug("raw_generate: run GenerationResult {}s".format(time.time()-start_time))
+    with use_core_manipulations():
+        if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
+            batch_encoded = tpu_raw_generate(
+                prompt_tokens=prompt_tokens,
+                max_new=max_new,
+                batch_count=batch_count,
+                gen_settings=gen_settings
+            )
+            result = GenerationResult(
+                out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
+            )
+        elif koboldai_vars.model in model_functions:
+            batch_encoded = model_functions[koboldai_vars.model](
+                prompt_tokens=prompt_tokens,
+                max_new=max_new,
+                batch_count=batch_count,
+                gen_settings=gen_settings
+            )
+            result = GenerationResult(
+                out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
+            )
+        elif koboldai_vars.model.startswith("RWKV"):
+            batch_encoded = rwkv_raw_generate(
+                prompt_tokens=prompt_tokens,
+                max_new=max_new,
+                batch_count=batch_count,
+                gen_settings=gen_settings
+            )
+            result = GenerationResult(
+                out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
+            )
+        else:
+            # Torch HF
+            start_time = time.time()
+            batch_encoded = torch_raw_generate(
+                prompt_tokens=prompt_tokens,
+                max_new=max_new if not bypass_hf_maxlength else int(2e9),
+                do_streaming=do_streaming,
+                do_dynamic_wi=do_dynamic_wi,
+                batch_count=batch_count,
+                gen_settings=gen_settings
+            )
+            logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
+            start_time = time.time()
+            result = GenerationResult(
+                out_batches=batch_encoded,
+                prompt=prompt_tokens,
+                is_whole_generation=False,
+                output_includes_prompt=True,
+            )
+            logger.debug("raw_generate: run GenerationResult {}s".format(time.time()-start_time))
 
     time_end = round(time.time() - time_start, 2)
     tokens_per_second = round(len(result.encoded[0]) / time_end, 2)
@@ -8077,8 +8101,8 @@ def file_popup(popup_title, starting_folder, return_event, upload=True, jailed=T
 def get_files_folders(starting_folder):
     import stat
     session['current_folder'] = os.path.abspath(starting_folder).replace("\\", "/")
-    item_check = session['popup_item_check']
-    extra_parameter_function = session['extra_parameter_function']
+    item_check = globals()[session['popup_item_check']] if session['popup_item_check'] is not None else None
+    extra_parameter_function = globals()[session['extra_parameter_function']] if session['extra_parameter_function'] is not None else None
     show_breadcrumbs = session['popup_show_breadcrumbs']
     show_hidden = session['popup_show_hidden']
     folder_only = session['popup_folder_only']
@@ -8090,7 +8114,7 @@ def get_files_folders(starting_folder):
     sort = session['sort']
     desc = session['desc']
     show_folders = session['show_folders']
-    advanced_sort = session['advanced_sort']
+    advanced_sort = globals()[session['advanced_sort']] if session['advanced_sort'] is not None else None
 
     if starting_folder == 'This PC':
         breadcrumbs = [['This PC', 'This PC']]
@@ -8444,10 +8468,10 @@ def UI_2_load_model(data):
 @logger.catch
 def UI_2_load_story_list(data):
     file_popup("Select Story to Load", "./stories", "load_story", upload=True, jailed=True, folder_only=False, renameable=True, 
-                  deleteable=True, show_breadcrumbs=True, item_check=valid_story,
-                  valid_only=True, hide_extention=True, extra_parameter_function=get_story_listing_data,
+                  deleteable=True, show_breadcrumbs=True, item_check="valid_story",
+                  valid_only=True, hide_extention=True, extra_parameter_function="get_story_listing_data",
                   column_names=['Story Name', 'Action Count', 'Last Loaded'], show_filename=False,
-                  column_widths=['minmax(150px, auto)', '140px', '160px'], advanced_sort=story_sort,
+                  column_widths=['minmax(150px, auto)', '140px', '160px'], advanced_sort="story_sort",
                   sort="Modified", desc=True, rename_return_emit_name="popup_rename_story")
 
@@ -8809,8 +8833,8 @@ def UI_2_load_softprompt_list(data):
         socketio.emit("error", "Soft prompts are not supported by your current model/backend", broadcast=True, room="UI_2")
     assert koboldai_vars.allowsp, "Soft prompts are not supported by your current model/backend"
     file_popup("Select Softprompt to Load", "./softprompts", "load_softprompt", upload=True, jailed=True, folder_only=False, renameable=True, 
-                  deleteable=True, show_breadcrumbs=True, item_check=valid_softprompt,
-                  valid_only=True, hide_extention=True, extra_parameter_function=get_softprompt_desc,
+                  deleteable=True, show_breadcrumbs=True, item_check="valid_softprompt",
+                  valid_only=True, hide_extention=True, extra_parameter_function="get_softprompt_desc",
                   column_names=['Softprompt Name', 'Softprompt Description'], show_filename=False,
                   column_widths=['150px', 'auto'])
@@ -8852,8 +8876,8 @@ def UI_2_load_softprompt(data):
 @logger.catch
 def UI_2_load_userscripts_list(data):
     file_popup("Select Userscripts to Load", "./userscripts", "load_userscripts", upload=True, jailed=True, folder_only=False, renameable=True, editable=True, 
-                  deleteable=True, show_breadcrumbs=False, item_check=valid_userscripts_to_load,
-                  valid_only=True, hide_extention=True, extra_parameter_function=get_userscripts_desc,
+                  deleteable=True, show_breadcrumbs=False, item_check="valid_userscripts_to_load",
+                  valid_only=True, hide_extention=True, extra_parameter_function="get_userscripts_desc",
                   column_names=['Module Name', 'Description'], show_filename=False, show_folders=False,
                   column_widths=['150px', 'auto'])
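The popup callbacks are now passed and stored by name because values kept in the Flask session must be serializable, and a bare function object is not. The lookup side is plain `globals()` indexing; a self-contained sketch, with an illustrative stand-in for the real checker:

```python
# Illustrative stand-in for a module-level checker like valid_story().
def valid_story(path: str) -> bool:
    return path.endswith(".json")

stored_name = "valid_story"   # what the session now holds instead of a function
item_check = globals()[stored_name] if stored_name is not None else None
assert item_check is not None and item_check("my_story.json")
```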
@@ -9199,24 +9223,33 @@ def text2img_horde(prompt, 
                filename = "story_art.png"):
     logger.debug("Generating Image using Horde")
     koboldai_vars.generating_image = True
-    final_imgen_params = {
-        "n": 1,
-        "width": 512,
-        "height": 512,
-        "steps": 50,
-    }
+
     final_submit_dict = {
         "prompt": "{}, {}".format(prompt, art_guide),
-        "api_key": koboldai_vars.sh_apikey if koboldai_vars.sh_apikey != '' else "0000000000",
-        "params": final_imgen_params,
+        "trusted_workers": False,
+        "models": [
+            "stable_diffusion"
+        ],
+        "params": {
+            "n":1,
+            "nsfw": True,
+            "sampler_name": "k_euler_a",
+            "karras": True,
+            "cfg_scale": 7.0,
+            "steps":25,
+            "width":512,
+            "height":512}
     }
+
+    cluster_headers = {'apikey': koboldai_vars.sh_apikey if koboldai_vars.sh_apikey != '' else "0000000000",}
+
     logger.debug(final_submit_dict)
-    submit_req = requests.post('https://stablehorde.net/api/v1/generate/sync', json = final_submit_dict)
+    submit_req = requests.post('https://stablehorde.net/api/v2/generate/sync', json = final_submit_dict, headers=cluster_headers)
     if submit_req.ok:
         results = submit_req.json()
-        for iter in range(len(results)):
-            b64img = results[iter]["img"]
+        for iter in range(len(results['generations'])):
+            b64img = results['generations'][iter]["img"]
             base64_bytes = b64img.encode('utf-8')
             img_bytes = base64.b64decode(base64_bytes)
             img = Image.open(BytesIO(img_bytes))
@@ -9285,7 +9318,7 @@ def text2img_api(prompt, 
         "prompt": "{}, {}".format(prompt, art_guide),
         "params": final_imgen_params,
     }
-    apiaddress = 'http://127.0.0.1:7860/sdapi/v1/txt2img'
+    apiaddress = '{}/sdapi/v1/txt2img'.format(koboldai_vars.img_gen_api_url)
    payload_json = json.dumps(final_submit_dict)
    logger.debug(final_submit_dict)
    submit_req = requests.post(url=f'{apiaddress}', data=payload_json).json()
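The Horde call moves from v1 to v2: the API key now travels in an `apikey` header rather than the request body, and results come back wrapped in a `generations` list. A trimmed sketch assembled only from the fields this hunk actually sends; anything beyond them should be checked against the Stable Horde API docs:

```python
import base64
import requests

payload = {
    "prompt": "a castle on a hill, fantasy art",  # illustrative prompt
    "trusted_workers": False,
    "models": ["stable_diffusion"],
    "params": {"n": 1, "nsfw": True, "sampler_name": "k_euler_a", "karras": True,
               "cfg_scale": 7.0, "steps": 25, "width": 512, "height": 512},
}
headers = {"apikey": "0000000000"}  # anonymous key, same fallback as the patch
resp = requests.post("https://stablehorde.net/api/v2/generate/sync",
                     json=payload, headers=headers)
if resp.ok:
    for gen in resp.json()["generations"]:        # v2 wraps images here
        img_bytes = base64.b64decode(gen["img"])  # each image is base64-encoded
```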
@@ -9384,14 +9417,10 @@ def summarize(text, max_length=100, min_length=30, unload=True):
     #Actual sumarization
     start_time = time.time()
-    global old_transfomers_functions
-    temp = transformers.generation_utils.GenerationMixin._get_stopping_criteria
-    transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']
     #make sure text is less than 1024 tokens, otherwise we'll crash
     if len(koboldai_vars.summary_tokenizer.encode(text)) > 1000:
         text = koboldai_vars.summary_tokenizer.decode(koboldai_vars.summary_tokenizer.encode(text)[:1000])
     output = tpool.execute(summarizer, text, max_length=max_length, min_length=min_length, do_sample=False)[0]['summary_text']
-    transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp
     logger.debug("Time to summarize: {}".format(time.time()-start_time))
     #move model back to CPU to save precious vram
     torch.cuda.empty_cache()
diff --git a/gensettings.py b/gensettings.py
index b398e88b..44be5dd5 100644
--- a/gensettings.py
+++ b/gensettings.py
@@ -554,6 +554,19 @@ gensettingstf = [
 	'children': [{'text': 'Use Local Only', 'value': 0}, {'text':'Prefer Local','value':1}, {'text':'Prefer Horde', 'value':2}, {'text':'Use Horde Only', 'value':3}, {'text':'Use Local SD-WebUI API', 'value':4}]
 	},
 	{
+	"UI_V2_Only": True,
+	"uitype": "text",
+	"unit": "text",
+	"label": "Img API URL",
+	"id": "img_gen_api_url",
+	"default": "",
+	"tooltip": "The URL to use when selecting Use Local SD-WebUI API setting in Image Priority",
+	"menu_path": "Interface",
+	"sub_path": "Images",
+	"classname": "user",
+	"name": "img_gen_api_url"
+	},
+	{
 	"UI_V2_Only": True,
 	"uitype": "toggle",
 	"unit": "bool",
diff --git a/koboldai_settings.py b/koboldai_settings.py
index e9be418b..4249c604 100644
--- a/koboldai_settings.py
+++ b/koboldai_settings.py
@@ -985,6 +985,7 @@ class user_settings(settings):
         self.beep_on_complete = False
         self.img_gen_priority = 1
         self.show_budget = False
+        self.img_gen_api_url = "http://127.0.0.1:7860/"
         self.cluster_requested_models = [] # The models which we allow to generate during cluster mode
 
@@ -1287,7 +1288,7 @@ class KoboldStoryRegister(object):
         if type(json_data) == str:
             import json
             json_data = json.loads(json_data)
-        self.action_count = json_data['action_count']
+        self.action_count = int(json_data['action_count'])
         #JSON forces keys to be strings, so let's fix that
         temp = {}
         data_to_send = []
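The `int()` cast in `KoboldStoryRegister` guards the JSON round trip: object keys always come back as strings, and `action_count` itself may arrive as a string if whatever produced the JSON stored it that way. A quick demonstration:

```python
import json

# Keys are coerced to strings, and a string-valued count survives as a string.
data = json.loads(json.dumps({"action_count": "12", "actions": {1: "a"}}))
assert data["action_count"] == "12"            # still a string
assert list(data["actions"]) == ["1"]          # int key 1 became "1"
assert int(data["action_count"]) + 1 == 13     # normalize before arithmetic
```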
diff --git a/static/koboldai.js b/static/koboldai.js
index 4c056c6a..932400c9 100644
--- a/static/koboldai.js
+++ b/static/koboldai.js
@@ -1784,6 +1784,13 @@ function world_info_entry(data) {
 	//First let's get the id of the element we're on so we can restore it after removing the object
 	var original_focus = document.activeElement.id;
 	
+	if (!(document.getElementById("world_info_folder_"+data.folder))) {
+		folder = document.createElement("div");
+		//console.log("Didn't find folder " + data.folder);
+	} else {
+		folder = document.getElementById("world_info_folder_"+data.folder);
+	}
+	
 	if (document.getElementById("world_info_"+data.uid)) {
 		world_info_card = document.getElementById("world_info_"+data.uid);
 	} else {
@@ -1791,6 +1798,7 @@
 		world_info_card = world_info_card_template.cloneNode(true);
 		world_info_card.id = "world_info_"+data.uid;
 		world_info_card.setAttribute("uid", data.uid);
+		folder.append(world_info_card);
 	}
 	if (data.used_in_game) {
 		world_info_card.classList.add("used_in_game");
@@ -2049,12 +2057,6 @@ function world_info_entry(data) {
 	constant.checked = data.constant;
 	constant.classList.remove("pulse");
 	
-	if (!(document.getElementById("world_info_folder_"+data.folder))) {
-		folder = document.createElement("div");
-		//console.log("Didn't find folder " + data.folder);
-	} else {
-		folder = document.getElementById("world_info_folder_"+data.folder);
-	}
 	//Let's figure out the order to insert this card
 	var found = false;
 	var moved = false;
@@ -2888,6 +2890,11 @@ function push_selection_to_world_info() {
 		menu.classList.add("open");
 	}
 	document.getElementById("story_flyout_tab_wi").onclick();
+	
+	if (!("root" in world_info_folder_data)) {
+		world_info_folder_data["root"] = [];
+		world_info_folder(world_info_folder_data);
+	}
 	create_new_wi_entry("root");
 	document.getElementById("world_info_entry_text_-1").value = getSelectionText();
 }
diff --git a/templates/settings item.html b/templates/settings item.html
index e8371916..c2aa3902 100644
--- a/templates/settings item.html
+++ b/templates/settings item.html
@@ -1,9 +1,9 @@
 {% for item in settings %}
 	{% if item["menu_path"] == menu and item['sub_path'] == sub_path %}
 		{% if 'extra_classes' in item %}
-