From 2d3db7b4ba388f566aaec88a0e76678fe4fade8d Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Mon, 13 Jun 2022 19:12:23 -0400
Subject: [PATCH 1/3] Implement support for sampler order in the backend code

---
 aiserver.py        | 27 +++++++++++++++++++--------
 tpu_mtj_backend.py | 46 +++++++++++++++++++++++++++-------------------
 utils.py           |  2 ++
 3 files changed, 48 insertions(+), 27 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 6267aec2..0bed5ad8 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -306,6 +306,7 @@ class vars:
     acregex_ui  = re.compile(r'^ *(>.*)$', re.MULTILINE)    # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
     comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)')  # Pattern for matching comments to remove them before sending them to the AI
     comregex_ui = re.compile(r'(<\|(?:.|\n)*?\|>)')  # Pattern for matching comments in the editor
+    sampler_order = utils.default_sampler_order.copy()
     chatmode    = False
     chatname    = "You"
     adventure   = False
@@ -1448,15 +1449,23 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
         transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor

+    class KoboldLogitsWarperList(LogitsProcessorList):
+        def __init__(self, beams: int = 1, **kwargs):
+            self.__warper_list: List[LogitsWarper] = []
+            self.__warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
+            self.__warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1)))
+            self.__warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
+            self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
+            self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
+            self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5))
+
+        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs):
+            for k in vars.sampler_order:
+                scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
+            return scores
+
     def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
-        warper_list = LogitsProcessorList()
-        warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TemperatureLogitsWarper(temperature=0.5))
-        return warper_list
+        return KoboldLogitsWarperList(beams=beams)

     def new_sample(self, *args, **kwargs):
         assert kwargs.pop("logits_warper", None) is not None
@@ -1816,6 +1825,7 @@ else:

     def tpumtjgenerate_settings_callback() -> dict:
         return {
+            "sampler_order": vars.sampler_order,
             "top_p": float(vars.top_p),
             "temp": float(vars.temp),
             "top_k": int(vars.top_k),
@@ -3910,6 +3920,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
                     rprange=vars.rep_pen_range,
                     soft_embeddings=vars.sp,
                     soft_tokens=soft_tokens,
+                    sampler_order=vars.sampler_order,
                 )
                 past = genout
                 for i in range(vars.numseqs):
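The key move in the aiserver.py hunk above is that each warper lives at a fixed index (0 = top-k, 1 = top-a, 2 = top-p, 3 = tail-free, 4 = typical, 5 = temperature) and `vars.sampler_order` decides the order in which they run. A minimal standalone sketch of that dispatch pattern, with stub warpers standing in for the real transformers classes:

```python
trace = []

def make_stub(name):
    # Stand-in for a real LogitsWarper: records its name, passes scores through.
    def warper(scores):
        trace.append(name)
        return scores
    return warper

class OrderedWarperList:
    # Same shape as KoboldLogitsWarperList: the index in the list is the sampler ID.
    def __init__(self, warpers):
        self.warpers = warpers

    def __call__(self, scores, sampler_order):
        for k in sampler_order:
            scores = self.warpers[k](scores)
        return scores

warpers = OrderedWarperList(
    [make_stub(n) for n in ("top_k", "top_a", "top_p", "tfs", "typical", "temperature")]
)
warpers([], sampler_order=[5, 0, 1, 2, 3, 4])  # temperature first
print(trace)  # ['temperature', 'top_k', 'top_a', 'top_p', 'tfs', 'typical']
```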
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index f66ad53c..67e006d6 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -65,6 +65,7 @@ def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List

 def settings_callback() -> dict:
     return {
+        "sampler_order": utils.default_sampler_order.copy(),
         "top_p": 0.9,
         "temp": 0.5,
         "top_k": 0,
@@ -159,7 +160,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
     logits[tokens] = penalty_logits
     return logits

-def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
+def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
     '''
     This gets called by generate_loop_fn to apply a series of 6 filters
     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
@@ -181,8 +182,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
             sorted_indices_to_remove,
         )
         return np.where(indices_to_remove, -np.inf, logits)
-    if top_k > 0:
-        logits = top_k_filter(logits)
     # Top-a (remove all tokens that have softmax probability less than
     # a*m^2 where m is the maximum softmax probability)
     def top_a_filter(logits):
@@ -195,8 +194,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
         probs_max = probabilities.max()
         # Remove tokens
         return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits)
-    if top_a > 0.0:
-        logits = top_a_filter(logits)
     # Top-p (after sorting the remaining tokens again in descending order of
     # logit, remove the ones that have cumulative softmax probability
     # greater than p)
@@ -222,8 +219,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
             sorted_indices_to_remove,
         )
         return np.where(indices_to_remove, -np.inf, logits)
-    if top_p < 1.0:
-        logits = top_p_filter(logits)
     # Tail free sampling (basically top-p a second time on remaining tokens
     # except it's the "cumulative normalized absolute second finite
     # differences of the softmax probabilities" instead of just the
@@ -262,8 +257,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
             sorted_indices_to_remove,
         )
         return np.where(indices_to_remove, -np.inf, logits)
-    if tfs < 1.0:
-        logits = tail_free_filter(logits)
     # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
     def typical_filter(logits):
         # Compute softmax probabilities and the natural logarithms of them
@@ -293,10 +286,16 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
             sorted_indices_to_remove,
         )
         return np.where(indices_to_remove, -jnp.inf, logits)
-    if typical < 1.0:
-        logits = typical_filter(logits)
     # Temperature (just divide the logits by the temperature)
-    logits /= temp
+    def temp_filter(logits):
+        return logits / temp
+    for k in sampler_order:
+        if k == 0 and top_k > 0: logits = top_k_filter(logits)
+        if k == 1 and top_a > 0.0: logits = top_a_filter(logits)
+        if k == 2 and top_p < 1.0: logits = top_p_filter(logits)
+        if k == 3 and tfs < 1.0: logits = tail_free_filter(logits)
+        if k == 4 and typical < 1.0: logits = typical_filter(logits)
+        if k == 5 and temp != 1.0: logits = temp_filter(logits)
     # Finally, pick one token using the softmax thingy again (it gives
     # an array whose elements sum to 1 so it can be used nicely as a
     # probability distribution)
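Exposing the order matters because the filters do not commute: temperature rescales the logits before the softmax, which changes how much cumulative mass top-p sees. A self-contained NumPy sketch with simplified filters (not the exact functions above) showing that temperature-then-top-p keeps fewer tokens than top-p-then-temperature:

```python
import numpy as np

def temperature(logits, temp):
    return logits / temp

def top_p_filter(logits, top_p):
    # Sort in descending order, then drop tokens once cumulative softmax
    # probability exceeds top_p (always keeping the most probable token).
    sorted_idx = np.argsort(-logits)
    probs = np.exp(logits[sorted_idx] - logits.max())
    probs /= probs.sum()
    remove_sorted = np.cumsum(probs) > top_p
    remove_sorted[1:] = remove_sorted[:-1].copy()
    remove_sorted[0] = False
    remove = np.empty_like(remove_sorted)
    remove[sorted_idx] = remove_sorted
    return np.where(remove, -np.inf, logits)

logits = np.array([2.0, 1.0, 0.5, 0.1])
a = top_p_filter(temperature(logits, 0.5), 0.9)  # sharpen first
b = temperature(top_p_filter(logits, 0.9), 0.5)  # filter first
print(np.isfinite(a).sum(), np.isfinite(b).sum())  # 2 vs. 3 tokens survive
```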
@@ -347,7 +346,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
     # positions in the logits array
     return logits.at[tokens].set(penalty_logits)

-def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
+def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
     '''
     This gets called by generate_loop_fn to apply a series of 6 filters
     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
@@ -369,7 +368,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
             sorted_indices_to_remove,
         )
         return jnp.where(indices_to_remove, -jnp.inf, logits)
-    logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits)
     # Top-a (remove all tokens that have softmax probability less than
     # a*m^2 where m is the maximum softmax probability)
     def top_a_filter(logits):
@@ -382,7 +380,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
         probs_max = probabilities.max()
         # Remove tokens
         return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits)
-    logits = jax.lax.cond(top_a > 0.0, top_a_filter, lambda x: x, logits)
     # Top-p (after sorting the remaining tokens again in descending order of
     # logit, remove the ones that have cumulative softmax probability
     # greater than p)
@@ -408,7 +405,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
             sorted_indices_to_remove,
         )
         return jnp.where(indices_to_remove, -jnp.inf, logits)
-    logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits)
     # Tail free sampling (basically top-p a second time on remaining tokens
     # except it's the "cumulative normalized absolute second finite
     # differences of the softmax probabilities" instead of just the
@@ -447,7 +443,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
             sorted_indices_to_remove,
         )
         return jnp.where(indices_to_remove, -jnp.inf, logits)
-    logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
     # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
     def typical_filter(logits):
         # Compute softmax probabilities and the natural logarithms of them
@@ -476,11 +471,16 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
             sorted_indices_to_remove,
         )
         return jnp.where(indices_to_remove, -jnp.inf, logits)
-    logits = jax.lax.cond(typical < 1.0, typical_filter, lambda x: x, logits)
     # Temperature (just divide the logits by the temperature)
     def temp_filter(logits):
         return logits / temp
-    logits = jax.lax.cond(True, temp_filter, lambda x: x, logits)
+    for k in sampler_order:
+        logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), top_k_filter, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), top_a_filter, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), top_p_filter, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
     # Finally, pick one token using the softmax thingy again (it gives
     # an array whose elements sum to 1 so it can be used nicely as a
     # probability distribution)
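In the static (pre-compiled) sampler, `sampler_order` is a traced array inside the jitted generation loop, so Python control flow like `if k == 0:` cannot branch on it; each filter is instead gated with `jax.lax.cond`, which traces both branches and picks one at run time. A minimal sketch of the pattern, with temperature as the only real filter and the other IDs falling through to the identity:

```python
import jax
import jax.numpy as jnp

@jax.jit
def apply_in_order(logits, sampler_order, temp):
    def temp_filter(x):
        return x / temp

    def identity(x):
        return x

    # The loop unrolls at trace time (the order always has length 6), and
    # each `k` is a traced scalar, hence lax.cond instead of a Python `if`.
    for k in sampler_order:
        logits = jax.lax.cond(
            jnp.logical_and(k == 5, temp != 1.0), temp_filter, identity, logits
        )
    return logits

order = jnp.array([5, 0, 1, 2, 3, 4], dtype=jnp.uint32)
print(apply_in_order(jnp.array([2.0, 1.0, 0.5]), order, 0.5))  # [4. 2. 1.]
```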
@@ -842,8 +842,12 @@ def infer_static(
     gen_len=80,
     soft_embeddings: Optional[np.array] = None,
     soft_tokens: Optional[np.array] = None,
+    sampler_order: Optional[List[int]] = None,
 ) -> List[np.array]:
     maps.thread_resources.env = thread_resources_env
+    if sampler_order is None:
+        sampler_order = utils.default_sampler_order.copy()
+    sampler_order = np.uint32(sampler_order)
     total_batch = 1
     tokens = context
     if(soft_tokens is not None):
@@ -854,6 +858,7 @@ def infer_static(
     batched_tokens = np.array([padded_tokens] * total_batch)
     samples = []
     batched_generator_params = {
+        "sampler_order": np.repeat(sampler_order[np.newaxis], total_batch, axis=0),
         "temp": temp * np.ones(total_batch),
         "top_p": top_p * np.ones(total_batch),
         "tfs": tfs * np.ones(total_batch),
@@ -1015,6 +1020,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):

 def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params

+    if not hasattr(vars, "sampler_order") or not vars.sampler_order:
+        vars.sampler_order = utils.default_sampler_order.copy()
+
     default_params = {
         "compat": "j",
         "layers": 28,
diff --git a/utils.py b/utils.py
index bc085412..96606269 100644
--- a/utils.py
+++ b/utils.py
@@ -20,6 +20,8 @@ from_pretrained_index_filename: Optional[str] = None
 from_pretrained_kwargs = {}
 bar = None

+default_sampler_order = [0, 1, 2, 3, 4, 5]
+
 #==================================================================#
 #  Decorator to prevent a function's actions from being run until
 #  at least x seconds have passed without the function being called

From 4c7d6f42d99d557130511f5d185249b34f9db5a1 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Mon, 13 Jun 2022 19:14:38 -0400
Subject: [PATCH 2/3] Add `sampler_order` to settings file

---
 aiserver.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/aiserver.py b/aiserver.py
index 0bed5ad8..abaffa77 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -568,6 +568,8 @@ def loadmodelsettings():
         vars.badwordsids = js["badwordsids"]
     if("nobreakmodel" in js):
         vars.nobreakmodel = js["nobreakmodel"]
+    if("sampler_order" in js):
+        vars.sampler_order = js["sampler_order"]
     if("temp" in js):
         vars.temp = js["temp"]
     if("top_p" in js):
@@ -611,6 +613,7 @@ def savesettings():
     js = {}
     js["apikey"]        = vars.apikey
     js["andepth"]       = vars.andepth
+    js["sampler_order"] = vars.sampler_order
     js["temp"]          = vars.temp
     js["top_p"]         = vars.top_p
     js["top_k"]         = vars.top_k
@@ -687,6 +690,8 @@ def processsettings(js):
         vars.apikey = js["apikey"]
     if("andepth" in js):
         vars.andepth = js["andepth"]
+    if("sampler_order" in js):
+        vars.sampler_order = js["sampler_order"]
     if("temp" in js):
         vars.temp = js["temp"]
     if("top_p" in js):
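Patch 2 above persists the order alongside the other sampler settings. A sketch of the JSON round trip (the file name and surrounding keys are illustrative): like `loadmodelsettings()`, it only overrides the default when the key is present, so settings files written before this change keep working:

```python
import json

settings = {
    "sampler_order": [5, 0, 1, 2, 3, 4],  # temperature applied first
    "temp": 0.5,
    "top_p": 0.9,
}
with open("settings_example.json", "w") as f:
    json.dump(settings, f)

with open("settings_example.json") as f:
    js = json.load(f)

default_sampler_order = [0, 1, 2, 3, 4, 5]
sampler_order = js["sampler_order"] if "sampler_order" in js else default_sampler_order.copy()
print(sampler_order)  # [5, 0, 1, 2, 3, 4]
```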
From 6231106f95221bdfa3ed452fdca0bb14b22aa453 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Mon, 13 Jun 2022 20:18:09 -0400
Subject: [PATCH 3/3] Add Samplers menu

---
 aiserver.py           | 12 ++++++
 static/application.js | 90 ++++++++++++++++++++++++++++++++++++++++++-
 static/custom.css     | 28 ++++++++++++--
 templates/index.html  | 20 +++++++++-
 4 files changed, 143 insertions(+), 7 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index abaffa77..06c65fc0 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -2873,6 +2873,8 @@ def get_message(msg):
     elif(msg['cmd'] == 'uslistrequest'):
         unloaded, loaded = getuslist()
         emit('from_server', {'cmd': 'buildus', 'data': {"unloaded": unloaded, "loaded": loaded}})
+    elif(msg['cmd'] == 'samplerlistrequest'):
+        emit('from_server', {'cmd': 'buildsamplers', 'data': vars.sampler_order})
     elif(msg['cmd'] == 'usloaded'):
         vars.userscripts = []
         for userscript in msg['data']:
@@ -2886,6 +2888,16 @@ def get_message(msg):
         load_lua_scripts()
         unloaded, loaded = getuslist()
         sendUSStatItems()
+    elif(msg['cmd'] == 'samplers'):
+        sampler_order = msg["data"]
+        if(not isinstance(sampler_order, list)):
+            raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}")
+        if(len(sampler_order) != len(vars.sampler_order)):
+            raise ValueError(f"Sampler order must be a list of length {len(vars.sampler_order)}, but got a list of length {len(sampler_order)}")
+        if(not all(isinstance(e, int) for e in sampler_order)):
+            raise ValueError("Sampler order must be a list of ints, but got a list with at least one non-int element")
+        vars.sampler_order = sampler_order
+        settingschanged()
     elif(msg['cmd'] == 'loadselect'):
         vars.loadselect = msg["data"]
     elif(msg['cmd'] == 'spselect'):
diff --git a/static/application.js b/static/application.js
index 55487f76..3cddea87 100644
--- a/static/application.js
+++ b/static/application.js
@@ -20,6 +20,7 @@ var button_settings;
 var button_format;
 var button_softprompt;
 var button_userscripts;
+var button_samplers;
 var button_mode;
 var button_mode_label;
 var button_send;
@@ -109,6 +110,9 @@ var do_clear_ent = false;
 // Whether or not an entry in the Userscripts menu is being dragged
 var us_dragging = false;

+// Whether or not an entry in the Samplers menu is being dragged
+var samplers_dragging = false;
+
 // Display vars
 var allowtoggle = false;
 var formatcount = 0;
@@ -976,6 +980,16 @@ function hideUSPopup() {
 	spcontent.html("");
 }

+function showSamplersPopup() {
+	samplerspopup.removeClass("hidden");
+	samplerspopup.addClass("flex");
+}
+
+function hideSamplersPopup() {
+	samplerspopup.removeClass("flex");
+	samplerspopup.addClass("hidden");
+}
+
 function buildLoadList(ar) {
 	disableButtons([load_accept]);
 	loadcontent.html("");
@@ -1109,6 +1123,29 @@ function buildUSList(unloaded, loaded) {
 	}
 }

+function buildSamplerList(samplers) {
+	samplerslist.html("");
+	showSamplersPopup();
+	var i;
+	var samplers_lookup_table = [
+		"Top-k Sampling",
+		"Top-a Sampling",
+		"Top-p Sampling",
+		"Tail-free Sampling",
+		"Typical Sampling",
+		"Temperature",
+	]
+	for(i=0; i<samplers.length; i++) {
+		samplerslist.append("<div class=\"flex\">\
+			<div class=\"samplerslistitem\" sid=\""+samplers[i]+"\">\
+				<div class=\"flex-row\">\
+					<div>"+samplers_lookup_table[samplers[i]]+"</div>\
+				</div>\
+			</div>\
+		</div>");
+	}
+}
+
 function highlightLoadLine(ref) {
 	$("#loadlistcontent > div > div.popuplistselected").removeClass("popuplistselected");
 	ref.addClass("popuplistselected");
@@ -1838,6 +1875,7 @@ $(document).ready(function(){
 	button_format     = $('#btn_format');
 	button_softprompt = $("#btn_softprompt");
 	button_userscripts= $("#btn_userscripts");
+	button_samplers   = $("#btn_samplers");
 	button_mode       = $('#btnmode')
 	button_mode_label = $('#btnmode_label')
 	button_send       = $('#btnsend');
@@ -1886,6 +1924,10 @@ $(document).ready(function(){
 	usloaded        = $("#uslistloaded");
 	us_accept       = $("#btn_usaccept");
 	us_close        = $("#btn_usclose");
+	samplerspopup   = $("#samplerscontainer");
+	samplerslist    = $("#samplerslist");
+	samplers_accept = $("#btn_samplersaccept");
+	samplers_close  = $("#btn_samplersclose");
 	nspopup         = $("#newgamecontainer");
 	ns_accept       = $("#btn_nsaccept");
 	ns_close        = $("#btn_nsclose");
@@ -1908,7 +1950,7 @@ $(document).ready(function(){
 				modelname = msg.modelname;
 			}
 			refreshTitle();
-			connect_status.html("<b>Connected to KoboldAI Process!</b>");
+			connect_status.html("<b>Connected to KoboldAI!</b>");
 			connect_status.removeClass("color_orange");
 			connect_status.addClass("color_green");
 			// Reset Menus
@@ -2310,6 +2352,8 @@ $(document).ready(function(){
 				buildSPList(msg.data);
 			} else if(msg.cmd == "buildus") {
 				buildUSList(msg.data.unloaded, msg.data.loaded);
+			} else if(msg.cmd == "buildsamplers") {
+				buildSamplerList(msg.data);
 			} else if(msg.cmd == "askforoverwrite") {
 				// Show overwrite warning
 				show([$(".saveasoverwrite")]);
@@ -2436,6 +2480,20 @@ $(document).ready(function(){
 		}, 10);
 	}

+	var samplers_click_handler = function(ev) {
+		setTimeout(function() {
+			if (samplers_dragging) {
+				return;
+			}
+			var target = $(ev.target).closest(".samplerslistitem");
+			var next = target.parent().next().find(".samplerslistitem");
+			if (!next.length) {
+				return;
+			}
+			next.parent().after(target.parent());
+		}, 10);
+	}
+
 	// Make the userscripts menu sortable
 	var us_sortable_settings = {
 		placeholder: "ussortable-placeholder",
@@ -2456,6 +2514,22 @@ $(document).ready(function(){
 		connectWith: "#uslistunloaded",
 	}, us_sortable_settings)).on("click", ".uslistitem", us_click_handler);

+	// Make the samplers menu sortable
+	var samplers_sortable_settings = {
+		placeholder: "samplerssortable-placeholder",
+		start: function() { samplers_dragging = true; },
+		stop: function() { samplers_dragging = false; },
+		delay: 2,
+		cursor: "move",
+		tolerance: "pointer",
+		opacity: 0.21,
+		revert: 173,
+		scrollSensitivity: 64,
+		scrollSpeed: 10,
+	}
+	samplerslist.sortable($.extend({
+	}, samplers_sortable_settings)).on("click", ".samplerslistitem", samplers_click_handler);
+
 	// Bind actions to UI buttons
 	button_send.on("click", function(ev) {
 		dosubmit();
@@ -2590,6 +2664,10 @@ $(document).ready(function(){
 	button_userscripts.on("click", function(ev) {
 		socket.send({'cmd': 'uslistrequest', 'data': ''});
 	});
+
+	button_samplers.on("click", function(ev) {
+		socket.send({'cmd': 'samplerlistrequest', 'data': ''});
+	});

 	load_close.on("click", function(ev) {
 		hideLoadPopup();
@@ -2623,6 +2701,16 @@ $(document).ready(function(){
 		socket.send({'cmd': 'usload', 'data': ''});
 		hideUSPopup();
 	});
+
+	samplers_close.on("click", function(ev) {
+		hideSamplersPopup();
+	});
+
+	samplers_accept.on("click", function(ev) {
+		hideMessage();
+		socket.send({'cmd': 'samplers', 'data': samplerslist.find(".samplerslistitem").map(function() { return parseInt($(this).attr("sid")); }).toArray()});
+		hideSamplersPopup();
+	});

 	button_newgame.on("click", function(ev) {
 		if(connected) {
diff --git a/static/custom.css b/static/custom.css
index d70fd34e..640cb8db 100644
--- a/static/custom.css
+++ b/static/custom.css
@@ -457,6 +457,26 @@ body.connected #popupfooter, #popupfooter.always-available {
 	overflow-wrap: anywhere;
 }

+#samplerspopup {
+	width: 300px;
+	background-color: #262626;
+	margin-top: 100px;
+}
+
+@media (max-width: 768px) {
+	#samplerspopup {
+		width: 100%;
+		background-color: #262626;
+		margin-top: 100px;
+	}
+}
+
+#samplerslist {
+	height: 300px;
+	overflow-y: scroll;
+	overflow-wrap: anywhere;
+}
+
 #nspopup {
 	width: 350px;
 	background-color: #262626;
@@ -750,7 +770,7 @@ body.connected .dropdown-item:hover, .dropdown-item.always-available:hover {
 	background-color: #3bf723;
 }

-.ussortable-placeholder {
+.ussortable-placeholder, .samplerssortable-placeholder {
 	height: 4px;
 	background-color: #3bf723;
 }
@@ -1340,7 +1360,7 @@ body.connected .popupfooter, .popupfooter.always-available {
 	background-color: #688f1f;
 }

-.uslistitem {
+.uslistitem, .samplerslistitem {
 	padding: 12px 10px 12px 10px;
 	display: flex;
 	flex-grow: 1;
@@ -1352,11 +1372,11 @@ body.connected .popupfooter, .popupfooter.always-available {
 	transition: background-color 0.25s ease-in;
 }

-.uslistitemsub {
+.uslistitemsub, .samplerslistitemsub {
 	color: #ba9;
 }

-.uslistitem:hover {
+.uslistitem:hover, .samplerslistitem:hover {
 	cursor: move;
 	background-color: #688f1f;
 }
diff --git a/templates/index.html b/templates/index.html
index 690535f7..7ec9f66c 100644
--- a/templates/index.html
+++ b/templates/index.html
@@ -9,7 +9,7 @@
-
+
-
+
@@ -17,7 +17,7 @@
-
+
-
+
@@ -71,6 +71,9 @@
+						<li>
+							<a class="dropdown-item" href="#" id="btn_samplers">Samplers</a>
+						</li>
@@ -299,6 +302,19 @@
+		<div class="popupcontainer hidden" id="samplerscontainer">
+			<div id="samplerspopup">
+				<div class="popuptitlebar">
+					<div class="popuptitletext">Drag-and-drop to change the order in which the samplers are applied</div>
+				</div>
+				<div id="samplerslist">
+				</div>
+				<div class="popupfooter">
+					<button type="button" class="btn btn-primary" id="btn_samplersaccept">Save</button>
+					<button type="button" class="btn btn-primary" id="btn_samplersclose">Cancel</button>
+				</div>
+			</div>
+		</div>
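End to end, the Samplers menu sends the reordered IDs (read from each list item's `sid` attribute) as a `samplers` message, and the server runs the checks from the `get_message` handler above before storing the new order. A sketch of the server-side view; `validate()` is an illustrative helper, not part of the patch:

```python
def validate(sampler_order, current_order):
    # Mirrors the checks in the 'samplers' branch of get_message().
    if not isinstance(sampler_order, list):
        raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}")
    if len(sampler_order) != len(current_order):
        raise ValueError(f"Sampler order must be a list of length {len(current_order)}, but got a list of length {len(sampler_order)}")
    if not all(isinstance(e, int) for e in sampler_order):
        raise ValueError("Sampler order must be a list of ints, but got a list with at least one non-int element")
    return sampler_order

msg = {"cmd": "samplers", "data": [5, 0, 1, 2, 3, 4]}  # what the UI emits
print(validate(msg["data"], [0, 1, 2, 3, 4, 5]))
```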