Merge pull request #148 from VE-FORBRYDERNE/overhaul-merge

Merge united into overhaul
This commit is contained in:
henk717 2022-06-15 00:56:14 +02:00 committed by GitHub
commit f3eb7cba5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 298 additions and 36 deletions

View File

@ -106,7 +106,6 @@ model_menu = {
["Adventure Models", "adventurelist", "", True],
["Novel Models", "novellist", "", True],
["NSFW Models", "nsfwlist", "", True],
["Chatbot Models", "chatlist", "", True],
["Untuned GPT-Neo/J", "gptneolist", "", True],
["Untuned Fairseq Dense", "fsdlist", "", True],
["Untuned OPT", "optlist", "", True],
@ -220,6 +219,7 @@ class vars:
temp = 0.5 # Default generator temperature
top_p = 0.9 # Default generator top_p
top_k = 0 # Default generator top_k
top_a = 0.0 # Default generator top-a
tfs = 1.0 # Default generator tfs (tail-free sampling)
typical = 1.0 # Default generator typical sampling threshold
numseqs = 1 # Number of sequences to ask the generator to create
@ -315,6 +315,7 @@ class vars:
acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)') # Pattern for matching comments to remove them before sending them to the AI
comregex_ui = re.compile(r'(&lt;\|(?:.|\n)*?\|&gt;)') # Pattern for matching comments in the editor
sampler_order = utils.default_sampler_order.copy()
chatmode = False
chatname = "You"
adventure = False
@ -647,6 +648,8 @@ def loadmodelsettings():
vars.badwordsids = js["badwordsids"]
if("nobreakmodel" in js):
vars.nobreakmodel = js["nobreakmodel"]
if("sampler_order" in js):
vars.sampler_order = js["sampler_order"]
if("temp" in js):
vars.temp = js["temp"]
if("top_p" in js):
@ -657,6 +660,8 @@ def loadmodelsettings():
vars.tfs = js["tfs"]
if("typical" in js):
vars.typical = js["typical"]
if("top_a" in js):
vars.top_a = js["top_a"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
@ -688,11 +693,13 @@ def savesettings():
js = {}
js["apikey"] = vars.apikey
js["andepth"] = vars.andepth
js["sampler_order"] = vars.sampler_order
js["temp"] = vars.temp
js["top_p"] = vars.top_p
js["top_k"] = vars.top_k
js["tfs"] = vars.tfs
js["typical"] = vars.typical
js["top_a"] = vars.top_a
js["rep_pen"] = vars.rep_pen
js["rep_pen_slope"] = vars.rep_pen_slope
js["rep_pen_range"] = vars.rep_pen_range
@ -763,6 +770,8 @@ def processsettings(js):
vars.apikey = js["apikey"]
if("andepth" in js):
vars.andepth = js["andepth"]
if("sampler_order" in js):
vars.sampler_order = js["sampler_order"]
if("temp" in js):
vars.temp = js["temp"]
if("top_p" in js):
@ -773,6 +782,8 @@ def processsettings(js):
vars.tfs = js["tfs"]
if("typical" in js):
vars.typical = js["typical"]
if("top_a" in js):
vars.top_a = js["top_a"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
@ -1268,7 +1279,7 @@ def patch_transformers():
# Patch transformers to use our custom logit warpers
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper, TopALogitsWarper
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
old_call = cls.__call__
@ -1288,6 +1299,7 @@ def patch_transformers():
cls.__call__ = new_call
dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
@ -1331,14 +1343,23 @@ def patch_transformers():
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
class KoboldLogitsWarperList(LogitsProcessorList):
def __init__(self, beams: int = 1, **kwargs):
self.__warper_list: List[LogitsWarper] = []
self.__warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
self.__warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1)))
self.__warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5))
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs):
for k in vars.sampler_order:
scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
return scores
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
warper_list = LogitsProcessorList()
warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TemperatureLogitsWarper(temperature=0.5))
return warper_list
return KoboldLogitsWarperList(beams=beams)
def new_sample(self, *args, **kwargs):
assert kwargs.pop("logits_warper", None) is not None
@ -1957,11 +1978,13 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
def tpumtjgenerate_settings_callback() -> dict:
return {
"sampler_order": vars.sampler_order,
"top_p": float(vars.top_p),
"temp": float(vars.temp),
"top_k": int(vars.top_k),
"tfs": float(vars.tfs),
"typical": float(vars.typical),
"top_a": float(vars.top_a),
"repetition_penalty": float(vars.rep_pen),
"rpslope": float(vars.rep_pen_slope),
"rprange": int(vars.rep_pen_range),
@ -2384,6 +2407,7 @@ def lua_has_setting(setting):
"settopk",
"settfs",
"settypical",
"settopa",
"setreppen",
"setreppenslope",
"setreppenrange",
@ -2403,6 +2427,7 @@ def lua_has_setting(setting):
"top_k",
"tfs",
"typical",
"topa",
"reppen",
"reppenslope",
"reppenrange",
@ -2437,6 +2462,7 @@ def lua_get_setting(setting):
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
if(setting in ("settfs", "tfs")): return vars.tfs
if(setting in ("settypical", "typical")): return vars.typical
if(setting in ("settopa", "topa")): return vars.top_a
if(setting in ("setreppen", "reppen")): return vars.rep_pen
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
@ -2472,6 +2498,7 @@ def lua_set_setting(setting, v):
if(setting in ("settopk", "topk")): vars.top_k = v
if(setting in ("settfs", "tfs")): vars.tfs = v
if(setting in ("settypical", "typical")): vars.typical = v
if(setting in ("settopa", "topa")): vars.top_a = v
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
@ -2862,6 +2889,11 @@ def get_message(msg):
emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'settopa'):
vars.top_a = float(msg['data'])
emit('from_server', {'cmd': 'setlabeltopa', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setreppen'):
vars.rep_pen = float(msg['data'])
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
@ -3015,6 +3047,8 @@ def get_message(msg):
elif(msg['cmd'] == 'uslistrequest'):
unloaded, loaded = getuslist()
emit('from_server', {'cmd': 'buildus', 'data': {"unloaded": unloaded, "loaded": loaded}})
elif(msg['cmd'] == 'samplerlistrequest'):
emit('from_server', {'cmd': 'buildsamplers', 'data': vars.sampler_order})
elif(msg['cmd'] == 'usloaded'):
vars.userscripts = []
for userscript in msg['data']:
@ -3028,6 +3062,16 @@ def get_message(msg):
load_lua_scripts()
unloaded, loaded = getuslist()
sendUSStatItems()
elif(msg['cmd'] == 'samplers'):
sampler_order = msg["data"]
if(not isinstance(sampler_order, list)):
raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}")
if(len(sampler_order) != len(vars.sampler_order)):
raise ValueError(f"Sampler order must be a list of length {len(vars.sampler_order)}, but got a list of length {len(sampler_order)}")
if(not all(isinstance(e, int) for e in sampler_order)):
raise ValueError(f"Sampler order must be a list of ints, but got a list with at least one non-int element")
vars.sampler_order = sampler_order
settingschanged()
elif(msg['cmd'] == 'list_model'):
sendModelSelection(menu=msg['data'])
elif(msg['cmd'] == 'load_model'):
@ -3988,6 +4032,7 @@ def sendtocolab(txt, min, max):
'top_k': vars.top_k,
'tfs': vars.tfs,
'typical': vars.typical,
'topa': vars.top_a,
'numseqs': vars.numseqs,
'retfultxt': False
}
@ -4125,12 +4170,14 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
top_k=vars.top_k,
tfs=vars.tfs,
typical=vars.typical,
top_a=vars.top_a,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
rpslope=vars.rep_pen_slope,
rprange=vars.rep_pen_range,
soft_embeddings=vars.sp,
soft_tokens=soft_tokens,
sampler_order=vars.sampler_order,
)
past = genout
for i in range(vars.numseqs):
@ -4311,6 +4358,7 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True)
emit('from_server', {'cmd': 'updatetopa', 'data': vars.top_a}, broadcast=True)
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
@ -4887,6 +4935,7 @@ def oairequest(txt, min, max):
'prompt': txt,
'max_tokens': vars.genamt,
'temperature': vars.temp,
'top_a': vars.top_a,
'top_p': vars.top_p,
'top_k': vars.top_k,
'tfs': vars.tfs,

