diff --git a/aiserver.py b/aiserver.py index 35b20613..0ef805b9 100644 --- a/aiserver.py +++ b/aiserver.py @@ -106,7 +106,6 @@ model_menu = { ["Adventure Models", "adventurelist", "", True], ["Novel Models", "novellist", "", True], ["NSFW Models", "nsfwlist", "", True], - ["Chatbot Models", "chatlist", "", True], ["Untuned GPT-Neo/J", "gptneolist", "", True], ["Untuned Fairseq Dense", "fsdlist", "", True], ["Untuned OPT", "optlist", "", True], @@ -220,6 +219,7 @@ class vars: temp = 0.5 # Default generator temperature top_p = 0.9 # Default generator top_p top_k = 0 # Default generator top_k + top_a = 0.0 # Default generator top-a tfs = 1.0 # Default generator tfs (tail-free sampling) typical = 1.0 # Default generator typical sampling threshold numseqs = 1 # Number of sequences to ask the generator to create @@ -315,6 +315,7 @@ class vars: acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses) comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)') # Pattern for matching comments to remove them before sending them to the AI comregex_ui = re.compile(r'(<\|(?:.|\n)*?\|>)') # Pattern for matching comments in the editor + sampler_order = utils.default_sampler_order.copy() chatmode = False chatname = "You" adventure = False @@ -647,6 +648,8 @@ def loadmodelsettings(): vars.badwordsids = js["badwordsids"] if("nobreakmodel" in js): vars.nobreakmodel = js["nobreakmodel"] + if("sampler_order" in js): + vars.sampler_order = js["sampler_order"] if("temp" in js): vars.temp = js["temp"] if("top_p" in js): @@ -657,6 +660,8 @@ def loadmodelsettings(): vars.tfs = js["tfs"] if("typical" in js): vars.typical = js["typical"] + if("top_a" in js): + vars.top_a = js["top_a"] if("rep_pen" in js): vars.rep_pen = js["rep_pen"] if("rep_pen_slope" in js): @@ -688,11 +693,13 @@ def savesettings(): js = {} js["apikey"] = vars.apikey 
js["andepth"] = vars.andepth + js["sampler_order"] = vars.sampler_order js["temp"] = vars.temp js["top_p"] = vars.top_p js["top_k"] = vars.top_k js["tfs"] = vars.tfs js["typical"] = vars.typical + js["top_a"] = vars.top_a js["rep_pen"] = vars.rep_pen js["rep_pen_slope"] = vars.rep_pen_slope js["rep_pen_range"] = vars.rep_pen_range @@ -763,6 +770,8 @@ def processsettings(js): vars.apikey = js["apikey"] if("andepth" in js): vars.andepth = js["andepth"] + if("sampler_order" in js): + vars.sampler_order = js["sampler_order"] if("temp" in js): vars.temp = js["temp"] if("top_p" in js): @@ -773,6 +782,8 @@ def processsettings(js): vars.tfs = js["tfs"] if("typical" in js): vars.typical = js["typical"] + if("top_a" in js): + vars.top_a = js["top_a"] if("rep_pen" in js): vars.rep_pen = js["rep_pen"] if("rep_pen_slope" in js): @@ -1268,7 +1279,7 @@ def patch_transformers(): # Patch transformers to use our custom logit warpers from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor - from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper + from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper, TopALogitsWarper def dynamic_processor_wrap(cls, field_name, var_name, cond=None): old_call = cls.__call__ @@ -1288,6 +1299,7 @@ def patch_transformers(): cls.__call__ = new_call dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0) dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0) + dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0) dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0) dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda 
x: x < 1.0) dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0) @@ -1331,14 +1343,23 @@ def patch_transformers(): new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor + class KoboldLogitsWarperList(LogitsProcessorList): + def __init__(self, beams: int = 1, **kwargs): + self.__warper_list: List[LogitsWarper] = [] + self.__warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))) + self.__warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1))) + self.__warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) + self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) + self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) + self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5)) + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs): + for k in vars.sampler_order: + scores = self.__warper_list[k](input_ids, scores, *args, **kwargs) + return scores + def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList: - warper_list = LogitsProcessorList() - warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TemperatureLogitsWarper(temperature=0.5)) - return warper_list + return KoboldLogitsWarperList(beams=beams) def new_sample(self, *args, **kwargs): assert kwargs.pop("logits_warper", None) is not None @@ -1957,11 +1978,13 @@ def load_model(use_gpu=True, 
gpu_layers=None, initial_load=False, online_model=" def tpumtjgenerate_settings_callback() -> dict: return { + "sampler_order": vars.sampler_order, "top_p": float(vars.top_p), "temp": float(vars.temp), "top_k": int(vars.top_k), "tfs": float(vars.tfs), "typical": float(vars.typical), + "top_a": float(vars.top_a), "repetition_penalty": float(vars.rep_pen), "rpslope": float(vars.rep_pen_slope), "rprange": int(vars.rep_pen_range), @@ -2384,6 +2407,7 @@ def lua_has_setting(setting): "settopk", "settfs", "settypical", + "settopa", "setreppen", "setreppenslope", "setreppenrange", @@ -2403,6 +2427,7 @@ def lua_has_setting(setting): "top_k", "tfs", "typical", + "topa", "reppen", "reppenslope", "reppenrange", @@ -2437,6 +2462,7 @@ def lua_get_setting(setting): if(setting in ("settopk", "topk", "top_k")): return vars.top_k if(setting in ("settfs", "tfs")): return vars.tfs if(setting in ("settypical", "typical")): return vars.typical + if(setting in ("settopa", "topa")): return vars.top_a if(setting in ("setreppen", "reppen")): return vars.rep_pen if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range @@ -2472,6 +2498,7 @@ def lua_set_setting(setting, v): if(setting in ("settopk", "topk")): vars.top_k = v if(setting in ("settfs", "tfs")): vars.tfs = v if(setting in ("settypical", "typical")): vars.typical = v + if(setting in ("settopa", "topa")): vars.top_a = v if(setting in ("setreppen", "reppen")): vars.rep_pen = v if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v @@ -2862,6 +2889,11 @@ def get_message(msg): emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True) settingschanged() refresh_settings() + elif(msg['cmd'] == 'settopa'): + vars.top_a = float(msg['data']) + emit('from_server', {'cmd': 'setlabeltopa', 'data': msg['data']}, broadcast=True) + 
settingschanged() + refresh_settings() elif(msg['cmd'] == 'setreppen'): vars.rep_pen = float(msg['data']) emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True) @@ -3015,6 +3047,8 @@ def get_message(msg): elif(msg['cmd'] == 'uslistrequest'): unloaded, loaded = getuslist() emit('from_server', {'cmd': 'buildus', 'data': {"unloaded": unloaded, "loaded": loaded}}) + elif(msg['cmd'] == 'samplerlistrequest'): + emit('from_server', {'cmd': 'buildsamplers', 'data': vars.sampler_order}) elif(msg['cmd'] == 'usloaded'): vars.userscripts = [] for userscript in msg['data']: @@ -3028,6 +3062,16 @@ def get_message(msg): load_lua_scripts() unloaded, loaded = getuslist() sendUSStatItems() + elif(msg['cmd'] == 'samplers'): + sampler_order = msg["data"] + if(not isinstance(sampler_order, list)): + raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}") + if(len(sampler_order) != len(vars.sampler_order)): + raise ValueError(f"Sampler order must be a list of length {len(vars.sampler_order)}, but got a list of length {len(sampler_order)}") + if(not all(isinstance(e, int) for e in sampler_order)): + raise ValueError(f"Sampler order must be a list of ints, but got a list with at least one non-int element") + vars.sampler_order = sampler_order + settingschanged() elif(msg['cmd'] == 'list_model'): sendModelSelection(menu=msg['data']) elif(msg['cmd'] == 'load_model'): @@ -3988,6 +4032,7 @@ def sendtocolab(txt, min, max): 'top_k': vars.top_k, 'tfs': vars.tfs, 'typical': vars.typical, + 'topa': vars.top_a, 'numseqs': vars.numseqs, 'retfultxt': False } @@ -4125,12 +4170,14 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): top_k=vars.top_k, tfs=vars.tfs, typical=vars.typical, + top_a=vars.top_a, numseqs=vars.numseqs, repetition_penalty=vars.rep_pen, rpslope=vars.rep_pen_slope, rprange=vars.rep_pen_range, soft_embeddings=vars.sp, soft_tokens=soft_tokens, + sampler_order=vars.sampler_order, ) past = genout for i in 
range(vars.numseqs): @@ -4311,6 +4358,7 @@ def refresh_settings(): emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True) emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True) emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True) + emit('from_server', {'cmd': 'updatetopa', 'data': vars.top_a}, broadcast=True) emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True) emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True) emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True) @@ -4887,6 +4935,7 @@ def oairequest(txt, min, max): 'prompt': txt, 'max_tokens': vars.genamt, 'temperature': vars.temp, + 'top_a': vars.top_a, 'top_p': vars.top_p, 'top_k': vars.top_k, 'tfs': vars.tfs, diff --git a/bridge.lua b/bridge.lua index ed0941c6..fc6c8823 100644 --- a/bridge.lua +++ b/bridge.lua @@ -867,6 +867,7 @@ return function(_python, _bridged) ---@field settopk integer ---@field settfs number ---@field settypical number + ---@field settopa number ---@field setreppen number ---@field setreppenslope number ---@field setreppenrange number @@ -884,6 +885,7 @@ return function(_python, _bridged) ---@field top_k integer ---@field tfs number ---@field typical number + ---@field topa number ---@field reppen number ---@field reppenslope number ---@field reppenrange number diff --git a/gensettings.py b/gensettings.py index e8d4e566..b3007c91 100644 --- a/gensettings.py +++ b/gensettings.py @@ -64,6 +64,17 @@ gensettingstf = [ "step": 0.05, "default": 1.0, "tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect." 
+ }, + { + "uitype": "slider", + "unit": "float", + "label": "Top a Sampling", + "id": "settopa", + "min": 0.0, + "max": 1.0, + "step": 0.01, + "default": 0.0, + "tooltip": "Alternative sampling method that reduces the randomness of the AI whenever the probability of one token is much higher than all the others. Higher values have a stronger effect. Set this setting to 0 to disable its effect." }, { "uitype": "slider", diff --git a/static/application.js b/static/application.js index 9470cee9..286a8992 100644 --- a/static/application.js +++ b/static/application.js @@ -21,6 +21,7 @@ var button_settings; var button_format; var button_softprompt; var button_userscripts; +var button_samplers; var button_mode; var button_mode_label; var button_send; @@ -112,6 +113,9 @@ var do_clear_ent = false; // Whether or not an entry in the Userscripts menu is being dragged var us_dragging = false; +// Whether or not an entry in the Samplers menu is being dragged +var samplers_dragging = false; + // Display vars var allowtoggle = false; var formatcount = 0; @@ -997,6 +1001,16 @@ function hideUSPopup() { spcontent.html(""); } +function showSamplersPopup() { + samplerspopup.removeClass("hidden"); + samplerspopup.addClass("flex"); +} + +function hideSamplersPopup() { + samplerspopup.removeClass("flex"); + samplerspopup.addClass("hidden"); +} + function buildLoadModelList(ar, menu, breadcrumbs) { disableButtons([load_model_accept]); @@ -1207,6 +1221,29 @@ function buildUSList(unloaded, loaded) { } } +function buildSamplerList(samplers) { + samplerslist.html(""); + showSamplersPopup(); + var i; + var samplers_lookup_table = [ + "Top-k Sampling", + "Top-a Sampling", + "Top-p Sampling", + "Tail-free Sampling", + "Typical Sampling", + "Temperature", + ] + for(i=0; i\ +
\ +
\ +
"+samplers_lookup_table[samplers[i]]+"
\ +
\ +
\ + "); + } +} + function highlightLoadLine(ref) { $("#loadlistcontent > div > div.popuplistselected").removeClass("popuplistselected"); $("#loadmodellistcontent > div > div.popuplistselected").removeClass("popuplistselected"); @@ -1963,6 +2000,7 @@ $(document).ready(function(){ button_format = $('#btn_format'); button_softprompt = $("#btn_softprompt"); button_userscripts= $("#btn_userscripts"); + button_samplers = $("#btn_samplers"); button_mode = $('#btnmode') button_mode_label = $('#btnmode_label') button_send = $('#btnsend'); @@ -2015,6 +2053,10 @@ $(document).ready(function(){ usloaded = $("#uslistloaded"); us_accept = $("#btn_usaccept"); us_close = $("#btn_usclose"); + samplerspopup = $("#samplerscontainer"); + samplerslist = $("#samplerslist"); + samplers_accept = $("#btn_samplersaccept"); + samplers_close = $("#btn_samplersclose"); nspopup = $("#newgamecontainer"); ns_accept = $("#btn_nsaccept"); ns_close = $("#btn_nsclose"); @@ -2038,7 +2080,7 @@ $(document).ready(function(){ modelname = msg.modelname; } refreshTitle(); - connect_status.html("Connected to KoboldAI Process!"); + connect_status.html("Connected to KoboldAI!"); connect_status.removeClass("color_orange"); connect_status.addClass("color_green"); // Reset Menus @@ -2231,6 +2273,10 @@ $(document).ready(function(){ // Send current typical value to input $("#settypicalcur").val(msg.data); $("#settypical").val(parseFloat(msg.data)).trigger("change"); + } else if(msg.cmd == "updatetopa") { + // Send current top a value to input + $("#settopacur").val(msg.data); + $("#settopa").val(parseFloat(msg.data)).trigger("change"); } else if(msg.cmd == "updatereppen") { // Send current rep pen value to input $("#setreppencur").val(msg.data); @@ -2270,6 +2316,9 @@ $(document).ready(function(){ } else if(msg.cmd == "setlabeltypical") { // Update setting label with value from server $("#settypicalcur").val(msg.data); + } else if(msg.cmd == "setlabeltopa") { + // Update setting label with value from server + 
$("#settopa").val(msg.data); } else if(msg.cmd == "setlabelreppen") { // Update setting label with value from server $("#setreppencur").val(msg.data); @@ -2440,6 +2489,8 @@ $(document).ready(function(){ buildSPList(msg.data); } else if(msg.cmd == "buildus") { buildUSList(msg.data.unloaded, msg.data.loaded); + } else if(msg.cmd == "buildsamplers") { + buildSamplerList(msg.data); } else if(msg.cmd == "askforoverwrite") { // Show overwrite warning show([$(".saveasoverwrite")]); @@ -2648,6 +2699,20 @@ $(document).ready(function(){ }, 10); } + var samplers_click_handler = function(ev) { + setTimeout(function() { + if (samplers_dragging) { + return; + } + var target = $(ev.target).closest(".samplerslistitem"); + var next = target.parent().next().find(".samplerslistitem"); + if (!next.length) { + return; + } + next.parent().after(target.parent()); + }, 10); + } + // Make the userscripts menu sortable var us_sortable_settings = { placeholder: "ussortable-placeholder", @@ -2668,6 +2733,22 @@ $(document).ready(function(){ connectWith: "#uslistunloaded", }, us_sortable_settings)).on("click", ".uslistitem", us_click_handler); + // Make the samplers menu sortable + var samplers_sortable_settings = { + placeholder: "samplerssortable-placeholder", + start: function() { samplers_dragging = true; }, + stop: function() { samplers_dragging = false; }, + delay: 2, + cursor: "move", + tolerance: "pointer", + opacity: 0.21, + revert: 173, + scrollSensitivity: 64, + scrollSpeed: 10, + } + samplerslist.sortable($.extend({ + }, samplers_sortable_settings)).on("click", ".samplerslistitem", samplers_click_handler); + // Bind actions to UI buttons button_send.on("click", function(ev) { dosubmit(); @@ -2802,6 +2883,10 @@ $(document).ready(function(){ button_userscripts.on("click", function(ev) { socket.send({'cmd': 'uslistrequest', 'data': ''}); }); + + button_samplers.on("click", function(ev) { + socket.send({'cmd': 'samplerlistrequest', 'data': ''}); + }); load_close.on("click", function(ev) 
{ hideLoadPopup(); @@ -2858,6 +2943,16 @@ $(document).ready(function(){ socket.send({'cmd': 'usload', 'data': ''}); hideUSPopup(); }); + + samplers_close.on("click", function(ev) { + hideSamplersPopup(); + }); + + samplers_accept.on("click", function(ev) { + hideMessage(); + socket.send({'cmd': 'samplers', 'data': samplerslist.find(".samplerslistitem").map(function() { return parseInt($(this).attr("sid")); }).toArray()}); + hideSamplersPopup(); + }); button_loadmodel.on("click", function(ev) { showLoadModelPopup(); diff --git a/static/custom.css b/static/custom.css index 078ae7d2..ea22bdf3 100644 --- a/static/custom.css +++ b/static/custom.css @@ -457,6 +457,26 @@ body.connected #popupfooter, #popupfooter.always-available { overflow-wrap: anywhere; } +#samplerspopup { + width: 300px; + background-color: #262626; + margin-top: 100px; +} + +@media (max-width: 768px) { + #samplerspopup { + width: 100%; + background-color: #262626; + margin-top: 100px; + } +} + +#samplerslist { + height: 300px; + overflow-y: scroll; + overflow-wrap: anywhere; +} + #nspopup { width: 350px; background-color: #262626; @@ -750,7 +770,7 @@ body.connected .dropdown-item:hover, .dropdown-item.always-available:hover { background-color: #3bf723; } -.ussortable-placeholder { +.ussortable-placeholder, .samplerssortable-placeholder { height: 4px; background-color: #3bf723; } @@ -1362,7 +1382,7 @@ body.connected .popupfooter, .popupfooter.always-available { background-color: #688f1f; } -.uslistitem { +.uslistitem, .samplerslistitem { padding: 12px 10px 12px 10px; display: flex; flex-grow: 1; @@ -1374,11 +1394,11 @@ body.connected .popupfooter, .popupfooter.always-available { transition: background-color 0.25s ease-in; } -.uslistitemsub { +.uslistitemsub, .samplerslistitemsub { color: #ba9; } -.uslistitem:hover { +.uslistitem:hover, .samplerslistitem:hover { cursor: move; background-color: #688f1f; } diff --git a/templates/index.html b/templates/index.html index 2050001a..b76f512f 100644 --- 
a/templates/index.html +++ b/templates/index.html @@ -9,7 +9,7 @@ - + @@ -17,7 +17,7 @@ - + @@ -81,6 +81,9 @@ + @@ -363,6 +366,19 @@ +