Merge pull request #72 from VE-FORBRYDERNE/rep-pen
Repetition penalty slope and range
commit 392c59d48b

aiserver.py
@@ -23,6 +23,7 @@ import packaging
 import contextlib
 import traceback
 import threading
+from collections.abc import Iterable
 from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List

 import requests
@@ -100,6 +101,8 @@ class vars:
     genamt      = 80    # Amount of text for each action to generate
     ikgen       = 200   # Number of characters for InferKit to generate
     rep_pen     = 1.1   # Default generator repetition_penalty
+    rep_pen_slope = 0.0 # Default generator repetition penalty slope
+    rep_pen_range = 0   # Default generator repetition penalty range
     temp        = 0.5   # Default generator temperature
     top_p       = 0.9   # Default generator top_p
     top_k       = 0     # Default generator top_k
@@ -696,65 +699,32 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme

     # Patch transformers to use our custom logit warpers
     from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
+    from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper

     def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
         old_call = cls.__call__
         def new_call(self, *args, **kwargs):
-            setattr(self, field_name, getattr(vars, var_name))
+            if(not isinstance(field_name, str) and isinstance(field_name, Iterable)):
+                conds = []
+                for f, v in zip(field_name, var_name):
+                    conds.append(getattr(vars, v))
+                    setattr(self, f, conds[-1])
+            else:
+                conds = getattr(vars, var_name)
+                setattr(self, field_name, conds)
             assert len(args) == 2
-            if(cond is None or cond(getattr(vars, var_name))):
+            if(cond is None or cond(conds)):
                 return old_call(self, *args, **kwargs)
             return args[1]
         cls.__call__ = new_call
-    dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen", cond=lambda x: x != 1.0)
+    dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
     dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
     dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
     dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
     dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
+    RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
+    RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
-
-    class TailFreeLogitsWarper(LogitsWarper):
-
-        def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
-            tfs = float(tfs)
-            if tfs < 0 or tfs > 1.0:
-                raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}")
-            self.tfs = tfs
-            self.filter_value = filter_value
-            self.min_tokens_to_keep = min_tokens_to_keep
-
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-            self.tfs = vars.tfs
-
-            if self.filter_value >= 1.0:
-                return scores
-            sorted_logits, sorted_indices = torch.sort(scores, descending=True)
-            probs = sorted_logits.softmax(dim=-1)
-
-            # Compute second derivative normalized CDF
-            d2 = probs.diff().diff().abs()
-            normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
-            normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
-
-            # Remove tokens with CDF value above the threshold (token with 0 are kept)
-            sorted_indices_to_remove = normalized_d2_cdf > self.tfs
-
-            # Centre the distribution around the cutoff as in the original implementation of the algorithm
-            sorted_indices_to_remove = torch.cat(
-                (
-                    torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
-                    sorted_indices_to_remove,
-                    torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
-                ),
-                dim=-1,
-            )
-
-            if self.min_tokens_to_keep > 1:
-                # Keep at least min_tokens_to_keep
-                sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
-
-            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-            scores = scores.masked_fill(indices_to_remove, self.filter_value)
-            return scores
-
     class LuaLogitsProcessor(LogitsProcessor):

         def __init__(self):
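The dynamic_processor_wrap patch above re-reads the live setting from the vars class on every call, which is what lets slider changes take effect mid-story without rebuilding the warper list, and lets cond short-circuit a disabled warper by returning the scores unchanged. A minimal self-contained sketch of the same technique (the Settings and TopKStub names are hypothetical stand-ins, not KoboldAI code):

# Sketch of the dynamic_processor_wrap technique, with hypothetical names.
class Settings:
    top_k = 40  # mutable "live" setting, playing the role of KoboldAI's vars

settings = Settings()

class TopKStub:
    def __init__(self, top_k):
        self.top_k = top_k
    def __call__(self, input_ids, scores):
        return f"filtered with top_k={self.top_k}"

def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
    old_call = cls.__call__
    def new_call(self, *args, **kwargs):
        value = getattr(settings, var_name)  # re-read the live setting
        setattr(self, field_name, value)     # refresh the possibly stale field
        if cond is None or cond(value):
            return old_call(self, *args, **kwargs)
        return args[1]                       # disabled: pass the scores through
    cls.__call__ = new_call

dynamic_processor_wrap(TopKStub, "top_k", "top_k", cond=lambda x: x > 0)

warper = TopKStub(top_k=0)     # constructor value no longer matters
settings.top_k = 20            # user moves the slider
print(warper(None, "scores"))  # filtered with top_k=20
settings.top_k = 0             # 0 disables the warper via cond
print(warper(None, "scores"))  # scores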
@@ -1072,6 +1042,8 @@ else:
         "top_k": int(vars.top_k),
         "tfs": float(vars.tfs),
         "repetition_penalty": float(vars.rep_pen),
+        "rpslope": float(vars.rep_pen_slope),
+        "rprange": int(vars.rep_pen_range),
     }

 # If we're running Colab or OAI, we still need a tokenizer.
@@ -1418,6 +1390,8 @@ def lua_has_setting(setting):
         "settopk",
         "settfs",
         "setreppen",
+        "setreppenslope",
+        "setreppenrange",
         "settknmax",
         "setwidepth",
         "setuseprompt",
@@ -1433,6 +1407,8 @@ def lua_has_setting(setting):
         "top_k",
         "tfs",
         "reppen",
+        "reppenslope",
+        "reppenrange",
         "tknmax",
         "widepth",
         "useprompt",
@@ -1464,6 +1440,8 @@ def lua_get_setting(setting):
    if(setting in ("settopk", "topk", "top_k")): return vars.top_k
    if(setting in ("settfs", "tfs")): return vars.tfs
    if(setting in ("setreppen", "reppen")): return vars.rep_pen
+    if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
+    if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
    if(setting in ("settknmax", "tknmax")): return vars.max_length
    if(setting == "anotedepth"): return vars.andepth
    if(setting in ("setwidepth", "widepth")): return vars.widepth
@@ -1495,6 +1473,8 @@ def lua_set_setting(setting, v):
    if(setting in ("settopk", "topk")): vars.top_k = v
    if(setting in ("settfs", "tfs")): vars.tfs = v
    if(setting in ("setreppen", "reppen")): vars.rep_pen = v
+    if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
+    if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
    if(setting in ("settknmax", "tknmax")): vars.max_length = v; return True
    if(setting == "anotedepth"): vars.andepth = v; return True
    if(setting in ("setwidepth", "widepth")): vars.widepth = v; return True
@@ -1881,6 +1861,16 @@ def get_message(msg):
         emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
         settingschanged()
         refresh_settings()
+    elif(msg['cmd'] == 'setreppenslope'):
+        vars.rep_pen_slope = float(msg['data'])
+        emit('from_server', {'cmd': 'setlabelreppenslope', 'data': msg['data']}, broadcast=True)
+        settingschanged()
+        refresh_settings()
+    elif(msg['cmd'] == 'setreppenrange'):
+        vars.rep_pen_range = float(msg['data'])
+        emit('from_server', {'cmd': 'setlabelreppenrange', 'data': msg['data']}, broadcast=True)
+        settingschanged()
+        refresh_settings()
     elif(msg['cmd'] == 'setoutput'):
         vars.genamt = int(msg['data'])
         emit('from_server', {'cmd': 'setlabeloutput', 'data': msg['data']}, broadcast=True)
@@ -2151,6 +2141,8 @@ def savesettings():
     js["top_k"]       = vars.top_k
     js["tfs"]         = vars.tfs
     js["rep_pen"]     = vars.rep_pen
+    js["rep_pen_slope"] = vars.rep_pen_slope
+    js["rep_pen_range"] = vars.rep_pen_range
     js["genamt"]      = vars.genamt
     js["max_length"]  = vars.max_length
     js["ikgen"]       = vars.ikgen
@@ -2206,6 +2198,10 @@ def loadsettings():
         vars.tfs = js["tfs"]
     if("rep_pen" in js):
         vars.rep_pen = js["rep_pen"]
+    if("rep_pen_slope" in js):
+        vars.rep_pen_slope = js["rep_pen_slope"]
+    if("rep_pen_range" in js):
+        vars.rep_pen_range = js["rep_pen_range"]
     if("genamt" in js):
         vars.genamt = js["genamt"]
     if("max_length" in js):
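These per-key guards are what keep settings files written before this commit loading cleanly: a missing "rep_pen_slope" or "rep_pen_range" simply leaves the default from the vars class in place. A rough sketch of the same fallback pattern (the defaults dict here is illustrative, not commit code):

import json

# Sketch: tolerate old settings files the same way loadsettings() does,
# falling back to a default whenever one of the new keys is absent.
defaults = {"rep_pen": 1.1, "rep_pen_slope": 0.0, "rep_pen_range": 0}
js = json.loads('{"rep_pen": 1.2}')  # file saved before slope/range existed
merged = {k: js.get(k, d) for k, d in defaults.items()}
print(merged)  # {'rep_pen': 1.2, 'rep_pen_slope': 0.0, 'rep_pen_range': 0}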
@@ -2285,6 +2281,10 @@ def loadmodelsettings():
         vars.tfs = js["tfs"]
     if("rep_pen" in js):
         vars.rep_pen = js["rep_pen"]
+    if("rep_pen_slope" in js):
+        vars.rep_pen_slope = js["rep_pen_slope"]
+    if("rep_pen_range" in js):
+        vars.rep_pen_range = js["rep_pen_range"]
     if("adventure" in js):
         vars.adventure = js["adventure"]
     if("chatmode" in js):
@@ -2958,6 +2958,8 @@ def sendtocolab(txt, min, max):
         'min': min,
         'max': max,
         'rep_pen': vars.rep_pen,
+        'rep_pen_slope': vars.rep_pen_slope,
+        'rep_pen_range': vars.rep_pen_range,
         'temperature': vars.temp,
         'top_p': vars.top_p,
         'top_k': vars.top_k,
@@ -3099,6 +3101,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
             tfs=vars.tfs,
             numseqs=vars.numseqs,
             repetition_penalty=vars.rep_pen,
+            rpslope=vars.rep_pen_slope,
+            rprange=vars.rep_pen_range,
             soft_embeddings=vars.sp,
             soft_tokens=soft_tokens,
         )
@@ -3281,6 +3285,8 @@ def refresh_settings():
     emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
     emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
     emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
+    emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
+    emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
     emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt}, broadcast=True)
     emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length}, broadcast=True)
     emit('from_server', {'cmd': 'updatenumseq', 'data': vars.numseqs}, broadcast=True)

bridge.lua
@@ -867,6 +867,8 @@ return function(_python, _bridged)
 ---@field settopk integer
 ---@field settfs number
 ---@field setreppen number
+---@field setreppenslope number
+---@field setreppenrange number
 ---@field settknmax integer
 ---@field setwidepth integer
 ---@field setuseprompt boolean
@@ -881,6 +883,8 @@ return function(_python, _bridged)
 ---@field top_k integer
 ---@field tfs number
 ---@field reppen number
+---@field reppenslope number
+---@field reppenrange number
 ---@field tknmax integer
 ---@field widepth integer
 ---@field useprompt boolean

gensettings.py
@@ -52,6 +52,28 @@ gensettingstf = [{
     "step": 0.01,
     "default": 1.1,
     "tooltip": "Used to penalize words that were already generated or belong to the context (Going over 1.2 breaks 6B models)."
     },
+    {
+    "uitype": "slider",
+    "unit": "int",
+    "label": "Rep Penalty Range",
+    "id": "setreppenrange",
+    "min": 0,
+    "max": 4096,
+    "step": 4,
+    "default": 0,
+    "tooltip": "Repetition penalty range. If set higher than 0, only applies repetition penalty to the last few tokens of your story rather than applying it to the entire story. This slider controls the amount of tokens at the end of your story to apply it to."
+    },
+    {
+    "uitype": "slider",
+    "unit": "float",
+    "label": "Rep Penalty Slope",
+    "id": "setreppenslope",
+    "min": 0.0,
+    "max": 10.0,
+    "step": 0.1,
+    "default": 0.0,
+    "tooltip": "Repetition penalty slope. If BOTH this setting and Rep Penalty Range are set higher than 0, will use sigmoid interpolation to apply repetition penalty more strongly on tokens that are closer to the end of your story. This setting controls the tension of the sigmoid curve; higher settings will result in the repetition penalty difference between the start and end of your story being more apparent. Setting this to 1 uses linear interpolation; setting this to 0 disables interpolation."
+    },
     {
     "uitype": "slider",
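The slope tooltip above can be checked numerically. A standalone sketch of the interpolation the apply_repetition_penalty functions implement (illustrative only, not commit code): a token's relative position p in the penalty range is mapped through f(p) = (s*p) / (1 + |p|*(s-1)), then rescaled from [-1, 1] onto [1, r]:

import numpy as np

def sloped_penalty(p, s, r):
    # p: relative position in [-1, 1] (-1 = oldest token in range, 1 = newest)
    # s: rep pen slope, r: base repetition penalty
    f = (s * p) / (1 + np.abs(p) * (s - 1))
    return 1 + ((f + 1) / 2) * (r - 1)

p = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])
print(sloped_penalty(p, 1.0, 1.2))   # [1.    1.05  1.1   1.15  1.2  ] -- linear
print(sloped_penalty(p, 10.0, 1.2))  # [1.    1.009 1.1   1.191 1.2  ] -- steep sigmoid

With slope 1 the penalty rises evenly across the range; with slope 10 nearly the full penalty is concentrated on the newest tokens, matching the "tension" described in the tooltip.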
static/application.js
@@ -1991,6 +1991,14 @@ $(document).ready(function(){
 		// Send current rep pen value to input
 		$("#setreppen").val(parseFloat(msg.data));
 		$("#setreppencur").html(msg.data);
+	} else if(msg.cmd == "updatereppenslope") {
+		// Send current rep pen value to input
+		$("#setreppenslope").val(parseFloat(msg.data));
+		$("#setreppenslopecur").html(msg.data);
+	} else if(msg.cmd == "updatereppenrange") {
+		// Send current rep pen value to input
+		$("#setreppenrange").val(parseFloat(msg.data));
+		$("#setreppenrangecur").html(msg.data);
 	} else if(msg.cmd == "updateoutlen") {
 		// Send current output amt value to input
 		$("#setoutput").val(parseInt(msg.data));
@@ -2018,6 +2026,12 @@ $(document).ready(function(){
 	} else if(msg.cmd == "setlabelreppen") {
 		// Update setting label with value from server
 		$("#setreppencur").html(msg.data);
+	} else if(msg.cmd == "setlabelreppenslope") {
+		// Update setting label with value from server
+		$("#setreppenslopecur").html(msg.data);
+	} else if(msg.cmd == "setlabelreppenrange") {
+		// Update setting label with value from server
+		$("#setreppenrangecur").html(msg.data);
 	} else if(msg.cmd == "setlabeloutput") {
 		// Update setting label with value from server
 		$("#setoutputcur").html(msg.data);

templates/index.html
@@ -17,7 +17,7 @@
 		<script src="static/bootstrap.min.js"></script>
 		<script src="static/bootstrap-toggle.min.js"></script>
 		<script src="static/rangy-core.min.js"></script>
-		<script src="static/application.js?ver=1.16.4ac"></script>
+		<script src="static/application.js?ver=1.17"></script>
 	</head>
 	<body>
 		<input type="file" id="remote-save-select" accept="application/json" style="display:none">

tpu_mtj_backend.py
@@ -1,3 +1,32 @@
+'''
+This file is AGPL-licensed.
+
+Some of the code in this file is from Clover Edition:
+https://github.com/cloveranon/Clover-Edition/blob/master/aidungeon/gpt2generator.py
+
+The license for Clover Edition is shown below:
+
+Copyright (c) 2019 Nick Walton
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+'''
+
 import multiprocessing
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
 import progressbar
@@ -33,6 +62,8 @@ def settings_callback() -> dict:
         "top_k": 0,
         "tfs": 1.0,
         "repetition_penalty": 1.0,
+        "rpslope": 0.0,
+        "rprange": 0,
     }

 def started_compiling_callback() -> None:
@@ -78,26 +109,40 @@ def __batch_xmap(shard_dim=1):
         return inner


-def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty):
+def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
     '''
     This gets called by generate_loop_fn to apply repetition penalty
     to the 1D array logits using the provided 1D array of tokens to penalize
     '''
     tokens = np.minimum(tokens, params["n_vocab"]-1)  # https://github.com/google/jax/issues/3774
+    rpslope = np.int32(rpslope)
+    rprange = np.int32(rprange)
+    clipped_rprange = rprange if rprange > 0 else tokens.shape[-1]
+    penalty_arange = np.roll(np.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1)
     # Make a new array with the same length as the tokens array but with
     # each element replaced by the value at the corresponding index in the
     # logits array; e.g.
     # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
     # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
     penalty_logits = np.take(logits, tokens)
+    # Repetition penalty slope
+    if rpslope != 0.0 and rprange > 0:
+        _penalty = (penalty_arange/(rprange - 1)) * 2 - 1
+        _penalty = (rpslope * _penalty) / (1 + np.abs(_penalty) * (rpslope - 1))
+        _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
+        repetition_penalty = _penalty
     # Divide positive values by repetition_penalty and multiply negative
     # values by repetition_penalty (the academic publication that described
     # this technique actually just only divided, but that would cause tokens
     # with negative logits to become more likely, which is obviously wrong)
     penalty_logits = np.where(
-        penalty_logits > 0,
-        penalty_logits/repetition_penalty,
-        penalty_logits*repetition_penalty,
+        penalty_arange >= 0,
+        np.where(
+            penalty_logits > 0,
+            penalty_logits/repetition_penalty,
+            penalty_logits*repetition_penalty,
+        ),
+        penalty_logits,
     )
     # Finally, put those penalized logit values back into their original
     # positions in the logits array
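Two details in the hunk above are easy to miss. First, penalty_arange is negative exactly at the positions that fall outside the penalty range (np.roll realigns the mask with the circular token buffer), so the outer np.where leaves those logits untouched. Second, as the comment notes, positive logits are divided while negative logits are multiplied; plain division would make already unlikely tokens more likely. A small illustration (not commit code):

import numpy as np

logits = np.array([2.0, -2.0])
r = 1.2
print(logits / r)                                    # [ 1.67 -1.67] -- negative logit got MORE likely
print(np.where(logits > 0, logits / r, logits * r))  # [ 1.67 -2.4 ] -- both correctly penalized

# The range mask: with rprange=2 over a 4-token buffer (generated_index=0),
# only the last two positions get non-negative values and are penalized.
rprange, buflen, generated_index = 2, 4, 0
penalty_arange = np.roll(np.arange(buflen) + (rprange - buflen), generated_index, axis=-1)
print(penalty_arange)  # [-2 -1  0  1] -- the first two positions are exempt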
@@ -202,25 +247,46 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
     # probability distribution)
     return jax.random.categorical(key, logits, -1).astype(np.uint32)

-def apply_repetition_penalty_static(logits, tokens, repetition_penalty):
+def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
     '''
     This gets called by generate_loop_fn to apply repetition penalty
     to the 1D array logits using the provided 1D array of tokens to penalize
     '''
+    rpslope = jnp.int32(rpslope)
+    rprange = jnp.int32(rprange)
+    clipped_rprange = jax.lax.cond(rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange)
+    penalty_arange = jnp.roll(jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1)
     # Make a new array with the same length as the tokens array but with
     # each element replaced by the value at the corresponding index in the
     # logits array; e.g.
     # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
     # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
     penalty_logits = jnp.take(logits, tokens)
+    # Repetition penalty slope
+    def apply_slope(carry):
+        repetition_penalty, rprange = carry
+        _penalty = (penalty_arange/(rprange - 1)) * 2 - 1
+        _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
+        _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
+        return _penalty
+    repetition_penalty = jax.lax.cond(
+        (rpslope != 0.0) & (rprange > 0),  # Not a typo; do not use `and` here, it makes JAX crash
+        apply_slope,
+        lambda carry: jnp.full(tokens.shape, carry[0]),
+        (repetition_penalty, rprange),
+    )
     # Divide positive values by repetition_penalty and multiply negative
     # values by repetition_penalty (the academic publication that described
     # this technique actually just only divided, but that would cause tokens
     # with negative logits to become more likely, which is obviously wrong)
     penalty_logits = jnp.where(
-        penalty_logits > 0,
-        penalty_logits/repetition_penalty,
-        penalty_logits*repetition_penalty,
+        penalty_arange >= 0,
+        jnp.where(
+            penalty_logits > 0,
+            penalty_logits/repetition_penalty,
+            penalty_logits*repetition_penalty,
+        ),
+        penalty_logits,
     )
     # Finally, put those penalized logit values back into their original
     # positions in the logits array
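The static (jit-compiled) variant must express the same branch with jax.lax.cond, and the in-code comment about `&` is worth dwelling on: under jit the predicate is a traced value, so Python's short-circuiting `and` cannot be used, and both branches must return arrays of identical shape and dtype, which is why the no-slope branch broadcasts the scalar penalty with jnp.full. A toy sketch of that constraint (illustrative, not commit code):

import jax
import jax.numpy as jnp

def sloped(carry):
    penalty, rprange = carry
    pos = jnp.arange(4) / (rprange - 1) * 2 - 1  # positions in [-1, 1]
    return 1 + ((pos + 1) / 2) * (penalty - 1)   # shape (4,)

def flat(carry):
    return jnp.full(4, carry[0])                 # same shape and dtype as sloped()

use_slope = jnp.bool_(True)
penalty = jax.lax.cond(use_slope, sloped, flat, (jnp.float32(1.2), jnp.float32(4)))
print(penalty)  # [1.        1.0666667 1.1333333 1.2      ]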
@@ -325,7 +391,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):

 pad_token_id = 50256

-def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options):
+def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_index, gen_length, rpslope, rprange, sampler_options):
     numseqs = numseqs_aux.shape[0]
     gi = data[0][1]
     def sample_loop_fn(carry):
@@ -339,7 +405,11 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_op
             logits = apply_repetition_penalty_dynamic(
                 logits,
                 generated,
-                repetition_penalty
+                repetition_penalty,
+                generated_index,
+                gen_length,
+                rpslope,
+                rprange,
             )
             # Remove any tokens in the badwords list by setting
             # their logits to negative infinity which effectively
@@ -401,6 +471,8 @@ class PenalizingCausalTransformer(CausalTransformer):
         initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
         # Get repetition penalty from the arguments
         repetition_penalty = sampler_options.pop('repetition_penalty', None)
+        rpslope = sampler_options.pop('rpslope', None)
+        rprange = sampler_options.pop('rprange', None)
         # This is the main generation loop
         def generate_loop_fn(carry):
             # Unpack current generate_loop_fn state
@@ -427,7 +499,11 @@ class PenalizingCausalTransformer(CausalTransformer):
             logits = apply_repetition_penalty_static(
                 logits,
                 generated,
-                repetition_penalty
+                repetition_penalty,
+                generated_index,
+                gen_length,
+                rpslope,
+                rprange,
             )
             # Remove any tokens in the badwords list by setting
             # their logits to negative infinity which effectively
@@ -586,7 +662,9 @@ class PenalizingCausalTransformer(CausalTransformer):
                 sample_data[i][2] = logits[i]
             sampler_options = settings_callback()
             repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
-            sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
+            rpslope = sampler_options.pop("rpslope", 0.0)
+            rprange = sampler_options.pop("rprange", 0)
+            sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, params["seq"] + n_generated, gen_length, rpslope, rprange, sampler_options)
             n_generated += 1
             for i in range(numseqs):
                 generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
@@ -659,6 +737,8 @@ def infer_static(
     top_k=0,
     tfs=1.0,
     repetition_penalty=1.0,
+    rpslope=0.0,
+    rprange=0,
     numseqs=1,
     gen_len=80,
     soft_embeddings: Optional[np.array] = None,
@@ -679,6 +759,8 @@ def infer_static(
         "top_p": top_p * np.ones(total_batch),
         "tfs": tfs * np.ones(total_batch),
         "repetition_penalty": repetition_penalty * np.ones(total_batch),
+        "rpslope": rpslope * np.ones(total_batch),
+        "rprange": np.full(total_batch, rprange, dtype=np.uint32),
         "top_k": np.full(total_batch, top_k, dtype=np.uint32)
     }
     output = network.generate_static(
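One small detail in this hunk: the float options are broadcast by multiplying np.ones, while rprange (like top_k) uses np.full with an explicit dtype, since multiplying by np.ones would silently promote the integer option to float64. A quick check (illustrative):

import numpy as np

total_batch = 4
print((512 * np.ones(total_batch)).dtype)                # float64 -- integer option silently promoted
print(np.full(total_batch, 512, dtype=np.uint32).dtype)  # uint32  -- dtype preserved, matching top_k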
warpers.py (new file)
@@ -0,0 +1,100 @@
+'''
+This file is AGPL-licensed.
+
+Some of the code in this file is from Clover Edition:
+https://github.com/cloveranon/Clover-Edition/blob/master/aidungeon/gpt2generator.py
+
+The license for Clover Edition is shown below:
+
+Copyright (c) 2019 Nick Walton
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+'''
+
+import torch
+from transformers import LogitsWarper, LogitsProcessor
+
+
+class AdvancedRepetitionPenaltyLogitsProcessor(LogitsProcessor):
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        self.penalty_range = int(self.penalty_range)
+        clipped_penalty_range = min(input_ids.shape[-1], self.penalty_range)
+
+        if self.penalty != 1.0:
+            if self.penalty_range > 0:
+                if clipped_penalty_range < input_ids.shape[1]:
+                    input_ids = input_ids[..., -clipped_penalty_range:]
+
+                if self.penalty_slope != 0:
+                    _penalty = (torch.arange(self.penalty_range, dtype=scores.dtype, device=scores.device)/(self.penalty_range - 1)) * 2. - 1
+                    _penalty = (self.penalty_slope * _penalty) / (1 + torch.abs(_penalty) * (self.penalty_slope - 1))
+                    _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (self.penalty - 1)
+                    self.penalty = _penalty[..., -clipped_penalty_range:]
+
+            score = torch.gather(scores, 1, input_ids)
+            score = torch.where(score <= 0, score * self.penalty, score / self.penalty)
+            scores.scatter_(1, input_ids, score)
+
+        return scores
+
+
+class TailFreeLogitsWarper(LogitsWarper):
+
+    def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+        tfs = float(tfs)
+        if tfs < 0 or tfs > 1.0:
+            raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}")
+        self.tfs = tfs
+        self.filter_value = filter_value
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if self.filter_value >= 1.0:
+            return scores
+        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
+        probs = sorted_logits.softmax(dim=-1)
+
+        # Compute second derivative normalized CDF
+        d2 = probs.diff().diff().abs()
+        normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
+        normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
+
+        # Remove tokens with CDF value above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = normalized_d2_cdf > self.tfs
+
+        # Centre the distribution around the cutoff as in the original implementation of the algorithm
+        sorted_indices_to_remove = torch.cat(
+            (
+                torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
+                sorted_indices_to_remove,
+                torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
+            ),
+            dim=-1,
+        )
+
+        if self.min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep
+            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        return scores
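A quick smoke test for the new module (a sketch assuming warpers.py is importable; the tensor shapes are illustrative). Note that inside KoboldAI the three penalty fields are reassigned before every call by dynamic_processor_wrap rather than set by hand, and that the slope path overwrites proc.penalty with a tensor, so the scalar must be reassigned before reuse:

import torch
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper

proc = AdvancedRepetitionPenaltyLogitsProcessor()
proc.penalty = 1.2        # base repetition penalty
proc.penalty_slope = 3.0  # sigmoid tension
proc.penalty_range = 64   # only the last 64 tokens are penalized

input_ids = torch.randint(0, 50257, (1, 128))
scores = proc(input_ids, torch.randn(1, 50257))

warper = TailFreeLogitsWarper(tfs=0.95)
scores = warper(input_ids, scores)  # tail of the distribution set to filter_value (-inf)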