View File

@ -867,6 +867,7 @@ return function(_python, _bridged)
---@field settopk integer
---@field settfs number
---@field settypical number
---@field settopa number
---@field setreppen number
---@field setreppenslope number
---@field setreppenrange number
@ -884,6 +885,7 @@ return function(_python, _bridged)
---@field top_k integer
---@field tfs number
---@field typical number
---@field topa number
---@field reppen number
---@field reppenslope number
---@field reppenrange number

View File

@ -64,6 +64,17 @@ gensettingstf = [
"step": 0.05,
"default": 1.0,
"tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect."
},
{
"uitype": "slider",
"unit": "float",
"label": "Top a Sampling",
"id": "settopa",
"min": 0.0,
"max": 1.0,
"step": 0.01,
"default": 0.0,
"tooltip": "Alternative sampling method that reduces the randomness of the AI whenever the probability of one token is much higher than all the others. Higher values have a stronger effect. Set this setting to 0 to disable its effect."
},
{
"uitype": "slider",

View File

@ -21,6 +21,7 @@ var button_settings;
var button_format;
var button_softprompt;
var button_userscripts;
var button_samplers;
var button_mode;
var button_mode_label;
var button_send;
@ -112,6 +113,9 @@ var do_clear_ent = false;
// Whether or not an entry in the Userscripts menu is being dragged
var us_dragging = false;
// Whether or not an entry in the Samplers menu is being dragged
var samplers_dragging = false;
// Display vars
var allowtoggle = false;
var formatcount = 0;
@ -997,6 +1001,16 @@ function hideUSPopup() {
spcontent.html("");
}
function showSamplersPopup() {
samplerspopup.removeClass("hidden");
samplerspopup.addClass("flex");
}
function hideSamplersPopup() {
samplerspopup.removeClass("flex");
samplerspopup.addClass("hidden");
}
function buildLoadModelList(ar, menu, breadcrumbs) {
disableButtons([load_model_accept]);
@ -1207,6 +1221,29 @@ function buildUSList(unloaded, loaded) {
}
}
function buildSamplerList(samplers) {
samplerslist.html("");
showSamplersPopup();
var i;
var samplers_lookup_table = [
"Top-k Sampling",
"Top-a Sampling",
"Top-p Sampling",
"Tail-free Sampling",
"Typical Sampling",
"Temperature",
]
for(i=0; i<samplers.length; i++) {
samplerslist.append("<div class=\"flex\">\
<div class=\"samplerslistitem flex-row-container\" sid=\""+samplers[i]+"\">\
<div class=\"flex-row\">\
<div>"+samplers_lookup_table[samplers[i]]+"</div>\
</div>\
</div>\
</div>");
}
}
function highlightLoadLine(ref) {
$("#loadlistcontent > div > div.popuplistselected").removeClass("popuplistselected");
$("#loadmodellistcontent > div > div.popuplistselected").removeClass("popuplistselected");
@ -1963,6 +2000,7 @@ $(document).ready(function(){
button_format = $('#btn_format');
button_softprompt = $("#btn_softprompt");
button_userscripts= $("#btn_userscripts");
button_samplers = $("#btn_samplers");
button_mode = $('#btnmode')
button_mode_label = $('#btnmode_label')
button_send = $('#btnsend');
@ -2015,6 +2053,10 @@ $(document).ready(function(){
usloaded = $("#uslistloaded");
us_accept = $("#btn_usaccept");
us_close = $("#btn_usclose");
samplerspopup = $("#samplerscontainer");
samplerslist = $("#samplerslist");
samplers_accept = $("#btn_samplersaccept");
samplers_close = $("#btn_samplersclose");
nspopup = $("#newgamecontainer");
ns_accept = $("#btn_nsaccept");
ns_close = $("#btn_nsclose");
@ -2038,7 +2080,7 @@ $(document).ready(function(){
modelname = msg.modelname;
}
refreshTitle();
connect_status.html("<b>Connected to KoboldAI Process!</b>");
connect_status.html("<b>Connected to KoboldAI!</b>");
connect_status.removeClass("color_orange");
connect_status.addClass("color_green");
// Reset Menus
@ -2231,6 +2273,10 @@ $(document).ready(function(){
// Send current typical value to input
$("#settypicalcur").val(msg.data);
$("#settypical").val(parseFloat(msg.data)).trigger("change");
} else if(msg.cmd == "updatetopa") {
// Send current top a value to input
$("#settopacur").val(msg.data);
$("#settopa").val(parseFloat(msg.data)).trigger("change");
} else if(msg.cmd == "updatereppen") {
// Send current rep pen value to input
$("#setreppencur").val(msg.data);
@ -2270,6 +2316,9 @@ $(document).ready(function(){
} else if(msg.cmd == "setlabeltypical") {
// Update setting label with value from server
$("#settypicalcur").val(msg.data);
} else if(msg.cmd == "setlabeltypical") {
// Update setting label with value from server
$("#settopa").val(msg.data);
} else if(msg.cmd == "setlabelreppen") {
// Update setting label with value from server
$("#setreppencur").val(msg.data);
@ -2440,6 +2489,8 @@ $(document).ready(function(){
buildSPList(msg.data);
} else if(msg.cmd == "buildus") {
buildUSList(msg.data.unloaded, msg.data.loaded);
} else if(msg.cmd == "buildsamplers") {
buildSamplerList(msg.data);
} else if(msg.cmd == "askforoverwrite") {
// Show overwrite warning
show([$(".saveasoverwrite")]);
@ -2648,6 +2699,20 @@ $(document).ready(function(){
}, 10);
}
var samplers_click_handler = function(ev) {
setTimeout(function() {
if (samplers_dragging) {
return;
}
var target = $(ev.target).closest(".samplerslistitem");
var next = target.parent().next().find(".samplerslistitem");
if (!next.length) {
return;
}
next.parent().after(target.parent());
}, 10);
}
// Make the userscripts menu sortable
var us_sortable_settings = {
placeholder: "ussortable-placeholder",
@ -2668,6 +2733,22 @@ $(document).ready(function(){
connectWith: "#uslistunloaded",
}, us_sortable_settings)).on("click", ".uslistitem", us_click_handler);
// Make the samplers menu sortable
var samplers_sortable_settings = {
placeholder: "samplerssortable-placeholder",
start: function() { samplers_dragging = true; },
stop: function() { samplers_dragging = false; },
delay: 2,
cursor: "move",
tolerance: "pointer",
opacity: 0.21,
revert: 173,
scrollSensitivity: 64,
scrollSpeed: 10,
}
samplerslist.sortable($.extend({
}, samplers_sortable_settings)).on("click", ".samplerslistitem", samplers_click_handler);
// Bind actions to UI buttons
button_send.on("click", function(ev) {
dosubmit();
@ -2802,6 +2883,10 @@ $(document).ready(function(){
button_userscripts.on("click", function(ev) {
socket.send({'cmd': 'uslistrequest', 'data': ''});
});
button_samplers.on("click", function(ev) {
socket.send({'cmd': 'samplerlistrequest', 'data': ''});
});
load_close.on("click", function(ev) {
hideLoadPopup();
@ -2858,6 +2943,16 @@ $(document).ready(function(){
socket.send({'cmd': 'usload', 'data': ''});
hideUSPopup();
});
samplers_close.on("click", function(ev) {
hideSamplersPopup();
});
samplers_accept.on("click", function(ev) {
hideMessage();
socket.send({'cmd': 'samplers', 'data': samplerslist.find(".samplerslistitem").map(function() { return parseInt($(this).attr("sid")); }).toArray()});
hideSamplersPopup();
});
button_loadmodel.on("click", function(ev) {
showLoadModelPopup();

View File

@ -457,6 +457,26 @@ body.connected #popupfooter, #popupfooter.always-available {
overflow-wrap: anywhere;
}
#samplerspopup {
width: 300px;
background-color: #262626;
margin-top: 100px;
}
@media (max-width: 768px) {
#samplerspopup {
width: 100%;
background-color: #262626;
margin-top: 100px;
}
}
#samplerslist {
height: 300px;
overflow-y: scroll;
overflow-wrap: anywhere;
}
#nspopup {
width: 350px;
background-color: #262626;
@ -750,7 +770,7 @@ body.connected .dropdown-item:hover, .dropdown-item.always-available:hover {
background-color: #3bf723;
}
.ussortable-placeholder {
.ussortable-placeholder, .samplerssortable-placeholder {
height: 4px;
background-color: #3bf723;
}
@ -1362,7 +1382,7 @@ body.connected .popupfooter, .popupfooter.always-available {
background-color: #688f1f;
}
.uslistitem {
.uslistitem, .samplerslistitem {
padding: 12px 10px 12px 10px;
display: flex;
flex-grow: 1;
@ -1374,11 +1394,11 @@ body.connected .popupfooter, .popupfooter.always-available {
transition: background-color 0.25s ease-in;
}
.uslistitemsub {
.uslistitemsub, .samplerslistitemsub {
color: #ba9;
}
.uslistitem:hover {
.uslistitem:hover, .samplerslistitem:hover {
cursor: move;
background-color: #688f1f;
}

View File

@ -9,7 +9,7 @@
<link rel="stylesheet" href="static/bootstrap.min.css">
<link rel="stylesheet" href="static/bootstrap-toggle.min.css">
<link rel="stylesheet" href="static/open-iconic-bootstrap.min.css">
<link rel="stylesheet" href="static/custom.css?ver=1.18b">
<link rel="stylesheet" href="static/custom.css?ver=1.18c">
<script src="static/jquery-3.6.0.min.js"></script>
<script src="static/jquery-ui.sortable.min.js"></script>
@ -17,7 +17,7 @@
<script src="static/bootstrap.min.js"></script>
<script src="static/bootstrap-toggle.min.js"></script>
<script src="static/rangy-core.min.js"></script>
<script src="static/application.js?ver=1.18c"></script>
<script src="static/application.js?ver=1.18e"></script>
<script src="static/favicon.js"></script>
</head>
<body>
@ -81,6 +81,9 @@
<li class="nav-item">
<a class="nav-link" href="#" id="btn_format">Formatting</a>
</li>
<li class="nav-item">
<a class="nav-link" href="#" id="btn_samplers">Samplers</a>
</li>
<li class="nav-item">
<a class="nav-link" href="#" id="btn_userscripts">Userscripts</a>
</li>
@ -363,6 +366,19 @@
</div>
</div>
</div>
<div class="popupcontainer hidden" id="samplerscontainer">
<div id="samplerspopup">
<div class="popuptitlebar">
<div class="popuptitletext">Drag-and-drop to change the order in which the samplers are applied</div>
</div>
<div id="samplerslist">
</div>
<div class="popupfooter">
<button type="button" class="btn btn-primary" id="btn_samplersaccept">Save</button>
<button type="button" class="btn btn-primary" id="btn_samplersclose">Cancel</button>
</div>
</div>
</div>
<div class="popupcontainer hidden" id="loadcontainerdelete">
<div id="loadpopupdelete">
<div class="popuptitlebar">

View File

@ -65,11 +65,13 @@ def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List
def settings_callback() -> dict:
return {
"sampler_order": utils.default_sampler_order.copy(),
"top_p": 0.9,
"temp": 0.5,
"top_k": 0,
"tfs": 1.0,
"typical": 1.0,
"top_a": 0.0,
"repetition_penalty": 1.0,
"rpslope": 0.0,
"rprange": 0,
@ -158,10 +160,10 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
logits[tokens] = penalty_logits
return logits
def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
'''
This gets called by generate_loop_fn to apply a series of 5 filters
to the logits (top-k, then top-p, then TFS, then typical, then temperature)
This gets called by generate_loop_fn to apply a series of 6 filters
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
before picking one token using the modified logits
'''
# Top-k (keep only the k tokens with the highest logits and remove
@ -180,8 +182,18 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
sorted_indices_to_remove,
)
return np.where(indices_to_remove, -np.inf, logits)
if top_k > 0:
logits = top_k_filter(logits)
# Top-a (remove all tokens that have softmax probability less than
# a*m^2 where m is the maximum softmax probability)
def top_a_filter(logits):
# Replace every element in the logits array
# with e (Euler's number) to the power of that element, and divide
# each element of the new array by the sum of the elements in the
# new array
probabilities = np.array(jax.nn.softmax(logits), copy=True)
# Find the largest probability
probs_max = probabilities.max()
# Remove tokens
return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits)
# Top-p (after sorting the remaining tokens again in descending order of
# logit, remove the ones that have cumulative softmax probability
# greater than p)
@ -207,8 +219,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
sorted_indices_to_remove,
)
return np.where(indices_to_remove, -np.inf, logits)
if top_p < 1.0:
logits = top_p_filter(logits)
# Tail free sampling (basically top-p a second time on remaining tokens
# except it's the "cumulative normalized absolute second finite
# differences of the softmax probabilities" instead of just the
@ -247,8 +257,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
sorted_indices_to_remove,
)
return np.where(indices_to_remove, -np.inf, logits)
if tfs < 1.0:
logits = tail_free_filter(logits)
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
def typical_filter(logits):
# Compute softmax probabilities and the natural logarithms of them
@ -278,10 +286,16 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
sorted_indices_to_remove,
)
return np.where(indices_to_remove, -jnp.inf, logits)
if typical < 1.0:
logits = typical_filter(logits)
# Temperature (just divide the logits by the temperature)
logits /= temp
def temp_filter(logits):
return logits / temp
for k in sampler_order:
if k == 0 and top_k > 0: logits = top_k_filter(logits)
if k == 1 and top_a > 0.0: logits = top_a_filter(logits)
if k == 2 and top_p < 1.0: logits = top_p_filter(logits)
if k == 3 and tfs < 1.0: logits = tail_free_filter(logits)
if k == 4 and typical < 1.0: logits = typical_filter(logits)
if k == 5 and temp != 1.0: logits = temp_filter(logits)
# Finally, pick one token using the softmax thingy again (it gives
# an array whose elements sum to 1 so it can be used nicely as a
# probability distribution)
@ -332,10 +346,10 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
# positions in the logits array
return logits.at[tokens].set(penalty_logits)
def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
'''
This gets called by generate_loop_fn to apply a series of 5 filters
to the logits (top-k, then top-p, then TFS, then typical, then temperature)
This gets called by generate_loop_fn to apply a series of 6 filters
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
before picking one token using the modified logits
'''
# Top-k (keep only the k tokens with the highest logits and remove
@ -354,7 +368,18 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits)
# Top-a (remove all tokens that have softmax probability less than
# a*m^2 where m is the maximum softmax probability)
def top_a_filter(logits):
# Replace every element in the logits array
# with e (Euler's number) to the power of that element, and divide
# each element of the new array by the sum of the elements in the
# new array
probabilities = jax.nn.softmax(logits)
# Find the largest probability
probs_max = probabilities.max()
# Remove tokens
return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits)
# Top-p (after sorting the remaining tokens again in descending order of
# logit, remove the ones that have cumulative softmax probability
# greater than p)
@ -380,7 +405,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits)
# Tail free sampling (basically top-p a second time on remaining tokens
# except it's the "cumulative normalized absolute second finite
# differences of the softmax probabilities" instead of just the
@ -419,7 +443,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
def typical_filter(logits):
# Compute softmax probabilities and the natural logarithms of them
@ -448,11 +471,16 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(typical < 1.0, typical_filter, lambda x: x, logits)
# Temperature (just divide the logits by the temperature)
def temp_filter(logits):
return logits / temp
logits = jax.lax.cond(True, temp_filter, lambda x: x, logits)
for k in sampler_order:
logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), top_k_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), top_a_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), top_p_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
# Finally, pick one token using the softmax thingy again (it gives
# an array whose elements sum to 1 so it can be used nicely as a
# probability distribution)
@ -806,6 +834,7 @@ def infer_static(
top_k=0,
tfs=1.0,
typical=1.0,
top_a=0.0,
repetition_penalty=1.0,
rpslope=0.0,
rprange=0,
@ -813,8 +842,12 @@ def infer_static(
gen_len=80,
soft_embeddings: Optional[np.array] = None,
soft_tokens: Optional[np.array] = None,
sampler_order: Optional[List[int]] = None,
) -> List[np.array]:
maps.thread_resources.env = thread_resources_env
if sampler_order is None:
sampler_order = utils.default_sampler_order.copy()
sampler_order = np.uint32(sampler_order)
total_batch = 1
tokens = context
if(soft_tokens is not None):
@ -825,10 +858,12 @@ def infer_static(
batched_tokens = np.array([padded_tokens] * total_batch)
samples = []
batched_generator_params = {
"sampler_order": np.repeat(sampler_order[np.newaxis], total_batch, axis=0),
"temp": temp * np.ones(total_batch),
"top_p": top_p * np.ones(total_batch),
"tfs": tfs * np.ones(total_batch),
"typical": typical * np.ones(total_batch),
"top_a": top_a * np.ones(total_batch),
"repetition_penalty": repetition_penalty * np.ones(total_batch),
"rpslope": rpslope * np.ones(total_batch),
"rprange": np.full(total_batch, rprange, dtype=np.uint32),
@ -985,6 +1020,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params
if not hasattr(vars, "sampler_order") or not vars.sampler_order:
vars.sampler_order = utils.default_sampler_order.copy()
default_params = {
"compat": "j",
"layers": 28,

View File

@ -20,6 +20,8 @@ from_pretrained_index_filename: Optional[str] = None
from_pretrained_kwargs = {}
bar = None
default_sampler_order = [0, 1, 2, 3, 4, 5]
#==================================================================#
# Decorator to prevent a function's actions from being run until
# at least x seconds have passed without the function being called

View File

@ -148,3 +148,32 @@ class TypicalLogitsWarper(LogitsWarper):
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
class TopALogitsWarper(LogitsWarper):
def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
top_a = float(top_a)
if top_a < 0 or top_a > 1.0:
raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}")
self.top_a = top_a
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if self.filter_value >= 1.0:
return scores
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
probs = sorted_logits.softmax(dim=-1)
# Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept)
probs_max = probs[..., 0, None]
sorted_indices_to_remove = probs < probs_max * probs_max * self.top_a
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores