diff --git a/aiserver.py b/aiserver.py index 8fbac449..ef66f82d 100644 --- a/aiserver.py +++ b/aiserver.py @@ -13,12 +13,15 @@ import json import requests import html import argparse +import sys +import gc # KoboldAI import fileops import gensettings from utils import debounce import utils +import breakmodel #==================================================================# # Variables & Storage @@ -100,6 +103,8 @@ class vars: saveow = False # Whether or not overwrite confirm has been displayed genseqs = [] # Temporary storage for generated sequences useprompt = True # Whether to send the full prompt with every submit action + breakmodel = False # For GPU users, whether to use both system RAM and VRAM to conserve VRAM while offering speedup compared to CPU-only + bmsupported = False # Whether the breakmodel option is supported (GPT-Neo/GPT-J only, currently) acregex_ai = re.compile(r'\n* *>(.|\n)*') # Pattern for matching adventure actions from the AI so we can remove them acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses) actionmode = 1 @@ -160,6 +165,8 @@ parser.add_argument("--remote", action='store_true', help="Optimizes KoboldAI fo parser.add_argument("--model", help="Specify the Model Type to skip the Menu") parser.add_argument("--path", help="Specify the Path for local models (For model NeoCustom or GPT2Custom)") parser.add_argument("--cpu", action='store_true', help="By default unattended launches are on the GPU use this option to force CPU usage.") +parser.add_argument("--breakmodel", action='store_true', help="For models that support GPU-CPU hybrid generation, use this feature instead of GPU or CPU generation") +parser.add_argument("--breakmodel_layers", type=int, help="Specify the number of layers to commit to system RAM if --breakmodel is used") args = parser.parse_args() vars.model = args.model; @@ -184,6 +191,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): import torch print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="") vars.hascuda = torch.cuda.is_available() + vars.bmsupported = vars.model in ("EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B", "NeoCustom") if(vars.hascuda): print("{0}FOUND!{1}".format(colors.GREEN, colors.END)) else: @@ -193,23 +201,40 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): if(vars.hascuda): genselected = True vars.usegpu = True + vars.breakmodel = False if(args.cpu): vars.usegpu = False - elif(vars.hascuda): - print("{0}Use GPU or CPU for generation?: (Default GPU){1}\n".format(colors.CYAN, colors.END)) - print(" 1 - GPU\n 2 - CPU\n") + vars.breakmodel = False + if(vars.bmsupported and args.breakmodel): + vars.usegpu = False + vars.breakmodel = True + elif(vars.hascuda): + if(vars.bmsupported): + print(colors.YELLOW + "You're using a model that supports GPU-CPU hybrid generation!\nCurrently only GPT-Neo models and GPT-J-6B support this feature.") + print("{0}Use GPU or CPU for generation?: (Default GPU){1}".format(colors.CYAN, colors.END)) + if(vars.bmsupported): + print(f" 1 - GPU\n 2 - CPU\n 3 - Both (slower than GPU-only but uses less VRAM)\n") + else: + print(" 1 - GPU\n 2 - CPU\n") genselected = False if(vars.hascuda): while(genselected == False): genselect = input("Mode> ") if(genselect == ""): + vars.breakmodel = False vars.usegpu = True genselected = True elif(genselect.isnumeric() and int(genselect) == 1): + vars.breakmodel = False vars.usegpu = True genselected = True elif(genselect.isnumeric() and int(genselect) == 2): + vars.breakmodel = False + vars.usegpu = False + genselected = True + elif(vars.bmsupported and genselect.isnumeric() and int(genselect) == 3): + vars.breakmodel = True vars.usegpu = False genselected = True else: @@ -343,15 +368,48 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END)) if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): if(not vars.noai): print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END)) - from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM + from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModel # If custom GPT Neo model was chosen if(vars.model == "NeoCustom"): model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth) tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth) # Is CUDA available? If so, use GPU, otherwise fall back to CPU - if(vars.hascuda and vars.usegpu): - generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=0) + if(vars.hascuda): + if(vars.usegpu): + generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=0) + elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel) + n_layers = model.config.num_layers + breakmodel.total_blocks = n_layers + model.half().to('cpu') + gc.collect() + model.transformer.wte.to(breakmodel.gpu_device) + model.transformer.ln_f.to(breakmodel.gpu_device) + if(hasattr(model, 'lm_head')): + model.lm_head.to(breakmodel.gpu_device) + if(not hasattr(model.config, 'rotary') or not model.config.rotary): + model.transformer.wpe.to(breakmodel.gpu_device) + gc.collect() + if(args.breakmodel_layers is not None): + breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers)) + else: + print(colors.CYAN + "\nHow many layers would you like to put into system RAM?") + print("The more of them you put into system RAM, the slower it will run,") + print("but it will require less VRAM") + print("(roughly proportional to number of layers).") + print(f"This model has{colors.YELLOW} {n_layers} {colors.CYAN}layers.{colors.END}\n") + while(True): + layerselect = input("# of layers> ") + if(layerselect.isnumeric() and 0 <= int(layerselect) <= n_layers): + breakmodel.ram_blocks = int(layerselect) + break + else: + print(f"{colors.RED}Please enter an integer between 0 and {n_layers}.{colors.END}") + print(f"{colors.PURPLE}Will commit{colors.YELLOW} {breakmodel.ram_blocks} {colors.PURPLE}of{colors.YELLOW} {n_layers} {colors.PURPLE}layers to system RAM.{colors.END}") + GPTNeoModel.forward = breakmodel.new_forward + generator = model.generate + else: + generator = pipeline('text-generation', model=model, tokenizer=tokenizer) else: generator = pipeline('text-generation', model=model, tokenizer=tokenizer) # If custom GPT2 model was chosen @@ -367,8 +425,42 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]): else: # Is CUDA available? If so, use GPU, otherwise fall back to CPU tokenizer = GPT2Tokenizer.from_pretrained(vars.model) - if(vars.hascuda and vars.usegpu): - generator = pipeline('text-generation', model=vars.model, device=0) + if(vars.hascuda): + if(vars.usegpu): + generator = pipeline('text-generation', model=vars.model, device=0) + elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel) + model = AutoModel.from_pretrained(vars.model) + n_layers = model.config.num_layers + breakmodel.total_blocks = n_layers + model.half().to('cpu') + gc.collect() + model.transformer.wte.to(breakmodel.gpu_device) + model.transformer.ln_f.to(breakmodel.gpu_device) + if(hasattr(model, 'lm_head')): + model.lm_head.to(breakmodel.gpu_device) + if(not hasattr(model.config, 'rotary') or not model.config.rotary): + model.transformer.wpe.to(breakmodel.gpu_device) + gc.collect() + if(args.breakmodel_layers is not None): + breakmodel.ram_blocks = max(0, min(n_layers, args.breakmodel_layers)) + else: + print(colors.CYAN + "\nHow many layers would you like to put into system RAM?") + print("The more of them you put into system RAM, the slower it will run,") + print("but it will require less VRAM") + print("(roughly proportional to number of layers).") + print(f"This model has{colors.YELLOW} {n_layers} {colors.CYAN}layers.{colors.END}\n") + while(True): + layerselect = input("# of layers> ") + if(layerselect.isnumeric() and 0 <= int(layerselect) <= n_layers): + breakmodel.ram_blocks = int(layerselect) + break + else: + print(f"{colors.RED}Please enter an integer between 0 and {n_layers}.{colors.END}") + print(f"{colors.PURPLE}Will commit{colors.YELLOW} {breakmodel.ram_blocks} {colors.PURPLE}of{colors.YELLOW} {n_layers} {colors.PURPLE}layers to system RAM.{colors.END}") + GPTNeoModel.forward = breakmodel.new_forward + generator = model.generate + else: + generator = pipeline('text-generation', model=vars.model) else: generator = pipeline('text-generation', model=vars.model) @@ -480,42 +572,42 @@ def get_message(msg): elif(msg['cmd'] == 'settemp'): vars.temp = float(msg['data']) emit('from_server', {'cmd': 'setlabeltemp', 'data': msg['data']}, broadcast=True) - settingschanged() + settingschanged() refresh_settings() elif(msg['cmd'] == 'settopp'): vars.top_p = float(msg['data']) emit('from_server', {'cmd': 'setlabeltopp', 'data': msg['data']}, broadcast=True) - settingschanged() + settingschanged() refresh_settings() elif(msg['cmd'] == 'settopk'): vars.top_k = int(msg['data']) emit('from_server', {'cmd': 'setlabeltopk', 'data': msg['data']}, broadcast=True) - settingschanged() + settingschanged() refresh_settings() elif(msg['cmd'] == 'settfs'): vars.tfs = float(msg['data']) emit('from_server', {'cmd': 'setlabeltfs', 'data': msg['data']}, broadcast=True) - settingschanged() + settingschanged() refresh_settings() elif(msg['cmd'] == 'setreppen'): vars.rep_pen = float(msg['data']) emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True) - settingschanged() + settingschanged() refresh_settings() elif(msg['cmd'] == 'setoutput'): vars.genamt = int(msg['data']) emit('from_server', {'cmd': 'setlabeloutput', 'data': msg['data']}, broadcast=True) - settingschanged() + settingschanged() refresh_settings() elif(msg['cmd'] == 'settknmax'): vars.max_length = int(msg['data']) emit('from_server', {'cmd': 'setlabeltknmax', 'data': msg['data']}, broadcast=True) - settingschanged() + settingschanged() refresh_settings() elif(msg['cmd'] == 'setikgen'): vars.ikgen = int(msg['data']) emit('from_server', {'cmd': 'setlabelikgen', 'data': msg['data']}, broadcast=True) - settingschanged() + settingschanged() refresh_settings() # Author's Note field update elif(msg['cmd'] == 'anote'): @@ -524,28 +616,28 @@ def get_message(msg): elif(msg['cmd'] == 'anotedepth'): vars.andepth = int(msg['data']) emit('from_server', {'cmd': 'setlabelanotedepth', 'data': msg['data']}, broadcast=True) - settingschanged() + settingschanged() refresh_settings() # Format - Trim incomplete sentences elif(msg['cmd'] == 'frmttriminc'): if('frmttriminc' in vars.formatoptns): vars.formatoptns["frmttriminc"] = msg['data'] - settingschanged() + settingschanged() refresh_settings() elif(msg['cmd'] == 'frmtrmblln'): if('frmtrmblln' in vars.formatoptns): vars.formatoptns["frmtrmblln"] = msg['data'] - settingschanged() + settingschanged() refresh_settings() elif(msg['cmd'] == 'frmtrmspch'): if('frmtrmspch' in vars.formatoptns): vars.formatoptns["frmtrmspch"] = msg['data'] - settingschanged() + settingschanged() refresh_settings() elif(msg['cmd'] == 'frmtadsnsp'): if('frmtadsnsp' in vars.formatoptns): vars.formatoptns["frmtadsnsp"] = msg['data'] - settingschanged() + settingschanged() refresh_settings() elif(msg['cmd'] == 'importselect'): vars.importnum = int(msg["data"].replace("import", "")) @@ -589,20 +681,20 @@ def get_message(msg): elif(msg['cmd'] == 'setnumseq'): vars.numseqs = int(msg['data']) emit('from_server', {'cmd': 'setlabelnumseq', 'data': msg['data']}) - settingschanged() + settingschanged() refresh_settings() elif(msg['cmd'] == 'setwidepth'): vars.widepth = int(msg['data']) emit('from_server', {'cmd': 'setlabelwidepth', 'data': msg['data']}) - settingschanged() + settingschanged() refresh_settings() elif(msg['cmd'] == 'setuseprompt'): vars.useprompt = msg['data'] - settingschanged() + settingschanged() refresh_settings() elif(msg['cmd'] == 'setadventure'): vars.adventure = msg['data'] - settingschanged() + settingschanged() refresh_settings() refresh_story() elif(msg['cmd'] == 'importwi'): @@ -984,7 +1076,8 @@ def generate(txt, min, max): vars.lastctx = txt # Clear CUDA cache if using GPU - if(vars.hascuda and vars.usegpu): + if(vars.hascuda and (vars.usegpu or vars.breakmodel)): + gc.collect() torch.cuda.empty_cache() # Submit input text to generator @@ -992,35 +1085,50 @@ def generate(txt, min, max): top_p = vars.top_p if vars.top_p > 0.0 else None top_k = vars.top_k if vars.top_k > 0 else None tfs = vars.tfs if vars.tfs > 0.0 else None + + # generator() only accepts a torch tensor of tokens (long datatype) as + # its first argument if we're using breakmodel, otherwise a string + # is fine + if(vars.hascuda and vars.breakmodel): + gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(breakmodel.gpu_device) + else: + gen_in = txt - genout = generator( - txt, - do_sample=True, - min_length=min, - max_length=max, - repetition_penalty=vars.rep_pen, - top_p=top_p, - top_k=top_k, - tfs=tfs, - temperature=vars.temp, - bad_words_ids=vars.badwordsids, - use_cache=True, - return_full_text=False, - num_return_sequences=vars.numseqs - ) + with torch.no_grad(): + genout = generator( + gen_in, + do_sample=True, + min_length=min, + max_length=max, + repetition_penalty=vars.rep_pen, + top_p=top_p, + top_k=top_k, + tfs=tfs, + temperature=vars.temp, + bad_words_ids=vars.badwordsids, + use_cache=True, + return_full_text=False, + num_return_sequences=vars.numseqs + ) except Exception as e: emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True) print("{0}{1}{2}".format(colors.RED, e, colors.END)) set_aibusy(0) return + # Need to manually strip and decode tokens if we're not using a pipeline + if(vars.hascuda and vars.breakmodel): + genout = [{"generated_text": tokenizer.decode(tokens[len(gen_in[0])-len(tokens):])} for tokens in genout] + if(len(genout) == 1): genresult(genout[0]["generated_text"]) else: genselect(genout) # Clear CUDA cache again if using GPU - if(vars.hascuda and vars.usegpu): + if(vars.hascuda and (vars.usegpu or vars.breakmodel)): + del genout + gc.collect() torch.cuda.empty_cache() set_aibusy(0) @@ -1966,4 +2074,4 @@ if __name__ == "__main__": else: import webbrowser webbrowser.open_new('http://localhost:5000') - socketio.run(app) \ No newline at end of file + socketio.run(app) diff --git a/breakmodel.py b/breakmodel.py new file mode 100644 index 00000000..8154b623 --- /dev/null +++ b/breakmodel.py @@ -0,0 +1,488 @@ +''' +This is a MODIFIED version of arrmansa's low VRAM patch. +https://github.com/arrmansa/Basic-UI-for-GPT-J-6B-with-low-vram/blob/main/GPT-J-6B-Low-Vram-UI.ipynb +Copyright 2021 arrmansa +Copyright 2021 finetuneanon +Copyright 2018 The Hugging Face team +Released under the Apache License 2.0 + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +''' + + +import torch +import torch.cuda.comm +import copy +import gc + +from transformers.modeling_outputs import BaseModelOutputWithPast + +from transformers.utils import logging +logger = logging.get_logger(__name__) + + +class MaxSharedRamBlocksException(Exception): + def __init__(self, i: int): + self.corrected_max_shared_ram_blocks = i + super().__init__('max_shared_ram_blocks is set too high, please set it to '+str(i)) + + +breakmodel = True +gpu_device = 'cuda' +total_blocks = 24 +ram_blocks = 7 +max_shared_ram_blocks = None + + +def new_forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + embs=None, + ): + global max_shared_ram_blocks + + if breakmodel: + if max_shared_ram_blocks is None: + max_shared_ram_blocks = total_blocks + + if not hasattr(self, 'extrastorage'): + setattr(self,"extrastorage",{}) + torch.cuda.empty_cache() + + for i in range(ram_blocks,len(self.h)): + self.h[i].to(gpu_device) + + for i in range(ram_blocks): + self.h[i].to("cpu") + self.extrastorage[i] = copy.deepcopy(self.h[i]) + smalltensor = torch.tensor(0).to(gpu_device) + for param1 in self.h[i].parameters(): + param1.data = smalltensor + self.h[i].to(gpu_device) + + for i in range(len(self.h)): + for param in self.h[i].parameters(): + param.requires_grad = False + param.data = param.data.detach() + gc.collect() + torch.cuda.empty_cache() + + for i in range(ram_blocks): + for param in self.extrastorage[i].parameters(): + param.requires_grad = False + if i < max_shared_ram_blocks: + try: + param.data = param.data.detach().pin_memory() + except: + raise MaxSharedRamBlocksException(i) + else: + param.data = param.data.detach() + gc.collect() + torch.cuda.empty_cache() + + for param1,param2 in zip(self.h[0].parameters(),self.extrastorage[0].parameters()): + param1.data = param2.data.to(gpu_device, non_blocking=False).detach() + + for param1,param2 in zip(self.h[ram_blocks-1].parameters(),self.extrastorage[ram_blocks-1].parameters()): + param1.data = param2.data.to(gpu_device, non_blocking=False).detach() + #END MODEL BREAK EDITS + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + + device = input_ids.device if input_ids is not None else inputs_embeds.device + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Attention mask. + if attention_mask is not None: + assert batch_size > 0, "batch_size has to be defined and > 0" + global_attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + global_attention_mask = global_attention_mask[:, None, None, :] + + # Since global_attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + global_attention_mask = global_attention_mask.to(dtype=self.dtype) # fp16 compatibility + global_attention_mask = (1.0 - global_attention_mask) * -10000.0 + else: + global_attention_mask = None + + # Local causal attention mask + batch_size, seq_length = input_shape + full_seq_length = seq_length + past_length + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + if embs is not None and not (use_cache is not None and use_cache and past_key_values is not None and len(past_key_values) > 0 and past_key_values[0] is not None): + offset = 0 + for pos, emb in embs: + pos += offset + if len(emb.shape) == 2: + emb = emb.repeat(input_shape[0], 1, 1) + inputs_embeds[:, pos:pos+emb.shape[1]] = emb + offset += emb.shape[1] + + if hasattr(self, 'rotary') and self.rotary: + hidden_states = inputs_embeds + else: + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + + if breakmodel: + copystream = torch.cuda.Stream(device=0,priority = -1) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + if breakmodel: + if i in range(ram_blocks): + index1 = (i+1)%ram_blocks + for param1,param2 in zip(self.h[index1].parameters(),self.h[(i-1)%ram_blocks].parameters()): + param1.data = param2.data + for param1,param2 in zip(self.h[index1].parameters(),self.extrastorage[index1].parameters()): + with torch.cuda.stream(copystream): + torch.cuda.comm.broadcast(param2.data,out = [param1.data]) + + + attn_type = self.config.attention_layers[i] + attn_mask = global_attention_mask + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states.cpu(),) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attn_mask, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attn_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + + if breakmodel: + if i in range(ram_blocks): + torch.cuda.synchronize() + torch.cuda.empty_cache() + + if breakmodel: + del copystream + + torch.cuda.empty_cache() + + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(*output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + )