Repetition penalty slope and range

Mirror of https://github.com/KoboldAI/KoboldAI-Client.git

aiserver.py
@@ -23,6 +23,7 @@ import packaging
 import contextlib
 import traceback
 import threading
+from collections.abc import Iterable
 from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
 
 import requests
@@ -100,6 +101,8 @@ class vars:
     genamt      = 80     # Amount of text for each action to generate
     ikgen       = 200    # Number of characters for InferKit to generate
     rep_pen     = 1.1    # Default generator repetition_penalty
+    rep_pen_slope = 0.0  # Default generator repetition penalty slope
+    rep_pen_range = 0    # Default generator repetition penalty range
     temp        = 0.5    # Default generator temperature
     top_p       = 0.9    # Default generator top_p
     top_k       = 0      # Default generator top_k
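
The two new defaults drive a windowed, sloped repetition penalty: rep_pen_range is the number of most recent tokens the penalty applies to (0 meaning the entire context), while rep_pen_slope controls how sharply the penalty ramps from roughly 1.0 at the oldest token in that window up to the full rep_pen at the most recent token. The processor itself lives in the new warpers.py module imported in the next hunk; the sketch below only illustrates the idea, with the function name and the exact ramp curve chosen for this example rather than taken from that file.

import torch

def sloped_repetition_penalty(scores, input_ids, penalty=1.1, slope=0.0, rng=0):
    # Sketch only: scores is (batch, vocab), input_ids is (batch, seq).
    # Penalize just the last `rng` generated tokens, ramping the penalty
    # from ~1.0 (oldest in the window) to `penalty` (newest) when `slope`
    # is nonzero. Assumes slope >= 0.
    if penalty == 1.0:
        return scores
    if rng > 0:
        input_ids = input_ids[..., -rng:]
    n = input_ids.shape[-1]
    if slope != 0.0 and n > 1:
        # Map window positions to [-1, 1], then bend the line into an
        # S-curve whose steepness grows with `slope`.
        x = torch.linspace(-1.0, 1.0, n, device=scores.device)
        x = (slope * x) / (1 + (slope - 1) * x.abs())
        pen = 1 + ((x + 1) / 2) * (penalty - 1)   # per-position penalty
    else:
        pen = torch.full((n,), float(penalty), device=scores.device)
    picked = scores.gather(1, input_ids)
    # Standard repetition-penalty rule: push positive logits down by
    # dividing, negative logits further down by multiplying.
    picked = torch.where(picked <= 0, picked * pen, picked / pen)
    return scores.scatter(1, input_ids, picked)
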
@@ -696,65 +699,32 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
 
     # Patch transformers to use our custom logit warpers
     from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
+    from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper
 
     def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
         old_call = cls.__call__
         def new_call(self, *args, **kwargs):
-            setattr(self, field_name, getattr(vars, var_name))
+            if(not isinstance(field_name, str) and isinstance(field_name, Iterable)):
+                conds = []
+                for f, v in zip(field_name, var_name):
+                    conds.append(getattr(vars, v))
+                    setattr(self, f, conds[-1])
+            else:
+                conds = getattr(vars, var_name)
+                setattr(self, field_name, conds)
             assert len(args) == 2
-            if(cond is None or cond(getattr(vars, var_name))):
+            if(cond is None or cond(conds)):
                 return old_call(self, *args, **kwargs)
             return args[1]
         cls.__call__ = new_call
-    dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen", cond=lambda x: x != 1.0)
+    dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
     dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
     dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
     dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
     dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
+    RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
+    RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
 
-    class TailFreeLogitsWarper(LogitsWarper):
-
-        def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
-            tfs = float(tfs)
-            if tfs < 0 or tfs > 1.0:
-                raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}")
-            self.tfs = tfs
-            self.filter_value = filter_value
-            self.min_tokens_to_keep = min_tokens_to_keep
-
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-            self.tfs = vars.tfs
-            if self.filter_value >= 1.0:
-                return scores
-            sorted_logits, sorted_indices = torch.sort(scores, descending=True)
-            probs = sorted_logits.softmax(dim=-1)
-
-            # Compute second derivative normalized CDF
-            d2 = probs.diff().diff().abs()
-            normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
-            normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
-
-            # Remove tokens with CDF value above the threshold (token with 0 are kept)
-            sorted_indices_to_remove = normalized_d2_cdf > self.tfs
-
-            # Centre the distribution around the cutoff as in the original implementation of the algorithm
-            sorted_indices_to_remove = torch.cat(
-                (
-                    torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
-                    sorted_indices_to_remove,
-                    torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
-                ),
-                dim=-1,
-            )
-
-            if self.min_tokens_to_keep > 1:
-                # Keep at least min_tokens_to_keep
-                sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
-
-            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-            scores = scores.masked_fill(indices_to_remove, self.filter_value)
-            return scores
-
     class LuaLogitsProcessor(LogitsProcessor):
 
         def __init__(self):
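
dynamic_processor_wrap rebinds a processor class's __call__ so that every invocation first refreshes the instance's fields from the live vars singleton, which is why moving a slider mid-session takes effect on the next generation without rebuilding the sampling pipeline; when cond reports the processor would be a no-op (a penalty of exactly 1.0, for instance), the logits in args[1] are returned untouched. The tuple form added in this hunk keeps several fields in sync at once and applies cond to the list of fetched values, so x[0] is rep_pen. A self-contained toy with hypothetical names that follows the same pattern:

class Settings:
    temp = 0.7                 # stands in for the live `vars` singleton

class Scaler:
    def __init__(self, temperature):
        self.temperature = temperature
    def __call__(self, input_ids, scores):
        return scores / self.temperature

def wrap(cls, field, var, cond=None):
    old_call = cls.__call__
    def new_call(self, input_ids, scores):
        setattr(self, field, getattr(Settings, var))   # refresh from live settings
        if cond is None or cond(getattr(Settings, var)):
            return old_call(self, input_ids, scores)
        return scores                                  # no-op: pass logits through
    cls.__call__ = new_call

wrap(Scaler, "temperature", "temp", cond=lambda x: x != 1.0)

s = Scaler(temperature=1.0)
print(s(None, 1.4))        # ~2.0: field refreshed to Settings.temp == 0.7
Settings.temp = 1.0
print(s(None, 1.4))        # 1.4: cond fails, scores returned untouched
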
@@ -1072,6 +1042,8 @@ else:
             "top_k": int(vars.top_k),
             "tfs": float(vars.tfs),
             "repetition_penalty": float(vars.rep_pen),
+            "rpslope": float(vars.rep_pen_slope),
+            "rprange": int(vars.rep_pen_range),
         }
 
     # If we're running Colab or OAI, we still need a tokenizer.
@@ -1418,6 +1390,8 @@ def lua_has_setting(setting):
         "settopk",
         "settfs",
         "setreppen",
+        "setreppenslope",
+        "setreppenrange",
         "settknmax",
         "setwidepth",
         "setuseprompt",
@@ -1433,6 +1407,8 @@ def lua_has_setting(setting):
         "top_k",
         "tfs",
         "reppen",
+        "reppenslope",
+        "reppenrange",
         "tknmax",
         "widepth",
         "useprompt",
@@ -1464,6 +1440,8 @@ def lua_get_setting(setting):
     if(setting in ("settopk", "topk", "top_k")): return vars.top_k
     if(setting in ("settfs", "tfs")): return vars.tfs
     if(setting in ("setreppen", "reppen")): return vars.rep_pen
+    if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
+    if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
     if(setting in ("settknmax", "tknmax")): return vars.max_length
     if(setting == "anotedepth"): return vars.andepth
     if(setting in ("setwidepth", "widepth")): return vars.widepth
@@ -1495,6 +1473,8 @@ def lua_set_setting(setting, v):
     if(setting in ("settopk", "topk")): vars.top_k = v
     if(setting in ("settfs", "tfs")): vars.tfs = v
     if(setting in ("setreppen", "reppen")): vars.rep_pen = v
+    if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
+    if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
     if(setting in ("settknmax", "tknmax")): vars.max_length = v; return True
     if(setting == "anotedepth"): vars.andepth = v; return True
     if(setting in ("setwidepth", "widepth")): vars.widepth = v; return True
@@ -1881,6 +1861,16 @@ def get_message(msg):
         emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
         settingschanged()
         refresh_settings()
+    elif(msg['cmd'] == 'setreppenslope'):
+        vars.rep_pen_slope = float(msg['data'])
+        emit('from_server', {'cmd': 'setlabelreppenslope', 'data': msg['data']}, broadcast=True)
+        settingschanged()
+        refresh_settings()
+    elif(msg['cmd'] == 'setreppenrange'):
+        vars.rep_pen_range = float(msg['data'])
+        emit('from_server', {'cmd': 'setlabelreppenrange', 'data': msg['data']}, broadcast=True)
+        settingschanged()
+        refresh_settings()
     elif(msg['cmd'] == 'setoutput'):
         vars.genamt = int(msg['data'])
         emit('from_server', {'cmd': 'setlabeloutput', 'data': msg['data']}, broadcast=True)
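
Each slider in the browser UI sends a {'cmd': ..., 'data': ...} message over Socket.IO; the handler stores the value on vars, echoes a label update to every connected client, and then settingschanged() and refresh_settings() propagate the change. A hypothetical client-side sketch of the two new messages, assuming the python-socketio package, a local server on the default port, and illustrative values:

import socketio

sio = socketio.Client()
sio.connect("http://localhost:5000")   # assumed local KoboldAI server
# Set repetition penalty slope and range, as the UI sliders would:
sio.emit("message", {"cmd": "setreppenslope", "data": 6.8})
sio.emit("message", {"cmd": "setreppenrange", "data": 1024})
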
@@ -2151,6 +2141,8 @@ def savesettings():
     js["top_k"] = vars.top_k
     js["tfs"] = vars.tfs
     js["rep_pen"] = vars.rep_pen
+    js["rep_pen_slope"] = vars.rep_pen_slope
+    js["rep_pen_range"] = vars.rep_pen_range
     js["genamt"] = vars.genamt
     js["max_length"] = vars.max_length
     js["ikgen"] = vars.ikgen
@@ -2206,6 +2198,10 @@ def loadsettings():
             vars.tfs = js["tfs"]
         if("rep_pen" in js):
             vars.rep_pen = js["rep_pen"]
+        if("rep_pen_slope" in js):
+            vars.rep_pen_slope = js["rep_pen_slope"]
+        if("rep_pen_range" in js):
+            vars.rep_pen_range = js["rep_pen_range"]
         if("genamt" in js):
             vars.genamt = js["genamt"]
         if("max_length" in js):
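
Because each key is read behind an `in js` membership check, settings files saved before this commit simply leave the new fields at the class vars defaults (slope 0.0, range 0) instead of raising KeyError. Illustratively, with made-up values, a saved settings file now parses like this:

import json

# Made-up excerpt; a real settings file holds many more keys.
js = json.loads('{"rep_pen": 1.1, "rep_pen_slope": 6.8, "rep_pen_range": 1024}')
if "rep_pen_slope" in js:          # same guard style as loadsettings()
    print(js["rep_pen_slope"])     # 6.8
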
@@ -2285,6 +2281,10 @@ def loadmodelsettings():
         vars.tfs = js["tfs"]
     if("rep_pen" in js):
         vars.rep_pen = js["rep_pen"]
+    if("rep_pen_slope" in js):
+        vars.rep_pen_slope = js["rep_pen_slope"]
+    if("rep_pen_range" in js):
+        vars.rep_pen_range = js["rep_pen_range"]
     if("adventure" in js):
         vars.adventure = js["adventure"]
     if("chatmode" in js):
@@ -2958,6 +2958,8 @@ def sendtocolab(txt, min, max):
         'min': min,
         'max': max,
         'rep_pen': vars.rep_pen,
+        'rep_pen_slope': vars.rep_pen_slope,
+        'rep_pen_range': vars.rep_pen_range,
         'temperature': vars.temp,
         'top_p': vars.top_p,
         'top_k': vars.top_k,
@@ -3099,6 +3101,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
             tfs=vars.tfs,
             numseqs=vars.numseqs,
             repetition_penalty=vars.rep_pen,
+            rpslope=vars.rep_pen_slope,
+            rprange=vars.rep_pen_range,
             soft_embeddings=vars.sp,
             soft_tokens=soft_tokens,
         )
@@ -3281,6 +3285,8 @@ def refresh_settings():
     emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
     emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
     emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
+    emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
+    emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
     emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt}, broadcast=True)
     emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length}, broadcast=True)
     emit('from_server', {'cmd': 'updatenumseq', 'data': vars.numseqs}, broadcast=True)