Merge branch 'henk717:united' into united

This commit is contained in:
ebolam
2022-01-26 11:27:12 -05:00
committed by GitHub
11 changed files with 414 additions and 78 deletions

View File

@ -1,7 +1,7 @@
#!/usr/bin/python3
#==================================================================#
# KoboldAI
# Version: 1.16.4
# Version: 1.17.0
# By: KoboldAIDev and the KoboldAI Community
#==================================================================#
@ -23,6 +23,7 @@ import packaging
import contextlib
import traceback
import threading
from collections.abc import Iterable
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
import requests
@ -100,6 +101,8 @@ class vars:
genamt = 80 # Amount of text for each action to generate
ikgen = 200 # Number of characters for InferKit to generate
rep_pen = 1.1 # Default generator repetition_penalty
rep_pen_slope = 0.0 # Default generator repetition penalty slope
rep_pen_range = 0 # Default generator repetition penalty range
temp = 0.5 # Default generator temperature
top_p = 0.9 # Default generator top_p
top_k = 0 # Default generator top_k
@ -709,65 +712,32 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
# Patch transformers to use our custom logit warpers
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
old_call = cls.__call__
def new_call(self, *args, **kwargs):
setattr(self, field_name, getattr(vars, var_name))
if(not isinstance(field_name, str) and isinstance(field_name, Iterable)):
conds = []
for f, v in zip(field_name, var_name):
conds.append(getattr(vars, v))
setattr(self, f, conds[-1])
else:
conds = getattr(vars, var_name)
setattr(self, field_name, conds)
assert len(args) == 2
if(cond is None or cond(getattr(vars, var_name))):
if(cond is None or cond(conds)):
return old_call(self, *args, **kwargs)
return args[1]
cls.__call__ = new_call
dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen", cond=lambda x: x != 1.0)
dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
class TailFreeLogitsWarper(LogitsWarper):
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
tfs = float(tfs)
if tfs < 0 or tfs > 1.0:
raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}")
self.tfs = tfs
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
self.tfs = vars.tfs
if self.filter_value >= 1.0:
return scores
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
probs = sorted_logits.softmax(dim=-1)
# Compute second derivative normalized CDF
d2 = probs.diff().diff().abs()
normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
# Remove tokens with CDF value above the threshold (token with 0 are kept)
sorted_indices_to_remove = normalized_d2_cdf > self.tfs
# Centre the distribution around the cutoff as in the original implementation of the algorithm
sorted_indices_to_remove = torch.cat(
(
torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
sorted_indices_to_remove,
torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
),
dim=-1,
)
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
class LuaLogitsProcessor(LogitsProcessor):
def __init__(self):
@ -1085,6 +1055,8 @@ else:
"top_k": int(vars.top_k),
"tfs": float(vars.tfs),
"repetition_penalty": float(vars.rep_pen),
"rpslope": float(vars.rep_pen_slope),
"rprange": int(vars.rep_pen_range),
}
# If we're running Colab or OAI, we still need a tokenizer.
@ -1432,6 +1404,8 @@ def lua_has_setting(setting):
"settopk",
"settfs",
"setreppen",
"setreppenslope",
"setreppenrange",
"settknmax",
"setwidepth",
"setuseprompt",
@ -1447,6 +1421,8 @@ def lua_has_setting(setting):
"top_k",
"tfs",
"reppen",
"reppenslope",
"reppenrange",
"tknmax",
"widepth",
"useprompt",
@ -1478,6 +1454,8 @@ def lua_get_setting(setting):
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
if(setting in ("settfs", "tfs")): return vars.tfs
if(setting in ("setreppen", "reppen")): return vars.rep_pen
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
if(setting in ("settknmax", "tknmax")): return vars.max_length
if(setting == "anotedepth"): return vars.andepth
if(setting in ("setwidepth", "widepth")): return vars.widepth
@ -1509,6 +1487,8 @@ def lua_set_setting(setting, v):
if(setting in ("settopk", "topk")): vars.top_k = v
if(setting in ("settfs", "tfs")): vars.tfs = v
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
if(setting in ("settknmax", "tknmax")): vars.max_length = v; return True
if(setting == "anotedepth"): vars.andepth = v; return True
if(setting in ("setwidepth", "widepth")): vars.widepth = v; return True
@ -1906,6 +1886,16 @@ def get_message(msg):
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setreppenslope'):
vars.rep_pen_slope = float(msg['data'])
emit('from_server', {'cmd': 'setlabelreppenslope', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setreppenrange'):
vars.rep_pen_range = float(msg['data'])
emit('from_server', {'cmd': 'setlabelreppenrange', 'data': msg['data']}, broadcast=True)
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setoutput'):
vars.genamt = int(msg['data'])
emit('from_server', {'cmd': 'setlabeloutput', 'data': msg['data']}, broadcast=True)
@ -2183,6 +2173,8 @@ def savesettings():
js["top_k"] = vars.top_k
js["tfs"] = vars.tfs
js["rep_pen"] = vars.rep_pen
js["rep_pen_slope"] = vars.rep_pen_slope
js["rep_pen_range"] = vars.rep_pen_range
js["genamt"] = vars.genamt
js["max_length"] = vars.max_length
js["ikgen"] = vars.ikgen
@ -2238,6 +2230,10 @@ def loadsettings():
vars.tfs = js["tfs"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
vars.rep_pen_slope = js["rep_pen_slope"]
if("rep_pen_range" in js):
vars.rep_pen_range = js["rep_pen_range"]
if("genamt" in js):
vars.genamt = js["genamt"]
if("max_length" in js):
@ -2303,7 +2299,10 @@ def loadmodelsettings():
model_js_config = str(model_config).partition(' ')[2]
js = json.loads(model_js_config)
except Exception as e:
model_js_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r")
try:
model_js_config = open(vars.custmodpth + "/config.json", "r")
except Exception as e:
model_js_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r")
js = json.load(model_js_config)
if("badwordsids" in js):
vars.badwordsids = js["badwordsids"]
@ -2317,6 +2316,10 @@ def loadmodelsettings():
vars.tfs = js["tfs"]
if("rep_pen" in js):
vars.rep_pen = js["rep_pen"]
if("rep_pen_slope" in js):
vars.rep_pen_slope = js["rep_pen_slope"]
if("rep_pen_range" in js):
vars.rep_pen_range = js["rep_pen_range"]
if("adventure" in js):
vars.adventure = js["adventure"]
if("chatmode" in js):
@ -3091,6 +3094,8 @@ def sendtocolab(txt, min, max):
'min': min,
'max': max,
'rep_pen': vars.rep_pen,
'rep_pen_slope': vars.rep_pen_slope,
'rep_pen_range': vars.rep_pen_range,
'temperature': vars.temp,
'top_p': vars.top_p,
'top_k': vars.top_k,
@ -3233,6 +3238,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
tfs=vars.tfs,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
rpslope=vars.rep_pen_slope,
rprange=vars.rep_pen_range,
soft_embeddings=vars.sp,
soft_tokens=soft_tokens,
)
@ -3415,6 +3422,8 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt}, broadcast=True)
emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length}, broadcast=True)
emit('from_server', {'cmd': 'updatenumseq', 'data': vars.numseqs}, broadcast=True)