Merge pull request #146 from VE-FORBRYDERNE/sampler-order

Add support for changing the order of the samplers
This commit is contained in:
henk717 2022-06-14 14:22:49 +02:00 committed by GitHub
commit 06c3a2a1fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 196 additions and 34 deletions

View File

@ -306,6 +306,7 @@ class vars:
acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses) acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)') # Pattern for matching comments to remove them before sending them to the AI comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)') # Pattern for matching comments to remove them before sending them to the AI
comregex_ui = re.compile(r'(&lt;\|(?:.|\n)*?\|&gt;)') # Pattern for matching comments in the editor comregex_ui = re.compile(r'(&lt;\|(?:.|\n)*?\|&gt;)') # Pattern for matching comments in the editor
sampler_order = utils.default_sampler_order.copy()
chatmode = False chatmode = False
chatname = "You" chatname = "You"
adventure = False adventure = False
@ -567,6 +568,8 @@ def loadmodelsettings():
vars.badwordsids = js["badwordsids"] vars.badwordsids = js["badwordsids"]
if("nobreakmodel" in js): if("nobreakmodel" in js):
vars.nobreakmodel = js["nobreakmodel"] vars.nobreakmodel = js["nobreakmodel"]
if("sampler_order" in js):
vars.sampler_order = js["sampler_order"]
if("temp" in js): if("temp" in js):
vars.temp = js["temp"] vars.temp = js["temp"]
if("top_p" in js): if("top_p" in js):
@ -610,6 +613,7 @@ def savesettings():
js = {} js = {}
js["apikey"] = vars.apikey js["apikey"] = vars.apikey
js["andepth"] = vars.andepth js["andepth"] = vars.andepth
js["sampler_order"] = vars.sampler_order
js["temp"] = vars.temp js["temp"] = vars.temp
js["top_p"] = vars.top_p js["top_p"] = vars.top_p
js["top_k"] = vars.top_k js["top_k"] = vars.top_k
@ -686,6 +690,8 @@ def processsettings(js):
vars.apikey = js["apikey"] vars.apikey = js["apikey"]
if("andepth" in js): if("andepth" in js):
vars.andepth = js["andepth"] vars.andepth = js["andepth"]
if("sampler_order" in js):
vars.sampler_order = js["sampler_order"]
if("temp" in js): if("temp" in js):
vars.temp = js["temp"] vars.temp = js["temp"]
if("top_p" in js): if("top_p" in js):
@ -1448,15 +1454,23 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
class KoboldLogitsWarperList(LogitsProcessorList):
    """Logits warper collection whose application order is configurable.

    The warpers are stored in a private list indexed by sampler ID:
    0 = top-k, 1 = top-a, 2 = top-p, 3 = tail-free, 4 = typical,
    5 = temperature.  The order in which they are applied is read from
    ``vars.sampler_order`` on every call, so a changed order takes effect
    without rebuilding this object.
    """

    def __init__(self, beams: int = 1, **kwargs):
        """Build the fixed set of warpers.

        Args:
            beams: Number of beams.  With more than one beam every filtering
                warper is told to keep at least two tokens instead of one.
            **kwargs: Accepted for call compatibility and ignored.
        """
        # Initialise the underlying list explicitly.  The list itself is left
        # empty; the warpers live in the private list below so that they can
        # be looked up by sampler ID.
        super().__init__()
        min_tokens_to_keep = 1 + (beams > 1)
        # NOTE(review): the numeric settings passed here look like
        # placeholders that are overridden elsewhere at generation time --
        # confirm against the warper implementations.
        self.__warper_list: List[LogitsWarper] = [
            TopKLogitsWarper(top_k=1, min_tokens_to_keep=min_tokens_to_keep),
            TopALogitsWarper(top_a=0.5, min_tokens_to_keep=min_tokens_to_keep),
            TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=min_tokens_to_keep),
            TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=min_tokens_to_keep),
            TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=min_tokens_to_keep),
            TemperatureLogitsWarper(temperature=0.5),
        ]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs):
        """Apply the warpers to ``scores`` in the order given by ``vars.sampler_order``.

        Each entry of ``vars.sampler_order`` is used as an index into the
        private warper list; the entries are not validated here, so an
        out-of-range value raises ``IndexError``.

        Returns:
            The scores after every listed warper has been applied in turn.
        """
        for k in vars.sampler_order:
            scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
        return scores
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList: def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
warper_list = LogitsProcessorList() return KoboldLogitsWarperList(beams=beams)
warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TemperatureLogitsWarper(temperature=0.5))
return warper_list
def new_sample(self, *args, **kwargs): def new_sample(self, *args, **kwargs):
assert kwargs.pop("logits_warper", None) is not None assert kwargs.pop("logits_warper", None) is not None
@ -1816,6 +1830,7 @@ else:
def tpumtjgenerate_settings_callback() -> dict: def tpumtjgenerate_settings_callback() -> dict:
return { return {
"sampler_order": vars.sampler_order,
"top_p": float(vars.top_p), "top_p": float(vars.top_p),
"temp": float(vars.temp), "temp": float(vars.temp),
"top_k": int(vars.top_k), "top_k": int(vars.top_k),
@ -2858,6 +2873,8 @@ def get_message(msg):
elif(msg['cmd'] == 'uslistrequest'): elif(msg['cmd'] == 'uslistrequest'):
unloaded, loaded = getuslist() unloaded, loaded = getuslist()
emit('from_server', {'cmd': 'buildus', 'data': {"unloaded": unloaded, "loaded": loaded}}) emit('from_server', {'cmd': 'buildus', 'data': {"unloaded": unloaded, "loaded": loaded}})
elif(msg['cmd'] == 'samplerlistrequest'):
emit('from_server', {'cmd': 'buildsamplers', 'data': vars.sampler_order})
elif(msg['cmd'] == 'usloaded'): elif(msg['cmd'] == 'usloaded'):
vars.userscripts = [] vars.userscripts = []
for userscript in msg['data']: for userscript in msg['data']:
@ -2871,6 +2888,16 @@ def get_message(msg):
load_lua_scripts() load_lua_scripts()
unloaded, loaded = getuslist() unloaded, loaded = getuslist()
sendUSStatItems() sendUSStatItems()
elif(msg['cmd'] == 'samplers'):
sampler_order = msg["data"]
if(not isinstance(sampler_order, list)):
raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}")
if(len(sampler_order) != len(vars.sampler_order)):
raise ValueError(f"Sampler order must be a list of length {len(vars.sampler_order)}, but got a list of length {len(sampler_order)}")
if(not all(isinstance(e, int) for e in sampler_order)):
raise ValueError(f"Sampler order must be a list of ints, but got a list with at least one non-int element")
vars.sampler_order = sampler_order
settingschanged()
elif(msg['cmd'] == 'loadselect'): elif(msg['cmd'] == 'loadselect'):
vars.loadselect = msg["data"] vars.loadselect = msg["data"]
elif(msg['cmd'] == 'spselect'): elif(msg['cmd'] == 'spselect'):
@ -3910,6 +3937,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
rprange=vars.rep_pen_range, rprange=vars.rep_pen_range,
soft_embeddings=vars.sp, soft_embeddings=vars.sp,
soft_tokens=soft_tokens, soft_tokens=soft_tokens,
sampler_order=vars.sampler_order,
) )
past = genout past = genout
for i in range(vars.numseqs): for i in range(vars.numseqs):

