From fdb2a7fa4ce87a586aae1590ded4630a1a03e48d Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 10 Jun 2022 22:28:20 -0400 Subject: [PATCH 1/7] Top-A sampling --- aiserver.py | 23 ++++++++++++++++++++++- bridge.lua | 2 ++ gensettings.py | 11 +++++++++++ static/application.js | 7 +++++++ templates/index.html | 2 +- tpu_mtj_backend.py | 42 ++++++++++++++++++++++++++++++++++++------ warpers.py | 29 +++++++++++++++++++++++++++++ 7 files changed, 108 insertions(+), 8 deletions(-) diff --git a/aiserver.py b/aiserver.py index fefab9b8..f14bbd77 100644 --- a/aiserver.py +++ b/aiserver.py @@ -212,6 +212,7 @@ class vars: temp = 0.5 # Default generator temperature top_p = 0.9 # Default generator top_p top_k = 0 # Default generator top_k + top_a = 0.0 # Default generator top-a tfs = 1.0 # Default generator tfs (tail-free sampling) typical = 1.0 # Default generator typical sampling threshold numseqs = 1 # Number of sequences to ask the generator to create @@ -577,6 +578,8 @@ def loadmodelsettings(): vars.tfs = js["tfs"] if("typical" in js): vars.typical = js["typical"] + if("top_a" in js): + vars.top_a = js["top_a"] if("rep_pen" in js): vars.rep_pen = js["rep_pen"] if("rep_pen_slope" in js): @@ -613,6 +616,7 @@ def savesettings(): js["top_k"] = vars.top_k js["tfs"] = vars.tfs js["typical"] = vars.typical + js["top_a"] = vars.top_a js["rep_pen"] = vars.rep_pen js["rep_pen_slope"] = vars.rep_pen_slope js["rep_pen_range"] = vars.rep_pen_range @@ -693,6 +697,8 @@ def processsettings(js): vars.tfs = js["tfs"] if("typical" in js): vars.typical = js["typical"] + if("top_a" in js): + vars.top_a = js["top_a"] if("rep_pen" in js): vars.rep_pen = js["rep_pen"] if("rep_pen_slope" in js): @@ -1379,7 +1385,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go # Patch transformers to use our custom logit warpers from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor - from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper + from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper, TopALogitsWarper def dynamic_processor_wrap(cls, field_name, var_name, cond=None): old_call = cls.__call__ @@ -1399,6 +1405,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go cls.__call__ = new_call dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0) dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0) + dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0) dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0) dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0) @@ -1445,6 +1452,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList: warper_list = LogitsProcessorList() warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))) + warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) 
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) @@ -1814,6 +1822,7 @@ else: "top_k": int(vars.top_k), "tfs": float(vars.tfs), "typical": float(vars.typical), + "top_a": float(vars.top_a), "repetition_penalty": float(vars.rep_pen), "rpslope": float(vars.rep_pen_slope), "rprange": int(vars.rep_pen_range), @@ -2176,6 +2185,7 @@ def lua_has_setting(setting): "settopk", "settfs", "settypical", + "settopa", "setreppen", "setreppenslope", "setreppenrange", @@ -2195,6 +2205,7 @@ def lua_has_setting(setting): "top_k", "tfs", "typical", + "topa", "reppen", "reppenslope", "reppenrange", @@ -2229,6 +2240,7 @@ def lua_get_setting(setting): if(setting in ("settopk", "topk", "top_k")): return vars.top_k if(setting in ("settfs", "tfs")): return vars.tfs if(setting in ("settypical", "typical")): return vars.typical + if(setting in ("settopa", "topa")): return vars.top_a if(setting in ("setreppen", "reppen")): return vars.rep_pen if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range @@ -2264,6 +2276,7 @@ def lua_set_setting(setting, v): if(setting in ("settopk", "topk")): vars.top_k = v if(setting in ("settfs", "tfs")): vars.tfs = v if(setting in ("settypical", "typical")): vars.typical = v + if(setting in ("settopa", "topa")): vars.top_a = v if(setting in ("setreppen", "reppen")): vars.rep_pen = v if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v @@ -2688,6 +2701,11 @@ def get_message(msg): emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True) settingschanged() refresh_settings() + elif(msg['cmd'] == 'settopa'): + vars.top_a = float(msg['data']) + emit('from_server', {'cmd': 'setlabeltopa', 'data': msg['data']}, broadcast=True) + settingschanged() + refresh_settings() elif(msg['cmd'] == 'setreppen'): vars.rep_pen = float(msg['data']) emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True) @@ -3748,6 +3766,7 @@ def sendtocolab(txt, min, max): 'top_k': vars.top_k, 'tfs': vars.tfs, 'typical': vars.typical, + 'topa': vars.top_a, 'numseqs': vars.numseqs, 'retfultxt': False } @@ -3885,6 +3904,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): top_k=vars.top_k, tfs=vars.tfs, typical=vars.typical, + top_a=vars.top_a, numseqs=vars.numseqs, repetition_penalty=vars.rep_pen, rpslope=vars.rep_pen_slope, @@ -4071,6 +4091,7 @@ def refresh_settings(): emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True) emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True) emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True) + emit('from_server', {'cmd': 'updatetopa', 'data': vars.top_a}, broadcast=True) emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True) emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True) emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True) diff --git a/bridge.lua b/bridge.lua index ed0941c6..fc6c8823 100644 --- a/bridge.lua +++ b/bridge.lua @@ -867,6 +867,7 @@ return function(_python, _bridged) ---@field settopk integer ---@field settfs number ---@field settypical number + ---@field settopa number ---@field setreppen number ---@field 
setreppenslope number ---@field setreppenrange number @@ -884,6 +885,7 @@ return function(_python, _bridged) ---@field top_k integer ---@field tfs number ---@field typical number + ---@field topa number ---@field reppen number ---@field reppenslope number ---@field reppenrange number diff --git a/gensettings.py b/gensettings.py index e8d4e566..b3007c91 100644 --- a/gensettings.py +++ b/gensettings.py @@ -64,6 +64,17 @@ gensettingstf = [ "step": 0.05, "default": 1.0, "tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect." + }, + { + "uitype": "slider", + "unit": "float", + "label": "Top a Sampling", + "id": "settopa", + "min": 0.0, + "max": 1.0, + "step": 0.01, + "default": 0.0, + "tooltip": "Alternative sampling method that reduces the randomness of the AI whenever the probability of one token is much higher than all the others. Higher values have a stronger effect. Set this setting to 0 to disable its effect." }, { "uitype": "slider", diff --git a/static/application.js b/static/application.js index 3b50281c..55487f76 100644 --- a/static/application.js +++ b/static/application.js @@ -2096,6 +2096,10 @@ $(document).ready(function(){ // Send current typical value to input $("#settypicalcur").val(msg.data); $("#settypical").val(parseFloat(msg.data)).trigger("change"); + } else if(msg.cmd == "updatetopa") { + // Send current top a value to input + $("#settopacur").val(msg.data); + $("#settopa").val(parseFloat(msg.data)).trigger("change"); } else if(msg.cmd == "updatereppen") { // Send current rep pen value to input $("#setreppencur").val(msg.data); @@ -2135,6 +2139,9 @@ } else if(msg.cmd == "setlabeltypical") { // Update setting label with value from server $("#settypicalcur").val(msg.data); + } else if(msg.cmd == "setlabeltopa") { + // Update setting label with value from server + $("#settopacur").val(msg.data); } else if(msg.cmd == "setlabelreppen") { // Update setting label with value from server $("#setreppencur").val(msg.data); diff --git a/templates/index.html b/templates/index.html index a3214ffa..690535f7 100644 --- a/templates/index.html +++ b/templates/index.html @@ -17,7 +17,7 @@ - + diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index fb2dc7ae..f66ad53c 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -70,6 +70,7 @@ def settings_callback() -> dict: "top_k": 0, "tfs": 1.0, "typical": 1.0, + "top_a": 0.0, "repetition_penalty": 1.0, "rpslope": 0.0, "rprange": 0, @@ -158,10 +159,10 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat logits[tokens] = penalty_logits return logits -def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0): +def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): ''' - This gets called by generate_loop_fn to apply a series of 5 filters - to the logits (top-k, then top-p, then TFS, then typical, then temperature) + This gets called by generate_loop_fn to apply a series of 6 filters + to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) before picking one token using the modified logits ''' # Top-k (keep only the k tokens with the highest logits and remove @@ -182,6 +183,20 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty return
np.where(indices_to_remove, -np.inf, logits) if top_k > 0: logits = top_k_filter(logits) + # Top-a (remove all tokens that have softmax probability less than + # a*m^2 where m is the maximum softmax probability) + def top_a_filter(logits): + # Replace every element in the logits array + # with e (Euler's number) to the power of that element, and divide + # each element of the new array by the sum of the elements in the + # new array + probabilities = np.array(jax.nn.softmax(logits), copy=True) + # Find the largest probability + probs_max = probabilities.max() + # Remove tokens + return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits) + if top_a > 0.0: + logits = top_a_filter(logits) # Top-p (after sorting the remaining tokens again in descending order of # logit, remove the ones that have cumulative softmax probability # greater than p) @@ -332,10 +347,10 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate # positions in the logits array return logits.at[tokens].set(penalty_logits) -def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0): +def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): ''' - This gets called by generate_loop_fn to apply a series of 5 filters - to the logits (top-k, then top-p, then TFS, then typical, then temperature) + This gets called by generate_loop_fn to apply a series of 6 filters + to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) before picking one token using the modified logits ''' # Top-k (keep only the k tokens with the highest logits and remove @@ -355,6 +370,19 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ ) return jnp.where(indices_to_remove, -jnp.inf, logits) logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits) + # Top-a (remove all tokens that have softmax probability less than + # a*m^2 where m is the maximum softmax probability) + def top_a_filter(logits): + # Replace every element in the logits array + # with e (Euler's number) to the power of that element, and divide + # each element of the new array by the sum of the elements in the + # new array + probabilities = jax.nn.softmax(logits) + # Find the largest probability + probs_max = probabilities.max() + # Remove tokens + return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits) + logits = jax.lax.cond(top_a > 0.0, top_a_filter, lambda x: x, logits) # Top-p (after sorting the remaining tokens again in descending order of # logit, remove the ones that have cumulative softmax probability # greater than p) @@ -806,6 +834,7 @@ def infer_static( top_k=0, tfs=1.0, typical=1.0, + top_a=0.0, repetition_penalty=1.0, rpslope=0.0, rprange=0, @@ -829,6 +858,7 @@ def infer_static( "top_p": top_p * np.ones(total_batch), "tfs": tfs * np.ones(total_batch), "typical": typical * np.ones(total_batch), + "top_a": top_a * np.ones(total_batch), "repetition_penalty": repetition_penalty * np.ones(total_batch), "rpslope": rpslope * np.ones(total_batch), "rprange": np.full(total_batch, rprange, dtype=np.uint32), diff --git a/warpers.py b/warpers.py index 7c4f854b..9c1f88eb 100644 --- a/warpers.py +++ b/warpers.py @@ -148,3 +148,32 @@ class TypicalLogitsWarper(LogitsWarper): indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores + + +class 
TopALogitsWarper(LogitsWarper): + def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + top_a = float(top_a) + if top_a < 0 or top_a > 1.0: + raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}") + self.top_a = top_a + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if self.top_a == 0.0: + return scores + + sorted_logits, sorted_indices = torch.sort(scores, descending=True) + probs = sorted_logits.softmax(dim=-1) + + # Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept) + probs_max = probs[..., 0, None] + sorted_indices_to_remove = probs >= probs_max * probs_max * self.top_a + + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 + + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores From 42b7a327b2440a3a48a969059beb06293a26e239 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 10 Jun 2022 22:34:14 -0400 Subject: [PATCH 2/7] Fix an unfortunate typo in top-a warper --- warpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/warpers.py b/warpers.py index 9c1f88eb..fb683f50 100644 --- a/warpers.py +++ b/warpers.py @@ -168,7 +168,7 @@ class TopALogitsWarper(LogitsWarper): # Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept) probs_max = probs[..., 0, None] - sorted_indices_to_remove = probs >= probs_max * probs_max * self.top_a + sorted_indices_to_remove = probs < probs_max * probs_max * self.top_a if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep From 5c81374a48822cdaeb742d9aa007a152db30260d Mon Sep 17 00:00:00 2001 From: Henk Date: Sat, 11 Jun 2022 22:04:37 +0200 Subject: [PATCH 3/7] Top A for GooseAI --- aiserver.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aiserver.py b/aiserver.py index f14bbd77..b51fe538 100644 --- a/aiserver.py +++ b/aiserver.py @@ -4668,6 +4668,7 @@ def oairequest(txt, min, max): 'prompt': txt, 'max_tokens': vars.genamt, 'temperature': vars.temp, + 'top_a': vars.top_a, 'top_p': vars.top_p, 'top_k': vars.top_k, 'tfs': vars.tfs, From 66c0dda485d9be7df300e7e791d5b625f0c825e3 Mon Sep 17 00:00:00 2001 From: Henk Date: Sat, 11 Jun 2022 22:54:51 +0200 Subject: [PATCH 4/7] Hide (Broken) Chatbot Models Removing this option because these models are currently unavailable. People who still have them can load them through the Load from file option. Once the models have been retrained and reuploaded I will add the menu back.
--- aiserver.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aiserver.py b/aiserver.py index b51fe538..6267aec2 100644 --- a/aiserver.py +++ b/aiserver.py @@ -90,7 +90,6 @@ mainmenu = [ ["Adventure Models", "adventurelist", ""], ["Novel Models", "novellist", ""], ["NSFW Models", "nsfwlist", ""], - ["Chatbot Models", "chatlist", ""], ["Untuned GPT-Neo/J", "gptneolist", ""], ["Untuned Fairseq Dense", "fsdlist", ""], ["Untuned OPT", "optlist", ""], From 2d3db7b4ba388f566aaec88a0e76678fe4fade8d Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Mon, 13 Jun 2022 19:12:23 -0400 Subject: [PATCH 5/7] Implement support for sampler order in the backend code --- aiserver.py | 27 +++++++++++++++++++-------- tpu_mtj_backend.py | 46 +++++++++++++++++++++++++++------------------- utils.py | 2 ++ 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/aiserver.py b/aiserver.py index 6267aec2..0bed5ad8 100644 --- a/aiserver.py +++ b/aiserver.py @@ -306,6 +306,7 @@ class vars: acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses) comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)') # Pattern for matching comments to remove them before sending them to the AI comregex_ui = re.compile(r'(<\|(?:.|\n)*?\|>)') # Pattern for matching comments in the editor + sampler_order = utils.default_sampler_order.copy() chatmode = False chatname = "You" adventure = False @@ -1448,15 +1449,23 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor + class KoboldLogitsWarperList(LogitsProcessorList): + def __init__(self, beams: int = 1, **kwargs): + self.__warper_list: List[LogitsWarper] = [] + self.__warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))) + self.__warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1))) + self.__warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) + self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) + self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) + self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5)) + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs): + for k in vars.sampler_order: + scores = self.__warper_list[k](input_ids, scores, *args, **kwargs) + return scores + def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList: - warper_list = LogitsProcessorList() - warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TemperatureLogitsWarper(temperature=0.5)) - return warper_list + return KoboldLogitsWarperList(beams=beams) def new_sample(self, *args, **kwargs): assert kwargs.pop("logits_warper", None) is not None @@ -1816,6 +1825,7 @@ else: def 
tpumtjgenerate_settings_callback() -> dict: return { + "sampler_order": vars.sampler_order, "top_p": float(vars.top_p), "temp": float(vars.temp), "top_k": int(vars.top_k), @@ -3910,6 +3920,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): rprange=vars.rep_pen_range, soft_embeddings=vars.sp, soft_tokens=soft_tokens, + sampler_order=vars.sampler_order, ) past = genout for i in range(vars.numseqs): diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index f66ad53c..67e006d6 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -65,6 +65,7 @@ def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List def settings_callback() -> dict: return { + "sampler_order": utils.default_sampler_order.copy(), "top_p": 0.9, "temp": 0.5, "top_k": 0, @@ -159,7 +160,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat logits[tokens] = penalty_logits return logits -def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): +def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): ''' This gets called by generate_loop_fn to apply a series of 6 filters to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) @@ -181,8 +182,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty sorted_indices_to_remove, ) return np.where(indices_to_remove, -np.inf, logits) - if top_k > 0: - logits = top_k_filter(logits) # Top-a (remove all tokens that have softmax probability less than # a*m^2 where m is the maximum softmax probability) def top_a_filter(logits): @@ -195,8 +194,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty probs_max = probabilities.max() # Remove tokens return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits) - if top_a > 0.0: - logits = top_a_filter(logits) # Top-p (after sorting the remaining tokens again in descending order of # logit, remove the ones that have cumulative softmax probability # greater than p) @@ -222,8 +219,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty sorted_indices_to_remove, ) return np.where(indices_to_remove, -np.inf, logits) - if top_p < 1.0: - logits = top_p_filter(logits) # Tail free sampling (basically top-p a second time on remaining tokens # except it's the "cumulative normalized absolute second finite # differences of the softmax probabilities" instead of just the @@ -262,8 +257,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty sorted_indices_to_remove, ) return np.where(indices_to_remove, -np.inf, logits) - if tfs < 1.0: - logits = tail_free_filter(logits) # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) def typical_filter(logits): # Compute softmax probabilities and the natural logarithms of them @@ -293,10 +286,16 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty sorted_indices_to_remove, ) return np.where(indices_to_remove, -jnp.inf, logits) - if typical < 1.0: - logits = typical_filter(logits) # Temperature (just divide the logits by the temperature) - logits /= temp + def temp_filter(logits): + return logits / temp + for k in sampler_order: + if k == 0 and top_k > 0: logits = top_k_filter(logits) + if k == 1 and top_a > 0.0: logits = top_a_filter(logits) + if k == 2 and top_p < 1.0: logits = top_p_filter(logits) + if k == 3 and 
tfs < 1.0: logits = tail_free_filter(logits) + if k == 4 and typical < 1.0: logits = typical_filter(logits) + if k == 5 and temp != 1.0: logits = temp_filter(logits) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) @@ -347,7 +346,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate # positions in the logits array return logits.at[tokens].set(penalty_logits) -def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): +def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): ''' This gets called by generate_loop_fn to apply a series of 6 filters to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) @@ -369,7 +368,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ sorted_indices_to_remove, ) return jnp.where(indices_to_remove, -jnp.inf, logits) - logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits) # Top-a (remove all tokens that have softmax probability less than # a*m^2 where m is the maximum softmax probability) def top_a_filter(logits): @@ -382,7 +380,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ probs_max = probabilities.max() # Remove tokens return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits) - logits = jax.lax.cond(top_a > 0.0, top_a_filter, lambda x: x, logits) # Top-p (after sorting the remaining tokens again in descending order of # logit, remove the ones that have cumulative softmax probability # greater than p) @@ -408,7 +405,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ sorted_indices_to_remove, ) return jnp.where(indices_to_remove, -jnp.inf, logits) - logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits) # Tail free sampling (basically top-p a second time on remaining tokens # except it's the "cumulative normalized absolute second finite # differences of the softmax probabilities" instead of just the @@ -447,7 +443,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ sorted_indices_to_remove, ) return jnp.where(indices_to_remove, -jnp.inf, logits) - logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits) # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) def typical_filter(logits): # Compute softmax probabilities and the natural logarithms of them @@ -476,11 +471,16 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ sorted_indices_to_remove, ) return jnp.where(indices_to_remove, -jnp.inf, logits) - logits = jax.lax.cond(typical < 1.0, typical_filter, lambda x: x, logits) # Temperature (just divide the logits by the temperature) def temp_filter(logits): return logits / temp - logits = jax.lax.cond(True, temp_filter, lambda x: x, logits) + for k in sampler_order: + logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), top_k_filter, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), top_a_filter, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), top_p_filter, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), 
typical_filter, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) @@ -842,8 +842,12 @@ def infer_static( gen_len=80, soft_embeddings: Optional[np.array] = None, soft_tokens: Optional[np.array] = None, + sampler_order: Optional[List[int]] = None, ) -> List[np.array]: maps.thread_resources.env = thread_resources_env + if sampler_order is None: + sampler_order = utils.default_sampler_order.copy() + sampler_order = np.uint32(sampler_order) total_batch = 1 tokens = context if(soft_tokens is not None): @@ -854,6 +858,7 @@ def infer_static( batched_tokens = np.array([padded_tokens] * total_batch) samples = [] batched_generator_params = { + "sampler_order": np.repeat(sampler_order[np.newaxis], total_batch, axis=0), "temp": temp * np.ones(total_batch), "top_p": top_p * np.ones(total_batch), "tfs": tfs * np.ones(total_batch), @@ -1015,6 +1020,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2): def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None: global thread_resources_env, seq, tokenizer, network, params + if not hasattr(vars, "sampler_order") or not vars.sampler_order: + vars.sampler_order = utils.default_sampler_order.copy() + default_params = { "compat": "j", "layers": 28, diff --git a/utils.py b/utils.py index bc085412..96606269 100644 --- a/utils.py +++ b/utils.py @@ -20,6 +20,8 @@ from_pretrained_index_filename: Optional[str] = None from_pretrained_kwargs = {} bar = None +default_sampler_order = [0, 1, 2, 3, 4, 5] + #==================================================================# # Decorator to prevent a function's actions from being run until # at least x seconds have passed without the function being called From 4c7d6f42d99d557130511f5d185249b34f9db5a1 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Mon, 13 Jun 2022 19:14:38 -0400 Subject: [PATCH 6/7] Add `sampler_order` to settings file --- aiserver.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/aiserver.py b/aiserver.py index 0bed5ad8..abaffa77 100644 --- a/aiserver.py +++ b/aiserver.py @@ -568,6 +568,8 @@ def loadmodelsettings(): vars.badwordsids = js["badwordsids"] if("nobreakmodel" in js): vars.nobreakmodel = js["nobreakmodel"] + if("sampler_order" in js): + vars.sampler_order = js["sampler_order"] if("temp" in js): vars.temp = js["temp"] if("top_p" in js): @@ -611,6 +613,7 @@ def savesettings(): js = {} js["apikey"] = vars.apikey js["andepth"] = vars.andepth + js["sampler_order"] = vars.sampler_order js["temp"] = vars.temp js["top_p"] = vars.top_p js["top_k"] = vars.top_k @@ -687,6 +690,8 @@ def processsettings(js): vars.apikey = js["apikey"] if("andepth" in js): vars.andepth = js["andepth"] + if("sampler_order" in js): + vars.sampler_order = js["sampler_order"] if("temp" in js): vars.temp = js["temp"] if("top_p" in js): From 6231106f95221bdfa3ed452fdca0bb14b22aa453 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Mon, 13 Jun 2022 20:18:09 -0400 Subject: [PATCH 7/7] Add Samplers menu --- aiserver.py | 12 ++++++ static/application.js | 90 ++++++++++++++++++++++++++++++++++++++++++- static/custom.css | 28 ++++++++++++-- templates/index.html | 20 +++++++++- 4 files changed, 143 insertions(+), 7 deletions(-) diff --git a/aiserver.py b/aiserver.py index abaffa77..06c65fc0 100644 --- a/aiserver.py +++ b/aiserver.py @@ 
-2873,6 +2873,8 @@ def get_message(msg): elif(msg['cmd'] == 'uslistrequest'): unloaded, loaded = getuslist() emit('from_server', {'cmd': 'buildus', 'data': {"unloaded": unloaded, "loaded": loaded}}) + elif(msg['cmd'] == 'samplerlistrequest'): + emit('from_server', {'cmd': 'buildsamplers', 'data': vars.sampler_order}) elif(msg['cmd'] == 'usloaded'): vars.userscripts = [] for userscript in msg['data']: @@ -2886,6 +2888,16 @@ def get_message(msg): load_lua_scripts() unloaded, loaded = getuslist() sendUSStatItems() + elif(msg['cmd'] == 'samplers'): + sampler_order = msg["data"] + if(not isinstance(sampler_order, list)): + raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}") + if(len(sampler_order) != len(vars.sampler_order)): + raise ValueError(f"Sampler order must be a list of length {len(vars.sampler_order)}, but got a list of length {len(sampler_order)}") + if(not all(isinstance(e, int) for e in sampler_order)): + raise ValueError(f"Sampler order must be a list of ints, but got a list with at least one non-int element") + vars.sampler_order = sampler_order + settingschanged() elif(msg['cmd'] == 'loadselect'): vars.loadselect = msg["data"] elif(msg['cmd'] == 'spselect'): diff --git a/static/application.js b/static/application.js index 55487f76..3cddea87 100644 --- a/static/application.js +++ b/static/application.js @@ -20,6 +20,7 @@ var button_settings; var button_format; var button_softprompt; var button_userscripts; +var button_samplers; var button_mode; var button_mode_label; var button_send; @@ -109,6 +110,9 @@ var do_clear_ent = false; // Whether or not an entry in the Userscripts menu is being dragged var us_dragging = false; +// Whether or not an entry in the Samplers menu is being dragged +var samplers_dragging = false; + // Display vars var allowtoggle = false; var formatcount = 0; @@ -976,6 +980,16 @@ function hideUSPopup() { spcontent.html(""); } +function showSamplersPopup() { + samplerspopup.removeClass("hidden"); + samplerspopup.addClass("flex"); +} + +function hideSamplersPopup() { + samplerspopup.removeClass("flex"); + samplerspopup.addClass("hidden"); +} + function buildLoadList(ar) { disableButtons([load_accept]); loadcontent.html(""); @@ -1109,6 +1123,29 @@ function buildUSList(unloaded, loaded) { } } +function buildSamplerList(samplers) { + samplerslist.html(""); + showSamplersPopup(); + var i; + var samplers_lookup_table = [ + "Top-k Sampling", + "Top-a Sampling", + "Top-p Sampling", + "Tail-free Sampling", + "Typical Sampling", + "Temperature", + ] + for(i=0; i<samplers.length; i++) { + samplerslist.append("<div class=\"flex\">\ + <div class=\"samplerslistitem flex-row-container\" sid=\""+samplers[i]+"\">\ + <div class=\"flex-row\">\ + <div>"+samplers_lookup_table[samplers[i]]+"</div>\ + </div>\ + </div>\ + </div>"); + } +} + function highlightLoadLine(ref) { $("#loadlistcontent > div > div.popuplistselected").removeClass("popuplistselected"); ref.addClass("popuplistselected"); } @@ -1838,6 +1875,7 @@ $(document).ready(function(){ button_format = $('#btn_format'); button_softprompt = $("#btn_softprompt"); button_userscripts= $("#btn_userscripts"); + button_samplers = $("#btn_samplers"); button_mode = $('#btnmode') button_mode_label = $('#btnmode_label') button_send = $('#btnsend'); @@ -1886,6 +1924,10 @@ $(document).ready(function(){ usloaded = $("#uslistloaded"); us_accept = $("#btn_usaccept"); us_close = $("#btn_usclose"); + samplerspopup = $("#samplerscontainer"); + samplerslist = $("#samplerslist"); + samplers_accept = $("#btn_samplersaccept"); + samplers_close = $("#btn_samplersclose"); nspopup = $("#newgamecontainer"); ns_accept = $("#btn_nsaccept"); ns_close = $("#btn_nsclose"); @@ -1908,7 +1950,7 @@ $(document).ready(function(){ modelname = msg.modelname; } refreshTitle(); - connect_status.html("Connected to KoboldAI Process!"); + connect_status.html("Connected to KoboldAI!"); connect_status.removeClass("color_orange"); connect_status.addClass("color_green"); // Reset Menus @@ -2310,6 +2352,8 @@ $(document).ready(function(){ buildSPList(msg.data); } else if(msg.cmd == "buildus") { buildUSList(msg.data.unloaded, msg.data.loaded); + } else if(msg.cmd == "buildsamplers") { + buildSamplerList(msg.data); } else if(msg.cmd == "askforoverwrite") { // Show overwrite warning show([$(".saveasoverwrite")]); @@ -2436,6 +2480,20 @@ $(document).ready(function(){ }, 10); } + var samplers_click_handler = function(ev) { + setTimeout(function() { + if (samplers_dragging) { + return; + } + var target = $(ev.target).closest(".samplerslistitem"); + var next = target.parent().next().find(".samplerslistitem"); + if (!next.length) { + return; + } + next.parent().after(target.parent()); + }, 10); + } + // Make the userscripts menu sortable var us_sortable_settings = { placeholder: "ussortable-placeholder", @@ -2456,6 +2514,22 @@ connectWith: "#uslistunloaded", }, us_sortable_settings)).on("click", ".uslistitem", us_click_handler); + // Make the samplers menu sortable + var samplers_sortable_settings = { + placeholder: "samplerssortable-placeholder", + start: function() { samplers_dragging = true; }, + stop: function() { samplers_dragging = false; }, + delay: 2, + cursor: "move", + tolerance: "pointer", + opacity: 0.21, + revert: 173, + scrollSensitivity: 64, + scrollSpeed: 10, + } + samplerslist.sortable($.extend({ + }, samplers_sortable_settings)).on("click", ".samplerslistitem", samplers_click_handler); + // Bind actions to UI buttons button_send.on("click", function(ev) { dosubmit(); @@ -2590,6 +2664,10 @@ button_userscripts.on("click", function(ev) { socket.send({'cmd': 'uslistrequest', 'data': ''}); }); + + button_samplers.on("click", function(ev) { + socket.send({'cmd': 'samplerlistrequest', 'data': ''}); + }); load_close.on("click", function(ev) { hideLoadPopup(); @@ -2623,6 +2701,16 @@ socket.send({'cmd': 'usload', 'data': ''}); hideUSPopup(); }); + + samplers_close.on("click", function(ev) { + hideSamplersPopup(); + }); + + samplers_accept.on("click", function(ev) { + hideMessage(); + socket.send({'cmd': 'samplers', 'data': samplerslist.find(".samplerslistitem").map(function() { return parseInt($(this).attr("sid")); }).toArray()}); + hideSamplersPopup(); + }); button_newgame.on("click", function(ev) { if(connected)
{ diff --git a/static/custom.css b/static/custom.css index d70fd34e..640cb8db 100644 --- a/static/custom.css +++ b/static/custom.css @@ -457,6 +457,26 @@ body.connected #popupfooter, #popupfooter.always-available { overflow-wrap: anywhere; } +#samplerspopup { + width: 300px; + background-color: #262626; + margin-top: 100px; +} + +@media (max-width: 768px) { + #samplerspopup { + width: 100%; + background-color: #262626; + margin-top: 100px; + } +} + +#samplerslist { + height: 300px; + overflow-y: scroll; + overflow-wrap: anywhere; +} + #nspopup { width: 350px; background-color: #262626; @@ -750,7 +770,7 @@ body.connected .dropdown-item:hover, .dropdown-item.always-available:hover { background-color: #3bf723; } -.ussortable-placeholder { +.ussortable-placeholder, .samplerssortable-placeholder { height: 4px; background-color: #3bf723; } @@ -1340,7 +1360,7 @@ body.connected .popupfooter, .popupfooter.always-available { background-color: #688f1f; } -.uslistitem { +.uslistitem, .samplerslistitem { padding: 12px 10px 12px 10px; display: flex; flex-grow: 1; @@ -1352,11 +1372,11 @@ body.connected .popupfooter, .popupfooter.always-available { transition: background-color 0.25s ease-in; } -.uslistitemsub { +.uslistitemsub, .samplerslistitemsub { color: #ba9; } -.uslistitem:hover { +.uslistitem:hover, .samplerslistitem:hover { cursor: move; background-color: #688f1f; } diff --git a/templates/index.html b/templates/index.html index 690535f7..7ec9f66c 100644 --- a/templates/index.html +++ b/templates/index.html @@ -9,7 +9,7 @@ - + @@ -17,7 +17,7 @@ - + @@ -71,6 +71,9 @@ + <li class="nav-item"> + <a class="nav-link" href="#" id="btn_samplers">Samplers</a> + </li> @@ -299,6 +302,19 @@ + <div class="popupcontainer hidden" id="samplerscontainer"> + <div id="samplerspopup"> + <div class="popuptitlebar"> + <div class="popuptitletext">Drag-and-drop to change the order in which the samplers are applied</div> + </div> + <div id="samplerslist"> + </div> + <div class="popupfooter"> + <button type="button" class="btn btn-primary" id="btn_samplersaccept">Save</button> + <button type="button" class="btn btn-primary" id="btn_samplersclose">Cancel</button> + </div> + </div> + </div>
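Note: the filters added by this series all share one shape: compute softmax probabilities, build a boolean mask, and send masked logits to -inf. The following standalone sketch (plain NumPy, outside the KoboldAI codebase; the function names and toy logit values are illustrative, not part of the patches) shows the top-a rule "probs < top_a * max(probs)^2" and the configurable sampler order from PATCH 5/7, where indices 0-5 mean top-k, top-a, top-p, TFS, typical and temperature:

    import numpy as np

    def softmax(logits):
        # Shift by the max for numerical stability before exponentiating
        e = np.exp(logits - logits.max())
        return e / e.sum()

    def top_a_filter(logits, top_a):
        # Drop every token whose softmax probability is below top_a * m^2,
        # where m is the largest softmax probability
        probs = softmax(logits)
        m = probs.max()
        return np.where(probs < m * m * top_a, -np.inf, logits)

    def temp_filter(logits, temp):
        # Temperature just divides the logits
        return logits / temp

    def apply_samplers(logits, sampler_order, top_a=0.0, temp=1.0):
        # Mirrors the dispatch in kobold_sample_dynamic: walk the order list
        # and apply each enabled filter (only top-a and temperature shown)
        for k in sampler_order:
            if k == 1 and top_a > 0.0:
                logits = top_a_filter(logits, top_a)
            if k == 5 and temp != 1.0:
                logits = temp_filter(logits, temp)
        return logits

    logits = np.array([5.0, 4.0, 1.0, 0.5])
    print(apply_samplers(logits, [0, 1, 2, 3, 4, 5], top_a=0.5, temp=0.8))
    # The two weakest tokens fall below 0.5 * m^2 and become -inf

With top_a=0.5 the example keeps only tokens whose probability is at least half the square of the top token's probability: the filter prunes aggressively when the distribution is peaked and has little effect when it is flat, which is exactly the behaviour the gensettings tooltip describes.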