diff --git a/aiserver.py b/aiserver.py index 434f98e7..e5395878 100644 --- a/aiserver.py +++ b/aiserver.py @@ -45,6 +45,7 @@ import sys import gc import lupa +import importlib # KoboldAI import fileops @@ -53,13 +54,21 @@ from utils import debounce import utils import structures import torch -from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer +from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, modeling_utils +from transformers import __version__ as transformers_version +import transformers +try: + from transformers.models.opt.modeling_opt import OPTDecoder +except: + pass +import transformers.generation_utils global tpu_mtj_backend if lupa.LUA_VERSION[:2] != (5, 4): print(f"Please install lupa==1.10. You have lupa {lupa.__version__}.", file=sys.stderr) +patch_causallm_patched = False # Make sure tqdm progress bars display properly in Colab from tqdm.auto import tqdm @@ -255,7 +264,8 @@ class vars: last_userscripts = [] # List of previous userscript filenames from the previous time userscripts were send via usstatitems corescript = "default.lua" # Filename of corescript to load # badwords = [] # Array of str/chr values that should be removed from output - badwordsids = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting + badwordsids = [] + badwordsids_default = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting badwordsids_neox = [[0], [1], [44162], [9502], [12520], [31841], [36320], [49824], [34417], [6038], [34494], [24815], [26635], [24345], [3455], [28905], [44270], [17278], [32666], [46880], [7086], [43189], [37322], [17778], [20879], [49821], [3138], [14490], [4681], [21391], [26786], [43134], [9336], [683], [48074], [41256], [19181], [29650], [28532], [36487], [45114], [46275], [16445], [15104], [11337], [1168], [5647], [29], [27482], [44965], [43782], [31011], [42944], [47389], [6334], [17548], [38329], [32044], [35487], [2239], [34761], [7444], [1084], [12399], [18990], [17636], [39083], [1184], [35830], [28365], [16731], [43467], [47744], [1138], [16079], [40116], [45564], [18297], [42368], [5456], [18022], [42696], [34476], [23505], [23741], [39334], [37944], [45382], [38709], [33440], [26077], [43600], [34418], [36033], [6660], [48167], [48471], [15775], [19884], [41533], [1008], [31053], [36692], [46576], [20095], [20629], [31759], [46410], [41000], [13488], [30952], [39258], [16160], [27655], [22367], [42767], [43736], [49694], [13811], [12004], [46768], [6257], [37471], [5264], [44153], [33805], [20977], [21083], [25416], [14277], [31096], [42041], [18331], [33376], [22372], [46294], [28379], [38475], [1656], [5204], [27075], [50001], [16616], [11396], [7748], [48744], [35402], [28120], [41512], [4207], [43144], [14767], [15640], [16595], [41305], [44479], [38958], [18474], [22734], [30522], [46267], [60], [13976], [31830], [48701], [39822], [9014], [21966], [31422], [28052], [34607], [2479], [3851], [32214], [44082], [45507], [3001], [34368], [34758], [13380], [38363], [4299], [46802], [30996], [12630], [49236], [7082], [8795], [5218], [44740], [9686], [9983], [45301], [27114], [40125], [1570], [26997], [544], [5290], [49193], [23781], [14193], [40000], [2947], [43781], [9102], [48064], [42274], [18772], [49384], [9884], [45635], [43521], [31258], [32056], [47686], [21760], [13143], [10148], [26119], [44308], [31379], [36399], [23983], [46694], [36134], [8562], [12977], [35117], [28591], [49021], [47093], [28653], [29013], [46468], [8605], [7254], [25896], [5032], [8168], [36893], [38270], [20499], [27501], [34419], [29547], [28571], [36586], [20871], [30537], [26842], [21375], [31148], [27618], [33094], [3291], [31789], [28391], [870], [9793], [41361], [47916], [27468], [43856], [8850], [35237], [15707], [47552], [2730], [41449], [45488], [3073], [49806], [21938], [24430], [22747], [20924], [46145], [20481], [20197], [8239], [28231], [17987], [42804], [47269], [29972], [49884], [21382], [46295], [36676], [34616], [3921], [26991], [27720], [46265], [654], [9855], [40354], [5291], [34904], [44342], [2470], [14598], [880], [19282], [2498], [24237], [21431], [16369], [8994], [44524], [45662], [13663], [37077], [1447], [37786], [30863], [42854], [1019], [20322], [4398], [12159], [44072], [48664], [31547], [18736], [9259], [31], [16354], [21810], [4357], [37982], [5064], [2033], [32871], [47446], [62], [22158], [37387], [8743], [47007], [17981], [11049], [4622], [37916], [36786], [35138], [29925], [14157], [18095], [27829], [1181], [22226], [5709], [4725], [30189], [37014], [1254], [11380], [42989], [696], [24576], [39487], [30119], [1092], [8088], [2194], [9899], [14412], [21828], [3725], [13544], [5180], [44679], [34398], [3891], [28739], [14219], [37594], [49550], [11326], [6904], [17266], [5749], [10174], [23405], [9955], [38271], [41018], [13011], [48392], [36784], [24254], [21687], [23734], [5413], [41447], [45472], [10122], [17555], [15830], [47384], [12084], [31350], [47940], [11661], [27988], [45443], [905], [49651], [16614], [34993], [6781], [30803], [35869], [8001], [41604], [28118], [46462], [46762], [16262], [17281], [5774], [10943], [5013], [18257], [6750], [4713], [3951], [11899], [38791], [16943], [37596], [9318], [18413], [40473], [13208], [16375]] badwordsids_opt = [[44717], [46613], [48513], [49923], [50185], [48755], [8488], [43303], [49659], [48601], [49817], [45405], [48742], [49925], [47720], [11227], [48937], [48784], [50017], [42248], [49310], [48082], [49895], [50025], [49092], [49007], [8061], [44226], [0], [742], [28578], [15698], [49784], [46679], [39365], [49281], [49609], [48081], [48906], [46161], [48554], [49670], [48677], [49721], [49632], [48610], [48462], [47457], [10975], [46077], [28696], [48709], [43839], [49798], [49154], [48203], [49625], [48395], [50155], [47161], [49095], [48833], [49420], [49666], [48443], [22176], [49242], [48651], [49138], [49750], [40389], [48021], [21838], [49070], [45333], [40862], [1], [49915], [33525], [49858], [50254], [44403], [48992], [48872], [46117], [49853], [47567], [50206], [41552], [50068], [48999], [49703], [49940], [49329], [47620], [49868], [49962], [2], [44082], [50236], [31274], [50260], [47052], [42645], [49177], [17523], [48691], [49900], [49069], [49358], [48794], [47529], [46479], [48457], [646], [49910], [48077], [48935], [46386], [48902], [49151], [48759], [49803], [45587], [48392], [47789], [48654], [49836], [49230], [48188], [50264], [46844], [44690], [48505], [50161], [27779], [49995], [41833], [50154], [49097], [48520], [50018], [8174], [50084], [49366], [49526], [50193], [7479], [49982], [3]] fp32_model = False # Whether or not the most recently loaded HF model was in fp32 format @@ -349,14 +359,40 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END)) #==================================================================# # Function to get model selection at startup #==================================================================# -def sendModelSelection(menu="mainmenu"): +def sendModelSelection(menu="mainmenu", folder="./models"): #If we send one of the manual load options, send back the list of model directories, otherwise send the menu if menu in ('NeoCustom', 'GPT2Custom'): - menu_list = [[folder, menu, "", False] for folder in next(os.walk('./models'))[1]] + (paths, breadcrumbs) = get_folder_path_info(folder) + menu_list = [[folder, menu, "", False] for folder in paths] menu_list.append(["Return to Main Menu", "mainmenu", "", True]) - emit('from_server', {'cmd': 'show_model_menu', 'data': menu_list, 'menu': 'custom'}, broadcast=True) + emit('from_server', {'cmd': 'show_model_menu', 'data': menu_list, 'menu': menu, 'breadcrumbs': breadcrumbs}, broadcast=True) else: - emit('from_server', {'cmd': 'show_model_menu', 'data': model_menu[menu], 'menu': menu}, broadcast=True) + emit('from_server', {'cmd': 'show_model_menu', 'data': model_menu[menu], 'menu': menu, 'breadcrumbs': []}, broadcast=True) + +def get_folder_path_info(base): + if base == 'This PC': + breadcrumbs = [['This PC', 'This PC']] + paths = [["{}:\\".format(chr(i)), "{}:\\".format(chr(i))] for i in range(65, 91) if os.path.exists("{}:".format(chr(i)))] + else: + path = os.path.abspath(base) + if path[-1] == "\\": + path = path[:-1] + breadcrumbs = [] + for i in range(len(path.split("\\"))): + breadcrumbs.append(["\\".join(path.split("\\")[:i+1]), + path.split("\\")[i]]) + if len(breadcrumbs) == 1: + breadcrumbs = [["{}:\\".format(chr(i)), "{}:\\".format(chr(i))] for i in range(65, 91) if os.path.exists("{}:".format(chr(i)))] + else: + if len([["{}:\\".format(chr(i)), "{}:\\".format(chr(i))] for i in range(65, 91) if os.path.exists("{}:".format(chr(i)))]) > 0: + breadcrumbs.insert(0, ['This PC', 'This PC']) + paths = [] + base_path = os.path.abspath(base) + for item in os.listdir(base_path): + if os.path.isdir(os.path.join(base_path, item)): + paths.append([os.path.join(base_path, item), item]) + # Paths/breadcrumbs is a list of lists, where the first element in the sublist is the full path and the second is the folder name + return (paths, breadcrumbs) def getModelSelection(modellist): print(" # Model\t\t\t\t\t\tVRAM\n ========================================================") @@ -395,6 +431,15 @@ def getModelSelection(modellist): print("{0}Select an AI model to continue:{1}\n".format(colors.CYAN, colors.END)) getModelSelection(mainmenu) +def check_if_dir_is_model(path): + try: + from transformers import AutoConfig + model_config = AutoConfig.from_pretrained(path, revision=vars.revision, cache_dir="cache") + except: + return False + return True + + #==================================================================# # Return all keys in tokenizer dictionary containing char #==================================================================# @@ -1024,12 +1069,12 @@ def get_layer_count(model, directory=""): # Get the model_type from the config or assume a model type if it isn't present else: from transformers import AutoConfig - if vars.custmodpth == "": + if directory == "": model_config = AutoConfig.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache") elif(os.path.isdir(vars.custmodpth.replace('/', '_'))): model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), revision=vars.revision, cache_dir="cache") - elif(os.path.isdir("models/{}".format(vars.custmodpth.replace('/', '_')))): - model_config = AutoConfig.from_pretrained("models/{}".format(vars.custmodpth.replace('/', '_')), revision=vars.revision, cache_dir="cache") + elif(os.path.isdir(directory)): + model_config = AutoConfig.from_pretrained(directory, revision=vars.revision, cache_dir="cache") else: model_config = AutoConfig.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache") @@ -1094,12 +1139,267 @@ def get_oai_models(key): emit('from_server', {'cmd': 'errmsg', 'data': req.json()}) +def patch_transformers(): + global transformers + old_from_pretrained = PreTrainedModel.from_pretrained.__func__ + @classmethod + def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + vars.fp32_model = False + utils.num_shards = None + utils.current_shard = 0 + utils.from_pretrained_model_name = pretrained_model_name_or_path + utils.from_pretrained_index_filename = None + utils.from_pretrained_kwargs = kwargs + utils.bar = None + if not args.no_aria2: + utils.aria2_hook(pretrained_model_name_or_path, **kwargs) + return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs) + PreTrainedModel.from_pretrained = new_from_pretrained + if(hasattr(modeling_utils, "get_checkpoint_shard_files")): + old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files + def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs): + utils.num_shards = utils.get_num_shards(index_filename) + utils.from_pretrained_index_filename = index_filename + return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs) + modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files + + # Some versions of transformers 4.17.0.dev0 are affected by + # https://github.com/huggingface/transformers/issues/15736 + # This is a workaround for those versions of transformers. + if(transformers_version == "4.17.0.dev0"): + try: + from transformers.models.xglm.modeling_xglm import XGLMSinusoidalPositionalEmbedding + except ImportError: + pass + else: + @torch.no_grad() + def new_forward(self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0): + bsz, seq_len = inputs_embeds.size()[:-1] + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + position_ids = torch.arange( + past_key_values_length + self.padding_idx + 1, past_key_values_length + sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ).unsqueeze(0).expand(input_shape).contiguous() + max_pos = self.padding_idx + 1 + seq_len + past_key_values_length + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach() + XGLMSinusoidalPositionalEmbedding.forward = new_forward + + # Patch transformers to use our soft prompt + def patch_causallm(cls): + old_forward = cls.forward + def new_causallm_forward(self, *args, **kwargs): + input_ids = kwargs.get('input_ids').to(self.device) + assert input_ids is not None + kwargs['input_ids'] = None + if(vars.sp is not None): + shifted_input_ids = input_ids - self.config.vocab_size + input_ids.clamp_(max=self.config.vocab_size-1) + if(hasattr(self, "transformer")): + inputs_embeds = self.transformer.wte(input_ids) + elif(not hasattr(self.model, "decoder")): + inputs_embeds = self.model.embed_tokens(input_ids) + else: + inputs_embeds = self.model.decoder.embed_tokens(input_ids) + if(vars.sp is not None): + vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device) + inputs_embeds = torch.where( + (shifted_input_ids >= 0)[..., None], + vars.sp[shifted_input_ids.clamp(min=0)], + inputs_embeds, + ) + if(hasattr(self, "model") and hasattr(self.model, "embed_scale")): + inputs_embeds *= self.model.embed_scale + kwargs['inputs_embeds'] = inputs_embeds + return old_forward(self, *args, **kwargs) + cls.forward = new_causallm_forward + for cls in (GPT2LMHeadModel, GPTNeoForCausalLM): + patch_causallm(cls) + for c in ("GPTJForCausalLM", "XGLMForCausalLM", "OPTForCausalLM"): + try: + patch_causallm(getattr(__import__("transformers"), c)) + except: + pass + + + # Fix a bug in OPTForCausalLM where self.lm_head is the wrong size + if(packaging.version.parse("4.19.0.dev0") <= packaging.version.parse(transformers_version) <= packaging.version.parse("4.19.2")): + try: + from transformers import OPTForCausalLM, OPTModel + except ImportError: + pass + else: + # This is the same as the original __init__ but with + # config.hidden_size + # replaced with + # config.word_embed_proj_dim + def new_init(self, config): + super(OPTForCausalLM, self).__init__(config) + self.model = OPTModel(config) + self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + self.post_init() + OPTForCausalLM.__init__ = new_init + + + # Patch transformers to use our custom logit warpers + from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor + from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper + + def dynamic_processor_wrap(cls, field_name, var_name, cond=None): + old_call = cls.__call__ + def new_call(self, *args, **kwargs): + if(not isinstance(field_name, str) and isinstance(field_name, Iterable)): + conds = [] + for f, v in zip(field_name, var_name): + conds.append(getattr(vars, v)) + setattr(self, f, conds[-1]) + else: + conds = getattr(vars, var_name) + setattr(self, field_name, conds) + assert len(args) == 2 + if(cond is None or cond(conds)): + return old_call(self, *args, **kwargs) + return args[1] + cls.__call__ = new_call + dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0) + dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0) + dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0) + dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) + dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0) + dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0) + RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__ + RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__ + + class LuaLogitsProcessor(LogitsProcessor): + + def __init__(self): + pass + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + assert scores.ndim == 2 + assert input_ids.ndim == 2 + self.regeneration_required = False + self.halt = False + + scores_shape = scores.shape + scores_list = scores.tolist() + vars.lua_koboldbridge.logits = vars.lua_state.table() + for r, row in enumerate(scores_list): + vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row) + vars.lua_koboldbridge.vocab_size = scores_shape[-1] + + execute_genmod() + + scores = torch.tensor( + tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()), + device=scores.device, + dtype=scores.dtype, + ) + assert scores.shape == scores_shape + + return scores + + def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList: + processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs) + processors.insert(0, LuaLogitsProcessor()) + return processors + new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor + transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor + + def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList: + warper_list = LogitsProcessorList() + warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))) + warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) + warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) + warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) + warper_list.append(TemperatureLogitsWarper(temperature=0.5)) + return warper_list + + def new_sample(self, *args, **kwargs): + assert kwargs.pop("logits_warper", None) is not None + kwargs["logits_warper"] = new_get_logits_warper( + beams=1, + ) + if(vars.newlinemode == "s") or (vars.newlinemode == "ns"): + kwargs["eos_token_id"] = -1 + kwargs.setdefault("pad_token_id", 2) + return new_sample.old_sample(self, *args, **kwargs) + new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample + transformers.generation_utils.GenerationMixin.sample = new_sample + + + # Allow bad words filter to ban <|endoftext|> token + import transformers.generation_logits_process + def new_init(self, bad_words_ids: List[List[int]], eos_token_id: int): + return new_init.old_init(self, bad_words_ids, -1) + new_init.old_init = transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ + transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ = new_init + + + # Sets up dynamic world info scanner + class DynamicWorldInfoScanCriteria(StoppingCriteria): + def __init__( + self, + tokenizer, + excluded_world_info: List[Set], + ): + self.regeneration_required = False + self.halt = False + self.tokenizer = tokenizer + self.excluded_world_info = excluded_world_info + def __call__( + self, + input_ids: torch.LongTensor, + scores: torch.FloatTensor, + **kwargs, + ) -> bool: + vars.generated_tkns += 1 + if(vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols): + raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({vars.generated_tkns} != {vars.lua_koboldbridge.generated_cols})") + if(vars.abort or vars.generated_tkns >= vars.genamt): + self.regeneration_required = False + self.halt = False + return True + + assert input_ids.ndim == 2 + assert len(self.excluded_world_info) == input_ids.shape[0] + self.regeneration_required = vars.lua_koboldbridge.regeneration_required + self.halt = not vars.lua_koboldbridge.generating + vars.lua_koboldbridge.regeneration_required = False + + for i in range(vars.numseqs): + vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(input_ids[i, -1].item()) + + if(not vars.dynamicscan): + return self.regeneration_required or self.halt + tail = input_ids[..., -vars.generated_tkns:] + for i, t in enumerate(tail): + decoded = utils.decodenewlines(tokenizer.decode(t)) + _, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions) + found -= self.excluded_world_info[i] + if(len(found) != 0): + self.regeneration_required = True + break + return self.regeneration_required or self.halt + old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria + def new_get_stopping_criteria(self, *args, **kwargs): + stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs) + global tokenizer + self.kai_scanner = DynamicWorldInfoScanCriteria( + tokenizer=tokenizer, + excluded_world_info=self.kai_scanner_excluded_world_info, + ) + stopping_criteria.insert(0, self.kai_scanner) + return stopping_criteria + transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model=""): global model global generator global torch global model_config + print("Loading vars.model: {} vars.custmodpth: {}".format(vars.model, vars.custmodpth)) vars.noai = False if not initial_load: set_aibusy(True) @@ -1113,10 +1413,14 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model=" #We need to wipe out the existing model and refresh the cuda cache model = None generator = None + model_config = None try: torch.cuda.empty_cache() except: pass + + #Reload our badwords + vars.badwordsids = vars.badwordsids_default #Let's set the GooseAI or OpenAI server URLs if that's applicable if online_model != "": @@ -1266,42 +1570,11 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model=" if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]): if(not vars.noai): print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END)) - from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer for m in ("GPTJModel", "XGLMModel"): try: globals()[m] = getattr(__import__("transformers"), m) except: pass - try: - from transformers.models.opt.modeling_opt import OPTDecoder - except: - pass - import transformers.generation_utils - from transformers import __version__ as transformers_version - - from transformers import PreTrainedModel - from transformers import modeling_utils - old_from_pretrained = PreTrainedModel.from_pretrained.__func__ - @classmethod - def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - vars.fp32_model = False - utils.num_shards = None - utils.current_shard = 0 - utils.from_pretrained_model_name = pretrained_model_name_or_path - utils.from_pretrained_index_filename = None - utils.from_pretrained_kwargs = kwargs - utils.bar = None - if not args.no_aria2: - utils.aria2_hook(pretrained_model_name_or_path, **kwargs) - return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs) - PreTrainedModel.from_pretrained = new_from_pretrained - if(hasattr(modeling_utils, "get_checkpoint_shard_files")): - old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files - def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs): - utils.num_shards = utils.get_num_shards(index_filename) - utils.from_pretrained_index_filename = index_filename - return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs) - modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files # Lazy loader import torch_lazy_loader @@ -1398,236 +1671,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model=" else: vars.lazy_load = False - # Some versions of transformers 4.17.0.dev0 are affected by - # https://github.com/huggingface/transformers/issues/15736 - # This is a workaround for those versions of transformers. - if(transformers_version == "4.17.0.dev0"): - try: - from transformers.models.xglm.modeling_xglm import XGLMSinusoidalPositionalEmbedding - except ImportError: - pass - else: - @torch.no_grad() - def new_forward(self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0): - bsz, seq_len = inputs_embeds.size()[:-1] - input_shape = inputs_embeds.size()[:-1] - sequence_length = input_shape[1] - position_ids = torch.arange( - past_key_values_length + self.padding_idx + 1, past_key_values_length + sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device - ).unsqueeze(0).expand(input_shape).contiguous() - max_pos = self.padding_idx + 1 + seq_len + past_key_values_length - if max_pos > self.weights.size(0): - self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) - return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach() - XGLMSinusoidalPositionalEmbedding.forward = new_forward - - # Patch transformers to use our soft prompt - def patch_causallm(cls): - old_forward = cls.forward - def new_causallm_forward(self, *args, **kwargs): - input_ids = kwargs.get('input_ids').to(self.device) - assert input_ids is not None - kwargs['input_ids'] = None - if(vars.sp is not None): - shifted_input_ids = input_ids - self.config.vocab_size - input_ids.clamp_(max=self.config.vocab_size-1) - if(hasattr(self, "transformer")): - inputs_embeds = self.transformer.wte(input_ids) - elif(not hasattr(self.model, "decoder")): - inputs_embeds = self.model.embed_tokens(input_ids) - else: - inputs_embeds = self.model.decoder.embed_tokens(input_ids) - if(vars.sp is not None): - vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device) - inputs_embeds = torch.where( - (shifted_input_ids >= 0)[..., None], - vars.sp[shifted_input_ids.clamp(min=0)], - inputs_embeds, - ) - if(hasattr(self, "model") and hasattr(self.model, "embed_scale")): - inputs_embeds *= self.model.embed_scale - kwargs['inputs_embeds'] = inputs_embeds - return old_forward(self, *args, **kwargs) - cls.forward = new_causallm_forward - for cls in (GPT2LMHeadModel, GPTNeoForCausalLM): - patch_causallm(cls) - for c in ("GPTJForCausalLM", "XGLMForCausalLM", "OPTForCausalLM"): - try: - patch_causallm(getattr(__import__("transformers"), c)) - except: - pass - - - # Fix a bug in OPTForCausalLM where self.lm_head is the wrong size - if(packaging.version.parse("4.19.0.dev0") <= packaging.version.parse(transformers_version) <= packaging.version.parse("4.19.2")): - try: - from transformers import OPTForCausalLM, OPTModel - except ImportError: - pass - else: - # This is the same as the original __init__ but with - # config.hidden_size - # replaced with - # config.word_embed_proj_dim - def new_init(self, config): - super(OPTForCausalLM, self).__init__(config) - self.model = OPTModel(config) - self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) - self.post_init() - OPTForCausalLM.__init__ = new_init - - - # Patch transformers to use our custom logit warpers - from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor - from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper - - def dynamic_processor_wrap(cls, field_name, var_name, cond=None): - old_call = cls.__call__ - def new_call(self, *args, **kwargs): - if(not isinstance(field_name, str) and isinstance(field_name, Iterable)): - conds = [] - for f, v in zip(field_name, var_name): - conds.append(getattr(vars, v)) - setattr(self, f, conds[-1]) - else: - conds = getattr(vars, var_name) - setattr(self, field_name, conds) - assert len(args) == 2 - if(cond is None or cond(conds)): - return old_call(self, *args, **kwargs) - return args[1] - cls.__call__ = new_call - dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0) - dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0) - dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0) - dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) - dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0) - dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0) - RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__ - RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__ - - class LuaLogitsProcessor(LogitsProcessor): - - def __init__(self): - pass - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - assert scores.ndim == 2 - assert input_ids.ndim == 2 - self.regeneration_required = False - self.halt = False - - scores_shape = scores.shape - scores_list = scores.tolist() - vars.lua_koboldbridge.logits = vars.lua_state.table() - for r, row in enumerate(scores_list): - vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row) - vars.lua_koboldbridge.vocab_size = scores_shape[-1] - - execute_genmod() - - scores = torch.tensor( - tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()), - device=scores.device, - dtype=scores.dtype, - ) - assert scores.shape == scores_shape - - return scores - def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList: - processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs) - processors.insert(0, LuaLogitsProcessor()) - return processors - new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor - transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor - - def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList: - warper_list = LogitsProcessorList() - warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TemperatureLogitsWarper(temperature=0.5)) - return warper_list - - def new_sample(self, *args, **kwargs): - assert kwargs.pop("logits_warper", None) is not None - kwargs["logits_warper"] = new_get_logits_warper( - beams=1, - ) - if(vars.newlinemode == "s") or (vars.newlinemode == "ns"): - kwargs["eos_token_id"] = -1 - kwargs.setdefault("pad_token_id", 2) - return new_sample.old_sample(self, *args, **kwargs) - new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample - transformers.generation_utils.GenerationMixin.sample = new_sample - - - # Allow bad words filter to ban <|endoftext|> token - import transformers.generation_logits_process - def new_init(self, bad_words_ids: List[List[int]], eos_token_id: int): - return new_init.old_init(self, bad_words_ids, -1) - new_init.old_init = transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ - transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ = new_init - - - # Sets up dynamic world info scanner - class DynamicWorldInfoScanCriteria(StoppingCriteria): - def __init__( - self, - tokenizer, - excluded_world_info: List[Set], - ): - self.regeneration_required = False - self.halt = False - self.tokenizer = tokenizer - self.excluded_world_info = excluded_world_info - def __call__( - self, - input_ids: torch.LongTensor, - scores: torch.FloatTensor, - **kwargs, - ) -> bool: - vars.generated_tkns += 1 - if(vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols): - raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({vars.generated_tkns} != {vars.lua_koboldbridge.generated_cols})") - if(vars.abort or vars.generated_tkns >= vars.genamt): - self.regeneration_required = False - self.halt = False - return True - - assert input_ids.ndim == 2 - assert len(self.excluded_world_info) == input_ids.shape[0] - self.regeneration_required = vars.lua_koboldbridge.regeneration_required - self.halt = not vars.lua_koboldbridge.generating - vars.lua_koboldbridge.regeneration_required = False - - for i in range(vars.numseqs): - vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(input_ids[i, -1].item()) - - if(not vars.dynamicscan): - return self.regeneration_required or self.halt - tail = input_ids[..., -vars.generated_tkns:] - for i, t in enumerate(tail): - decoded = utils.decodenewlines(tokenizer.decode(t)) - _, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions) - found -= self.excluded_world_info[i] - if(len(found) != 0): - self.regeneration_required = True - break - return self.regeneration_required or self.halt - old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria - def new_get_stopping_criteria(self, *args, **kwargs): - stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs) - global tokenizer - self.kai_scanner = DynamicWorldInfoScanCriteria( - tokenizer=tokenizer, - excluded_world_info=self.kai_scanner_excluded_world_info, - ) - stopping_criteria.insert(0, self.kai_scanner) - return stopping_criteria - transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria def get_hidden_size_from_model(model): try: @@ -2983,19 +3027,23 @@ def get_message(msg): # If we're on a custom line that we have selected a model for, the path variable will be in msg # so if that's missing we need to run the menu to show the model folders in the models folder if msg['data'] in ('NeoCustom', 'GPT2Custom') and 'path' not in msg: - sendModelSelection(menu=msg['data']) - #elif msg['data'] in ('OAI', 'GooseAI'): - # vars.model = msg['data'] - # get_oai_models() - # emit('from_server', {'cmd': 'hide_layer_bar'}, broadcast=True) - # emit('from_server', {'cmd': 'check_enable_model_load', 'model': vars.model}, broadcast=True) + if 'folder' not in msg: + folder = "./models" + else: + folder = msg['folder'] + sendModelSelection(menu=msg['data'], folder=folder) + elif msg['data'] in ('NeoCustom', 'GPT2Custom'): + if check_if_dir_is_model(msg['path']): + vars.model = msg['data'] + vars.custmodpth = msg['path'] + get_model_info(msg['data'], directory=msg['path']) + else: + sendModelSelection(menu=msg['data'], folder=msg['path']) else: vars.model = msg['data'] if 'path' in msg: - if msg['data'] == 'NeoCustom': - get_model_info(vars.custmodpth, directory=msg['path']) - else: - get_model_info(vars.model, directory=msg['path']) + vars.custmodpth = msg['path'] + get_model_info(msg['data'], directory=msg['path']) else: get_model_info(vars.model) @@ -5677,6 +5725,7 @@ if __name__ == "__main__": print("{0}\nStarting webserver...{1}".format(colors.GREEN, colors.END), flush=True) general_startup() + patch_transformers() #show_select_model_list() if vars.model == "" or vars.model is None: vars.model = "ReadOnly" @@ -5732,6 +5781,7 @@ if __name__ == "__main__": else: general_startup() + patch_transformers() #show_select_model_list() if vars.model == "" or vars.model is None: vars.model = "ReadOnly" diff --git a/static/application.js b/static/application.js index cfdd42ac..035cdddf 100644 --- a/static/application.js +++ b/static/application.js @@ -991,22 +991,41 @@ function hideUSPopup() { } -function buildLoadModelList(ar, menu) { +function buildLoadModelList(ar, menu, breadcrumbs) { disableButtons([load_model_accept]); loadmodelcontent.html(""); + $("#loadmodellistbreadcrumbs").html(""); var i; + for(i=0; i"+breadcrumbs[i][1]+""); + $("#model_breadcrumbs"+i).off("click").on("click", (function () { + return function () { + socket.send({'cmd': 'selectmodel', 'data': $(this).attr("name"), 'folder': $(this).attr("value")}); + disableButtons([load_model_accept]); + } + })(i)); + } for(i=0; i\
" + //if the menu item is a link to another menu if(ar[i][3]) { html = html + "" } else { + //this is a model html = html + "
" } + if (Array.isArray(ar[i][0])) { + full_path = ar[i][0][0]; + folder = ar[i][0][1]; + } else { + full_path = ""; + folder = ar[i][0]; + } html = html + "
\ -
\ -
"+ar[i][0]+"
\ +
\ +
"+folder+"
\
"+ar[i][2]+"
\
\
" @@ -1020,7 +1039,7 @@ function buildLoadModelList(ar, menu) { } })(i)); //If we're in the custom load menu (we need to send the path data back in that case) - } else if(menu == 'custom') { + } else if(['NeoCustom', 'GPT2Custom'].includes(menu)) { $("#loadmodel"+i).off("click").on("click", (function () { return function () { socket.send({'cmd': 'selectmodel', 'data': $(this).attr("name"), 'path': $(this).attr("pretty_name")}); @@ -2472,11 +2491,12 @@ $(document).ready(function(){ debug_area.addClass("hidden"); } } else if(msg.cmd == 'show_model_menu') { + console.log(msg) $("#use_gpu_div").addClass("hidden"); $("#modelkey").addClass("hidden"); $("#modellayers").addClass("hidden"); $("#oaimodel").addClass("hidden") - buildLoadModelList(msg.data, msg.menu); + buildLoadModelList(msg.data, msg.menu, msg.breadcrumbs); } else if(msg.cmd == 'selected_model_info') { enableButtons([load_model_accept]); $("#oaimodel").addClass("hidden") diff --git a/templates/index.html b/templates/index.html index 648fef27..ac3c322e 100644 --- a/templates/index.html +++ b/templates/index.html @@ -279,8 +279,8 @@
Select A Model To Load
-
-
Model
+
+
diff --git a/utils.py b/utils.py index bc085412..2f1e84f4 100644 --- a/utils.py +++ b/utils.py @@ -149,7 +149,7 @@ def decodenewlines(txt): # Returns number of layers given an HF model config #==================================================================# def num_layers(config): - return config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers + return config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else None #==================================================================# # Downloads huggingface checkpoints using aria2c if possible