Merge pull request #148 from VE-FORBRYDERNE/overhaul-merge
Merge united into overhaul
This commit is contained in:
commit
f3eb7cba5c
67
aiserver.py
67
aiserver.py
|
@ -106,7 +106,6 @@ model_menu = {
|
||||||
["Adventure Models", "adventurelist", "", True],
|
["Adventure Models", "adventurelist", "", True],
|
||||||
["Novel Models", "novellist", "", True],
|
["Novel Models", "novellist", "", True],
|
||||||
["NSFW Models", "nsfwlist", "", True],
|
["NSFW Models", "nsfwlist", "", True],
|
||||||
["Chatbot Models", "chatlist", "", True],
|
|
||||||
["Untuned GPT-Neo/J", "gptneolist", "", True],
|
["Untuned GPT-Neo/J", "gptneolist", "", True],
|
||||||
["Untuned Fairseq Dense", "fsdlist", "", True],
|
["Untuned Fairseq Dense", "fsdlist", "", True],
|
||||||
["Untuned OPT", "optlist", "", True],
|
["Untuned OPT", "optlist", "", True],
|
||||||
|
@ -220,6 +219,7 @@ class vars:
|
||||||
temp = 0.5 # Default generator temperature
|
temp = 0.5 # Default generator temperature
|
||||||
top_p = 0.9 # Default generator top_p
|
top_p = 0.9 # Default generator top_p
|
||||||
top_k = 0 # Default generator top_k
|
top_k = 0 # Default generator top_k
|
||||||
|
top_a = 0.0 # Default generator top-a
|
||||||
tfs = 1.0 # Default generator tfs (tail-free sampling)
|
tfs = 1.0 # Default generator tfs (tail-free sampling)
|
||||||
typical = 1.0 # Default generator typical sampling threshold
|
typical = 1.0 # Default generator typical sampling threshold
|
||||||
numseqs = 1 # Number of sequences to ask the generator to create
|
numseqs = 1 # Number of sequences to ask the generator to create
|
||||||
|
@ -315,6 +315,7 @@ class vars:
|
||||||
acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
|
acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
|
||||||
comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)') # Pattern for matching comments to remove them before sending them to the AI
|
comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)') # Pattern for matching comments to remove them before sending them to the AI
|
||||||
comregex_ui = re.compile(r'(<\|(?:.|\n)*?\|>)') # Pattern for matching comments in the editor
|
comregex_ui = re.compile(r'(<\|(?:.|\n)*?\|>)') # Pattern for matching comments in the editor
|
||||||
|
sampler_order = utils.default_sampler_order.copy()
|
||||||
chatmode = False
|
chatmode = False
|
||||||
chatname = "You"
|
chatname = "You"
|
||||||
adventure = False
|
adventure = False
|
||||||
|
@ -647,6 +648,8 @@ def loadmodelsettings():
|
||||||
vars.badwordsids = js["badwordsids"]
|
vars.badwordsids = js["badwordsids"]
|
||||||
if("nobreakmodel" in js):
|
if("nobreakmodel" in js):
|
||||||
vars.nobreakmodel = js["nobreakmodel"]
|
vars.nobreakmodel = js["nobreakmodel"]
|
||||||
|
if("sampler_order" in js):
|
||||||
|
vars.sampler_order = js["sampler_order"]
|
||||||
if("temp" in js):
|
if("temp" in js):
|
||||||
vars.temp = js["temp"]
|
vars.temp = js["temp"]
|
||||||
if("top_p" in js):
|
if("top_p" in js):
|
||||||
|
@ -657,6 +660,8 @@ def loadmodelsettings():
|
||||||
vars.tfs = js["tfs"]
|
vars.tfs = js["tfs"]
|
||||||
if("typical" in js):
|
if("typical" in js):
|
||||||
vars.typical = js["typical"]
|
vars.typical = js["typical"]
|
||||||
|
if("top_a" in js):
|
||||||
|
vars.top_a = js["top_a"]
|
||||||
if("rep_pen" in js):
|
if("rep_pen" in js):
|
||||||
vars.rep_pen = js["rep_pen"]
|
vars.rep_pen = js["rep_pen"]
|
||||||
if("rep_pen_slope" in js):
|
if("rep_pen_slope" in js):
|
||||||
|
@ -688,11 +693,13 @@ def savesettings():
|
||||||
js = {}
|
js = {}
|
||||||
js["apikey"] = vars.apikey
|
js["apikey"] = vars.apikey
|
||||||
js["andepth"] = vars.andepth
|
js["andepth"] = vars.andepth
|
||||||
|
js["sampler_order"] = vars.sampler_order
|
||||||
js["temp"] = vars.temp
|
js["temp"] = vars.temp
|
||||||
js["top_p"] = vars.top_p
|
js["top_p"] = vars.top_p
|
||||||
js["top_k"] = vars.top_k
|
js["top_k"] = vars.top_k
|
||||||
js["tfs"] = vars.tfs
|
js["tfs"] = vars.tfs
|
||||||
js["typical"] = vars.typical
|
js["typical"] = vars.typical
|
||||||
|
js["top_a"] = vars.top_a
|
||||||
js["rep_pen"] = vars.rep_pen
|
js["rep_pen"] = vars.rep_pen
|
||||||
js["rep_pen_slope"] = vars.rep_pen_slope
|
js["rep_pen_slope"] = vars.rep_pen_slope
|
||||||
js["rep_pen_range"] = vars.rep_pen_range
|
js["rep_pen_range"] = vars.rep_pen_range
|
||||||
|
@ -763,6 +770,8 @@ def processsettings(js):
|
||||||
vars.apikey = js["apikey"]
|
vars.apikey = js["apikey"]
|
||||||
if("andepth" in js):
|
if("andepth" in js):
|
||||||
vars.andepth = js["andepth"]
|
vars.andepth = js["andepth"]
|
||||||
|
if("sampler_order" in js):
|
||||||
|
vars.sampler_order = js["sampler_order"]
|
||||||
if("temp" in js):
|
if("temp" in js):
|
||||||
vars.temp = js["temp"]
|
vars.temp = js["temp"]
|
||||||
if("top_p" in js):
|
if("top_p" in js):
|
||||||
|
@ -773,6 +782,8 @@ def processsettings(js):
|
||||||
vars.tfs = js["tfs"]
|
vars.tfs = js["tfs"]
|
||||||
if("typical" in js):
|
if("typical" in js):
|
||||||
vars.typical = js["typical"]
|
vars.typical = js["typical"]
|
||||||
|
if("top_a" in js):
|
||||||
|
vars.top_a = js["top_a"]
|
||||||
if("rep_pen" in js):
|
if("rep_pen" in js):
|
||||||
vars.rep_pen = js["rep_pen"]
|
vars.rep_pen = js["rep_pen"]
|
||||||
if("rep_pen_slope" in js):
|
if("rep_pen_slope" in js):
|
||||||
|
@ -1268,7 +1279,7 @@ def patch_transformers():
|
||||||
|
|
||||||
# Patch transformers to use our custom logit warpers
|
# Patch transformers to use our custom logit warpers
|
||||||
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
|
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
|
||||||
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper
|
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper, TopALogitsWarper
|
||||||
|
|
||||||
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
|
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
|
||||||
old_call = cls.__call__
|
old_call = cls.__call__
|
||||||
|
@ -1288,6 +1299,7 @@ def patch_transformers():
|
||||||
cls.__call__ = new_call
|
cls.__call__ = new_call
|
||||||
dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
|
dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
|
||||||
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
|
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
|
||||||
|
dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
|
||||||
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
|
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
|
||||||
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
|
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
|
||||||
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
|
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
|
||||||
|
@ -1331,14 +1343,23 @@ def patch_transformers():
|
||||||
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
|
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
|
||||||
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
|
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
|
||||||
|
|
||||||
|
class KoboldLogitsWarperList(LogitsProcessorList):
|
||||||
|
def __init__(self, beams: int = 1, **kwargs):
|
||||||
|
self.__warper_list: List[LogitsWarper] = []
|
||||||
|
self.__warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
|
||||||
|
self.__warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||||
|
self.__warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||||
|
self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||||
|
self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||||
|
self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5))
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs):
|
||||||
|
for k in vars.sampler_order:
|
||||||
|
scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
|
||||||
|
return scores
|
||||||
|
|
||||||
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
|
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
|
||||||
warper_list = LogitsProcessorList()
|
return KoboldLogitsWarperList(beams=beams)
|
||||||
warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
|
|
||||||
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
|
||||||
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
|
||||||
warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
|
||||||
warper_list.append(TemperatureLogitsWarper(temperature=0.5))
|
|
||||||
return warper_list
|
|
||||||
|
|
||||||
def new_sample(self, *args, **kwargs):
|
def new_sample(self, *args, **kwargs):
|
||||||
assert kwargs.pop("logits_warper", None) is not None
|
assert kwargs.pop("logits_warper", None) is not None
|
||||||
|
@ -1957,11 +1978,13 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
|
||||||
|
|
||||||
def tpumtjgenerate_settings_callback() -> dict:
|
def tpumtjgenerate_settings_callback() -> dict:
|
||||||
return {
|
return {
|
||||||
|
"sampler_order": vars.sampler_order,
|
||||||
"top_p": float(vars.top_p),
|
"top_p": float(vars.top_p),
|
||||||
"temp": float(vars.temp),
|
"temp": float(vars.temp),
|
||||||
"top_k": int(vars.top_k),
|
"top_k": int(vars.top_k),
|
||||||
"tfs": float(vars.tfs),
|
"tfs": float(vars.tfs),
|
||||||
"typical": float(vars.typical),
|
"typical": float(vars.typical),
|
||||||
|
"top_a": float(vars.top_a),
|
||||||
"repetition_penalty": float(vars.rep_pen),
|
"repetition_penalty": float(vars.rep_pen),
|
||||||
"rpslope": float(vars.rep_pen_slope),
|
"rpslope": float(vars.rep_pen_slope),
|
||||||
"rprange": int(vars.rep_pen_range),
|
"rprange": int(vars.rep_pen_range),
|
||||||
|
@ -2384,6 +2407,7 @@ def lua_has_setting(setting):
|
||||||
"settopk",
|
"settopk",
|
||||||
"settfs",
|
"settfs",
|
||||||
"settypical",
|
"settypical",
|
||||||
|
"settopa",
|
||||||
"setreppen",
|
"setreppen",
|
||||||
"setreppenslope",
|
"setreppenslope",
|
||||||
"setreppenrange",
|
"setreppenrange",
|
||||||
|
@ -2403,6 +2427,7 @@ def lua_has_setting(setting):
|
||||||
"top_k",
|
"top_k",
|
||||||
"tfs",
|
"tfs",
|
||||||
"typical",
|
"typical",
|
||||||
|
"topa",
|
||||||
"reppen",
|
"reppen",
|
||||||
"reppenslope",
|
"reppenslope",
|
||||||
"reppenrange",
|
"reppenrange",
|
||||||
|
@ -2437,6 +2462,7 @@ def lua_get_setting(setting):
|
||||||
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
|
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
|
||||||
if(setting in ("settfs", "tfs")): return vars.tfs
|
if(setting in ("settfs", "tfs")): return vars.tfs
|
||||||
if(setting in ("settypical", "typical")): return vars.typical
|
if(setting in ("settypical", "typical")): return vars.typical
|
||||||
|
if(setting in ("settopa", "topa")): return vars.top_a
|
||||||
if(setting in ("setreppen", "reppen")): return vars.rep_pen
|
if(setting in ("setreppen", "reppen")): return vars.rep_pen
|
||||||
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
|
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
|
||||||
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
|
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
|
||||||
|
@ -2472,6 +2498,7 @@ def lua_set_setting(setting, v):
|
||||||
if(setting in ("settopk", "topk")): vars.top_k = v
|
if(setting in ("settopk", "topk")): vars.top_k = v
|
||||||
if(setting in ("settfs", "tfs")): vars.tfs = v
|
if(setting in ("settfs", "tfs")): vars.tfs = v
|
||||||
if(setting in ("settypical", "typical")): vars.typical = v
|
if(setting in ("settypical", "typical")): vars.typical = v
|
||||||
|
if(setting in ("settopa", "topa")): vars.top_a = v
|
||||||
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
|
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
|
||||||
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
|
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
|
||||||
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
|
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
|
||||||
|
@ -2862,6 +2889,11 @@ def get_message(msg):
|
||||||
emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True)
|
emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True)
|
||||||
settingschanged()
|
settingschanged()
|
||||||
refresh_settings()
|
refresh_settings()
|
||||||
|
elif(msg['cmd'] == 'settopa'):
|
||||||
|
vars.top_a = float(msg['data'])
|
||||||
|
emit('from_server', {'cmd': 'setlabeltopa', 'data': msg['data']}, broadcast=True)
|
||||||
|
settingschanged()
|
||||||
|
refresh_settings()
|
||||||
elif(msg['cmd'] == 'setreppen'):
|
elif(msg['cmd'] == 'setreppen'):
|
||||||
vars.rep_pen = float(msg['data'])
|
vars.rep_pen = float(msg['data'])
|
||||||
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
|
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
|
||||||
|
@ -3015,6 +3047,8 @@ def get_message(msg):
|
||||||
elif(msg['cmd'] == 'uslistrequest'):
|
elif(msg['cmd'] == 'uslistrequest'):
|
||||||
unloaded, loaded = getuslist()
|
unloaded, loaded = getuslist()
|
||||||
emit('from_server', {'cmd': 'buildus', 'data': {"unloaded": unloaded, "loaded": loaded}})
|
emit('from_server', {'cmd': 'buildus', 'data': {"unloaded": unloaded, "loaded": loaded}})
|
||||||
|
elif(msg['cmd'] == 'samplerlistrequest'):
|
||||||
|
emit('from_server', {'cmd': 'buildsamplers', 'data': vars.sampler_order})
|
||||||
elif(msg['cmd'] == 'usloaded'):
|
elif(msg['cmd'] == 'usloaded'):
|
||||||
vars.userscripts = []
|
vars.userscripts = []
|
||||||
for userscript in msg['data']:
|
for userscript in msg['data']:
|
||||||
|
@ -3028,6 +3062,16 @@ def get_message(msg):
|
||||||
load_lua_scripts()
|
load_lua_scripts()
|
||||||
unloaded, loaded = getuslist()
|
unloaded, loaded = getuslist()
|
||||||
sendUSStatItems()
|
sendUSStatItems()
|
||||||
|
elif(msg['cmd'] == 'samplers'):
|
||||||
|
sampler_order = msg["data"]
|
||||||
|
if(not isinstance(sampler_order, list)):
|
||||||
|
raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}")
|
||||||
|
if(len(sampler_order) != len(vars.sampler_order)):
|
||||||
|
raise ValueError(f"Sampler order must be a list of length {len(vars.sampler_order)}, but got a list of length {len(sampler_order)}")
|
||||||
|
if(not all(isinstance(e, int) for e in sampler_order)):
|
||||||
|
raise ValueError(f"Sampler order must be a list of ints, but got a list with at least one non-int element")
|
||||||
|
vars.sampler_order = sampler_order
|
||||||
|
settingschanged()
|
||||||
elif(msg['cmd'] == 'list_model'):
|
elif(msg['cmd'] == 'list_model'):
|
||||||
sendModelSelection(menu=msg['data'])
|
sendModelSelection(menu=msg['data'])
|
||||||
elif(msg['cmd'] == 'load_model'):
|
elif(msg['cmd'] == 'load_model'):
|
||||||
|
@ -3988,6 +4032,7 @@ def sendtocolab(txt, min, max):
|
||||||
'top_k': vars.top_k,
|
'top_k': vars.top_k,
|
||||||
'tfs': vars.tfs,
|
'tfs': vars.tfs,
|
||||||
'typical': vars.typical,
|
'typical': vars.typical,
|
||||||
|
'topa': vars.top_a,
|
||||||
'numseqs': vars.numseqs,
|
'numseqs': vars.numseqs,
|
||||||
'retfultxt': False
|
'retfultxt': False
|
||||||
}
|
}
|
||||||
|
@ -4125,12 +4170,14 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||||
top_k=vars.top_k,
|
top_k=vars.top_k,
|
||||||
tfs=vars.tfs,
|
tfs=vars.tfs,
|
||||||
typical=vars.typical,
|
typical=vars.typical,
|
||||||
|
top_a=vars.top_a,
|
||||||
numseqs=vars.numseqs,
|
numseqs=vars.numseqs,
|
||||||
repetition_penalty=vars.rep_pen,
|
repetition_penalty=vars.rep_pen,
|
||||||
rpslope=vars.rep_pen_slope,
|
rpslope=vars.rep_pen_slope,
|
||||||
rprange=vars.rep_pen_range,
|
rprange=vars.rep_pen_range,
|
||||||
soft_embeddings=vars.sp,
|
soft_embeddings=vars.sp,
|
||||||
soft_tokens=soft_tokens,
|
soft_tokens=soft_tokens,
|
||||||
|
sampler_order=vars.sampler_order,
|
||||||
)
|
)
|
||||||
past = genout
|
past = genout
|
||||||
for i in range(vars.numseqs):
|
for i in range(vars.numseqs):
|
||||||
|
@ -4311,6 +4358,7 @@ def refresh_settings():
|
||||||
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
|
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
|
||||||
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
|
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
|
||||||
emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True)
|
emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True)
|
||||||
|
emit('from_server', {'cmd': 'updatetopa', 'data': vars.top_a}, broadcast=True)
|
||||||
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
|
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
|
||||||
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
|
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
|
||||||
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
|
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
|
||||||
|
@ -4887,6 +4935,7 @@ def oairequest(txt, min, max):
|
||||||
'prompt': txt,
|
'prompt': txt,
|
||||||
'max_tokens': vars.genamt,
|
'max_tokens': vars.genamt,
|
||||||
'temperature': vars.temp,
|
'temperature': vars.temp,
|
||||||
|
'top_a': vars.top_a,
|
||||||
'top_p': vars.top_p,
|
'top_p': vars.top_p,
|
||||||
'top_k': vars.top_k,
|
'top_k': vars.top_k,
|
||||||
'tfs': vars.tfs,
|
'tfs': vars.tfs,
|
||||||
|
|
|
@ -867,6 +867,7 @@ return function(_python, _bridged)
|
||||||
---@field settopk integer
|
---@field settopk integer
|
||||||
---@field settfs number
|
---@field settfs number
|
||||||
---@field settypical number
|
---@field settypical number
|
||||||
|
---@field settopa number
|
||||||
---@field setreppen number
|
---@field setreppen number
|
||||||
---@field setreppenslope number
|
---@field setreppenslope number
|
||||||
---@field setreppenrange number
|
---@field setreppenrange number
|
||||||
|
@ -884,6 +885,7 @@ return function(_python, _bridged)
|
||||||
---@field top_k integer
|
---@field top_k integer
|
||||||
---@field tfs number
|
---@field tfs number
|
||||||
---@field typical number
|
---@field typical number
|
||||||
|
---@field topa number
|
||||||
---@field reppen number
|
---@field reppen number
|
||||||
---@field reppenslope number
|
---@field reppenslope number
|
||||||
---@field reppenrange number
|
---@field reppenrange number
|
||||||
|
|
|
@ -64,6 +64,17 @@ gensettingstf = [
|
||||||
"step": 0.05,
|
"step": 0.05,
|
||||||
"default": 1.0,
|
"default": 1.0,
|
||||||
"tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect."
|
"tooltip": "Alternative sampling method described in the paper \"Typical Decoding for Natural Language Generation\" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"uitype": "slider",
|
||||||
|
"unit": "float",
|
||||||
|
"label": "Top a Sampling",
|
||||||
|
"id": "settopa",
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 1.0,
|
||||||
|
"step": 0.01,
|
||||||
|
"default": 0.0,
|
||||||
|
"tooltip": "Alternative sampling method that reduces the randomness of the AI whenever the probability of one token is much higher than all the others. Higher values have a stronger effect. Set this setting to 0 to disable its effect."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"uitype": "slider",
|
"uitype": "slider",
|
||||||
|
|
|
@ -21,6 +21,7 @@ var button_settings;
|
||||||
var button_format;
|
var button_format;
|
||||||
var button_softprompt;
|
var button_softprompt;
|
||||||
var button_userscripts;
|
var button_userscripts;
|
||||||
|
var button_samplers;
|
||||||
var button_mode;
|
var button_mode;
|
||||||
var button_mode_label;
|
var button_mode_label;
|
||||||
var button_send;
|
var button_send;
|
||||||
|
@ -112,6 +113,9 @@ var do_clear_ent = false;
|
||||||
// Whether or not an entry in the Userscripts menu is being dragged
|
// Whether or not an entry in the Userscripts menu is being dragged
|
||||||
var us_dragging = false;
|
var us_dragging = false;
|
||||||
|
|
||||||
|
// Whether or not an entry in the Samplers menu is being dragged
|
||||||
|
var samplers_dragging = false;
|
||||||
|
|
||||||
// Display vars
|
// Display vars
|
||||||
var allowtoggle = false;
|
var allowtoggle = false;
|
||||||
var formatcount = 0;
|
var formatcount = 0;
|
||||||
|
@ -997,6 +1001,16 @@ function hideUSPopup() {
|
||||||
spcontent.html("");
|
spcontent.html("");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function showSamplersPopup() {
|
||||||
|
samplerspopup.removeClass("hidden");
|
||||||
|
samplerspopup.addClass("flex");
|
||||||
|
}
|
||||||
|
|
||||||
|
function hideSamplersPopup() {
|
||||||
|
samplerspopup.removeClass("flex");
|
||||||
|
samplerspopup.addClass("hidden");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
function buildLoadModelList(ar, menu, breadcrumbs) {
|
function buildLoadModelList(ar, menu, breadcrumbs) {
|
||||||
disableButtons([load_model_accept]);
|
disableButtons([load_model_accept]);
|
||||||
|
@ -1207,6 +1221,29 @@ function buildUSList(unloaded, loaded) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function buildSamplerList(samplers) {
|
||||||
|
samplerslist.html("");
|
||||||
|
showSamplersPopup();
|
||||||
|
var i;
|
||||||
|
var samplers_lookup_table = [
|
||||||
|
"Top-k Sampling",
|
||||||
|
"Top-a Sampling",
|
||||||
|
"Top-p Sampling",
|
||||||
|
"Tail-free Sampling",
|
||||||
|
"Typical Sampling",
|
||||||
|
"Temperature",
|
||||||
|
]
|
||||||
|
for(i=0; i<samplers.length; i++) {
|
||||||
|
samplerslist.append("<div class=\"flex\">\
|
||||||
|
<div class=\"samplerslistitem flex-row-container\" sid=\""+samplers[i]+"\">\
|
||||||
|
<div class=\"flex-row\">\
|
||||||
|
<div>"+samplers_lookup_table[samplers[i]]+"</div>\
|
||||||
|
</div>\
|
||||||
|
</div>\
|
||||||
|
</div>");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function highlightLoadLine(ref) {
|
function highlightLoadLine(ref) {
|
||||||
$("#loadlistcontent > div > div.popuplistselected").removeClass("popuplistselected");
|
$("#loadlistcontent > div > div.popuplistselected").removeClass("popuplistselected");
|
||||||
$("#loadmodellistcontent > div > div.popuplistselected").removeClass("popuplistselected");
|
$("#loadmodellistcontent > div > div.popuplistselected").removeClass("popuplistselected");
|
||||||
|
@ -1963,6 +2000,7 @@ $(document).ready(function(){
|
||||||
button_format = $('#btn_format');
|
button_format = $('#btn_format');
|
||||||
button_softprompt = $("#btn_softprompt");
|
button_softprompt = $("#btn_softprompt");
|
||||||
button_userscripts= $("#btn_userscripts");
|
button_userscripts= $("#btn_userscripts");
|
||||||
|
button_samplers = $("#btn_samplers");
|
||||||
button_mode = $('#btnmode')
|
button_mode = $('#btnmode')
|
||||||
button_mode_label = $('#btnmode_label')
|
button_mode_label = $('#btnmode_label')
|
||||||
button_send = $('#btnsend');
|
button_send = $('#btnsend');
|
||||||
|
@ -2015,6 +2053,10 @@ $(document).ready(function(){
|
||||||
usloaded = $("#uslistloaded");
|
usloaded = $("#uslistloaded");
|
||||||
us_accept = $("#btn_usaccept");
|
us_accept = $("#btn_usaccept");
|
||||||
us_close = $("#btn_usclose");
|
us_close = $("#btn_usclose");
|
||||||
|
samplerspopup = $("#samplerscontainer");
|
||||||
|
samplerslist = $("#samplerslist");
|
||||||
|
samplers_accept = $("#btn_samplersaccept");
|
||||||
|
samplers_close = $("#btn_samplersclose");
|
||||||
nspopup = $("#newgamecontainer");
|
nspopup = $("#newgamecontainer");
|
||||||
ns_accept = $("#btn_nsaccept");
|
ns_accept = $("#btn_nsaccept");
|
||||||
ns_close = $("#btn_nsclose");
|
ns_close = $("#btn_nsclose");
|
||||||
|
@ -2038,7 +2080,7 @@ $(document).ready(function(){
|
||||||
modelname = msg.modelname;
|
modelname = msg.modelname;
|
||||||
}
|
}
|
||||||
refreshTitle();
|
refreshTitle();
|
||||||
connect_status.html("<b>Connected to KoboldAI Process!</b>");
|
connect_status.html("<b>Connected to KoboldAI!</b>");
|
||||||
connect_status.removeClass("color_orange");
|
connect_status.removeClass("color_orange");
|
||||||
connect_status.addClass("color_green");
|
connect_status.addClass("color_green");
|
||||||
// Reset Menus
|
// Reset Menus
|
||||||
|
@ -2231,6 +2273,10 @@ $(document).ready(function(){
|
||||||
// Send current typical value to input
|
// Send current typical value to input
|
||||||
$("#settypicalcur").val(msg.data);
|
$("#settypicalcur").val(msg.data);
|
||||||
$("#settypical").val(parseFloat(msg.data)).trigger("change");
|
$("#settypical").val(parseFloat(msg.data)).trigger("change");
|
||||||
|
} else if(msg.cmd == "updatetopa") {
|
||||||
|
// Send current top a value to input
|
||||||
|
$("#settopacur").val(msg.data);
|
||||||
|
$("#settopa").val(parseFloat(msg.data)).trigger("change");
|
||||||
} else if(msg.cmd == "updatereppen") {
|
} else if(msg.cmd == "updatereppen") {
|
||||||
// Send current rep pen value to input
|
// Send current rep pen value to input
|
||||||
$("#setreppencur").val(msg.data);
|
$("#setreppencur").val(msg.data);
|
||||||
|
@ -2270,6 +2316,9 @@ $(document).ready(function(){
|
||||||
} else if(msg.cmd == "setlabeltypical") {
|
} else if(msg.cmd == "setlabeltypical") {
|
||||||
// Update setting label with value from server
|
// Update setting label with value from server
|
||||||
$("#settypicalcur").val(msg.data);
|
$("#settypicalcur").val(msg.data);
|
||||||
|
} else if(msg.cmd == "setlabeltypical") {
|
||||||
|
// Update setting label with value from server
|
||||||
|
$("#settopa").val(msg.data);
|
||||||
} else if(msg.cmd == "setlabelreppen") {
|
} else if(msg.cmd == "setlabelreppen") {
|
||||||
// Update setting label with value from server
|
// Update setting label with value from server
|
||||||
$("#setreppencur").val(msg.data);
|
$("#setreppencur").val(msg.data);
|
||||||
|
@ -2440,6 +2489,8 @@ $(document).ready(function(){
|
||||||
buildSPList(msg.data);
|
buildSPList(msg.data);
|
||||||
} else if(msg.cmd == "buildus") {
|
} else if(msg.cmd == "buildus") {
|
||||||
buildUSList(msg.data.unloaded, msg.data.loaded);
|
buildUSList(msg.data.unloaded, msg.data.loaded);
|
||||||
|
} else if(msg.cmd == "buildsamplers") {
|
||||||
|
buildSamplerList(msg.data);
|
||||||
} else if(msg.cmd == "askforoverwrite") {
|
} else if(msg.cmd == "askforoverwrite") {
|
||||||
// Show overwrite warning
|
// Show overwrite warning
|
||||||
show([$(".saveasoverwrite")]);
|
show([$(".saveasoverwrite")]);
|
||||||
|
@ -2648,6 +2699,20 @@ $(document).ready(function(){
|
||||||
}, 10);
|
}, 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var samplers_click_handler = function(ev) {
|
||||||
|
setTimeout(function() {
|
||||||
|
if (samplers_dragging) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
var target = $(ev.target).closest(".samplerslistitem");
|
||||||
|
var next = target.parent().next().find(".samplerslistitem");
|
||||||
|
if (!next.length) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
next.parent().after(target.parent());
|
||||||
|
}, 10);
|
||||||
|
}
|
||||||
|
|
||||||
// Make the userscripts menu sortable
|
// Make the userscripts menu sortable
|
||||||
var us_sortable_settings = {
|
var us_sortable_settings = {
|
||||||
placeholder: "ussortable-placeholder",
|
placeholder: "ussortable-placeholder",
|
||||||
|
@ -2668,6 +2733,22 @@ $(document).ready(function(){
|
||||||
connectWith: "#uslistunloaded",
|
connectWith: "#uslistunloaded",
|
||||||
}, us_sortable_settings)).on("click", ".uslistitem", us_click_handler);
|
}, us_sortable_settings)).on("click", ".uslistitem", us_click_handler);
|
||||||
|
|
||||||
|
// Make the samplers menu sortable
|
||||||
|
var samplers_sortable_settings = {
|
||||||
|
placeholder: "samplerssortable-placeholder",
|
||||||
|
start: function() { samplers_dragging = true; },
|
||||||
|
stop: function() { samplers_dragging = false; },
|
||||||
|
delay: 2,
|
||||||
|
cursor: "move",
|
||||||
|
tolerance: "pointer",
|
||||||
|
opacity: 0.21,
|
||||||
|
revert: 173,
|
||||||
|
scrollSensitivity: 64,
|
||||||
|
scrollSpeed: 10,
|
||||||
|
}
|
||||||
|
samplerslist.sortable($.extend({
|
||||||
|
}, samplers_sortable_settings)).on("click", ".samplerslistitem", samplers_click_handler);
|
||||||
|
|
||||||
// Bind actions to UI buttons
|
// Bind actions to UI buttons
|
||||||
button_send.on("click", function(ev) {
|
button_send.on("click", function(ev) {
|
||||||
dosubmit();
|
dosubmit();
|
||||||
|
@ -2803,6 +2884,10 @@ $(document).ready(function(){
|
||||||
socket.send({'cmd': 'uslistrequest', 'data': ''});
|
socket.send({'cmd': 'uslistrequest', 'data': ''});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
button_samplers.on("click", function(ev) {
|
||||||
|
socket.send({'cmd': 'samplerlistrequest', 'data': ''});
|
||||||
|
});
|
||||||
|
|
||||||
load_close.on("click", function(ev) {
|
load_close.on("click", function(ev) {
|
||||||
hideLoadPopup();
|
hideLoadPopup();
|
||||||
});
|
});
|
||||||
|
@ -2859,6 +2944,16 @@ $(document).ready(function(){
|
||||||
hideUSPopup();
|
hideUSPopup();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
samplers_close.on("click", function(ev) {
|
||||||
|
hideSamplersPopup();
|
||||||
|
});
|
||||||
|
|
||||||
|
samplers_accept.on("click", function(ev) {
|
||||||
|
hideMessage();
|
||||||
|
socket.send({'cmd': 'samplers', 'data': samplerslist.find(".samplerslistitem").map(function() { return parseInt($(this).attr("sid")); }).toArray()});
|
||||||
|
hideSamplersPopup();
|
||||||
|
});
|
||||||
|
|
||||||
button_loadmodel.on("click", function(ev) {
|
button_loadmodel.on("click", function(ev) {
|
||||||
showLoadModelPopup();
|
showLoadModelPopup();
|
||||||
socket.send({'cmd': 'list_model', 'data': 'mainmenu'});
|
socket.send({'cmd': 'list_model', 'data': 'mainmenu'});
|
||||||
|
|
|
@ -457,6 +457,26 @@ body.connected #popupfooter, #popupfooter.always-available {
|
||||||
overflow-wrap: anywhere;
|
overflow-wrap: anywhere;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#samplerspopup {
|
||||||
|
width: 300px;
|
||||||
|
background-color: #262626;
|
||||||
|
margin-top: 100px;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (max-width: 768px) {
|
||||||
|
#samplerspopup {
|
||||||
|
width: 100%;
|
||||||
|
background-color: #262626;
|
||||||
|
margin-top: 100px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#samplerslist {
|
||||||
|
height: 300px;
|
||||||
|
overflow-y: scroll;
|
||||||
|
overflow-wrap: anywhere;
|
||||||
|
}
|
||||||
|
|
||||||
#nspopup {
|
#nspopup {
|
||||||
width: 350px;
|
width: 350px;
|
||||||
background-color: #262626;
|
background-color: #262626;
|
||||||
|
@ -750,7 +770,7 @@ body.connected .dropdown-item:hover, .dropdown-item.always-available:hover {
|
||||||
background-color: #3bf723;
|
background-color: #3bf723;
|
||||||
}
|
}
|
||||||
|
|
||||||
.ussortable-placeholder {
|
.ussortable-placeholder, .samplerssortable-placeholder {
|
||||||
height: 4px;
|
height: 4px;
|
||||||
background-color: #3bf723;
|
background-color: #3bf723;
|
||||||
}
|
}
|
||||||
|
@ -1362,7 +1382,7 @@ body.connected .popupfooter, .popupfooter.always-available {
|
||||||
background-color: #688f1f;
|
background-color: #688f1f;
|
||||||
}
|
}
|
||||||
|
|
||||||
.uslistitem {
|
.uslistitem, .samplerslistitem {
|
||||||
padding: 12px 10px 12px 10px;
|
padding: 12px 10px 12px 10px;
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-grow: 1;
|
flex-grow: 1;
|
||||||
|
@ -1374,11 +1394,11 @@ body.connected .popupfooter, .popupfooter.always-available {
|
||||||
transition: background-color 0.25s ease-in;
|
transition: background-color 0.25s ease-in;
|
||||||
}
|
}
|
||||||
|
|
||||||
.uslistitemsub {
|
.uslistitemsub, .samplerslistitemsub {
|
||||||
color: #ba9;
|
color: #ba9;
|
||||||
}
|
}
|
||||||
|
|
||||||
.uslistitem:hover {
|
.uslistitem:hover, .samplerslistitem:hover {
|
||||||
cursor: move;
|
cursor: move;
|
||||||
background-color: #688f1f;
|
background-color: #688f1f;
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
<link rel="stylesheet" href="static/bootstrap.min.css">
|
<link rel="stylesheet" href="static/bootstrap.min.css">
|
||||||
<link rel="stylesheet" href="static/bootstrap-toggle.min.css">
|
<link rel="stylesheet" href="static/bootstrap-toggle.min.css">
|
||||||
<link rel="stylesheet" href="static/open-iconic-bootstrap.min.css">
|
<link rel="stylesheet" href="static/open-iconic-bootstrap.min.css">
|
||||||
<link rel="stylesheet" href="static/custom.css?ver=1.18b">
|
<link rel="stylesheet" href="static/custom.css?ver=1.18c">
|
||||||
|
|
||||||
<script src="static/jquery-3.6.0.min.js"></script>
|
<script src="static/jquery-3.6.0.min.js"></script>
|
||||||
<script src="static/jquery-ui.sortable.min.js"></script>
|
<script src="static/jquery-ui.sortable.min.js"></script>
|
||||||
|
@ -17,7 +17,7 @@
|
||||||
<script src="static/bootstrap.min.js"></script>
|
<script src="static/bootstrap.min.js"></script>
|
||||||
<script src="static/bootstrap-toggle.min.js"></script>
|
<script src="static/bootstrap-toggle.min.js"></script>
|
||||||
<script src="static/rangy-core.min.js"></script>
|
<script src="static/rangy-core.min.js"></script>
|
||||||
<script src="static/application.js?ver=1.18c"></script>
|
<script src="static/application.js?ver=1.18e"></script>
|
||||||
<script src="static/favicon.js"></script>
|
<script src="static/favicon.js"></script>
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
|
@ -81,6 +81,9 @@
|
||||||
<li class="nav-item">
|
<li class="nav-item">
|
||||||
<a class="nav-link" href="#" id="btn_format">Formatting</a>
|
<a class="nav-link" href="#" id="btn_format">Formatting</a>
|
||||||
</li>
|
</li>
|
||||||
|
<li class="nav-item">
|
||||||
|
<a class="nav-link" href="#" id="btn_samplers">Samplers</a>
|
||||||
|
</li>
|
||||||
<li class="nav-item">
|
<li class="nav-item">
|
||||||
<a class="nav-link" href="#" id="btn_userscripts">Userscripts</a>
|
<a class="nav-link" href="#" id="btn_userscripts">Userscripts</a>
|
||||||
</li>
|
</li>
|
||||||
|
@ -363,6 +366,19 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<div class="popupcontainer hidden" id="samplerscontainer">
|
||||||
|
<div id="samplerspopup">
|
||||||
|
<div class="popuptitlebar">
|
||||||
|
<div class="popuptitletext">Drag-and-drop to change the order in which the samplers are applied</div>
|
||||||
|
</div>
|
||||||
|
<div id="samplerslist">
|
||||||
|
</div>
|
||||||
|
<div class="popupfooter">
|
||||||
|
<button type="button" class="btn btn-primary" id="btn_samplersaccept">Save</button>
|
||||||
|
<button type="button" class="btn btn-primary" id="btn_samplersclose">Cancel</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
<div class="popupcontainer hidden" id="loadcontainerdelete">
|
<div class="popupcontainer hidden" id="loadcontainerdelete">
|
||||||
<div id="loadpopupdelete">
|
<div id="loadpopupdelete">
|
||||||
<div class="popuptitlebar">
|
<div class="popuptitlebar">
|
||||||
|
|
|
@ -65,11 +65,13 @@ def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List
|
||||||
|
|
||||||
def settings_callback() -> dict:
|
def settings_callback() -> dict:
|
||||||
return {
|
return {
|
||||||
|
"sampler_order": utils.default_sampler_order.copy(),
|
||||||
"top_p": 0.9,
|
"top_p": 0.9,
|
||||||
"temp": 0.5,
|
"temp": 0.5,
|
||||||
"top_k": 0,
|
"top_k": 0,
|
||||||
"tfs": 1.0,
|
"tfs": 1.0,
|
||||||
"typical": 1.0,
|
"typical": 1.0,
|
||||||
|
"top_a": 0.0,
|
||||||
"repetition_penalty": 1.0,
|
"repetition_penalty": 1.0,
|
||||||
"rpslope": 0.0,
|
"rpslope": 0.0,
|
||||||
"rprange": 0,
|
"rprange": 0,
|
||||||
|
@ -158,10 +160,10 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
|
||||||
logits[tokens] = penalty_logits
|
logits[tokens] = penalty_logits
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
|
def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
|
||||||
'''
|
'''
|
||||||
This gets called by generate_loop_fn to apply a series of 5 filters
|
This gets called by generate_loop_fn to apply a series of 6 filters
|
||||||
to the logits (top-k, then top-p, then TFS, then typical, then temperature)
|
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
|
||||||
before picking one token using the modified logits
|
before picking one token using the modified logits
|
||||||
'''
|
'''
|
||||||
# Top-k (keep only the k tokens with the highest logits and remove
|
# Top-k (keep only the k tokens with the highest logits and remove
|
||||||
|
@ -180,8 +182,18 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
|
||||||
sorted_indices_to_remove,
|
sorted_indices_to_remove,
|
||||||
)
|
)
|
||||||
return np.where(indices_to_remove, -np.inf, logits)
|
return np.where(indices_to_remove, -np.inf, logits)
|
||||||
if top_k > 0:
|
# Top-a (remove all tokens that have softmax probability less than
|
||||||
logits = top_k_filter(logits)
|
# a*m^2 where m is the maximum softmax probability)
|
||||||
|
def top_a_filter(logits):
|
||||||
|
# Replace every element in the logits array
|
||||||
|
# with e (Euler's number) to the power of that element, and divide
|
||||||
|
# each element of the new array by the sum of the elements in the
|
||||||
|
# new array
|
||||||
|
probabilities = np.array(jax.nn.softmax(logits), copy=True)
|
||||||
|
# Find the largest probability
|
||||||
|
probs_max = probabilities.max()
|
||||||
|
# Remove tokens
|
||||||
|
return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits)
|
||||||
# Top-p (after sorting the remaining tokens again in descending order of
|
# Top-p (after sorting the remaining tokens again in descending order of
|
||||||
# logit, remove the ones that have cumulative softmax probability
|
# logit, remove the ones that have cumulative softmax probability
|
||||||
# greater than p)
|
# greater than p)
|
||||||
|
@ -207,8 +219,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
|
||||||
sorted_indices_to_remove,
|
sorted_indices_to_remove,
|
||||||
)
|
)
|
||||||
return np.where(indices_to_remove, -np.inf, logits)
|
return np.where(indices_to_remove, -np.inf, logits)
|
||||||
if top_p < 1.0:
|
|
||||||
logits = top_p_filter(logits)
|
|
||||||
# Tail free sampling (basically top-p a second time on remaining tokens
|
# Tail free sampling (basically top-p a second time on remaining tokens
|
||||||
# except it's the "cumulative normalized absolute second finite
|
# except it's the "cumulative normalized absolute second finite
|
||||||
# differences of the softmax probabilities" instead of just the
|
# differences of the softmax probabilities" instead of just the
|
||||||
|
@ -247,8 +257,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
|
||||||
sorted_indices_to_remove,
|
sorted_indices_to_remove,
|
||||||
)
|
)
|
||||||
return np.where(indices_to_remove, -np.inf, logits)
|
return np.where(indices_to_remove, -np.inf, logits)
|
||||||
if tfs < 1.0:
|
|
||||||
logits = tail_free_filter(logits)
|
|
||||||
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
|
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
|
||||||
def typical_filter(logits):
|
def typical_filter(logits):
|
||||||
# Compute softmax probabilities and the natural logarithms of them
|
# Compute softmax probabilities and the natural logarithms of them
|
||||||
|
@ -278,10 +286,16 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
|
||||||
sorted_indices_to_remove,
|
sorted_indices_to_remove,
|
||||||
)
|
)
|
||||||
return np.where(indices_to_remove, -jnp.inf, logits)
|
return np.where(indices_to_remove, -jnp.inf, logits)
|
||||||
if typical < 1.0:
|
|
||||||
logits = typical_filter(logits)
|
|
||||||
# Temperature (just divide the logits by the temperature)
|
# Temperature (just divide the logits by the temperature)
|
||||||
logits /= temp
|
def temp_filter(logits):
|
||||||
|
return logits / temp
|
||||||
|
for k in sampler_order:
|
||||||
|
if k == 0 and top_k > 0: logits = top_k_filter(logits)
|
||||||
|
if k == 1 and top_a > 0.0: logits = top_a_filter(logits)
|
||||||
|
if k == 2 and top_p < 1.0: logits = top_p_filter(logits)
|
||||||
|
if k == 3 and tfs < 1.0: logits = tail_free_filter(logits)
|
||||||
|
if k == 4 and typical < 1.0: logits = typical_filter(logits)
|
||||||
|
if k == 5 and temp != 1.0: logits = temp_filter(logits)
|
||||||
# Finally, pick one token using the softmax thingy again (it gives
|
# Finally, pick one token using the softmax thingy again (it gives
|
||||||
# an array whose elements sum to 1 so it can be used nicely as a
|
# an array whose elements sum to 1 so it can be used nicely as a
|
||||||
# probability distribution)
|
# probability distribution)
|
||||||
|
@ -332,10 +346,10 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
|
||||||
# positions in the logits array
|
# positions in the logits array
|
||||||
return logits.at[tokens].set(penalty_logits)
|
return logits.at[tokens].set(penalty_logits)
|
||||||
|
|
||||||
def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0):
|
def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
|
||||||
'''
|
'''
|
||||||
This gets called by generate_loop_fn to apply a series of 5 filters
|
This gets called by generate_loop_fn to apply a series of 6 filters
|
||||||
to the logits (top-k, then top-p, then TFS, then typical, then temperature)
|
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
|
||||||
before picking one token using the modified logits
|
before picking one token using the modified logits
|
||||||
'''
|
'''
|
||||||
# Top-k (keep only the k tokens with the highest logits and remove
|
# Top-k (keep only the k tokens with the highest logits and remove
|
||||||
|
@ -354,7 +368,18 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
|
||||||
sorted_indices_to_remove,
|
sorted_indices_to_remove,
|
||||||
)
|
)
|
||||||
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
||||||
logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits)
|
# Top-a (remove all tokens that have softmax probability less than
|
||||||
|
# a*m^2 where m is the maximum softmax probability)
|
||||||
|
def top_a_filter(logits):
|
||||||
|
# Replace every element in the logits array
|
||||||
|
# with e (Euler's number) to the power of that element, and divide
|
||||||
|
# each element of the new array by the sum of the elements in the
|
||||||
|
# new array
|
||||||
|
probabilities = jax.nn.softmax(logits)
|
||||||
|
# Find the largest probability
|
||||||
|
probs_max = probabilities.max()
|
||||||
|
# Remove tokens
|
||||||
|
return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits)
|
||||||
# Top-p (after sorting the remaining tokens again in descending order of
|
# Top-p (after sorting the remaining tokens again in descending order of
|
||||||
# logit, remove the ones that have cumulative softmax probability
|
# logit, remove the ones that have cumulative softmax probability
|
||||||
# greater than p)
|
# greater than p)
|
||||||
|
@ -380,7 +405,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
|
||||||
sorted_indices_to_remove,
|
sorted_indices_to_remove,
|
||||||
)
|
)
|
||||||
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
||||||
logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits)
|
|
||||||
# Tail free sampling (basically top-p a second time on remaining tokens
|
# Tail free sampling (basically top-p a second time on remaining tokens
|
||||||
# except it's the "cumulative normalized absolute second finite
|
# except it's the "cumulative normalized absolute second finite
|
||||||
# differences of the softmax probabilities" instead of just the
|
# differences of the softmax probabilities" instead of just the
|
||||||
|
@ -419,7 +443,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
|
||||||
sorted_indices_to_remove,
|
sorted_indices_to_remove,
|
||||||
)
|
)
|
||||||
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
||||||
logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
|
|
||||||
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
|
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
|
||||||
def typical_filter(logits):
|
def typical_filter(logits):
|
||||||
# Compute softmax probabilities and the natural logarithms of them
|
# Compute softmax probabilities and the natural logarithms of them
|
||||||
|
@ -448,11 +471,16 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
|
||||||
sorted_indices_to_remove,
|
sorted_indices_to_remove,
|
||||||
)
|
)
|
||||||
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
||||||
logits = jax.lax.cond(typical < 1.0, typical_filter, lambda x: x, logits)
|
|
||||||
# Temperature (just divide the logits by the temperature)
|
# Temperature (just divide the logits by the temperature)
|
||||||
def temp_filter(logits):
|
def temp_filter(logits):
|
||||||
return logits / temp
|
return logits / temp
|
||||||
logits = jax.lax.cond(True, temp_filter, lambda x: x, logits)
|
for k in sampler_order:
|
||||||
|
logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), top_k_filter, lambda x: x, logits)
|
||||||
|
logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), top_a_filter, lambda x: x, logits)
|
||||||
|
logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), top_p_filter, lambda x: x, logits)
|
||||||
|
logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits)
|
||||||
|
logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
|
||||||
|
logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
|
||||||
# Finally, pick one token using the softmax thingy again (it gives
|
# Finally, pick one token using the softmax thingy again (it gives
|
||||||
# an array whose elements sum to 1 so it can be used nicely as a
|
# an array whose elements sum to 1 so it can be used nicely as a
|
||||||
# probability distribution)
|
# probability distribution)
|
||||||
|
@ -806,6 +834,7 @@ def infer_static(
|
||||||
top_k=0,
|
top_k=0,
|
||||||
tfs=1.0,
|
tfs=1.0,
|
||||||
typical=1.0,
|
typical=1.0,
|
||||||
|
top_a=0.0,
|
||||||
repetition_penalty=1.0,
|
repetition_penalty=1.0,
|
||||||
rpslope=0.0,
|
rpslope=0.0,
|
||||||
rprange=0,
|
rprange=0,
|
||||||
|
@ -813,8 +842,12 @@ def infer_static(
|
||||||
gen_len=80,
|
gen_len=80,
|
||||||
soft_embeddings: Optional[np.array] = None,
|
soft_embeddings: Optional[np.array] = None,
|
||||||
soft_tokens: Optional[np.array] = None,
|
soft_tokens: Optional[np.array] = None,
|
||||||
|
sampler_order: Optional[List[int]] = None,
|
||||||
) -> List[np.array]:
|
) -> List[np.array]:
|
||||||
maps.thread_resources.env = thread_resources_env
|
maps.thread_resources.env = thread_resources_env
|
||||||
|
if sampler_order is None:
|
||||||
|
sampler_order = utils.default_sampler_order.copy()
|
||||||
|
sampler_order = np.uint32(sampler_order)
|
||||||
total_batch = 1
|
total_batch = 1
|
||||||
tokens = context
|
tokens = context
|
||||||
if(soft_tokens is not None):
|
if(soft_tokens is not None):
|
||||||
|
@ -825,10 +858,12 @@ def infer_static(
|
||||||
batched_tokens = np.array([padded_tokens] * total_batch)
|
batched_tokens = np.array([padded_tokens] * total_batch)
|
||||||
samples = []
|
samples = []
|
||||||
batched_generator_params = {
|
batched_generator_params = {
|
||||||
|
"sampler_order": np.repeat(sampler_order[np.newaxis], total_batch, axis=0),
|
||||||
"temp": temp * np.ones(total_batch),
|
"temp": temp * np.ones(total_batch),
|
||||||
"top_p": top_p * np.ones(total_batch),
|
"top_p": top_p * np.ones(total_batch),
|
||||||
"tfs": tfs * np.ones(total_batch),
|
"tfs": tfs * np.ones(total_batch),
|
||||||
"typical": typical * np.ones(total_batch),
|
"typical": typical * np.ones(total_batch),
|
||||||
|
"top_a": top_a * np.ones(total_batch),
|
||||||
"repetition_penalty": repetition_penalty * np.ones(total_batch),
|
"repetition_penalty": repetition_penalty * np.ones(total_batch),
|
||||||
"rpslope": rpslope * np.ones(total_batch),
|
"rpslope": rpslope * np.ones(total_batch),
|
||||||
"rprange": np.full(total_batch, rprange, dtype=np.uint32),
|
"rprange": np.full(total_batch, rprange, dtype=np.uint32),
|
||||||
|
@ -985,6 +1020,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
|
||||||
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
|
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
|
||||||
global thread_resources_env, seq, tokenizer, network, params
|
global thread_resources_env, seq, tokenizer, network, params
|
||||||
|
|
||||||
|
if not hasattr(vars, "sampler_order") or not vars.sampler_order:
|
||||||
|
vars.sampler_order = utils.default_sampler_order.copy()
|
||||||
|
|
||||||
default_params = {
|
default_params = {
|
||||||
"compat": "j",
|
"compat": "j",
|
||||||
"layers": 28,
|
"layers": 28,
|
||||||
|
|
2
utils.py
2
utils.py
|
@ -20,6 +20,8 @@ from_pretrained_index_filename: Optional[str] = None
|
||||||
from_pretrained_kwargs = {}
|
from_pretrained_kwargs = {}
|
||||||
bar = None
|
bar = None
|
||||||
|
|
||||||
|
default_sampler_order = [0, 1, 2, 3, 4, 5]
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Decorator to prevent a function's actions from being run until
|
# Decorator to prevent a function's actions from being run until
|
||||||
# at least x seconds have passed without the function being called
|
# at least x seconds have passed without the function being called
|
||||||
|
|
29
warpers.py
29
warpers.py
|
@ -148,3 +148,32 @@ class TypicalLogitsWarper(LogitsWarper):
|
||||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class TopALogitsWarper(LogitsWarper):
|
||||||
|
def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||||
|
top_a = float(top_a)
|
||||||
|
if top_a < 0 or top_a > 1.0:
|
||||||
|
raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}")
|
||||||
|
self.top_a = top_a
|
||||||
|
self.filter_value = filter_value
|
||||||
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
if self.filter_value >= 1.0:
|
||||||
|
return scores
|
||||||
|
|
||||||
|
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
||||||
|
probs = sorted_logits.softmax(dim=-1)
|
||||||
|
|
||||||
|
# Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept)
|
||||||
|
probs_max = probs[..., 0, None]
|
||||||
|
sorted_indices_to_remove = probs < probs_max * probs_max * self.top_a
|
||||||
|
|
||||||
|
if self.min_tokens_to_keep > 1:
|
||||||
|
# Keep at least min_tokens_to_keep
|
||||||
|
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
|
||||||
|
|
||||||
|
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||||
|
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||||
|
return scores
|
||||||
|
|
Loading…
Reference in New Issue