Merge pull request #148 from VE-FORBRYDERNE/overhaul-merge

Merge united into overhaul
This commit is contained in:
henk717 2022-06-15 00:56:14 +02:00 committed by GitHub
commit f3eb7cba5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 298 additions and 36 deletions

View File

@ -106,7 +106,6 @@ model_menu = {
["Adventure Models", "adventurelist", "", True], ["Adventure Models", "adventurelist", "", True],
["Novel Models", "novellist", "", True], ["Novel Models", "novellist", "", True],
["NSFW Models", "nsfwlist", "", True], ["NSFW Models", "nsfwlist", "", True],
["Chatbot Models", "chatlist", "", True],
["Untuned GPT-Neo/J", "gptneolist", "", True], ["Untuned GPT-Neo/J", "gptneolist", "", True],
["Untuned Fairseq Dense", "fsdlist", "", True], ["Untuned Fairseq Dense", "fsdlist", "", True],
["Untuned OPT", "optlist", "", True], ["Untuned OPT", "optlist", "", True],
@ -220,6 +219,7 @@ class vars:
temp = 0.5 # Default generator temperature temp = 0.5 # Default generator temperature
top_p = 0.9 # Default generator top_p top_p = 0.9 # Default generator top_p
top_k = 0 # Default generator top_k top_k = 0 # Default generator top_k
top_a = 0.0 # Default generator top-a
tfs = 1.0 # Default generator tfs (tail-free sampling) tfs = 1.0 # Default generator tfs (tail-free sampling)
typical = 1.0 # Default generator typical sampling threshold typical = 1.0 # Default generator typical sampling threshold
numseqs = 1 # Number of sequences to ask the generator to create numseqs = 1 # Number of sequences to ask the generator to create
@ -315,6 +315,7 @@ class vars:
acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses) acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)') # Pattern for matching comments to remove them before sending them to the AI comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)') # Pattern for matching comments to remove them before sending them to the AI
comregex_ui = re.compile(r'(&lt;\|(?:.|\n)*?\|&gt;)') # Pattern for matching comments in the editor comregex_ui = re.compile(r'(&lt;\|(?:.|\n)*?\|&gt;)') # Pattern for matching comments in the editor
sampler_order = utils.default_sampler_order.copy()
chatmode = False chatmode = False
chatname = "You" chatname = "You"
adventure = False adventure = False
@ -647,6 +648,8 @@ def loadmodelsettings():
vars.badwordsids = js["badwordsids"] vars.badwordsids = js["badwordsids"]
if("nobreakmodel" in js): if("nobreakmodel" in js):
vars.nobreakmodel = js["nobreakmodel"] vars.nobreakmodel = js["nobreakmodel"]
if("sampler_order" in js):
vars.sampler_order = js["sampler_order"]
if("temp" in js): if("temp" in js):
vars.temp = js["temp"] vars.temp = js["temp"]
if("top_p" in js): if("top_p" in js):
@ -657,6 +660,8 @@ def loadmodelsettings():
vars.tfs = js["tfs"] vars.tfs = js["tfs"]
if("typical" in js): if("typical" in js):
vars.typical = js["typical"] vars.typical = js["typical"]
if("top_a" in js):
vars.top_a = js["top_a"]
if("rep_pen" in js): if("rep_pen" in js):
vars.rep_pen = js["rep_pen"] vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js): if("rep_pen_slope" in js):
@ -688,11 +693,13 @@ def savesettings():
js = {} js = {}
js["apikey"] = vars.apikey js["apikey"] = vars.apikey
js["andepth"] = vars.andepth js["andepth"] = vars.andepth
js["sampler_order"] = vars.sampler_order
js["temp"] = vars.temp js["temp"] = vars.temp
js["top_p"] = vars.top_p js["top_p"] = vars.top_p
js["top_k"] = vars.top_k js["top_k"] = vars.top_k
js["tfs"] = vars.tfs js["tfs"] = vars.tfs
js["typical"] = vars.typical js["typical"] = vars.typical
js["top_a"] = vars.top_a
js["rep_pen"] = vars.rep_pen js["rep_pen"] = vars.rep_pen
js["rep_pen_slope"] = vars.rep_pen_slope js["rep_pen_slope"] = vars.rep_pen_slope
js["rep_pen_range"] = vars.rep_pen_range js["rep_pen_range"] = vars.rep_pen_range
@ -763,6 +770,8 @@ def processsettings(js):
vars.apikey = js["apikey"] vars.apikey = js["apikey"]
if("andepth" in js): if("andepth" in js):
vars.andepth = js["andepth"] vars.andepth = js["andepth"]
if("sampler_order" in js):
vars.sampler_order = js["sampler_order"]
if("temp" in js): if("temp" in js):
vars.temp = js["temp"] vars.temp = js["temp"]
if("top_p" in js): if("top_p" in js):
@ -773,6 +782,8 @@ def processsettings(js):
vars.tfs = js["tfs"] vars.tfs = js["tfs"]
if("typical" in js): if("typical" in js):
vars.typical = js["typical"] vars.typical = js["typical"]
if("top_a" in js):
vars.top_a = js["top_a"]
if("rep_pen" in js): if("rep_pen" in js):
vars.rep_pen = js["rep_pen"] vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js): if("rep_pen_slope" in js):
@ -1268,7 +1279,7 @@ def patch_transformers():
# Patch transformers to use our custom logit warpers # Patch transformers to use our custom logit warpers
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper, TopALogitsWarper
def dynamic_processor_wrap(cls, field_name, var_name, cond=None): def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
old_call = cls.__call__ old_call = cls.__call__
@ -1288,6 +1299,7 @@ def patch_transformers():
cls.__call__ = new_call cls.__call__ = new_call
dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0) dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0) dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0) dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0) dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
@ -1331,14 +1343,23 @@ def patch_transformers():
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
class KoboldLogitsWarperList(LogitsProcessorList):
    """Logits warper collection that applies its warpers in a user-chosen order.

    The warpers are kept in a private list at fixed positions:
    0 = top-k, 1 = top-a, 2 = top-p, 3 = tail-free, 4 = typical,
    5 = temperature.  On every call the entries of ``vars.sampler_order``
    are used as indices into that list, so the order can change between
    calls without rebuilding this object.
    """

    def __init__(self, beams: int = 1, **kwargs):
        # Keep two candidate tokens instead of one when beam search is active.
        min_keep = 1 + (beams > 1)
        # NOTE(review): the numeric arguments below look like placeholders that
        # dynamic_processor_wrap overrides at call time -- confirm against the
        # wrapper's body.
        self.__warper_list: List[LogitsWarper] = [
            TopKLogitsWarper(top_k=1, min_tokens_to_keep=min_keep),
            TopALogitsWarper(top_a=0.5, min_tokens_to_keep=min_keep),
            TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=min_keep),
            TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=min_keep),
            TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=min_keep),
            TemperatureLogitsWarper(temperature=0.5),
        ]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs):
        """Run ``scores`` through every warper in ``vars.sampler_order`` order and return the result."""
        warpers = self.__warper_list
        for sampler_id in vars.sampler_order:
            scores = warpers[sampler_id](input_ids, scores, *args, **kwargs)
        return scores
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList: def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
warper_list = LogitsProcessorList() return KoboldLogitsWarperList(beams=beams)
warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TemperatureLogitsWarper(temperature=0.5))
return warper_list
def new_sample(self, *args, **kwargs): def new_sample(self, *args, **kwargs):
assert kwargs.pop("logits_warper", None) is not None assert kwargs.pop("logits_warper", None) is not None
@ -1957,11 +1978,13 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
def tpumtjgenerate_settings_callback() -> dict: def tpumtjgenerate_settings_callback() -> dict:
return { return {
"sampler_order": vars.sampler_order,
"top_p": float(vars.top_p), "top_p": float(vars.top_p),
"temp": float(vars.temp), "temp": float(vars.temp),
"top_k": int(vars.top_k), "top_k": int(vars.top_k),
"tfs": float(vars.tfs), "tfs": float(vars.tfs),
"typical": float(vars.typical), "typical": float(vars.typical),
"top_a": float(vars.top_a),
"repetition_penalty": float(vars.rep_pen), "repetition_penalty": float(vars.rep_pen),
"rpslope": float(vars.rep_pen_slope), "rpslope": float(vars.rep_pen_slope),
"rprange": int(vars.rep_pen_range), "rprange": int(vars.rep_pen_range),
@ -2384,6 +2407,7 @@ def lua_has_setting(setting):
"settopk", "settopk",
"settfs", "settfs",
"settypical", "settypical",
"settopa",
"setreppen", "setreppen",
"setreppenslope", "setreppenslope",
"setreppenrange", "setreppenrange",
@ -2403,6 +2427,7 @@ def lua_has_setting(setting):
"top_k", "top_k",
"tfs", "tfs",
"typical", "typical",
"topa",
"reppen", "reppen",
"reppenslope", "reppenslope",
"reppenrange", "reppenrange",
@ -2437,6 +2462,7 @@ def lua_get_setting(setting):
if(setting in ("settopk", "topk", "top_k")): return vars.top_k if(setting in ("settopk", "topk", "top_k")): return vars.top_k
if(setting in ("settfs", "tfs")): return vars.tfs if(setting in ("settfs", "tfs")): return vars.tfs
if(setting in ("settypical", "typical")): return vars.typical if(setting in ("settypical", "typical")): return vars.typical
if(setting in ("settopa", "topa")): return vars.top_a
if(setting in ("setreppen", "reppen")): return vars.rep_pen if(setting in ("setreppen", "reppen")): return vars.rep_pen
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
@ -2472,6 +2498,7 @@ def lua_set_setting(setting, v):
if(setting in ("settopk", "topk")): vars.top_k = v if(setting in ("settopk", "topk")): vars.top_k = v
if(setting in ("settfs", "tfs")): vars.tfs = v if(setting in ("settfs", "tfs")): vars.tfs = v
if(setting in ("settypical", "typical")): vars.typical = v if(setting in ("settypical", "typical")): vars.typical = v
if(setting in ("settopa", "topa")): vars.top_a = v
if(setting in ("setreppen", "reppen")): vars.rep_pen = v if(setting in ("setreppen", "reppen")): vars.rep_pen = v
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
@ -2862,6 +2889,11 @@ def get_message(msg):
emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True) emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True)
settingschanged() settingschanged()
refresh_settings() refresh_settings()
elif(msg['cmd'] == 'settopa'):
vars.top_a = float(msg['data'])
emit('from_server', {'cmd': 'setlabeltopa', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setreppen'): elif(msg['cmd'] == 'setreppen'):
vars.rep_pen = float(msg['data']) vars.rep_pen = float(msg['data'])
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True) emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
@ -3015,6 +3047,8 @@ def get_message(msg):
elif(msg['cmd'] == 'uslistrequest'): elif(msg['cmd'] == 'uslistrequest'):
unloaded, loaded = getuslist() unloaded, loaded = getuslist()
emit('from_server', {'cmd': 'buildus', 'data': {"unloaded": unloaded, "loaded": loaded}}) emit('from_server', {'cmd': 'buildus', 'data': {"unloaded": unloaded, "loaded": loaded}})
elif(msg['cmd'] == 'samplerlistrequest'):
emit('from_server', {'cmd': 'buildsamplers', 'data': vars.sampler_order})
elif(msg['cmd'] == 'usloaded'): elif(msg['cmd'] == 'usloaded'):
vars.userscripts = [] vars.userscripts = []
for userscript in msg['data']: for userscript in msg['data']:
@ -3028,6 +3062,16 @@ def get_message(msg):
load_lua_scripts() load_lua_scripts()
unloaded, loaded = getuslist() unloaded, loaded = getuslist()
sendUSStatItems() sendUSStatItems()
elif(msg['cmd'] == 'samplers'):
sampler_order = msg["data"]
if(not isinstance(sampler_order, list)):
raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}")
if(len(sampler_order) != len(vars.sampler_order)):
raise ValueError(f"Sampler order must be a list of length {len(vars.sampler_order)}, but got a list of length {len(sampler_order)}")
if(not all(isinstance(e, int) for e in sampler_order)):
raise ValueError(f"Sampler order must be a list of ints, but got a list with at least one non-int element")
vars.sampler_order = sampler_order
settingschanged()
elif(msg['cmd'] == 'list_model'): elif(msg['cmd'] == 'list_model'):
sendModelSelection(menu=msg['data']) sendModelSelection(menu=msg['data'])
elif(msg['cmd'] == 'load_model'): elif(msg['cmd'] == 'load_model'):
@ -3988,6 +4032,7 @@ def sendtocolab(txt, min, max):
'top_k': vars.top_k, 'top_k': vars.top_k,
'tfs': vars.tfs, 'tfs': vars.tfs,
'typical': vars.typical, 'typical': vars.typical,
'topa': vars.top_a,
'numseqs': vars.numseqs, 'numseqs': vars.numseqs,
'retfultxt': False 'retfultxt': False
} }
@ -4125,12 +4170,14 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
top_k=vars.top_k, top_k=vars.top_k,
tfs=vars.tfs, tfs=vars.tfs,
typical=vars.typical, typical=vars.typical,
top_a=vars.top_a,
numseqs=vars.numseqs, numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen, repetition_penalty=vars.rep_pen,
rpslope=vars.rep_pen_slope, rpslope=vars.rep_pen_slope,
rprange=vars.rep_pen_range, rprange=vars.rep_pen_range,
soft_embeddings=vars.sp, soft_embeddings=vars.sp,
soft_tokens=soft_tokens, soft_tokens=soft_tokens,
sampler_order=vars.sampler_order,
) )
past = genout past = genout
for i in range(vars.numseqs): for i in range(vars.numseqs):
@ -4311,6 +4358,7 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True) emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True) emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True) emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True)
emit('from_server', {'cmd': 'updatetopa', 'data': vars.top_a}, broadcast=True)
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True) emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True) emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True) emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
@ -4887,6 +4935,7 @@ def oairequest(txt, min, max):
'prompt': txt, 'prompt': txt,
'max_tokens': vars.genamt, 'max_tokens': vars.genamt,
'temperature': vars.temp, 'temperature': vars.temp,
'top_a': vars.top_a,
'top_p': vars.top_p, 'top_p': vars.top_p,
'top_k': vars.top_k, 'top_k': vars.top_k,
'tfs': vars.tfs, 'tfs': vars.tfs,

View File

@ -867,6 +867,7 @@ return function(_python, _bridged)
---@field settopk integer ---@field settopk integer
---@field settfs number ---@field settfs number
---@field settypical number ---@field settypical number
---@field settopa number
---@field setreppen number ---@field setreppen number
---@field setreppenslope number ---@field setreppenslope number
---@field setreppenrange number ---@field setreppenrange number
@ -884,6 +885,7 @@ return function(_python, _bridged)
---@field top_k integer ---@field top_k integer
---@field tfs number ---@field tfs number
---@field typical number ---@field typical number
---@field topa number
---@field reppen number ---@field reppen number
---@field reppenslope number ---@field reppenslope number
---@field reppenrange number ---@field reppenrange number

View File

@ -64,6 +64,17 @@ gensettingstf = [
"step": 0.05, "step": 0.05,
"default": 1.0, "default": 1.0,
"tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect." "tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect."
},
{
"uitype": "slider",
"unit": "float",
"label": "Top a Sampling",
"id": "settopa",
"min": 0.0,
"max": 1.0,
"step": 0.01,
"default": 0.0,
"tooltip": "Alternative sampling method that reduces the randomness of the AI whenever the probability of one token is much higher than all the others. Higher values have a stronger effect. Set this setting to 0 to disable its effect."
}, },
{ {
"uitype": "slider", "uitype": "slider",

View File

@ -21,6 +21,7 @@ var button_settings;
var button_format; var button_format;
var button_softprompt; var button_softprompt;
var button_userscripts; var button_userscripts;
var button_samplers;
var button_mode; var button_mode;
var button_mode_label; var button_mode_label;
var button_send; var button_send;
@ -112,6 +113,9 @@ var do_clear_ent = false;
// Whether or not an entry in the Userscripts menu is being dragged // Whether or not an entry in the Userscripts menu is being dragged
var us_dragging = false; var us_dragging = false;
// Whether or not an entry in the Samplers menu is being dragged
var samplers_dragging = false;
// Display vars // Display vars
var allowtoggle = false; var allowtoggle = false;
var formatcount = 0; var formatcount = 0;
@ -997,6 +1001,16 @@ function hideUSPopup() {
spcontent.html(""); spcontent.html("");
} }
// Make the samplers popup visible by swapping its "hidden" class for "flex".
function showSamplersPopup() {
	samplerspopup.removeClass("hidden").addClass("flex");
}
// Hide the samplers popup by swapping its "flex" class for "hidden".
function hideSamplersPopup() {
	samplerspopup.removeClass("flex").addClass("hidden");
}
function buildLoadModelList(ar, menu, breadcrumbs) { function buildLoadModelList(ar, menu, breadcrumbs) {
disableButtons([load_model_accept]); disableButtons([load_model_accept]);
@ -1207,6 +1221,29 @@ function buildUSList(unloaded, loaded) {
} }
} }
// Rebuild the contents of the samplers popup and show it.
// `samplers` is the sampler order received from the server: an array of
// indices into samplers_lookup_table, listed in the order the samplers are
// applied. Each entry becomes one list row whose "sid" attribute holds the
// index, which the Save button later reads back.
function buildSamplerList(samplers) {
// Drop rows left over from a previous opening of the popup
samplerslist.html("");
showSamplersPopup();
var i;
// Display names, positioned by sampler index
var samplers_lookup_table = [
"Top-k Sampling",
"Top-a Sampling",
"Top-p Sampling",
"Tail-free Sampling",
"Typical Sampling",
"Temperature",
]
// NOTE(review): an index with no entry in the table is rendered as the
// text "undefined"; nothing here validates the incoming values.
for(i=0; i<samplers.length; i++) {
samplerslist.append("<div class=\"flex\">\
<div class=\"samplerslistitem flex-row-container\" sid=\""+samplers[i]+"\">\
<div class=\"flex-row\">\
<div>"+samplers_lookup_table[samplers[i]]+"</div>\
</div>\
</div>\
</div>");
}
}
function highlightLoadLine(ref) { function highlightLoadLine(ref) {
$("#loadlistcontent > div > div.popuplistselected").removeClass("popuplistselected"); $("#loadlistcontent > div > div.popuplistselected").removeClass("popuplistselected");
$("#loadmodellistcontent > div > div.popuplistselected").removeClass("popuplistselected"); $("#loadmodellistcontent > div > div.popuplistselected").removeClass("popuplistselected");
@ -1963,6 +2000,7 @@ $(document).ready(function(){
button_format = $('#btn_format'); button_format = $('#btn_format');
button_softprompt = $("#btn_softprompt"); button_softprompt = $("#btn_softprompt");
button_userscripts= $("#btn_userscripts"); button_userscripts= $("#btn_userscripts");
button_samplers = $("#btn_samplers");
button_mode = $('#btnmode') button_mode = $('#btnmode')
button_mode_label = $('#btnmode_label') button_mode_label = $('#btnmode_label')
button_send = $('#btnsend'); button_send = $('#btnsend');
@ -2015,6 +2053,10 @@ $(document).ready(function(){
usloaded = $("#uslistloaded"); usloaded = $("#uslistloaded");
us_accept = $("#btn_usaccept"); us_accept = $("#btn_usaccept");
us_close = $("#btn_usclose"); us_close = $("#btn_usclose");
samplerspopup = $("#samplerscontainer");
samplerslist = $("#samplerslist");
samplers_accept = $("#btn_samplersaccept");
samplers_close = $("#btn_samplersclose");
nspopup = $("#newgamecontainer"); nspopup = $("#newgamecontainer");
ns_accept = $("#btn_nsaccept"); ns_accept = $("#btn_nsaccept");
ns_close = $("#btn_nsclose"); ns_close = $("#btn_nsclose");
@ -2038,7 +2080,7 @@ $(document).ready(function(){
modelname = msg.modelname; modelname = msg.modelname;
} }
refreshTitle(); refreshTitle();
connect_status.html("<b>Connected to KoboldAI Process!</b>"); connect_status.html("<b>Connected to KoboldAI!</b>");
connect_status.removeClass("color_orange"); connect_status.removeClass("color_orange");
connect_status.addClass("color_green"); connect_status.addClass("color_green");
// Reset Menus // Reset Menus
@ -2231,6 +2273,10 @@ $(document).ready(function(){
// Send current typical value to input // Send current typical value to input
$("#settypicalcur").val(msg.data); $("#settypicalcur").val(msg.data);
$("#settypical").val(parseFloat(msg.data)).trigger("change"); $("#settypical").val(parseFloat(msg.data)).trigger("change");
} else if(msg.cmd == "updatetopa") {
// Send current top a value to input
$("#settopacur").val(msg.data);
$("#settopa").val(parseFloat(msg.data)).trigger("change");
} else if(msg.cmd == "updatereppen") { } else if(msg.cmd == "updatereppen") {
// Send current rep pen value to input // Send current rep pen value to input
$("#setreppencur").val(msg.data); $("#setreppencur").val(msg.data);
@ -2270,6 +2316,9 @@ $(document).ready(function(){
} else if(msg.cmd == "setlabeltypical") { } else if(msg.cmd == "setlabeltypical") {
// Update setting label with value from server // Update setting label with value from server
$("#settypicalcur").val(msg.data); $("#settypicalcur").val(msg.data);
} else if(msg.cmd == "setlabeltypical") {
// Update setting label with value from server
$("#settopa").val(msg.data);
} else if(msg.cmd == "setlabelreppen") { } else if(msg.cmd == "setlabelreppen") {
// Update setting label with value from server // Update setting label with value from server
$("#setreppencur").val(msg.data); $("#setreppencur").val(msg.data);
@ -2440,6 +2489,8 @@ $(document).ready(function(){
buildSPList(msg.data); buildSPList(msg.data);
} else if(msg.cmd == "buildus") { } else if(msg.cmd == "buildus") {
buildUSList(msg.data.unloaded, msg.data.loaded); buildUSList(msg.data.unloaded, msg.data.loaded);
} else if(msg.cmd == "buildsamplers") {
buildSamplerList(msg.data);
} else if(msg.cmd == "askforoverwrite") { } else if(msg.cmd == "askforoverwrite") {
// Show overwrite warning // Show overwrite warning
show([$(".saveasoverwrite")]); show([$(".saveasoverwrite")]);
@ -2648,6 +2699,20 @@ $(document).ready(function(){
}, 10); }, 10);
} }
// Click handler for rows of the samplers list: a plain click (not the end of
// a drag) moves the clicked sampler one position down. The check runs after
// a short timeout so that it reads samplers_dragging once the click event
// has finished dispatching.
var samplers_click_handler = function(ev) {
	setTimeout(function() {
		if (!samplers_dragging) {
			var clicked_row = $(ev.target).closest(".samplerslistitem").parent();
			var following_item = clicked_row.next().find(".samplerslistitem");
			// Nothing to do when the clicked sampler is already last
			if (following_item.length) {
				following_item.parent().after(clicked_row);
			}
		}
	}, 10);
}
// Make the userscripts menu sortable // Make the userscripts menu sortable
var us_sortable_settings = { var us_sortable_settings = {
placeholder: "ussortable-placeholder", placeholder: "ussortable-placeholder",
@ -2668,6 +2733,22 @@ $(document).ready(function(){
connectWith: "#uslistunloaded", connectWith: "#uslistunloaded",
}, us_sortable_settings)).on("click", ".uslistitem", us_click_handler); }, us_sortable_settings)).on("click", ".uslistitem", us_click_handler);
// Make the samplers menu sortable
// Options handed to the jQuery UI sortable widget for the samplers list.
var samplers_sortable_settings = {
// CSS class of the drop-position marker shown while dragging
placeholder: "samplerssortable-placeholder",
// Track whether a drag is in progress; samplers_click_handler checks this
// flag and ignores the click when it is set
start: function() { samplers_dragging = true; },
stop: function() { samplers_dragging = false; },
delay: 2,
cursor: "move",
tolerance: "pointer",
opacity: 0.21,
revert: 173,
scrollSensitivity: 64,
scrollSpeed: 10,
}
// Enable drag-and-drop reordering and bind the click handler to the
// list rows through event delegation
samplerslist.sortable($.extend({
}, samplers_sortable_settings)).on("click", ".samplerslistitem", samplers_click_handler);
// Bind actions to UI buttons // Bind actions to UI buttons
button_send.on("click", function(ev) { button_send.on("click", function(ev) {
dosubmit(); dosubmit();
@ -2803,6 +2884,10 @@ $(document).ready(function(){
socket.send({'cmd': 'uslistrequest', 'data': ''}); socket.send({'cmd': 'uslistrequest', 'data': ''});
}); });
button_samplers.on("click", function(ev) {
socket.send({'cmd': 'samplerlistrequest', 'data': ''});
});
load_close.on("click", function(ev) { load_close.on("click", function(ev) {
hideLoadPopup(); hideLoadPopup();
}); });
@ -2859,6 +2944,16 @@ $(document).ready(function(){
hideUSPopup(); hideUSPopup();
}); });
samplers_close.on("click", function(ev) {
hideSamplersPopup();
});
samplers_accept.on("click", function(ev) {
hideMessage();
socket.send({'cmd': 'samplers', 'data': samplerslist.find(".samplerslistitem").map(function() { return parseInt($(this).attr("sid")); }).toArray()});
hideSamplersPopup();
});
button_loadmodel.on("click", function(ev) { button_loadmodel.on("click", function(ev) {
showLoadModelPopup(); showLoadModelPopup();
socket.send({'cmd': 'list_model', 'data': 'mainmenu'}); socket.send({'cmd': 'list_model', 'data': 'mainmenu'});

View File

@ -457,6 +457,26 @@ body.connected #popupfooter, #popupfooter.always-available {
overflow-wrap: anywhere; overflow-wrap: anywhere;
} }
/* Samplers popup: fixed-width dialog box */
#samplerspopup {
	margin-top: 100px;
	width: 300px;
	background-color: #262626;
}

/* At viewport widths up to 768px the samplers popup spans the full width */
@media (max-width: 768px) {
	#samplerspopup {
		margin-top: 100px;
		width: 100%;
		background-color: #262626;
	}
}

/* Scrollable list of sampler rows inside the popup */
#samplerslist {
	overflow-y: scroll;
	overflow-wrap: anywhere;
	height: 300px;
}
#nspopup { #nspopup {
width: 350px; width: 350px;
background-color: #262626; background-color: #262626;
@ -750,7 +770,7 @@ body.connected .dropdown-item:hover, .dropdown-item.always-available:hover {
background-color: #3bf723; background-color: #3bf723;
} }
.ussortable-placeholder { .ussortable-placeholder, .samplerssortable-placeholder {
height: 4px; height: 4px;
background-color: #3bf723; background-color: #3bf723;
} }
@ -1362,7 +1382,7 @@ body.connected .popupfooter, .popupfooter.always-available {
background-color: #688f1f; background-color: #688f1f;
} }
.uslistitem { .uslistitem, .samplerslistitem {
padding: 12px 10px 12px 10px; padding: 12px 10px 12px 10px;
display: flex; display: flex;
flex-grow: 1; flex-grow: 1;
@ -1374,11 +1394,11 @@ body.connected .popupfooter, .popupfooter.always-available {
transition: background-color 0.25s ease-in; transition: background-color 0.25s ease-in;
} }
.uslistitemsub { .uslistitemsub, .samplerslistitemsub {
color: #ba9; color: #ba9;
} }
.uslistitem:hover { .uslistitem:hover, .samplerslistitem:hover {
cursor: move; cursor: move;
background-color: #688f1f; background-color: #688f1f;
} }

View File

@ -9,7 +9,7 @@
<link rel="stylesheet" href="static/bootstrap.min.css"> <link rel="stylesheet" href="static/bootstrap.min.css">
<link rel="stylesheet" href="static/bootstrap-toggle.min.css"> <link rel="stylesheet" href="static/bootstrap-toggle.min.css">
<link rel="stylesheet" href="static/open-iconic-bootstrap.min.css"> <link rel="stylesheet" href="static/open-iconic-bootstrap.min.css">
<link rel="stylesheet" href="static/custom.css?ver=1.18b"> <link rel="stylesheet" href="static/custom.css?ver=1.18c">
<script src="static/jquery-3.6.0.min.js"></script> <script src="static/jquery-3.6.0.min.js"></script>
<script src="static/jquery-ui.sortable.min.js"></script> <script src="static/jquery-ui.sortable.min.js"></script>
@ -17,7 +17,7 @@
<script src="static/bootstrap.min.js"></script> <script src="static/bootstrap.min.js"></script>
<script src="static/bootstrap-toggle.min.js"></script> <script src="static/bootstrap-toggle.min.js"></script>
<script src="static/rangy-core.min.js"></script> <script src="static/rangy-core.min.js"></script>
<script src="static/application.js?ver=1.18c"></script> <script src="static/application.js?ver=1.18e"></script>
<script src="static/favicon.js"></script> <script src="static/favicon.js"></script>
</head> </head>
<body> <body>
@ -81,6 +81,9 @@
<li class="nav-item"> <li class="nav-item">
<a class="nav-link" href="#" id="btn_format">Formatting</a> <a class="nav-link" href="#" id="btn_format">Formatting</a>
</li> </li>
<li class="nav-item">
<a class="nav-link" href="#" id="btn_samplers">Samplers</a>
</li>
<li class="nav-item"> <li class="nav-item">
<a class="nav-link" href="#" id="btn_userscripts">Userscripts</a> <a class="nav-link" href="#" id="btn_userscripts">Userscripts</a>
</li> </li>
@ -363,6 +366,19 @@
</div> </div>
</div> </div>
</div> </div>
<div class="popupcontainer hidden" id="samplerscontainer">
<div id="samplerspopup">
<div class="popuptitlebar">
<div class="popuptitletext">Drag-and-drop to change the order in which the samplers are applied</div>
</div>
<div id="samplerslist">
</div>
<div class="popupfooter">
<button type="button" class="btn btn-primary" id="btn_samplersaccept">Save</button>
<button type="button" class="btn btn-primary" id="btn_samplersclose">Cancel</button>
</div>
</div>
</div>
<div class="popupcontainer hidden" id="loadcontainerdelete"> <div class="popupcontainer hidden" id="loadcontainerdelete">
<div id="loadpopupdelete"> <div id="loadpopupdelete">
<div class="popuptitlebar"> <div class="popuptitlebar">

View File

@ -65,11 +65,13 @@ def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List
def settings_callback() -> dict: def settings_callback() -> dict:
return { return {
"sampler_order": utils.default_sampler_order.copy(),
"top_p": 0.9, "top_p": 0.9,
"temp": 0.5, "temp": 0.5,
"top_k": 0, "top_k": 0,
"tfs": 1.0, "tfs": 1.0,
"typical": 1.0, "typical": 1.0,
"top_a": 0.0,
"repetition_penalty": 1.0, "repetition_penalty": 1.0,
"rpslope": 0.0, "rpslope": 0.0,
"rprange": 0, "rprange": 0,
@ -158,10 +160,10 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
logits[tokens] = penalty_logits logits[tokens] = penalty_logits
return logits return logits
def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0): def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
''' '''
This gets called by generate_loop_fn to apply a series of 5 filters This gets called by generate_loop_fn to apply a series of 6 filters
to the logits (top-k, then top-p, then TFS, then typical, then temperature) to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
before picking one token using the modified logits before picking one token using the modified logits
''' '''
# Top-k (keep only the k tokens with the highest logits and remove # Top-k (keep only the k tokens with the highest logits and remove
@ -180,8 +182,18 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return np.where(indices_to_remove, -np.inf, logits) return np.where(indices_to_remove, -np.inf, logits)
if top_k > 0: # Top-a (remove all tokens that have softmax probability less than
logits = top_k_filter(logits) # a*m^2 where m is the maximum softmax probability)
def top_a_filter(logits):
    # Convert the logits to a probability distribution with softmax and
    # materialise the result as a fresh NumPy array
    probs = np.array(jax.nn.softmax(logits), copy=True)
    # The cutoff is a * m^2, where m is the probability of the single
    # most likely token and a is the enclosing function's top_a
    peak = probs.max()
    cutoff = peak * peak * top_a
    # Tokens whose probability falls below the cutoff get a logit of
    # negative infinity; every other logit is passed through untouched
    below_cutoff = probs < cutoff
    return np.where(below_cutoff, -np.inf, logits)
# Top-p (after sorting the remaining tokens again in descending order of # Top-p (after sorting the remaining tokens again in descending order of
# logit, remove the ones that have cumulative softmax probability # logit, remove the ones that have cumulative softmax probability
# greater than p) # greater than p)
@ -207,8 +219,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return np.where(indices_to_remove, -np.inf, logits) return np.where(indices_to_remove, -np.inf, logits)
if top_p < 1.0:
logits = top_p_filter(logits)
# Tail free sampling (basically top-p a second time on remaining tokens # Tail free sampling (basically top-p a second time on remaining tokens
# except it's the "cumulative normalized absolute second finite # except it's the "cumulative normalized absolute second finite
# differences of the softmax probabilities" instead of just the # differences of the softmax probabilities" instead of just the
@ -247,8 +257,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return np.where(indices_to_remove, -np.inf, logits) return np.where(indices_to_remove, -np.inf, logits)
if tfs < 1.0:
logits = tail_free_filter(logits)
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
def typical_filter(logits): def typical_filter(logits):
# Compute softmax probabilities and the natural logarithms of them # Compute softmax probabilities and the natural logarithms of them
@ -278,10 +286,16 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return np.where(indices_to_remove, -jnp.inf, logits) return np.where(indices_to_remove, -jnp.inf, logits)
if typical < 1.0:
logits = typical_filter(logits)
# Temperature (just divide the logits by the temperature) # Temperature (just divide the logits by the temperature)
logits /= temp def temp_filter(logits):
return logits / temp
for k in sampler_order:
if k == 0 and top_k > 0: logits = top_k_filter(logits)
if k == 1 and top_a > 0.0: logits = top_a_filter(logits)
if k == 2 and top_p < 1.0: logits = top_p_filter(logits)
if k == 3 and tfs < 1.0: logits = tail_free_filter(logits)
if k == 4 and typical < 1.0: logits = typical_filter(logits)
if k == 5 and temp != 1.0: logits = temp_filter(logits)
# Finally, pick one token using the softmax thingy again (it gives # Finally, pick one token using the softmax thingy again (it gives
# an array whose elements sum to 1 so it can be used nicely as a # an array whose elements sum to 1 so it can be used nicely as a
# probability distribution) # probability distribution)
@ -332,10 +346,10 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
# positions in the logits array # positions in the logits array
return logits.at[tokens].set(penalty_logits) return logits.at[tokens].set(penalty_logits)
def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0): def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
''' '''
This gets called by generate_loop_fn to apply a series of 5 filters This gets called by generate_loop_fn to apply a series of 6 filters
to the logits (top-k, then top-p, then TFS, then typical, then temperature) to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
before picking one token using the modified logits before picking one token using the modified logits
''' '''
# Top-k (keep only the k tokens with the highest logits and remove # Top-k (keep only the k tokens with the highest logits and remove
@ -354,7 +368,18 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return jnp.where(indices_to_remove, -jnp.inf, logits) return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits) # Top-a (remove all tokens that have softmax probability less than
# a*m^2 where m is the maximum softmax probability)
def top_a_filter(logits):
    # Convert the logits to a probability distribution with softmax
    probs = jax.nn.softmax(logits)
    # The cutoff is a * m^2, where m is the probability of the single
    # most likely token and a is the enclosing function's top_a
    peak = probs.max()
    cutoff = peak * peak * top_a
    # Tokens whose probability falls below the cutoff get a logit of
    # negative infinity; every other logit is passed through untouched
    below_cutoff = probs < cutoff
    return jnp.where(below_cutoff, -jnp.inf, logits)
# Top-p (after sorting the remaining tokens again in descending order of # Top-p (after sorting the remaining tokens again in descending order of
# logit, remove the ones that have cumulative softmax probability # logit, remove the ones that have cumulative softmax probability
# greater than p) # greater than p)
@ -380,7 +405,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return jnp.where(indices_to_remove, -jnp.inf, logits) return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits)
# Tail free sampling (basically top-p a second time on remaining tokens # Tail free sampling (basically top-p a second time on remaining tokens
# except it's the "cumulative normalized absolute second finite # except it's the "cumulative normalized absolute second finite
# differences of the softmax probabilities" instead of just the # differences of the softmax probabilities" instead of just the
@ -419,7 +443,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return jnp.where(indices_to_remove, -jnp.inf, logits) return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
def typical_filter(logits): def typical_filter(logits):
# Compute softmax probabilities and the natural logarithms of them # Compute softmax probabilities and the natural logarithms of them
@ -448,11 +471,16 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return jnp.where(indices_to_remove, -jnp.inf, logits) return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(typical < 1.0, typical_filter, lambda x: x, logits)
# Temperature (just divide the logits by the temperature) # Temperature (just divide the logits by the temperature)
def temp_filter(logits): def temp_filter(logits):
return logits / temp return logits / temp
logits = jax.lax.cond(True, temp_filter, lambda x: x, logits) for k in sampler_order:
logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), top_k_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), top_a_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), top_p_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
# Finally, pick one token using the softmax thingy again (it gives # Finally, pick one token using the softmax thingy again (it gives
# an array whose elements sum to 1 so it can be used nicely as a # an array whose elements sum to 1 so it can be used nicely as a
# probability distribution) # probability distribution)
@ -806,6 +834,7 @@ def infer_static(
top_k=0, top_k=0,
tfs=1.0, tfs=1.0,
typical=1.0, typical=1.0,
top_a=0.0,
repetition_penalty=1.0, repetition_penalty=1.0,
rpslope=0.0, rpslope=0.0,
rprange=0, rprange=0,
@ -813,8 +842,12 @@ def infer_static(
gen_len=80, gen_len=80,
soft_embeddings: Optional[np.array] = None, soft_embeddings: Optional[np.array] = None,
soft_tokens: Optional[np.array] = None, soft_tokens: Optional[np.array] = None,
sampler_order: Optional[List[int]] = None,
) -> List[np.array]: ) -> List[np.array]:
maps.thread_resources.env = thread_resources_env maps.thread_resources.env = thread_resources_env
if sampler_order is None:
sampler_order = utils.default_sampler_order.copy()
sampler_order = np.uint32(sampler_order)
total_batch = 1 total_batch = 1
tokens = context tokens = context
if(soft_tokens is not None): if(soft_tokens is not None):
@ -825,10 +858,12 @@ def infer_static(
batched_tokens = np.array([padded_tokens] * total_batch) batched_tokens = np.array([padded_tokens] * total_batch)
samples = [] samples = []
batched_generator_params = { batched_generator_params = {
"sampler_order": np.repeat(sampler_order[np.newaxis], total_batch, axis=0),
"temp": temp * np.ones(total_batch), "temp": temp * np.ones(total_batch),
"top_p": top_p * np.ones(total_batch), "top_p": top_p * np.ones(total_batch),
"tfs": tfs * np.ones(total_batch), "tfs": tfs * np.ones(total_batch),
"typical": typical * np.ones(total_batch), "typical": typical * np.ones(total_batch),
"top_a": top_a * np.ones(total_batch),
"repetition_penalty": repetition_penalty * np.ones(total_batch), "repetition_penalty": repetition_penalty * np.ones(total_batch),
"rpslope": rpslope * np.ones(total_batch), "rpslope": rpslope * np.ones(total_batch),
"rprange": np.full(total_batch, rprange, dtype=np.uint32), "rprange": np.full(total_batch, rprange, dtype=np.uint32),
@ -985,6 +1020,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None: def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params global thread_resources_env, seq, tokenizer, network, params
if not hasattr(vars, "sampler_order") or not vars.sampler_order:
vars.sampler_order = utils.default_sampler_order.copy()
default_params = { default_params = {
"compat": "j", "compat": "j",
"layers": 28, "layers": 28,

View File

@ -20,6 +20,8 @@ from_pretrained_index_filename: Optional[str] = None
from_pretrained_kwargs = {} from_pretrained_kwargs = {}
bar = None bar = None
# Default order in which the samplers are applied. Each entry is a sampler
# ID as interpreted by the sampling loops that consume this list:
#   0 = top-k, 1 = top-a, 2 = top-p, 3 = tail-free sampling,
#   4 = typical sampling, 5 = temperature
# Consumers take a .copy() of this list rather than using it directly.
default_sampler_order = [0, 1, 2, 3, 4, 5]
#==================================================================# #==================================================================#
# Decorator to prevent a function's actions from being run until # Decorator to prevent a function's actions from being run until
# at least x seconds have passed without the function being called # at least x seconds have passed without the function being called

View File

@ -148,3 +148,32 @@ class TypicalLogitsWarper(LogitsWarper):
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value) scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores return scores
class TopALogitsWarper(LogitsWarper):
    """Logits warper that performs top-a sampling.

    A token is removed when its softmax probability is lower than
    ``top_a * m ** 2``, where ``m`` is the highest token probability in
    the same row. Removed tokens have their score set to ``filter_value``.

    Args:
        top_a: Threshold coefficient, must lie in ``[0, 1]``. A value of
            ``0`` disables the filter.
        filter_value: Score assigned to removed tokens.
        min_tokens_to_keep: Number of highest-scoring tokens per row that
            are never removed.

    Raises:
        ValueError: If ``top_a`` is not a number in ``[0, 1]``.
    """

    def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_a = float(top_a)
        # Written as a chained comparison so that NaN is rejected as well
        # (NaN fails every ordering comparison).
        if not 0.0 <= top_a <= 1.0:
            raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}")
        self.top_a = top_a
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with the tokens rejected by top-a set to ``filter_value``.

        ``input_ids`` is accepted for interface compatibility and is not
        used. Filtering is applied along the last dimension of ``scores``.
        """
        # With top_a == 0 the threshold is 0 and no probability can be
        # below it, so the filter is a no-op; skip the sort entirely.
        # (The decision must depend on top_a, not on filter_value.)
        if self.top_a <= 0.0:
            return scores

        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        # Remove tokens with probability less than top_a * max(probs)^2.
        # After the descending sort the maximum is at index 0 of each row.
        probs_max = probs[..., 0, None]
        sorted_indices_to_remove = probs < probs_max * probs_max * self.top_a

        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep. The top token never needs
            # this protection: m < top_a * m^2 cannot hold for top_a <= 1.
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = False

        # Map the removal mask from sorted order back to vocabulary order.
        # The last dimension is used to match the sort/softmax above.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores