diff --git a/aiserver.py b/aiserver.py
index 2498f82f..04245558 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -64,7 +64,7 @@ from utils import debounce
 import utils
 import koboldai_settings
 import torch
-from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, modeling_utils, AutoModelForTokenClassification
+from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel, modeling_utils, AutoModelForTokenClassification
 from transformers import __version__ as transformers_version
 import transformers
 try:
@@ -8162,7 +8162,8 @@ def UI_2_save_revision(data):
 #==================================================================#
 @socketio.on("generate_image")
 def UI_2_generate_image(data):
-    socketio.start_background_task(generate_image_in_background)
+    koboldai_vars.generating_image = True
+    tpool.execute(generate_image_in_background)
 
 def generate_image_in_background():
     koboldai_vars.generating_image = True
@@ -8192,23 +8193,44 @@ def generate_image_in_background():
     if len(keys) < 4:
         from transformers import pipeline as summary_pipeline
         start_time = time.time()
-        if koboldai_vars.summarizer is None:
-            koboldai_vars.summarizer = summary_pipeline("summarization", model=args.summarizer_model)
-            #koboldai_vars.summarizer = summary_pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", cache="./cache")
-        logger.debug("Time to load summarizer: {}".format(time.time()-start_time))
-        #text to summarize:
-        start_time = time.time()
         if len(koboldai_vars.actions) < 5:
             text = "".join(koboldai_vars.actions[:-5]+[koboldai_vars.prompt])
         else:
             text = "".join(koboldai_vars.actions[:-5])
+
+
+
+        if koboldai_vars.summarizer is None:
+            if os.path.exists("models/{}".format(args.summarizer_model.replace('/', '_'))):
+                koboldai_vars.summary_tokenizer = AutoTokenizer.from_pretrained("models/{}".format(args.summarizer_model.replace('/', '_')), cache_dir="cache")
+                koboldai_vars.summarizer = AutoModelForSeq2SeqLM.from_pretrained("models/{}".format(args.summarizer_model.replace('/', '_')), cache_dir="cache")
+            else:
+                koboldai_vars.summary_tokenizer = AutoTokenizer.from_pretrained(args.summarizer_model, cache_dir="cache")
+                koboldai_vars.summarizer = AutoModelForSeq2SeqLM.from_pretrained(args.summarizer_model, cache_dir="cache")
+                koboldai_vars.summary_tokenizer.save_pretrained("models/{}".format(args.summarizer_model.replace('/', '_')), max_shard_size="500MiB")
+                koboldai_vars.summarizer.save_pretrained("models/{}".format(args.summarizer_model.replace('/', '_')), max_shard_size="500MiB")
+
+        #Try GPU accel
+        if koboldai_vars.hascuda and torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved(0) >= 6000000000:
+            koboldai_vars.summarizer.to(0)
+            device=0
+        else:
+            device="cpu"
+        summarizer = summary_pipeline(task="summarization", model=koboldai_vars.summarizer, tokenizer=koboldai_vars.summary_tokenizer, device=device)
+        logger.debug("Time to load summarizer: {}".format(time.time()-start_time))
+
+        #Actual sumarization
+        start_time = time.time()
         global old_transfomers_functions
         temp = transformers.generation_utils.GenerationMixin._get_stopping_criteria
         transformers.generation_utils.GenerationMixin._get_stopping_criteria = old_transfomers_functions['transformers.generation_utils.GenerationMixin._get_stopping_criteria']
-        keys = [koboldai_vars.summarizer(text, max_length=args.max_summary_length, min_length=30, do_sample=False)[0]['summary_text']]
+        keys = [summarizer(text, max_length=args.max_summary_length, min_length=30, do_sample=False)[0]['summary_text']]
         transformers.generation_utils.GenerationMixin._get_stopping_criteria = temp
         logger.debug("Time to summarize: {}".format(time.time()-start_time))
+        #move model back to CPU to save precious vram
+        koboldai_vars.summarizer.to("cpu")
+        torch.cuda.empty_cache()
 
         logger.debug("Original Text: {}".format(text))
         logger.debug("Summarized Text: {}".format(keys[0]))
@@ -8251,6 +8273,7 @@ def text2img_local(prompt, art_guide="", filename="new.png"):
     else:
         pipe = koboldai_vars.image_pipeline.to("cuda")
     logger.debug("time to load: {}".format(time.time() - start_time))
+    start_time = time.time()
     from torch import autocast
     with autocast("cuda"):
         image = pipe(prompt, num_inference_steps=35)["sample"][0]
@@ -8258,6 +8281,7 @@ def text2img_local(prompt, art_guide="", filename="new.png"):
     image.save(buffered, format="JPEG")
     img_str = base64.b64encode(buffered.getvalue()).decode('ascii')
     logger.debug("time to generate: {}".format(time.time() - start_time))
+    start_time = time.time()
     if koboldai_vars.keep_img_gen_in_memory:
         pipe.to("cpu")
     if koboldai_vars.image_pipeline is None:
diff --git a/koboldai_settings.py b/koboldai_settings.py
index 4df71624..37095d36 100644
--- a/koboldai_settings.py
+++ b/koboldai_settings.py
@@ -729,8 +729,8 @@ class user_settings(settings):
             process_variable_changes(self.socketio, self.__class__.__name__.replace("_settings", ""), name, value, old_value)
 
 class system_settings(settings):
-    local_only_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'regex_sl', 'acregex_ai', 'acregex_ui', 'comregex_ai', 'comregex_ui', 'sp', '_horde_pid', 'image_pipeline', 'summarizer']
-    no_save_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'sp', '_horde_pid', 'horde_share', 'aibusy', 'serverstarted', 'image_pipeline', 'summarizer']
+    local_only_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'regex_sl', 'acregex_ai', 'acregex_ui', 'comregex_ai', 'comregex_ui', 'sp', '_horde_pid', 'image_pipeline', 'summarizer', 'summary_tokenizer']
+    no_save_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'sp', '_horde_pid', 'horde_share', 'aibusy', 'serverstarted', 'image_pipeline', 'summarizer', 'summary_tokenizer']
     settings_name = "system"
     def __init__(self, socketio):
         self.socketio = socketio
@@ -810,6 +810,7 @@ class system_settings(settings):
         self.generating_image = False #The current status of image generation
         self.image_pipeline = None
         self.summarizer = None
+        self.summary_tokenizer = None
         self.keep_img_gen_in_memory = False
         self.cookies = {} #cookies for colab since colab's URL changes, cookies are lost
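
For context, the new loading path in `generate_image_in_background` caches the summarizer under `models/` so it is downloaded only once, and uses GPU 0 only when roughly 6 GB of VRAM is free. Below is a minimal standalone sketch of that pattern, not part of the patch: the `load_summarizer` helper name is illustrative, and `sshleifer/distilbart-cnn-12-6` (taken from the commented-out line the patch removes) stands in for `args.summarizer_model`. It also substitutes `torch.cuda.is_available()` for the app's `koboldai_vars.hascuda` flag.

```python
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

def load_summarizer(model_name="sshleifer/distilbart-cnn-12-6"):  # stand-in for args.summarizer_model
    local_dir = "models/{}".format(model_name.replace("/", "_"))
    if os.path.exists(local_dir):
        # Reuse the copy saved by a previous run instead of re-downloading
        tokenizer = AutoTokenizer.from_pretrained(local_dir, cache_dir="cache")
        model = AutoModelForSeq2SeqLM.from_pretrained(local_dir, cache_dir="cache")
    else:
        # First run: download, then persist the weights locally as small shards
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="cache")
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir="cache")
        tokenizer.save_pretrained(local_dir)
        model.save_pretrained(local_dir, max_shard_size="500MiB")

    # Mirror the patch's VRAM check: use GPU 0 only if ~6 GB is unreserved
    if torch.cuda.is_available() and \
            torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved(0) >= 6_000_000_000:
        device = 0
    else:
        device = "cpu"
    return pipeline(task="summarization", model=model, tokenizer=tokenizer, device=device)

summarizer = load_summarizer()
print(summarizer("Some long story text ...", max_length=100, min_length=30,
                 do_sample=False)[0]["summary_text"])
```

After summarizing, the patch moves the model back to the CPU with `.to("cpu")` and calls `torch.cuda.empty_cache()` so the VRAM is returned to the main text model between image-generation requests.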