mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-01-28 08:09:30 +01:00
commit
77ae893f4d
24
aiserver.py
24
aiserver.py
@ -154,6 +154,7 @@ class vars:
|
|||||||
top_p = 0.9 # Default generator top_p
|
top_p = 0.9 # Default generator top_p
|
||||||
top_k = 0 # Default generator top_k
|
top_k = 0 # Default generator top_k
|
||||||
tfs = 1.0 # Default generator tfs (tail-free sampling)
|
tfs = 1.0 # Default generator tfs (tail-free sampling)
|
||||||
|
typical = 1.0 # Default generator typical sampling threshold
|
||||||
numseqs = 1 # Number of sequences to ask the generator to create
|
numseqs = 1 # Number of sequences to ask the generator to create
|
||||||
gamestarted = False # Whether the game has started (disables UI elements)
|
gamestarted = False # Whether the game has started (disables UI elements)
|
||||||
gamesaved = True # Whether or not current game is saved
|
gamesaved = True # Whether or not current game is saved
|
||||||
@ -499,6 +500,8 @@ def loadmodelsettings():
|
|||||||
vars.top_k = js["top_k"]
|
vars.top_k = js["top_k"]
|
||||||
if("tfs" in js):
|
if("tfs" in js):
|
||||||
vars.tfs = js["tfs"]
|
vars.tfs = js["tfs"]
|
||||||
|
if("typical" in js):
|
||||||
|
vars.typical = js["typical"]
|
||||||
if("rep_pen" in js):
|
if("rep_pen" in js):
|
||||||
vars.rep_pen = js["rep_pen"]
|
vars.rep_pen = js["rep_pen"]
|
||||||
if("rep_pen_slope" in js):
|
if("rep_pen_slope" in js):
|
||||||
@ -534,6 +537,7 @@ def savesettings():
|
|||||||
js["top_p"] = vars.top_p
|
js["top_p"] = vars.top_p
|
||||||
js["top_k"] = vars.top_k
|
js["top_k"] = vars.top_k
|
||||||
js["tfs"] = vars.tfs
|
js["tfs"] = vars.tfs
|
||||||
|
js["typical"] = vars.typical
|
||||||
js["rep_pen"] = vars.rep_pen
|
js["rep_pen"] = vars.rep_pen
|
||||||
js["rep_pen_slope"] = vars.rep_pen_slope
|
js["rep_pen_slope"] = vars.rep_pen_slope
|
||||||
js["rep_pen_range"] = vars.rep_pen_range
|
js["rep_pen_range"] = vars.rep_pen_range
|
||||||
@ -600,6 +604,8 @@ def loadsettings():
|
|||||||
vars.top_k = js["top_k"]
|
vars.top_k = js["top_k"]
|
||||||
if("tfs" in js):
|
if("tfs" in js):
|
||||||
vars.tfs = js["tfs"]
|
vars.tfs = js["tfs"]
|
||||||
|
if("typical" in js):
|
||||||
|
vars.typical = js["typical"]
|
||||||
if("rep_pen" in js):
|
if("rep_pen" in js):
|
||||||
vars.rep_pen = js["rep_pen"]
|
vars.rep_pen = js["rep_pen"]
|
||||||
if("rep_pen_slope" in js):
|
if("rep_pen_slope" in js):
|
||||||
@ -1172,7 +1178,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
|||||||
|
|
||||||
# Patch transformers to use our custom logit warpers
|
# Patch transformers to use our custom logit warpers
|
||||||
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
|
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
|
||||||
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper
|
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper
|
||||||
|
|
||||||
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
|
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
|
||||||
old_call = cls.__call__
|
old_call = cls.__call__
|
||||||
@ -1194,6 +1200,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
|||||||
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
|
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
|
||||||
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
|
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
|
||||||
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
|
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
|
||||||
|
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
|
||||||
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
|
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
|
||||||
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
|
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
|
||||||
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
|
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
|
||||||
@ -1239,6 +1246,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
|||||||
warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
|
warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
|
||||||
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||||
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||||
|
warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||||
warper_list.append(TemperatureLogitsWarper(temperature=0.5))
|
warper_list.append(TemperatureLogitsWarper(temperature=0.5))
|
||||||
return warper_list
|
return warper_list
|
||||||
|
|
||||||
@ -1540,6 +1548,7 @@ else:
|
|||||||
"temp": float(vars.temp),
|
"temp": float(vars.temp),
|
||||||
"top_k": int(vars.top_k),
|
"top_k": int(vars.top_k),
|
||||||
"tfs": float(vars.tfs),
|
"tfs": float(vars.tfs),
|
||||||
|
"typical": float(vars.typical),
|
||||||
"repetition_penalty": float(vars.rep_pen),
|
"repetition_penalty": float(vars.rep_pen),
|
||||||
"rpslope": float(vars.rep_pen_slope),
|
"rpslope": float(vars.rep_pen_slope),
|
||||||
"rprange": int(vars.rep_pen_range),
|
"rprange": int(vars.rep_pen_range),
|
||||||
@ -1901,6 +1910,7 @@ def lua_has_setting(setting):
|
|||||||
"settopp",
|
"settopp",
|
||||||
"settopk",
|
"settopk",
|
||||||
"settfs",
|
"settfs",
|
||||||
|
"settypical",
|
||||||
"setreppen",
|
"setreppen",
|
||||||
"setreppenslope",
|
"setreppenslope",
|
||||||
"setreppenrange",
|
"setreppenrange",
|
||||||
@ -1919,6 +1929,7 @@ def lua_has_setting(setting):
|
|||||||
"topk",
|
"topk",
|
||||||
"top_k",
|
"top_k",
|
||||||
"tfs",
|
"tfs",
|
||||||
|
"typical",
|
||||||
"reppen",
|
"reppen",
|
||||||
"reppenslope",
|
"reppenslope",
|
||||||
"reppenrange",
|
"reppenrange",
|
||||||
@ -1952,6 +1963,7 @@ def lua_get_setting(setting):
|
|||||||
if(setting in ("settopp", "topp", "top_p")): return vars.top_p
|
if(setting in ("settopp", "topp", "top_p")): return vars.top_p
|
||||||
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
|
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
|
||||||
if(setting in ("settfs", "tfs")): return vars.tfs
|
if(setting in ("settfs", "tfs")): return vars.tfs
|
||||||
|
if(setting in ("settypical", "typical")): return vars.typical
|
||||||
if(setting in ("setreppen", "reppen")): return vars.rep_pen
|
if(setting in ("setreppen", "reppen")): return vars.rep_pen
|
||||||
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
|
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
|
||||||
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
|
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
|
||||||
@ -1986,6 +1998,7 @@ def lua_set_setting(setting, v):
|
|||||||
if(setting in ("settopp", "topp")): vars.top_p = v
|
if(setting in ("settopp", "topp")): vars.top_p = v
|
||||||
if(setting in ("settopk", "topk")): vars.top_k = v
|
if(setting in ("settopk", "topk")): vars.top_k = v
|
||||||
if(setting in ("settfs", "tfs")): vars.tfs = v
|
if(setting in ("settfs", "tfs")): vars.tfs = v
|
||||||
|
if(setting in ("settypical", "typical")): vars.typical = v
|
||||||
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
|
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
|
||||||
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
|
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
|
||||||
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
|
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
|
||||||
@ -2382,6 +2395,11 @@ def get_message(msg):
|
|||||||
emit('from_server', {'cmd': 'setlabeltfs', 'data': msg['data']}, broadcast=True)
|
emit('from_server', {'cmd': 'setlabeltfs', 'data': msg['data']}, broadcast=True)
|
||||||
settingschanged()
|
settingschanged()
|
||||||
refresh_settings()
|
refresh_settings()
|
||||||
|
elif(msg['cmd'] == 'settypical'):
|
||||||
|
vars.typical = float(msg['data'])
|
||||||
|
emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True)
|
||||||
|
settingschanged()
|
||||||
|
refresh_settings()
|
||||||
elif(msg['cmd'] == 'setreppen'):
|
elif(msg['cmd'] == 'setreppen'):
|
||||||
vars.rep_pen = float(msg['data'])
|
vars.rep_pen = float(msg['data'])
|
||||||
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
|
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
|
||||||
@ -3442,6 +3460,7 @@ def sendtocolab(txt, min, max):
|
|||||||
'top_p': vars.top_p,
|
'top_p': vars.top_p,
|
||||||
'top_k': vars.top_k,
|
'top_k': vars.top_k,
|
||||||
'tfs': vars.tfs,
|
'tfs': vars.tfs,
|
||||||
|
'typical': vars.typical,
|
||||||
'numseqs': vars.numseqs,
|
'numseqs': vars.numseqs,
|
||||||
'retfultxt': False
|
'retfultxt': False
|
||||||
}
|
}
|
||||||
@ -3578,6 +3597,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
|||||||
top_p=vars.top_p,
|
top_p=vars.top_p,
|
||||||
top_k=vars.top_k,
|
top_k=vars.top_k,
|
||||||
tfs=vars.tfs,
|
tfs=vars.tfs,
|
||||||
|
typical=vars.typical,
|
||||||
numseqs=vars.numseqs,
|
numseqs=vars.numseqs,
|
||||||
repetition_penalty=vars.rep_pen,
|
repetition_penalty=vars.rep_pen,
|
||||||
rpslope=vars.rep_pen_slope,
|
rpslope=vars.rep_pen_slope,
|
||||||
@ -3763,6 +3783,7 @@ def refresh_settings():
|
|||||||
emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p}, broadcast=True)
|
emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p}, broadcast=True)
|
||||||
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
|
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
|
||||||
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
|
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
|
||||||
|
emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True)
|
||||||
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
|
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
|
||||||
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
|
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
|
||||||
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
|
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
|
||||||
@ -4341,6 +4362,7 @@ def oairequest(txt, min, max):
|
|||||||
'top_p': vars.top_p,
|
'top_p': vars.top_p,
|
||||||
'top_k': vars.top_k,
|
'top_k': vars.top_k,
|
||||||
'tfs': vars.tfs,
|
'tfs': vars.tfs,
|
||||||
|
'typical': vars.typical,
|
||||||
'repetition_penalty': vars.rep_pen,
|
'repetition_penalty': vars.rep_pen,
|
||||||
'repetition_penalty_slope': vars.rep_pen_slope,
|
'repetition_penalty_slope': vars.rep_pen_slope,
|
||||||
'repetition_penalty_range': vars.rep_pen_range,
|
'repetition_penalty_range': vars.rep_pen_range,
|
||||||
|
@ -866,6 +866,7 @@ return function(_python, _bridged)
|
|||||||
---@field settopp number
|
---@field settopp number
|
||||||
---@field settopk integer
|
---@field settopk integer
|
||||||
---@field settfs number
|
---@field settfs number
|
||||||
|
---@field settypical number
|
||||||
---@field setreppen number
|
---@field setreppen number
|
||||||
---@field setreppenslope number
|
---@field setreppenslope number
|
||||||
---@field setreppenrange number
|
---@field setreppenrange number
|
||||||
@ -882,6 +883,7 @@ return function(_python, _bridged)
|
|||||||
---@field top_p number
|
---@field top_p number
|
||||||
---@field top_k integer
|
---@field top_k integer
|
||||||
---@field tfs number
|
---@field tfs number
|
||||||
|
---@field typical number
|
||||||
---@field reppen number
|
---@field reppen number
|
||||||
---@field reppenslope number
|
---@field reppenslope number
|
||||||
---@field reppenrange number
|
---@field reppenrange number
|
||||||
|
@ -51,8 +51,19 @@ gensettingstf = [
|
|||||||
"min": 0.0,
|
"min": 0.0,
|
||||||
"max": 1.0,
|
"max": 1.0,
|
||||||
"step": 0.05,
|
"step": 0.05,
|
||||||
"default": 0.0,
|
"default": 1.0,
|
||||||
"tooltip": "Alternative sampling method; it is recommended to disable top_p and top_k (set top_p to 1 and top_k to 0) if using this. 0.95 is thought to be a good value. (Put this value on 1 to disable its effect)"
|
"tooltip": "Alternative sampling method; it is recommended to disable top_p and top_k (set top_p to 1 and top_k to 0) if using this. 0.95 is thought to be a good value. (Put this value on 1 to disable its effect)"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"uitype": "slider",
|
||||||
|
"unit": "float",
|
||||||
|
"label": "Typical Sampling",
|
||||||
|
"id": "settypical",
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 1.0,
|
||||||
|
"step": 0.05,
|
||||||
|
"default": 1.0,
|
||||||
|
"tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"uitype": "slider",
|
"uitype": "slider",
|
||||||
|
@ -2041,6 +2041,10 @@ $(document).ready(function(){
|
|||||||
// Send current tfs value to input
|
// Send current tfs value to input
|
||||||
$("#settfs").val(parseFloat(msg.data));
|
$("#settfs").val(parseFloat(msg.data));
|
||||||
$("#settfscur").html(msg.data);
|
$("#settfscur").html(msg.data);
|
||||||
|
} else if(msg.cmd == "updatetypical") {
|
||||||
|
// Send current typical value to input
|
||||||
|
$("#settypical").val(parseFloat(msg.data));
|
||||||
|
$("#settypicalcur").html(msg.data);
|
||||||
} else if(msg.cmd == "updatereppen") {
|
} else if(msg.cmd == "updatereppen") {
|
||||||
// Send current rep pen value to input
|
// Send current rep pen value to input
|
||||||
$("#setreppen").val(parseFloat(msg.data));
|
$("#setreppen").val(parseFloat(msg.data));
|
||||||
@ -2077,6 +2081,9 @@ $(document).ready(function(){
|
|||||||
} else if(msg.cmd == "setlabeltfs") {
|
} else if(msg.cmd == "setlabeltfs") {
|
||||||
// Update setting label with value from server
|
// Update setting label with value from server
|
||||||
$("#settfscur").html(msg.data);
|
$("#settfscur").html(msg.data);
|
||||||
|
} else if(msg.cmd == "setlabeltypical") {
|
||||||
|
// Update setting label with value from server
|
||||||
|
$("#settypicalcur").html(msg.data);
|
||||||
} else if(msg.cmd == "setlabelreppen") {
|
} else if(msg.cmd == "setlabelreppen") {
|
||||||
// Update setting label with value from server
|
// Update setting label with value from server
|
||||||
$("#setreppencur").html(msg.data);
|
$("#setreppencur").html(msg.data);
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
<script src="static/bootstrap.min.js"></script>
|
<script src="static/bootstrap.min.js"></script>
|
||||||
<script src="static/bootstrap-toggle.min.js"></script>
|
<script src="static/bootstrap-toggle.min.js"></script>
|
||||||
<script src="static/rangy-core.min.js"></script>
|
<script src="static/rangy-core.min.js"></script>
|
||||||
<script src="static/application.js?ver=1.17a"></script>
|
<script src="static/application.js?ver=1.17b"></script>
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<input type="file" id="remote-save-select" accept="application/json" style="display:none">
|
<input type="file" id="remote-save-select" accept="application/json" style="display:none">
|
||||||
|
@ -67,6 +67,7 @@ def settings_callback() -> dict:
|
|||||||
"temp": 0.5,
|
"temp": 0.5,
|
||||||
"top_k": 0,
|
"top_k": 0,
|
||||||
"tfs": 1.0,
|
"tfs": 1.0,
|
||||||
|
"typical": 1.0,
|
||||||
"repetition_penalty": 1.0,
|
"repetition_penalty": 1.0,
|
||||||
"rpslope": 0.0,
|
"rpslope": 0.0,
|
||||||
"rprange": 0,
|
"rprange": 0,
|
||||||
@ -155,11 +156,11 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
|
|||||||
logits[tokens] = penalty_logits
|
logits[tokens] = penalty_logits
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
|
||||||
'''
|
'''
|
||||||
This gets called by generate_loop_fn to apply a series of 4 filters
|
This gets called by generate_loop_fn to apply a series of 5 filters
|
||||||
to the logits (top-k, then top-p, then TFS, then temperature) before
|
to the logits (top-k, then top-p, then TFS, then typical, then temperature)
|
||||||
picking one token using the modified logits
|
before picking one token using the modified logits
|
||||||
'''
|
'''
|
||||||
# Top-k (keep only the k tokens with the highest logits and remove
|
# Top-k (keep only the k tokens with the highest logits and remove
|
||||||
# the rest, by setting their logits to negative infinity)
|
# the rest, by setting their logits to negative infinity)
|
||||||
@ -246,6 +247,37 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
|||||||
return np.where(indices_to_remove, -np.inf, logits)
|
return np.where(indices_to_remove, -np.inf, logits)
|
||||||
if tfs < 1.0:
|
if tfs < 1.0:
|
||||||
logits = tail_free_filter(logits)
|
logits = tail_free_filter(logits)
|
||||||
|
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
|
||||||
|
def typical_filter(logits):
|
||||||
|
# Compute softmax probabilities and the natural logarithms of them
|
||||||
|
probs = jax.nn.softmax(logits)
|
||||||
|
with np.errstate(divide="ignore"):
|
||||||
|
log_probs = np.log(probs)
|
||||||
|
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
|
||||||
|
# in the set of softmax probabilities of the logits
|
||||||
|
neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
|
||||||
|
# Determine absolute difference between the negative entropy and the
|
||||||
|
# log probabilities
|
||||||
|
entropy_deviation = np.abs(neg_entropy - log_probs)
|
||||||
|
# Keep certain tokens such that the sum of the entropy_deviation of the
|
||||||
|
# kept tokens is the smallest possible value such that the sum of the
|
||||||
|
# softmax probabilities of the kept tokens is at least the threshold
|
||||||
|
# value (by sorting the tokens in ascending order of entropy_deviation
|
||||||
|
# and then keeping the smallest possible number of tokens from the
|
||||||
|
# beginning such that sum of softmax probabilities is at or above the
|
||||||
|
# threshold)
|
||||||
|
_, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
|
||||||
|
sorted_indices_to_remove = np.cumsum(sorted_logits, axis=-1) >= typical
|
||||||
|
sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1, axis=-1)
|
||||||
|
sorted_indices_to_remove[0] = False
|
||||||
|
# Unsort and remove
|
||||||
|
_, indices_to_remove = jax.lax.sort_key_val(
|
||||||
|
jnp.argsort(entropy_deviation),
|
||||||
|
sorted_indices_to_remove,
|
||||||
|
)
|
||||||
|
return np.where(indices_to_remove, -jnp.inf, logits)
|
||||||
|
if typical < 1.0:
|
||||||
|
logits = typical_filter(logits)
|
||||||
# Temperature (just divide the logits by the temperature)
|
# Temperature (just divide the logits by the temperature)
|
||||||
logits /= temp
|
logits /= temp
|
||||||
# Finally, pick one token using the softmax thingy again (it gives
|
# Finally, pick one token using the softmax thingy again (it gives
|
||||||
@ -298,11 +330,11 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
|
|||||||
# positions in the logits array
|
# positions in the logits array
|
||||||
return logits.at[tokens].set(penalty_logits)
|
return logits.at[tokens].set(penalty_logits)
|
||||||
|
|
||||||
def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
|
||||||
'''
|
'''
|
||||||
This gets called by generate_loop_fn to apply a series of 4 filters
|
This gets called by generate_loop_fn to apply a series of 5 filters
|
||||||
to the logits (top-k, then top-p, then TFS, then temperature) before
|
to the logits (top-k, then top-p, then TFS, then typical, then temperature)
|
||||||
picking one token using the modified logits
|
before picking one token using the modified logits
|
||||||
'''
|
'''
|
||||||
# Top-k (keep only the k tokens with the highest logits and remove
|
# Top-k (keep only the k tokens with the highest logits and remove
|
||||||
# the rest, by setting their logits to negative infinity)
|
# the rest, by setting their logits to negative infinity)
|
||||||
@ -386,6 +418,35 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
|||||||
)
|
)
|
||||||
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
||||||
logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
|
logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
|
||||||
|
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
|
||||||
|
def typical_filter(logits):
|
||||||
|
# Compute softmax probabilities and the natural logarithms of them
|
||||||
|
probs = jax.nn.softmax(logits)
|
||||||
|
log_probs = jnp.log(probs)
|
||||||
|
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
|
||||||
|
# in the set of softmax probabilities of the logits
|
||||||
|
neg_entropy = (probs * log_probs).sum(axis=-1, keepdims=True)
|
||||||
|
# Determine absolute difference between the negative entropy and the
|
||||||
|
# log probabilities
|
||||||
|
entropy_deviation = jnp.abs(neg_entropy - log_probs)
|
||||||
|
# Keep certain tokens such that the sum of the entropy_deviation of the
|
||||||
|
# kept tokens is the smallest possible value such that the sum of the
|
||||||
|
# softmax probabilities of the kept tokens is at least the threshold
|
||||||
|
# value (by sorting the tokens in ascending order of entropy_deviation
|
||||||
|
# and then keeping the smallest possible number of tokens from the
|
||||||
|
# beginning such that sum of softmax probabilities is at or above the
|
||||||
|
# threshold)
|
||||||
|
_, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
|
||||||
|
sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= typical
|
||||||
|
sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
|
||||||
|
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
|
||||||
|
# Unsort and remove
|
||||||
|
_, indices_to_remove = jax.lax.sort_key_val(
|
||||||
|
jnp.argsort(entropy_deviation),
|
||||||
|
sorted_indices_to_remove,
|
||||||
|
)
|
||||||
|
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
||||||
|
logits = jax.lax.cond(typical < 1.0, typical_filter, lambda x: x, logits)
|
||||||
# Temperature (just divide the logits by the temperature)
|
# Temperature (just divide the logits by the temperature)
|
||||||
def temp_filter(logits):
|
def temp_filter(logits):
|
||||||
return logits / temp
|
return logits / temp
|
||||||
@ -742,6 +803,7 @@ def infer_static(
|
|||||||
temp=0.5,
|
temp=0.5,
|
||||||
top_k=0,
|
top_k=0,
|
||||||
tfs=1.0,
|
tfs=1.0,
|
||||||
|
typical=1.0,
|
||||||
repetition_penalty=1.0,
|
repetition_penalty=1.0,
|
||||||
rpslope=0.0,
|
rpslope=0.0,
|
||||||
rprange=0,
|
rprange=0,
|
||||||
@ -764,6 +826,7 @@ def infer_static(
|
|||||||
"temp": temp * np.ones(total_batch),
|
"temp": temp * np.ones(total_batch),
|
||||||
"top_p": top_p * np.ones(total_batch),
|
"top_p": top_p * np.ones(total_batch),
|
||||||
"tfs": tfs * np.ones(total_batch),
|
"tfs": tfs * np.ones(total_batch),
|
||||||
|
"typical": typical * np.ones(total_batch),
|
||||||
"repetition_penalty": repetition_penalty * np.ones(total_batch),
|
"repetition_penalty": repetition_penalty * np.ones(total_batch),
|
||||||
"rpslope": rpslope * np.ones(total_batch),
|
"rpslope": rpslope * np.ones(total_batch),
|
||||||
"rprange": np.full(total_batch, rprange, dtype=np.uint32),
|
"rprange": np.full(total_batch, rprange, dtype=np.uint32),
|
||||||
|
52
warpers.py
52
warpers.py
@ -62,7 +62,7 @@ class TailFreeLogitsWarper(LogitsWarper):
|
|||||||
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||||
tfs = float(tfs)
|
tfs = float(tfs)
|
||||||
if tfs < 0 or tfs > 1.0:
|
if tfs < 0 or tfs > 1.0:
|
||||||
raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}")
|
raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}")
|
||||||
self.tfs = tfs
|
self.tfs = tfs
|
||||||
self.filter_value = filter_value
|
self.filter_value = filter_value
|
||||||
self.min_tokens_to_keep = min_tokens_to_keep
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
@ -98,3 +98,53 @@ class TailFreeLogitsWarper(LogitsWarper):
|
|||||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class TypicalLogitsWarper(LogitsWarper):
    '''
    Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf

    Keeps the tokens whose log probability is closest to the negative
    entropy of the distribution, taking as few of them as are needed for
    their combined probability to reach the `typical` threshold.  Every
    other token has its score replaced with `filter_value`.
    '''

    def __init__(self, typical: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        '''
        typical            -- threshold in the closed range [0, 1]; a value
                              of 1 disables the filter
        filter_value       -- score assigned to the removed tokens
        min_tokens_to_keep -- lower bound on the number of tokens that
                              survive (values below 1 are treated as 1)

        Raises ValueError if `typical` is outside [0, 1].
        '''
        typical = float(typical)
        if typical < 0 or typical > 1.0:
            raise ValueError(f"`typical` has to be a float >= 0 and <= 1, but is {typical}")
        self.typical = typical
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        '''
        Return `scores` with the atypical tokens set to `filter_value`.
        The last dimension of `scores` is treated as the vocabulary
        dimension.  `input_ids` is not used.
        '''
        # A threshold of 1 keeps the whole distribution, so there is
        # nothing to filter.  (This has to test the threshold itself, not
        # filter_value, which is -inf by default and would never match.)
        if self.typical >= 1.0:
            return scores

        # Compute the natural logarithms of the softmax probabilities and
        # the probabilities themselves.  log_softmax is used instead of
        # softmax().log() so that very unlikely tokens do not underflow to
        # log(0).
        log_probs = scores.log_softmax(dim=-1)
        probs = log_probs.exp()

        # Compute the negative of entropy, which is the sum of p*ln(p) for all p
        # in the set of softmax probabilities of the logits.  Tokens that an
        # earlier warper already set to -inf have p = 0 and ln(p) = -inf,
        # whose product is NaN; nansum skips those so that they contribute
        # nothing instead of turning the whole sum into NaN.
        neg_entropy = (probs * log_probs).nansum(dim=-1, keepdim=True)

        # Determine absolute difference between the negative entropy and the
        # log probabilities (infinite for tokens that are already removed,
        # which sorts them to the end)
        entropy_deviation = (neg_entropy - log_probs).abs()

        # Keep certain tokens such that the sum of the entropy_deviation of the
        # kept tokens is the smallest possible value such that the sum of the
        # softmax probabilities of the kept tokens is at least the threshold
        # value (by sorting the tokens in ascending order of entropy_deviation
        # and then keeping the smallest possible number of tokens from the
        # beginning such that sum of softmax probabilities is at or above the
        # threshold)
        _, sorted_indices = torch.sort(entropy_deviation)
        sorted_probs = probs.gather(-1, sorted_indices)
        sorted_indices_to_remove = sorted_probs.cumsum(dim=-1) >= self.typical
        # Shift the mask right by one so that the token which makes the
        # cumulative probability reach the threshold is itself kept
        sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)

        min_tokens_to_keep = max(self.min_tokens_to_keep, 1)
        # Keep at least min_tokens_to_keep (this also clears the element
        # that the roll wrapped around to the front)
        sorted_indices_to_remove[..., :min_tokens_to_keep] = False

        # Unsort the mask back into vocabulary order and apply it
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
|
||||||
|
Loading…
x
Reference in New Issue
Block a user