From 2db1f2f7bb4dea89fb69aff93f4f1207f2974ace Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 25 Jan 2022 15:05:21 -0500 Subject: [PATCH 01/23] AvrilAI-style repetition penalty test --- aiserver.py | 5 ++--- tpu_mtj_backend.py | 45 ++++++++++++++++++++------------------------- warpers.py | 2 +- 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/aiserver.py b/aiserver.py index 64470d1e..5be7a17a 100644 --- a/aiserver.py +++ b/aiserver.py @@ -722,8 +722,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0) dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0) - RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__ - RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__ class LuaLogitsProcessor(LogitsProcessor): @@ -767,6 +765,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TemperatureLogitsWarper(temperature=0.5)) + warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor()) return warper_list def new_sample(self, *args, **kwargs): @@ -2771,7 +2770,7 @@ def _generate(txt, minimum, maximum, found_entries): do_sample=True, min_length=minimum, max_length=int(2e9), - repetition_penalty=1.1, + repetition_penalty=1.0, bad_words_ids=vars.badwordsids, use_cache=True, num_return_sequences=numseqs diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 653f8cf1..e7632eba 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -149,7 +149,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat logits[tokens] = penalty_logits return logits -def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): +def kobold_sample_dynamic(key, logits, rpargs, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): ''' This gets called by generate_loop_fn to apply a series of 4 filters to the logits (top-k, then top-p, then TFS, then temperature) before @@ -245,6 +245,7 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) + logits = apply_repetition_penalty_dynamic(logits, *rpargs) return jax.random.categorical(key, logits, -1).astype(np.uint32) def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange): @@ -292,7 +293,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate # positions in the logits array return logits.at[tokens].set(penalty_logits) -def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): +def kobold_sample_static(key, logits, rpargs, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): ''' This gets called by generate_loop_fn to apply a series of 4 filters to the logits (top-k, then top-p, then TFS, then temperature) before @@ -387,6 +388,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) + logits = apply_repetition_penalty_static(logits, *rpargs) return jax.random.categorical(key, logits, -1).astype(jnp.uint32) pad_token_id = 50256 @@ -400,17 +402,6 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_ # Get the pseudo-random number generator key that will # be used by kobold_sample_dynamic to randomly pick a token sample_key, new_key = jax.random.split(sample_key, num=2) - # Apply repetition penalty to all tokens that are - # currently inside the "generated" array - logits = apply_repetition_penalty_dynamic( - logits, - generated, - repetition_penalty, - generated_index, - gen_length, - rpslope, - rprange, - ) # Remove any tokens in the badwords list by setting # their logits to negative infinity which effectively # makes their probabilities of being chosen zero @@ -422,6 +413,14 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_ next_token = kobold_sample_dynamic( sample_key, logits, + ( + generated, + repetition_penalty, + generated_index, + gen_length, + rpslope, + rprange, + ) **sampler_options, ) # Remember what token was picked @@ -493,18 +492,6 @@ class PenalizingCausalTransformer(CausalTransformer): assert logits.shape == (1, config["n_vocab"]) # Flatten it into a 1D array to make it easier to use logits = logits[0] - # Apply repetition penalty to all tokens that are - # currently inside the "generated" array - if repetition_penalty is not None: - logits = apply_repetition_penalty_static( - logits, - generated, - repetition_penalty, - generated_index, - gen_length, - rpslope, - rprange, - ) # Remove any tokens in the badwords list by setting # their logits to negative infinity which effectively # makes their probabilities of being chosen zero @@ -516,6 +503,14 @@ class PenalizingCausalTransformer(CausalTransformer): next_token = kobold_sample_static( sample_key, logits, + ( + generated, + repetition_penalty, + generated_index, + gen_length, + rpslope, + rprange, + ), **sampler_options, ) # Remember what token was picked diff --git a/warpers.py b/warpers.py index 07670f6d..122bc1cd 100644 --- a/warpers.py +++ b/warpers.py @@ -31,7 +31,7 @@ import torch from transformers import LogitsWarper, LogitsProcessor -class AdvancedRepetitionPenaltyLogitsProcessor(LogitsProcessor): +class AdvancedRepetitionPenaltyLogitsProcessor(LogitsWarper): def __init__(self, *args, **kwargs): pass From 0032462837aa8d64f41745aa0f943308d857068a Mon Sep 17 00:00:00 2001 From: ebolam Date: Sat, 13 Aug 2022 20:12:35 -0400 Subject: [PATCH 02/23] Fix for vars.model getting set on AI selection in the UI rather than when actually loaded --- aiserver.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/aiserver.py b/aiserver.py index 7adf0e0c..8a91b4f6 100644 --- a/aiserver.py +++ b/aiserver.py @@ -239,6 +239,7 @@ class vars: submission = "" # Same as above, but after applying input formatting lastctx = "" # The last context submitted to the generator model = "" # Model ID string chosen at startup + model_selected = "" #selected model in UI model_type = "" # Model Type (Automatically taken from the model config) noai = False # Runs the script without starting up the transformers pipeline aibusy = False # Stops submissions while the AI is working @@ -1474,11 +1475,11 @@ def get_layer_count(model, directory=""): else: from transformers import AutoConfig if directory == "": - model_config = AutoConfig.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache") - elif(os.path.isdir(vars.custmodpth.replace('/', '_'))): - model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), revision=vars.revision, cache_dir="cache") + model_config = AutoConfig.from_pretrained(model, revision=vars.revision, cache_dir="cache") elif(os.path.isdir(directory)): model_config = AutoConfig.from_pretrained(directory, revision=vars.revision, cache_dir="cache") + elif(os.path.isdir(vars.custmodpth.replace('/', '_'))): + model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), revision=vars.revision, cache_dir="cache") else: model_config = AutoConfig.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache") @@ -3669,8 +3670,8 @@ def get_message(msg): changed = True if not utils.HAS_ACCELERATE: msg['disk_layers'] = "0" - if os.path.exists("settings/" + vars.model.replace('/', '_') + ".breakmodel"): - with open("settings/" + vars.model.replace('/', '_') + ".breakmodel", "r") as file: + if os.path.exists("settings/" + vars.model_selected.replace('/', '_') + ".breakmodel"): + with open("settings/" + vars.model_selected.replace('/', '_') + ".breakmodel", "r") as file: data = file.read().split('\n')[:2] if len(data) < 2: data.append("0") @@ -3678,14 +3679,15 @@ def get_message(msg): if gpu_layers == msg['gpu_layers'] and disk_layers == msg['disk_layers']: changed = False if changed: - if vars.model in ["NeoCustom", "GPT2Custom"]: + if vars.model_selected in ["NeoCustom", "GPT2Custom"]: filename = "settings/{}.breakmodel".format(os.path.basename(os.path.normpath(vars.custmodpth))) else: - filename = "settings/{}.breakmodel".format(vars.model.replace('/', '_')) + filename = "settings/{}.breakmodel".format(vars.model_selected.replace('/', '_')) f = open(filename, "w") f.write(str(msg['gpu_layers']) + '\n' + str(msg['disk_layers'])) f.close() vars.colaburl = msg['url'] + "/request" + vars.model = vars.model_selected load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], disk_layers=msg['disk_layers'], online_model=msg['online_model']) elif(msg['cmd'] == 'show_model'): print("Model Name: {}".format(getmodelname())) @@ -3710,18 +3712,18 @@ def get_message(msg): elif msg['data'] in ('NeoCustom', 'GPT2Custom') and 'path_modelname' in msg: #Here the user entered custom text in the text box. This could be either a model name or a path. if check_if_dir_is_model(msg['path_modelname']): - vars.model = msg['data'] + vars.model_selected = msg['data'] vars.custmodpth = msg['path_modelname'] get_model_info(msg['data'], directory=msg['path']) else: - vars.model = msg['path_modelname'] + vars.model_selected = msg['path_modelname'] try: - get_model_info(vars.model) + get_model_info(vars.model_selected) except: emit('from_server', {'cmd': 'errmsg', 'data': "The model entered doesn't exist."}) elif msg['data'] in ('NeoCustom', 'GPT2Custom'): if check_if_dir_is_model(msg['path']): - vars.model = msg['data'] + vars.model_selected = msg['data'] vars.custmodpth = msg['path'] get_model_info(msg['data'], directory=msg['path']) else: @@ -3730,12 +3732,12 @@ def get_message(msg): else: sendModelSelection(menu=msg['data'], folder=msg['path']) else: - vars.model = msg['data'] + vars.model_selected = msg['data'] if 'path' in msg: vars.custmodpth = msg['path'] get_model_info(msg['data'], directory=msg['path']) else: - get_model_info(vars.model) + get_model_info(vars.model_selected) elif(msg['cmd'] == 'delete_model'): if "{}/models".format(os.getcwd()) in os.path.abspath(msg['data']) or "{}\\models".format(os.getcwd()) in os.path.abspath(msg['data']): if check_if_dir_is_model(msg['data']): From 137695106d6dc07f896aa332f6aee4d180dd0caf Mon Sep 17 00:00:00 2001 From: ebolam Date: Wed, 17 Aug 2022 18:03:48 -0400 Subject: [PATCH 03/23] Fix for gooseai --- aiserver.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/aiserver.py b/aiserver.py index 5bbabe15..ef785313 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1497,9 +1497,9 @@ def get_layer_count(model, directory=""): def get_oai_models(key): vars.oaiapikey = key - if vars.model == 'OAI': + if vars.model_selected == 'OAI': url = "https://api.openai.com/v1/engines" - elif vars.model == 'GooseAI': + elif vars.model_selected == 'GooseAI': url = "https://api.goose.ai/v1/engines" else: return @@ -1528,8 +1528,8 @@ def get_oai_models(key): # If the client settings file doesn't exist, create it # Write API key to file os.makedirs('settings', exist_ok=True) - if path.exists("settings/{}.settings".format(vars.model)): - with open("settings/{}.settings".format(vars.model), "r") as file: + if path.exists("settings/{}.settings".format(vars.model_selected)): + with open("settings/{}.settings".format(vars.model_selected), "r") as file: js = json.load(file) if 'online_model' in js: online_model = js['online_model'] @@ -1537,7 +1537,7 @@ def get_oai_models(key): if js['apikey'] != key: changed=True if changed: - with open("settings/{}.settings".format(vars.model), "w") as file: + with open("settings/{}.settings".format(vars.model_selected), "w") as file: js["apikey"] = key file.write(json.dumps(js, indent=3)) From b04a3a2fbbefb01b93bc8722df694bbd9f7886d7 Mon Sep 17 00:00:00 2001 From: Henk Date: Thu, 18 Aug 2022 23:10:19 +0200 Subject: [PATCH 04/23] Dismiss reload warning when needed --- static/application.js | 1 + 1 file changed, 1 insertion(+) diff --git a/static/application.js b/static/application.js index ccd90601..dd0fc413 100644 --- a/static/application.js +++ b/static/application.js @@ -2962,6 +2962,7 @@ $(document).ready(function(){ $("#showmodelnamecontainer").removeClass("hidden"); } else if(msg.cmd == 'hide_model_name') { $("#showmodelnamecontainer").addClass("hidden"); + $(window).off('beforeunload'); location.reload(); //console.log("Closing window"); } else if(msg.cmd == 'model_load_status') { From 10e3e64b0b2c432f5bb0ff076d61938884a8d80e Mon Sep 17 00:00:00 2001 From: ebolam Date: Thu, 18 Aug 2022 19:10:18 -0400 Subject: [PATCH 05/23] Update for execution time timer --- static/application.js | 4 ---- static/favicon.js | 6 ++++++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/static/application.js b/static/application.js index ccd90601..8cb20aea 100644 --- a/static/application.js +++ b/static/application.js @@ -2445,10 +2445,6 @@ $(document).ready(function(){ } else if(msg.cmd == "updatechunk") { hideMessage(); game_text.attr('contenteditable', allowedit); - if (typeof submit_start !== 'undefined') { - $("#runtime")[0].innerHTML = `Generation time: ${Math.round((Date.now() - submit_start)/1000)} sec`; - delete submit_start; - } var index = msg.data.index; var html = msg.data.html; var existingChunk = game_text.children('#n' + index); diff --git a/static/favicon.js b/static/favicon.js index 180059ff..fb40ac84 100644 --- a/static/favicon.js +++ b/static/favicon.js @@ -2,6 +2,7 @@ var fav_icon2 = "data:image/x-icon;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAMAAAAoLQ9TAAAABGdBTUEAALGPC/xhBQAAACBjSFJNAAB6JgAAgIQAAPoAAACA6AAAdTAAAOpgAAA6mAAAF3CculE8AAAB+1BMVEUAAAAAAAAAAAAAAAAAAQAAAAAAAQAAAAAAAAASFhBBWD4iUyoFEwgFEwguUTM+VDoMFAwAAAA+elIudz8AAAAAAAA0MigyLyQAAAAbLh1LdElSbUoVMBkAAABAZ0M2fkUAAAABAQFMiGQraDkAAQANFxEGFQkLFg8EEAYAAAAsZDonZjUAAABCgVVAnFYrSjhEjFpFi1sdRScAAAAjOi8VMxx1dGOFgGYAAABOTEabmIdlYlQaGhgaGhddXFauqY5JRjoAAAAAAAABAQFGeExIl1lX0XRW0XRHi1RFe02vv5W31KFd1Hpc1Hpe1HvO1KvDvJlqZ1plYVOmoIVt1IFl1H7AuZp1cV9jX1AmSCw3Nzg7NmA1MTJuz4Bm1H5MST9HPl9BQEMgNiNXgWKiobFgXICDd5dfw3RZVnJiV3zGv9Bqf29Oj2G/v8hTTpGhl8dbxHVd0npiYoxhWJvIxtlcimZFn1lRclg9SkZNblZBeEpDbEZCa0ZBc0hLY1BAS1BdaV87j01Vx3FWynJSrGZOhlVasGtas2xatm1at21WnWJQm15WyXJQvmlavnBZrGlEYEJWe1RBWz9Um2BavXBgxn9XhllGY0RLaklXiFlTwG5OpmVSfFNMbUpGZEVLa0lShldEhVCChHiKiHvWz6/Kw6WWlZGAfmj///8kr0X+AAAARHRSTlMAASFrcAhxIjLb/vWvsPb+20b4+DFFyMkz2vf43CP9/m5y9vZysLGvsQn19mz+/tz4+NxHycr3+Ejb/vaxsPX+3TRtcBrzrrgAAAABYktHRKhQCDaSAAAAB3RJTUUH5gYJFyQy3tftxgAAAQBJREFUGNNjYGBgYGRiZmFlZWNmZ2SAAA5OLm4eXj5+AQ6ogKCQi6ubu4ensCCIxygiKubl7ePr6+cfIC4owcjAJCkVGBQc4usbGhYeIS0jy8AsFxkVHRPr6xsXn5CYJK/AoKiUnJKalg5UkZGZla2swsCqmpObl1/g61tYVFxSqsbKwKpeVl5RWVVdU1tX39CoocnAotXU3NLa1t7R2dXd06utwqCj6+vb1z9h4sRJk6f4+uopMLDrG0z1nTZ94sQZM31nGRrJMjBKGJvMnjN3wrz5CxaaCnKAvSNqtmjxkqXLlptbQP0iYmllbWNrZ+/gCBVgZHdS1GR1VpAFqQcApI0/jqlZOvEAAAAldEVYdGRhdGU6Y3JlYXRlADIwMjItMDYtMDlUMjM6MzY6NTArMDA6MDDi0xr+AAAAJXRFWHRkYXRlOm1vZGlmeQAyMDIyLTA2LTA5VDIzOjM2OjUwKzAwOjAwk46iQgAAAABJRU5ErkJggg=="; var fav_icon1 = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAMAAAAoLQ9TAAAABGdBTUEAALGPC/xhBQAAACBjSFJNAAB6JgAAgIQAAPoAAACA6AAAdTAAAOpgAAA6mAAAF3CculE8AAAB+FBMVEUAAAAAAAAAAAAAAAAAAAEAAAAAAQEAAAAAAAAUFRlLVGYrSWgHEBoHEBk3S19HUGMOExkAAABOcos7apIAAAAAAAA2Ly01KyoAAAAgKzdVaX9bZHIaKzwAAABKYHhDcZgAAAABAQFfgJY2XX0AAQEQFhoIEhwOFRgGDRUAAAAAAQE3W3cyWnwAAABSeJJRjLs1R1FVgaFWgJ4lPlMAAAAsOD4aLj55bm2Md3QAAABPSkmfko9pXlsbGRkbGRlfWlm1oJxMQkAAAAAAAAABAQFTb4tYibFtvPpWgKNScpC6s7nExtNzwPp1wPnZx8jMsKtuZGFoXVutmJODwfJ7wfbHr6p5a2hnW1gtQlI4ODk7N2A2LzWDvet8wPZPRkRHPl9CQUQlMTthe4+ko7RhXYGEeJhzsuJaVXRjWHzIwtNwfYddhqLCwcpTTpGimMhvsuVzv/djYpBgWJvLydxlgptVirdZbX1ASFZUaXtOb4xOZX1OZHxNa4ZRX21DSV5gaG9Je6lqsepstO1knclcfJxtoc5tpNFuptVup9ZnkbdgjrVss+xjpuBvrd9snspOW29jdI5LVmlkj7Vvrd54t+RlfptQXXJWZHtlf51oruNgmMFfdJBYZn1RXnRWZXthfZxSeZiGgYGOhYLdxb/RubWZlpWFd3T////2kwjgAAAARXRSTlMAASFrcAhxIjLb/vWvsPb+20b4+DFFyMkz2vf43CP9/m5y9vZysLGvsQlw9fZs/v7c+PjcR8nK9/hI2/72sbD1/t00bXBAFktiAAAAAWJLR0SnwLcrAwAAAAd0SU1FB+YGCRchHQhxJNoAAAD/SURBVBjTY2BgYGBkYmZhZWVjZmdkgAAOTi5uHl4+fgEOqICgkKubu7uHp7AgiMcoIirm5e3j4+Pr5y8uKMHIwCQpFRAYFOzjExIaFi4tI8vALBcRGRUd4+MTGxefkCivwKColJSckpoGVJGekZmlrMLAqpqdk5uX7+NTUFhUXKLGysCqXlpWXlFZVV1TW1ffoKHJoKXd2NTc0trW3tHZ1d2jo8Kgq+fj09vXP2HCxEmTfXz0FRjYDQyn+EydNmHC9Bk+M42MZRkYJUxMZ82e0z933vwFZoIcYO+Imi9ctHjJ0mUWllC/iFhZ29ja2Ts4OkEFGNmdFTVZXRRkQeoBhkE/Yj5NSZ4AAAAldEVYdGRhdGU6Y3JlYXRlADIwMjItMDYtMDlUMjM6MzM6MjgrMDA6MDA90JbEAAAAJXRFWHRkYXRlOm1vZGlmeQAyMDIyLTA2LTA5VDIzOjMzOjI4KzAwOjAwTI0ueAAAAABJRU5ErkJggg=="; var fav_icon = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAMAAAAoLQ9TAAAABGdBTUEAALGPC/xhBQAAACBjSFJNAAB6JgAAgIQAAPoAAACA6AAAdTAAAOpgAAA6mAAAF3CculE8AAAB8lBMVEUAAAAAAAAAAAAAAAABAAAAAAABAAAAAAAAAAAdEBB0Pz5rKCgaBwcZBwdkMzJxPDocDAwAAACLTU6SOzsAAAAAAAA9Mic/LyEAAAA6HByQUUaIVEY+GBgAAACAQkKaQUIAAAABAQGWXl9+NjYBAAAaEBAcCAgZDQ0WBQUAAAB3Nzd9MjIAAACTUVK7UVJRNTWhVVaeVldTJSUAAAA+LC0+GhuGcmCgf2EAAABUTESrl4NzYlEdGhcdGhdiXFbIqIhWRjcAAAAAAAABAQGUSkq1VVX6bW6oUVGXS0vmro7+uJn6c3T6dXX/yqPnu5F3aFhxYVG/oH/7gHv6enjeuJOEcFtzX01VLCs4ODk7NmA5MTH1gHr6e3hWSTxHPl9CQUQ/JCKPYGGko7RhXYGEeJjmcW9cVnFjWH3IwtOHb3CjXV3CwcpTTpGimMjlb3D4c3RmYI1gWJvLydybZWW+T0x+V1hRP0Z7U1WTSEiHRUWGRUSORkZuTlBRQVBwX2CvRkXtaGjvamrNYWKmU1PVZ2fXaGjbaWncaWnAX1+7W1vkYF/ja2zRZWV9QkGeVFN2Pz69XV3ia2zkeHmpWFd/REOJSUirWVjjaGjBYGCeUlKMSkl8QkGBRUSoVlWeUE2QgXeWiHr1zqjmw5+bl5KVe2T///8NZLRGAAAARHRSTlMAASFrcAhxIjLb/vWvsPb+20b4+DFFyMkz2vf43CP9/m5y9vZysLGvsQn19mz+/tz4+NxHycr3+Ejb/vaxsPX+3TRtcBrzrrgAAAABYktHRKUuuUovAAAAB3RJTUUH5gYJFzsfVlK/LQAAAP9JREFUGNNjYGBgYGRiZmFlZWNmZ2SAAA5OLm4eXj5+AQ6ogKCQi6ubm7uHsCCIxygiKubp5e3t7ePrJy4owcjAJCnlHxAY5O0dHBIaJi0jy8AsFx4RGRXt7R0TGxefIK/AoKiUmJSckgpUkZaekamswsCqmpWdk5vn7Z1fUFhUrMbKwKpeUlpWXlFZVV1TW1evocnAotXQ2NTc0trW3tHZ2KWtwqCj6+3d3dPb19c/YaK3t54CA7u+wSTvyVP6+qZO855uaCTLwChhbDJj5qzZc6bOnWcqyAH2jqjZ/AULFy1eYm4B9YuIpZW1ja2dvYMjVICR3UlRk9VZQRakHgAlRz6K4dvoSgAAACV0RVh0ZGF0ZTpjcmVhdGUAMjAyMi0wNi0wOVQyMzo1OTozMSswMDowMJt1iQMAAAAldEVYdGRhdGU6bW9kaWZ5ADIwMjItMDYtMDlUMjM6NTk6MzErMDA6MDDqKDG/AAAAAElFTkSuQmCC" +var submit_start; var favicon = { @@ -53,11 +54,16 @@ var favicon = { start_swap: function() { this.run = true; this.auto_swap(); + submit_start = Date.now(); }, stop_swap: function() { this.run = false; this.change(fav_icon); + if (typeof submit_start !== 'undefined') { + $("#runtime")[0].innerHTML = `Execution time: ${Math.round((Date.now() - submit_start)/1000)} sec`; + delete submit_start; + } }, docHead:document.getElementsByTagName("head")[0] From 081240fad14f73c5760b85285b9451d648a29d56 Mon Sep 17 00:00:00 2001 From: ebolam Date: Fri, 19 Aug 2022 10:19:10 -0400 Subject: [PATCH 06/23] Add print in console for model downloading when using Aria2 --- utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/utils.py b/utils.py index 9a35c623..2db82610 100644 --- a/utils.py +++ b/utils.py @@ -177,6 +177,7 @@ class Send_to_socketio(object): def write(self, bar): time.sleep(0.01) try: + print(bar, end="\r") emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True) except: pass From ec90b7606456a72825b4e1eb3b7e6237d3b7dbf7 Mon Sep 17 00:00:00 2001 From: ebolam Date: Fri, 19 Aug 2022 10:19:10 -0400 Subject: [PATCH 07/23] Add print in console for model downloading when using Aria2 --- utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/utils.py b/utils.py index 9a35c623..49ab9987 100644 --- a/utils.py +++ b/utils.py @@ -176,7 +176,9 @@ from flask_socketio import emit class Send_to_socketio(object): def write(self, bar): time.sleep(0.01) + print("got bar data") try: + print(bar, end="\r") emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True) except: pass From c55d1503ee2058aa7a63515dcfb02e86f8b5d131 Mon Sep 17 00:00:00 2001 From: henk717 Date: Fri, 19 Aug 2022 17:23:11 +0200 Subject: [PATCH 08/23] Skein 20B --- colab/TPU.ipynb | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/colab/TPU.ipynb b/colab/TPU.ipynb index 448880e7..22c35e08 100644 --- a/colab/TPU.ipynb +++ b/colab/TPU.ipynb @@ -66,7 +66,7 @@ "#@title <-- Select your model below and then click this to start KoboldAI\n", "#@markdown You can find a description of the models below along with instructions on how to start KoboldAI.\n", "\n", - "Model = \"Nerys 13B V2\" #@param [\"Nerys 13B V2\", \"Janeway 13B\", \"Shinen 13B\", \"Skein 6B\", \"Janeway 6B\", \"Adventure 6B\", \"Shinen 6B\", \"Lit 6B\", \"NeoX 20B\", \"OPT 13B\", \"Fairseq Dense 13B\", \"GPT-J-6B\"] {allow-input: true}\n", + "Model = \"Nerys 13B V2\" #@param [\"Nerys 13B V2\", \"Janeway 13B\", \"Shinen 13B\", \"Skein 20B\", \"Skein 6B\", \"Janeway 6B\", \"Adventure 6B\", \"Shinen 6B\", \"Lit 6B\", \"NeoX 20B\", \"OPT 13B\", \"Fairseq Dense 13B\", \"GPT-J-6B\"] {allow-input: true}\n", "Version = \"Official\" #@param [\"Official\", \"United\"] {allow-input: true}\n", "Provider = \"Cloudflare\" #@param [\"Localtunnel\", \"Cloudflare\"]\n", "\n", @@ -93,6 +93,10 @@ " Model = \"KoboldAI/fairseq-dense-13B-Shinen\"\n", " path = \"\"\n", " download = \"\"\n", + "elif Model == \"Skein 20B\":\n", + " Model = \"KoboldAI/GPT-NeoX-20B-Skein\"\n", + " path = \"\"\n", + " download = \"\"\n", "elif Model == \"NeoX 20B\":\n", " Model = \"EleutherAI/gpt-neox-20b\"\n", " path = \"\"\n", @@ -128,7 +132,7 @@ "elif Model == \"GPT-J-6B\":\n", " Model = \"EleutherAI/gpt-j-6B\"\n", " path = \"\"\n", - " download = \"\"\n", + " download = \"\"\n", "else:\n", " path = \"\"\n", " download = \"\"\n", @@ -225,4 +229,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file From 7eee21d674bc13b22dc05361d123f9bb32d02043 Mon Sep 17 00:00:00 2001 From: ebolam Date: Fri, 19 Aug 2022 12:05:22 -0400 Subject: [PATCH 09/23] Fix for Colab download status bar --- utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/utils.py b/utils.py index 49ab9987..e441f5a7 100644 --- a/utils.py +++ b/utils.py @@ -279,6 +279,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d done = True break if bar is None: + print("setting up status bar for aria2 download") bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio()) visited = set() for x in r: From 046f9d8ace173a01146501d5b7f624f65fde1bcd Mon Sep 17 00:00:00 2001 From: ebolam Date: Fri, 19 Aug 2022 12:05:22 -0400 Subject: [PATCH 10/23] Fix for Colab download status bar --- utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/utils.py b/utils.py index 49ab9987..c8d564ff 100644 --- a/utils.py +++ b/utils.py @@ -181,6 +181,7 @@ class Send_to_socketio(object): print(bar, end="\r") emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True) except: + raise pass def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs): @@ -279,6 +280,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d done = True break if bar is None: + print("setting up status bar for aria2 download") bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio()) visited = set() for x in r: From 812ac8f27d42af8a7ad6ba94bdd89cdbddbb15dd Mon Sep 17 00:00:00 2001 From: ebolam Date: Fri, 19 Aug 2022 12:10:29 -0400 Subject: [PATCH 11/23] Debug --- utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils.py b/utils.py index c8d564ff..100156e7 100644 --- a/utils.py +++ b/utils.py @@ -178,7 +178,7 @@ class Send_to_socketio(object): time.sleep(0.01) print("got bar data") try: - print(bar, end="\r") + print("Bar data: {}".format(bar)) emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True) except: raise From 513f59791ad9781d13b0f8712b94ad8e113e6bf5 Mon Sep 17 00:00:00 2001 From: ebolam Date: Fri, 19 Aug 2022 12:10:29 -0400 Subject: [PATCH 12/23] Debug --- utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/utils.py b/utils.py index c8d564ff..493ec382 100644 --- a/utils.py +++ b/utils.py @@ -178,10 +178,9 @@ class Send_to_socketio(object): time.sleep(0.01) print("got bar data") try: - print(bar, end="\r") + print("Bar data: {}".format(bar)) emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True) except: - raise pass def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs): From 8ba68e05ec610751c9699f347323d9ba51d8fb07 Mon Sep 17 00:00:00 2001 From: ebolam Date: Fri, 19 Aug 2022 12:13:46 -0400 Subject: [PATCH 13/23] Aria2 Status Bar Download Fix --- utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/utils.py b/utils.py index 493ec382..7fd82072 100644 --- a/utils.py +++ b/utils.py @@ -176,9 +176,8 @@ from flask_socketio import emit class Send_to_socketio(object): def write(self, bar): time.sleep(0.01) - print("got bar data") try: - print("Bar data: {}".format(bar)) + print(bar) emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True) except: pass @@ -279,7 +278,6 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d done = True break if bar is None: - print("setting up status bar for aria2 download") bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio()) visited = set() for x in r: From 55f45c49127d43a3cfffb13010f969cddbc0e664 Mon Sep 17 00:00:00 2001 From: vfbd Date: Mon, 22 Aug 2022 14:45:02 -0400 Subject: [PATCH 14/23] Fix the model selection GUI when there is no internet connection --- aiserver.py | 20 ++++++++++---------- utils.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/aiserver.py b/aiserver.py index ef785313..642ced7d 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1474,22 +1474,22 @@ def get_model_info(model, directory=""): def get_layer_count(model, directory=""): if(model not in ["InferKit", "Colab", "API", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]): - if(vars.model == "GPT2Custom"): - model_config = open(vars.custmodpth + "/config.json", "r") + if(model == "GPT2Custom"): + with open(os.path.join(directory, "config.json"), "r") as f: + model_config = json.load(f) # Get the model_type from the config or assume a model type if it isn't present else: + if(directory): + model = directory from transformers import AutoConfig - if directory == "": - model_config = AutoConfig.from_pretrained(model, revision=vars.revision, cache_dir="cache") + if(os.path.isdir(model.replace('/', '_'))): + model_config = AutoConfig.from_pretrained(model.replace('/', '_'), revision=vars.revision, cache_dir="cache") + elif(os.path.isdir("models/{}".format(model.replace('/', '_')))): + model_config = AutoConfig.from_pretrained("models/{}".format(model.replace('/', '_')), revision=vars.revision, cache_dir="cache") elif(os.path.isdir(directory)): model_config = AutoConfig.from_pretrained(directory, revision=vars.revision, cache_dir="cache") - elif(os.path.isdir(vars.custmodpth.replace('/', '_'))): - model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), revision=vars.revision, cache_dir="cache") else: - model_config = AutoConfig.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache") - - - + model_config = AutoConfig.from_pretrained(model, revision=vars.revision, cache_dir="cache") return utils.num_layers(model_config) else: return None diff --git a/utils.py b/utils.py index 7fd82072..44c1129a 100644 --- a/utils.py +++ b/utils.py @@ -167,7 +167,7 @@ def decodenewlines(txt): # Returns number of layers given an HF model config #==================================================================# def num_layers(config): - return config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else None + return config["n_layer"] if isinstance(config, dict) else config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else None #==================================================================# # Downloads huggingface checkpoints using aria2c if possible From d7ebd2ae2050ee67cde8fa1a67d2042122a32e4e Mon Sep 17 00:00:00 2001 From: somebody Date: Mon, 22 Aug 2022 17:25:33 -0500 Subject: [PATCH 15/23] Dont broadcast token usage --- aiserver.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aiserver.py b/aiserver.py index ef785313..1a53c344 100644 --- a/aiserver.py +++ b/aiserver.py @@ -3848,7 +3848,6 @@ def get_message(msg): emit( 'from_server', {'cmd': 'showfieldbudget', 'data': {"length": None, "max": None, "field": field}}, - broadcast=True ) return From 95796faf41d9e52ed3a9e3daf1ed89202767b014 Mon Sep 17 00:00:00 2001 From: somebody Date: Mon, 22 Aug 2022 17:25:55 -0500 Subject: [PATCH 16/23] Add show budget setting --- gensettings.py | 11 +++++++++++ static/application.js | 23 ++++++++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/gensettings.py b/gensettings.py index 3839ff4e..4824ed27 100644 --- a/gensettings.py +++ b/gensettings.py @@ -274,6 +274,17 @@ gensettingstf = [ "default": 0, "tooltip": "Shows token selection probabilities. Does not work with more than one gens per action." }, + { + "uitype": "toggle", + "unit": "bool", + "label": "Show Field Budget", + "id": "setshowbudget", + "min": 0, + "max": 1, + "step": 1, + "default": 0, + "tooltip": "Shows token usage when typing in relevant text boxes. May lag slower devices." + }, ] gensettingsik =[{ diff --git a/static/application.js b/static/application.js index 54757912..06c426b4 100644 --- a/static/application.js +++ b/static/application.js @@ -241,8 +241,27 @@ function addSetting(ob) { if(ob.id == "setadventure"){ setadventure($(this).prop('checked')); } + }); } + + if (ob.id === "setshowbudget") { + $("#setshowbudget").on("change", function () { + for (const el of document.getElementsByClassName("input-token-usage")) { + if (this.checked) { + el.classList.remove("hidden"); + } else { + el.classList.add("hidden"); + } + } + }); + + if (!$("#input-token-usage")[0].checked) { + for (const el of document.getElementsByClassName("input-token-usage")) { + el.classList.add("hidden"); + } + } + } } function refreshTitle() { @@ -2165,6 +2184,9 @@ function interpolateRGB(color0, color1, t) { } function updateInputBudget(inputElement) { + let budgetElement = document.getElementById("setshowbudget"); + if (budgetElement && !budgetElement.checked) return; + let data = {"unencoded": inputElement.value, "field": inputElement.id}; if (inputElement.id === "anoteinput") { @@ -2182,7 +2204,6 @@ function registerTokenCounters() { let span = document.createElement("span"); span.classList.add("input-token-usage"); - span.innerText = "?/? Tokens"; el.appendChild(span); let inputElement = el.querySelector("input, textarea"); From 9eecb61feaa14a35aedb5401b3b0d4b84052b58e Mon Sep 17 00:00:00 2001 From: vfbd Date: Tue, 23 Aug 2022 14:52:45 -0400 Subject: [PATCH 17/23] Remove unused import from warpers.py --- warpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/warpers.py b/warpers.py index 2eac074e..488a901e 100644 --- a/warpers.py +++ b/warpers.py @@ -28,7 +28,7 @@ SOFTWARE. ''' import torch -from transformers import LogitsWarper, LogitsProcessor +from transformers import LogitsWarper class AdvancedRepetitionPenaltyLogitsProcessor(LogitsWarper): From 6ffaf43548b7a73b969b91eb5e16e0c6c86f6483 Mon Sep 17 00:00:00 2001 From: vfbd Date: Tue, 23 Aug 2022 15:10:21 -0400 Subject: [PATCH 18/23] Repetition penalty is now sampler #6 in the sampler order --- aiserver.py | 20 +++++++++++++++----- tpu_mtj_backend.py | 7 +++++-- utils.py | 2 +- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/aiserver.py b/aiserver.py index 310067ad..6539fcb8 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1806,7 +1806,10 @@ def patch_transformers(): self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor()) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs): - for k in vars.sampler_order: + sampler_order = vars.sampler_order[:] + if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present + sampler_order = [6] + sampler_order + for k in sampler_order: scores = self.__warper_list[k](input_ids, scores, *args, **kwargs) return scores @@ -1939,7 +1942,7 @@ def reset_model_settings(): vars.badwordsids = [] vars.fp32_model = False # Whether or not the most recently loaded HF model was in fp32 format vars.modeldim = -1 # Embedding dimension of your model (e.g. it's 4096 for GPT-J-6B and 2560 for GPT-Neo-2.7B) - vars.sampler_order = [0, 1, 2, 3, 4, 5] + vars.sampler_order = [6, 0, 1, 2, 3, 4, 5] vars.newlinemode = "n" vars.revision = None @@ -2550,8 +2553,11 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal vars.compiling = False def tpumtjgenerate_settings_callback() -> dict: + sampler_order = vars.sampler_order[:] + if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present + sampler_order = [6] + sampler_order return { - "sampler_order": vars.sampler_order, + "sampler_order": sampler_order, "top_p": float(vars.top_p), "temp": float(vars.temp), "top_k": int(vars.top_k), @@ -3658,12 +3664,16 @@ def get_message(msg): sendUSStatItems() elif(msg['cmd'] == 'samplers'): sampler_order = msg["data"] + sampler_order_min_length = 6 + sampler_order_max_length = 7 if(not isinstance(sampler_order, list)): raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}") - if(len(sampler_order) != len(vars.sampler_order)): - raise ValueError(f"Sampler order must be a list of length {len(vars.sampler_order)}, but got a list of length {len(sampler_order)}") + if(not (sampler_order_min_length <= len(sampler_order) <= sampler_order_max_length)): + raise ValueError(f"Sampler order must be a list of length greater than or equal to {sampler_order_min_length} and less than or equal to {sampler_order_max_length}, but got a list of length {len(sampler_order)}") if(not all(isinstance(e, int) for e in sampler_order)): raise ValueError(f"Sampler order must be a list of ints, but got a list with at least one non-int element") + if(min(sampler_order) != 0 or max(sampler_order) != len(sampler_order) - 1 or len(set(sampler_order)) != len(sampler_order)): + raise ValueError(f"Sampler order list of length {len(sampler_order)} must be a permutation of the first {len(sampler_order)} nonnegative integers") vars.sampler_order = sampler_order settingschanged() elif(msg['cmd'] == 'list_model'): diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 1837fae6..19296e0a 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -312,10 +312,10 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra if k == 3 and tfs < 1.0: logits = tail_free_filter(logits) if k == 4 and typical < 1.0: logits = typical_filter(logits) if k == 5 and temp != 1.0: logits = temp_filter(logits) + if k == 6 and rpargs[1] != 1.0: logits = apply_repetition_penalty_dynamic(logits, *rpargs) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) - logits = apply_repetition_penalty_dynamic(logits, *rpargs) return jax.random.categorical(key, logits, -1).astype(np.uint32) def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange): @@ -498,10 +498,10 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits) logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits) logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), apply_repetition_penalty_static, lambda x, *_: x, logits, *rpargs) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) - logits = apply_repetition_penalty_static(logits, *rpargs) return jax.random.categorical(key, logits, -1).astype(jnp.uint32) pad_token_id = 50256 @@ -858,6 +858,9 @@ def infer_static( maps.thread_resources.env = thread_resources_env if sampler_order is None: sampler_order = utils.default_sampler_order.copy() + sampler_order = sampler_order[:] + if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present + sampler_order = [6] + sampler_order sampler_order = np.uint32(sampler_order) total_batch = 1 tokens = context diff --git a/utils.py b/utils.py index 7fd82072..76c04ea2 100644 --- a/utils.py +++ b/utils.py @@ -33,7 +33,7 @@ layers_module_names: Optional[List[str]] = None module_names: Optional[List[str]] = None named_buffers: Optional[List[tuple]] = None -default_sampler_order = [0, 1, 2, 3, 4, 5] +default_sampler_order = [6, 0, 1, 2, 3, 4, 5] #==================================================================# # Decorator to prevent a function's actions from being run until From aee4beb27a58c9f0dfb024e13f20da39fc7b9a48 Mon Sep 17 00:00:00 2001 From: vfbd Date: Tue, 23 Aug 2022 15:26:15 -0400 Subject: [PATCH 19/23] Fix the Show Field Budget toggle --- static/application.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/static/application.js b/static/application.js index 06c426b4..25564cdf 100644 --- a/static/application.js +++ b/static/application.js @@ -256,7 +256,7 @@ function addSetting(ob) { } }); - if (!$("#input-token-usage")[0].checked) { + if (!$("#setshowbudget")[0].checked) { for (const el of document.getElementsByClassName("input-token-usage")) { el.classList.add("hidden"); } From cbfe456409a82872396e12941f920e30ff708720 Mon Sep 17 00:00:00 2001 From: vfbd Date: Tue, 23 Aug 2022 15:30:07 -0400 Subject: [PATCH 20/23] Repetition penalty is now added to sampler list when loading from settings files --- aiserver.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/aiserver.py b/aiserver.py index 6539fcb8..2a26bc5e 100644 --- a/aiserver.py +++ b/aiserver.py @@ -963,7 +963,10 @@ def loadmodelsettings(): if("nobreakmodel" in js): vars.nobreakmodel = js["nobreakmodel"] if("sampler_order" in js): - vars.sampler_order = js["sampler_order"] + sampler_order = vars.sampler_order + if(len(sampler_order) < 7): + sampler_order = [6] + sampler_order + vars.sampler_order = sampler_order if("temp" in js): vars.temp = js["temp"] if("top_p" in js): @@ -1094,7 +1097,10 @@ def processsettings(js): if("andepth" in js): vars.andepth = js["andepth"] if("sampler_order" in js): - vars.sampler_order = js["sampler_order"] + sampler_order = vars.sampler_order + if(len(sampler_order) < 7): + sampler_order = [6] + sampler_order + vars.sampler_order = sampler_order if("temp" in js): vars.temp = js["temp"] if("top_p" in js): From ff9058896ebc7dd71781b022fe994a37de652aa4 Mon Sep 17 00:00:00 2001 From: vfbd Date: Tue, 23 Aug 2022 15:42:23 -0400 Subject: [PATCH 21/23] Add Repetition Penalty to Samplers menu --- static/application.js | 3 ++- static/custom.css | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/static/application.js b/static/application.js index 25564cdf..9107e161 100644 --- a/static/application.js +++ b/static/application.js @@ -1306,12 +1306,13 @@ function buildSamplerList(samplers) { "Tail-free Sampling", "Typical Sampling", "Temperature", + "Repetition Penalty", ] for(i=0; i\
\
\ -
"+samplers_lookup_table[samplers[i]]+"
\ +
"+(samplers[i] < samplers_lookup_table.length ? samplers_lookup_table[samplers[i]] : "Unknown sampler #" + samplers[i])+"
\
\
\ "); diff --git a/static/custom.css b/static/custom.css index af238dc7..d4bfe872 100644 --- a/static/custom.css +++ b/static/custom.css @@ -473,7 +473,7 @@ body.connected #popupfooter, #popupfooter.always-available { } #samplerslist { - height: 300px; + height: 310px; overflow-y: scroll; overflow-wrap: anywhere; } From 938e1eddf32a67fc1af127afd6eb7cfbdabde2e8 Mon Sep 17 00:00:00 2001 From: vfbd Date: Tue, 23 Aug 2022 18:13:46 -0400 Subject: [PATCH 22/23] Fix `jax.lax.cond` call --- tpu_mtj_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 19296e0a..effb3de0 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -498,7 +498,7 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits) logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits) logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits) - logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), apply_repetition_penalty_static, lambda x, *_: x, logits, *rpargs) + logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: apply_repetition_penalty_static(*x), lambda x: x[0], (logits, *rpargs)) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) From cbacfbdfac372d42fd5138783e1add3b84586e89 Mon Sep 17 00:00:00 2001 From: vfbd Date: Sat, 27 Aug 2022 17:42:49 -0400 Subject: [PATCH 23/23] Fix error that occurs when using dynamic TPU backend --- tpu_mtj_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index effb3de0..29ac4b42 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -533,7 +533,7 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_ gen_length, rpslope, rprange, - ) + ), **sampler_options, ) # Remember what token was picked