View File

@ -20,6 +20,7 @@ var button_settings;
var button_format; var button_format;
var button_softprompt; var button_softprompt;
var button_userscripts; var button_userscripts;
var button_samplers;
var button_mode; var button_mode;
var button_mode_label; var button_mode_label;
var button_send; var button_send;
@ -109,6 +110,9 @@ var do_clear_ent = false;
// Whether or not an entry in the Userscripts menu is being dragged // Whether or not an entry in the Userscripts menu is being dragged
var us_dragging = false; var us_dragging = false;
// Whether or not an entry in the Samplers menu is being dragged
var samplers_dragging = false;
// Display vars // Display vars
var allowtoggle = false; var allowtoggle = false;
var formatcount = 0; var formatcount = 0;
@ -976,6 +980,16 @@ function hideUSPopup() {
spcontent.html(""); spcontent.html("");
} }
// Make the samplers popup visible by swapping its "hidden" class for "flex"
function showSamplersPopup() {
    samplerspopup.removeClass("hidden").addClass("flex");
}
// Hide the samplers popup by swapping its "flex" class for "hidden"
function hideSamplersPopup() {
    samplerspopup.removeClass("flex").addClass("hidden");
}
function buildLoadList(ar) { function buildLoadList(ar) {
disableButtons([load_accept]); disableButtons([load_accept]);
loadcontent.html(""); loadcontent.html("");
@ -1109,6 +1123,29 @@ function buildUSList(unloaded, loaded) {
} }
} }
// Rebuild the samplers menu from an array of sampler IDs (in application
// order) and show the popup.  Each ID becomes one row whose "sid" attribute
// holds the ID and whose text is the sampler's display name.
function buildSamplerList(samplers) {
    // Display names indexed by sampler ID
    var samplers_lookup_table = [
        "Top-k Sampling",
        "Top-a Sampling",
        "Top-p Sampling",
        "Tail-free Sampling",
        "Typical Sampling",
        "Temperature",
    ];
    samplerslist.html("");
    showSamplersPopup();
    for(var i = 0; i < samplers.length; i++) {
        var sid = samplers[i];
        var label = samplers_lookup_table[sid];
        if(label === undefined) {
            // Unknown ID: show something readable instead of "undefined"
            label = "Unknown sampler (" + sid + ")";
        }
        // Build the row from DOM nodes rather than an HTML string so that
        // the received values are never interpreted as markup.  The nesting
        // (wrapper > .samplerslistitem > .flex-row > label) is unchanged.
        var item = $("<div>")
            .addClass("samplerslistitem flex-row-container")
            .attr("sid", sid)
            .append($("<div>").addClass("flex-row").append($("<div>").text(label)));
        samplerslist.append($("<div>").addClass("flex").append(item));
    }
}
function highlightLoadLine(ref) { function highlightLoadLine(ref) {
$("#loadlistcontent > div > div.popuplistselected").removeClass("popuplistselected"); $("#loadlistcontent > div > div.popuplistselected").removeClass("popuplistselected");
ref.addClass("popuplistselected"); ref.addClass("popuplistselected");
@ -1838,6 +1875,7 @@ $(document).ready(function(){
button_format = $('#btn_format'); button_format = $('#btn_format');
button_softprompt = $("#btn_softprompt"); button_softprompt = $("#btn_softprompt");
button_userscripts= $("#btn_userscripts"); button_userscripts= $("#btn_userscripts");
button_samplers = $("#btn_samplers");
button_mode = $('#btnmode') button_mode = $('#btnmode')
button_mode_label = $('#btnmode_label') button_mode_label = $('#btnmode_label')
button_send = $('#btnsend'); button_send = $('#btnsend');
@ -1886,6 +1924,10 @@ $(document).ready(function(){
usloaded = $("#uslistloaded"); usloaded = $("#uslistloaded");
us_accept = $("#btn_usaccept"); us_accept = $("#btn_usaccept");
us_close = $("#btn_usclose"); us_close = $("#btn_usclose");
samplerspopup = $("#samplerscontainer");
samplerslist = $("#samplerslist");
samplers_accept = $("#btn_samplersaccept");
samplers_close = $("#btn_samplersclose");
nspopup = $("#newgamecontainer"); nspopup = $("#newgamecontainer");
ns_accept = $("#btn_nsaccept"); ns_accept = $("#btn_nsaccept");
ns_close = $("#btn_nsclose"); ns_close = $("#btn_nsclose");
@ -1908,7 +1950,7 @@ $(document).ready(function(){
modelname = msg.modelname; modelname = msg.modelname;
} }
refreshTitle(); refreshTitle();
connect_status.html("<b>Connected to KoboldAI Process!</b>"); connect_status.html("<b>Connected to KoboldAI!</b>");
connect_status.removeClass("color_orange"); connect_status.removeClass("color_orange");
connect_status.addClass("color_green"); connect_status.addClass("color_green");
// Reset Menus // Reset Menus
@ -2310,6 +2352,8 @@ $(document).ready(function(){
buildSPList(msg.data); buildSPList(msg.data);
} else if(msg.cmd == "buildus") { } else if(msg.cmd == "buildus") {
buildUSList(msg.data.unloaded, msg.data.loaded); buildUSList(msg.data.unloaded, msg.data.loaded);
} else if(msg.cmd == "buildsamplers") {
buildSamplerList(msg.data);
} else if(msg.cmd == "askforoverwrite") { } else if(msg.cmd == "askforoverwrite") {
// Show overwrite warning // Show overwrite warning
show([$(".saveasoverwrite")]); show([$(".saveasoverwrite")]);
@ -2436,6 +2480,20 @@ $(document).ready(function(){
}, 10); }, 10);
} }
// Click handler for entries in the samplers menu: moves the clicked entry
// one position down by re-inserting its row after the following row.
var samplers_click_handler = function(ev) {
    // The work is deferred by 10 ms and skipped while a drag is in progress
    // (presumably so that a click produced by a drag does not also move the
    // entry -- verify against the sortable start/stop callbacks).
    setTimeout(function() {
        if (!samplers_dragging) {
            var clicked_row = $(ev.target).closest(".samplerslistitem").parent();
            var next_item = clicked_row.next().find(".samplerslistitem");
            // Nothing to do when the clicked entry is already the last one
            if (next_item.length) {
                next_item.parent().after(clicked_row);
            }
        }
    }, 10);
};
// Make the userscripts menu sortable // Make the userscripts menu sortable
var us_sortable_settings = { var us_sortable_settings = {
placeholder: "ussortable-placeholder", placeholder: "ussortable-placeholder",
@ -2456,6 +2514,22 @@ $(document).ready(function(){
connectWith: "#uslistunloaded", connectWith: "#uslistunloaded",
}, us_sortable_settings)).on("click", ".uslistitem", us_click_handler); }, us_sortable_settings)).on("click", ".uslistitem", us_click_handler);
// Make the samplers menu sortable
var samplers_sortable_settings = {
    // Record whether a drag is in progress; the click handler reads this flag
    start: function() { samplers_dragging = true; },
    stop: function() { samplers_dragging = false; },
    // Appearance and feel of the drag operation
    placeholder: "samplerssortable-placeholder",
    cursor: "move",
    opacity: 0.21,
    revert: 173,
    delay: 2,
    tolerance: "pointer",
    // Auto-scrolling while dragging near the edge of the list
    scrollSensitivity: 64,
    scrollSpeed: 10,
};
// Pass a copy of the settings to sortable() and bind the click handler to
// the list entries (delegated, so rebuilt entries are covered too)
samplerslist
    .sortable($.extend({}, samplers_sortable_settings))
    .on("click", ".samplerslistitem", samplers_click_handler);
// Bind actions to UI buttons // Bind actions to UI buttons
button_send.on("click", function(ev) { button_send.on("click", function(ev) {
dosubmit(); dosubmit();
@ -2590,6 +2664,10 @@ $(document).ready(function(){
button_userscripts.on("click", function(ev) { button_userscripts.on("click", function(ev) {
socket.send({'cmd': 'uslistrequest', 'data': ''}); socket.send({'cmd': 'uslistrequest', 'data': ''});
}); });
// Ask the server for the current sampler order; the menu is built when the
// server answers with a "buildsamplers" message
button_samplers.on("click", function() {
    var request = {'cmd': 'samplerlistrequest', 'data': ''};
    socket.send(request);
});
load_close.on("click", function(ev) { load_close.on("click", function(ev) {
hideLoadPopup(); hideLoadPopup();
@ -2623,6 +2701,16 @@ $(document).ready(function(){
socket.send({'cmd': 'usload', 'data': ''}); socket.send({'cmd': 'usload', 'data': ''});
hideUSPopup(); hideUSPopup();
}); });
// "Cancel": close the samplers popup without sending anything
samplers_close.on("click", function(ev) {
    hideSamplersPopup();
});

// "Save": send the sampler IDs to the server in their current on-screen order
samplers_accept.on("click", function(ev) {
    hideMessage();
    var order = samplerslist.find(".samplerslistitem").map(function() {
        // Always parse as decimal so the result never depends on a prefix
        // in the attribute value
        return parseInt($(this).attr("sid"), 10);
    }).toArray();
    socket.send({'cmd': 'samplers', 'data': order});
    hideSamplersPopup();
});
button_newgame.on("click", function(ev) { button_newgame.on("click", function(ev) {
if(connected) { if(connected) {

View File

@ -457,6 +457,26 @@ body.connected #popupfooter, #popupfooter.always-available {
overflow-wrap: anywhere; overflow-wrap: anywhere;
} }
/* Samplers popup window */
#samplerspopup {
    margin-top: 100px;
    width: 300px;
    background-color: #262626;
}

/* On narrow screens the popup spans the full width; the background colour
   and top margin carry over from the rule above */
@media (max-width: 768px) {
    #samplerspopup {
        width: 100%;
    }
}
/* List of samplers inside the samplers popup: fixed height with a vertical
   scrollbar, and long text may wrap at any point */
#samplerslist {
height: 300px;
overflow-y: scroll;
overflow-wrap: anywhere;
}
#nspopup { #nspopup {
width: 350px; width: 350px;
background-color: #262626; background-color: #262626;
@ -750,7 +770,7 @@ body.connected .dropdown-item:hover, .dropdown-item.always-available:hover {
background-color: #3bf723; background-color: #3bf723;
} }
.ussortable-placeholder { .ussortable-placeholder, .samplerssortable-placeholder {
height: 4px; height: 4px;
background-color: #3bf723; background-color: #3bf723;
} }
@ -1340,7 +1360,7 @@ body.connected .popupfooter, .popupfooter.always-available {
background-color: #688f1f; background-color: #688f1f;
} }
.uslistitem { .uslistitem, .samplerslistitem {
padding: 12px 10px 12px 10px; padding: 12px 10px 12px 10px;
display: flex; display: flex;
flex-grow: 1; flex-grow: 1;
@ -1352,11 +1372,11 @@ body.connected .popupfooter, .popupfooter.always-available {
transition: background-color 0.25s ease-in; transition: background-color 0.25s ease-in;
} }
.uslistitemsub { .uslistitemsub, .samplerslistitemsub {
color: #ba9; color: #ba9;
} }
.uslistitem:hover { .uslistitem:hover, .samplerslistitem:hover {
cursor: move; cursor: move;
background-color: #688f1f; background-color: #688f1f;
} }

View File

@ -9,7 +9,7 @@
<link rel="stylesheet" href="static/bootstrap.min.css"> <link rel="stylesheet" href="static/bootstrap.min.css">
<link rel="stylesheet" href="static/bootstrap-toggle.min.css"> <link rel="stylesheet" href="static/bootstrap-toggle.min.css">
<link rel="stylesheet" href="static/open-iconic-bootstrap.min.css"> <link rel="stylesheet" href="static/open-iconic-bootstrap.min.css">
<link rel="stylesheet" href="static/custom.css?ver=1.18b"> <link rel="stylesheet" href="static/custom.css?ver=1.18c">
<script src="static/jquery-3.6.0.min.js"></script> <script src="static/jquery-3.6.0.min.js"></script>
<script src="static/jquery-ui.sortable.min.js"></script> <script src="static/jquery-ui.sortable.min.js"></script>
@ -17,7 +17,7 @@
<script src="static/bootstrap.min.js"></script> <script src="static/bootstrap.min.js"></script>
<script src="static/bootstrap-toggle.min.js"></script> <script src="static/bootstrap-toggle.min.js"></script>
<script src="static/rangy-core.min.js"></script> <script src="static/rangy-core.min.js"></script>
<script src="static/application.js?ver=1.18d"></script> <script src="static/application.js?ver=1.18e"></script>
</head> </head>
<body> <body>
<input type="file" id="remote-save-select" accept="application/json" style="display:none"> <input type="file" id="remote-save-select" accept="application/json" style="display:none">
@ -71,6 +71,9 @@
<li class="nav-item"> <li class="nav-item">
<a class="nav-link" href="#" id="btn_format">Formatting</a> <a class="nav-link" href="#" id="btn_format">Formatting</a>
</li> </li>
<li class="nav-item">
<a class="nav-link" href="#" id="btn_samplers">Samplers</a>
</li>
<li class="nav-item"> <li class="nav-item">
<a class="nav-link" href="#" id="btn_userscripts">Userscripts</a> <a class="nav-link" href="#" id="btn_userscripts">Userscripts</a>
</li> </li>
@ -299,6 +302,19 @@
</div> </div>
</div> </div>
</div> </div>
<!-- Samplers popup: a title bar, the list of samplers (#samplerslist, filled in by script) and Save/Cancel buttons -->
<div class="popupcontainer hidden" id="samplerscontainer">
<div id="samplerspopup">
<div class="popuptitlebar">
<div class="popuptitletext">Drag-and-drop to change the order in which the samplers are applied</div>
</div>
<div id="samplerslist">
</div>
<div class="popupfooter">
<button type="button" class="btn btn-primary" id="btn_samplersaccept">Save</button>
<button type="button" class="btn btn-primary" id="btn_samplersclose">Cancel</button>
</div>
</div>
</div>
<div class="popupcontainer hidden" id="loadcontainerdelete"> <div class="popupcontainer hidden" id="loadcontainerdelete">
<div id="loadpopupdelete"> <div id="loadpopupdelete">
<div class="popuptitlebar"> <div class="popuptitlebar">

View File

@ -65,6 +65,7 @@ def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List
def settings_callback() -> dict: def settings_callback() -> dict:
return { return {
"sampler_order": utils.default_sampler_order.copy(),
"top_p": 0.9, "top_p": 0.9,
"temp": 0.5, "temp": 0.5,
"top_k": 0, "top_k": 0,
@ -159,7 +160,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
logits[tokens] = penalty_logits logits[tokens] = penalty_logits
return logits return logits
def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
''' '''
This gets called by generate_loop_fn to apply a series of 6 filters This gets called by generate_loop_fn to apply a series of 6 filters
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
@ -181,8 +182,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return np.where(indices_to_remove, -np.inf, logits) return np.where(indices_to_remove, -np.inf, logits)
if top_k > 0:
logits = top_k_filter(logits)
# Top-a (remove all tokens that have softmax probability less than # Top-a (remove all tokens that have softmax probability less than
# a*m^2 where m is the maximum softmax probability) # a*m^2 where m is the maximum softmax probability)
def top_a_filter(logits): def top_a_filter(logits):
@ -195,8 +194,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
probs_max = probabilities.max() probs_max = probabilities.max()
# Remove tokens # Remove tokens
return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits) return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits)
if top_a > 0.0:
logits = top_a_filter(logits)
# Top-p (after sorting the remaining tokens again in descending order of # Top-p (after sorting the remaining tokens again in descending order of
# logit, remove the ones that have cumulative softmax probability # logit, remove the ones that have cumulative softmax probability
# greater than p) # greater than p)
@ -222,8 +219,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return np.where(indices_to_remove, -np.inf, logits) return np.where(indices_to_remove, -np.inf, logits)
if top_p < 1.0:
logits = top_p_filter(logits)
# Tail free sampling (basically top-p a second time on remaining tokens # Tail free sampling (basically top-p a second time on remaining tokens
# except it's the "cumulative normalized absolute second finite # except it's the "cumulative normalized absolute second finite
# differences of the softmax probabilities" instead of just the # differences of the softmax probabilities" instead of just the
@ -262,8 +257,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return np.where(indices_to_remove, -np.inf, logits) return np.where(indices_to_remove, -np.inf, logits)
if tfs < 1.0:
logits = tail_free_filter(logits)
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
def typical_filter(logits): def typical_filter(logits):
# Compute softmax probabilities and the natural logarithms of them # Compute softmax probabilities and the natural logarithms of them
@ -293,10 +286,16 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return np.where(indices_to_remove, -jnp.inf, logits) return np.where(indices_to_remove, -jnp.inf, logits)
if typical < 1.0:
logits = typical_filter(logits)
# Temperature (just divide the logits by the temperature) # Temperature (just divide the logits by the temperature)
logits /= temp def temp_filter(logits):
return logits / temp
for k in sampler_order:
if k == 0 and top_k > 0: logits = top_k_filter(logits)
if k == 1 and top_a > 0.0: logits = top_a_filter(logits)
if k == 2 and top_p < 1.0: logits = top_p_filter(logits)
if k == 3 and tfs < 1.0: logits = tail_free_filter(logits)
if k == 4 and typical < 1.0: logits = typical_filter(logits)
if k == 5 and temp != 1.0: logits = temp_filter(logits)
# Finally, pick one token using the softmax thingy again (it gives # Finally, pick one token using the softmax thingy again (it gives
# an array whose elements sum to 1 so it can be used nicely as a # an array whose elements sum to 1 so it can be used nicely as a
# probability distribution) # probability distribution)
@ -347,7 +346,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
# positions in the logits array # positions in the logits array
return logits.at[tokens].set(penalty_logits) return logits.at[tokens].set(penalty_logits)
def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
''' '''
This gets called by generate_loop_fn to apply a series of 6 filters This gets called by generate_loop_fn to apply a series of 6 filters
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
@ -369,7 +368,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return jnp.where(indices_to_remove, -jnp.inf, logits) return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits)
# Top-a (remove all tokens that have softmax probability less than # Top-a (remove all tokens that have softmax probability less than
# a*m^2 where m is the maximum softmax probability) # a*m^2 where m is the maximum softmax probability)
def top_a_filter(logits): def top_a_filter(logits):
@ -382,7 +380,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
probs_max = probabilities.max() probs_max = probabilities.max()
# Remove tokens # Remove tokens
return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits) return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits)
logits = jax.lax.cond(top_a > 0.0, top_a_filter, lambda x: x, logits)
# Top-p (after sorting the remaining tokens again in descending order of # Top-p (after sorting the remaining tokens again in descending order of
# logit, remove the ones that have cumulative softmax probability # logit, remove the ones that have cumulative softmax probability
# greater than p) # greater than p)
@ -408,7 +405,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return jnp.where(indices_to_remove, -jnp.inf, logits) return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits)
# Tail free sampling (basically top-p a second time on remaining tokens # Tail free sampling (basically top-p a second time on remaining tokens
# except it's the "cumulative normalized absolute second finite # except it's the "cumulative normalized absolute second finite
# differences of the softmax probabilities" instead of just the # differences of the softmax probabilities" instead of just the
@ -447,7 +443,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return jnp.where(indices_to_remove, -jnp.inf, logits) return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
def typical_filter(logits): def typical_filter(logits):
# Compute softmax probabilities and the natural logarithms of them # Compute softmax probabilities and the natural logarithms of them
@ -476,11 +471,16 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return jnp.where(indices_to_remove, -jnp.inf, logits) return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(typical < 1.0, typical_filter, lambda x: x, logits)
# Temperature (just divide the logits by the temperature) # Temperature (just divide the logits by the temperature)
def temp_filter(logits): def temp_filter(logits):
return logits / temp return logits / temp
logits = jax.lax.cond(True, temp_filter, lambda x: x, logits) for k in sampler_order:
logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), top_k_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), top_a_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), top_p_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
# Finally, pick one token using the softmax thingy again (it gives # Finally, pick one token using the softmax thingy again (it gives
# an array whose elements sum to 1 so it can be used nicely as a # an array whose elements sum to 1 so it can be used nicely as a
# probability distribution) # probability distribution)
@ -842,8 +842,12 @@ def infer_static(
gen_len=80, gen_len=80,
soft_embeddings: Optional[np.array] = None, soft_embeddings: Optional[np.array] = None,
soft_tokens: Optional[np.array] = None, soft_tokens: Optional[np.array] = None,
sampler_order: Optional[List[int]] = None,
) -> List[np.array]: ) -> List[np.array]:
maps.thread_resources.env = thread_resources_env maps.thread_resources.env = thread_resources_env
if sampler_order is None:
sampler_order = utils.default_sampler_order.copy()
sampler_order = np.uint32(sampler_order)
total_batch = 1 total_batch = 1
tokens = context tokens = context
if(soft_tokens is not None): if(soft_tokens is not None):
@ -854,6 +858,7 @@ def infer_static(
batched_tokens = np.array([padded_tokens] * total_batch) batched_tokens = np.array([padded_tokens] * total_batch)
samples = [] samples = []
batched_generator_params = { batched_generator_params = {
"sampler_order": np.repeat(sampler_order[np.newaxis], total_batch, axis=0),
"temp": temp * np.ones(total_batch), "temp": temp * np.ones(total_batch),
"top_p": top_p * np.ones(total_batch), "top_p": top_p * np.ones(total_batch),
"tfs": tfs * np.ones(total_batch), "tfs": tfs * np.ones(total_batch),
@ -1015,6 +1020,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None: def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params global thread_resources_env, seq, tokenizer, network, params
if not hasattr(vars, "sampler_order") or not vars.sampler_order:
vars.sampler_order = utils.default_sampler_order.copy()
default_params = { default_params = {
"compat": "j", "compat": "j",
"layers": 28, "layers": 28,

View File

@ -20,6 +20,8 @@ from_pretrained_index_filename: Optional[str] = None
from_pretrained_kwargs = {} from_pretrained_kwargs = {}
bar = None bar = None
# Default order in which the samplers are applied, as a list of sampler IDs:
# 0 = top-k, 1 = top-a, 2 = top-p, 3 = tail-free, 4 = typical, 5 = temperature.
# Callers take a .copy() of this list rather than using it directly.
default_sampler_order = [0, 1, 2, 3, 4, 5]
#==================================================================# #==================================================================#
# Decorator to prevent a function's actions from being run until # Decorator to prevent a function's actions from being run until
# at least x seconds have passed without the function being called # at least x seconds have passed without the function being called