mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge commit '8a38b258f497281af06fcb0c2559f382b419b938' into overhaul-merge
This commit is contained in:
23
aiserver.py
23
aiserver.py
@ -220,6 +220,7 @@ class vars:
|
||||
temp = 0.5 # Default generator temperature
|
||||
top_p = 0.9 # Default generator top_p
|
||||
top_k = 0 # Default generator top_k
|
||||
top_a = 0.0 # Default generator top-a
|
||||
tfs = 1.0 # Default generator tfs (tail-free sampling)
|
||||
typical = 1.0 # Default generator typical sampling threshold
|
||||
numseqs = 1 # Number of sequences to ask the generator to create
|
||||
@ -657,6 +658,8 @@ def loadmodelsettings():
|
||||
vars.tfs = js["tfs"]
|
||||
if("typical" in js):
|
||||
vars.typical = js["typical"]
|
||||
if("top_a" in js):
|
||||
vars.top_a = js["top_a"]
|
||||
if("rep_pen" in js):
|
||||
vars.rep_pen = js["rep_pen"]
|
||||
if("rep_pen_slope" in js):
|
||||
@ -693,6 +696,7 @@ def savesettings():
|
||||
js["top_k"] = vars.top_k
|
||||
js["tfs"] = vars.tfs
|
||||
js["typical"] = vars.typical
|
||||
js["top_a"] = vars.top_a
|
||||
js["rep_pen"] = vars.rep_pen
|
||||
js["rep_pen_slope"] = vars.rep_pen_slope
|
||||
js["rep_pen_range"] = vars.rep_pen_range
|
||||
@ -773,6 +777,8 @@ def processsettings(js):
|
||||
vars.tfs = js["tfs"]
|
||||
if("typical" in js):
|
||||
vars.typical = js["typical"]
|
||||
if("top_a" in js):
|
||||
vars.top_a = js["top_a"]
|
||||
if("rep_pen" in js):
|
||||
vars.rep_pen = js["rep_pen"]
|
||||
if("rep_pen_slope" in js):
|
||||
@ -1268,7 +1274,7 @@ def patch_transformers():
|
||||
|
||||
# Patch transformers to use our custom logit warpers
|
||||
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
|
||||
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper
|
||||
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper, TopALogitsWarper
|
||||
|
||||
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
|
||||
old_call = cls.__call__
|
||||
@ -1288,6 +1294,7 @@ def patch_transformers():
|
||||
cls.__call__ = new_call
|
||||
dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
|
||||
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
|
||||
dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
|
||||
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
|
||||
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
|
||||
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
|
||||
@ -1334,6 +1341,7 @@ def patch_transformers():
|
||||
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
|
||||
warper_list = LogitsProcessorList()
|
||||
warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
|
||||
warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||
warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||
@ -1962,6 +1970,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
|
||||
"top_k": int(vars.top_k),
|
||||
"tfs": float(vars.tfs),
|
||||
"typical": float(vars.typical),
|
||||
"top_a": float(vars.top_a),
|
||||
"repetition_penalty": float(vars.rep_pen),
|
||||
"rpslope": float(vars.rep_pen_slope),
|
||||
"rprange": int(vars.rep_pen_range),
|
||||
@ -2384,6 +2393,7 @@ def lua_has_setting(setting):
|
||||
"settopk",
|
||||
"settfs",
|
||||
"settypical",
|
||||
"settopa",
|
||||
"setreppen",
|
||||
"setreppenslope",
|
||||
"setreppenrange",
|
||||
@ -2403,6 +2413,7 @@ def lua_has_setting(setting):
|
||||
"top_k",
|
||||
"tfs",
|
||||
"typical",
|
||||
"topa",
|
||||
"reppen",
|
||||
"reppenslope",
|
||||
"reppenrange",
|
||||
@ -2437,6 +2448,7 @@ def lua_get_setting(setting):
|
||||
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
|
||||
if(setting in ("settfs", "tfs")): return vars.tfs
|
||||
if(setting in ("settypical", "typical")): return vars.typical
|
||||
if(setting in ("settopa", "topa")): return vars.top_a
|
||||
if(setting in ("setreppen", "reppen")): return vars.rep_pen
|
||||
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
|
||||
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
|
||||
@ -2472,6 +2484,7 @@ def lua_set_setting(setting, v):
|
||||
if(setting in ("settopk", "topk")): vars.top_k = v
|
||||
if(setting in ("settfs", "tfs")): vars.tfs = v
|
||||
if(setting in ("settypical", "typical")): vars.typical = v
|
||||
if(setting in ("settopa", "topa")): vars.top_a = v
|
||||
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
|
||||
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
|
||||
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
|
||||
@ -2862,6 +2875,11 @@ def get_message(msg):
|
||||
emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True)
|
||||
settingschanged()
|
||||
refresh_settings()
|
||||
elif(msg['cmd'] == 'settopa'):
|
||||
vars.top_a = float(msg['data'])
|
||||
emit('from_server', {'cmd': 'setlabeltopa', 'data': msg['data']}, broadcast=True)
|
||||
settingschanged()
|
||||
refresh_settings()
|
||||
elif(msg['cmd'] == 'setreppen'):
|
||||
vars.rep_pen = float(msg['data'])
|
||||
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
|
||||
@ -3988,6 +4006,7 @@ def sendtocolab(txt, min, max):
|
||||
'top_k': vars.top_k,
|
||||
'tfs': vars.tfs,
|
||||
'typical': vars.typical,
|
||||
'topa': vars.top_a,
|
||||
'numseqs': vars.numseqs,
|
||||
'retfultxt': False
|
||||
}
|
||||
@ -4125,6 +4144,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||
top_k=vars.top_k,
|
||||
tfs=vars.tfs,
|
||||
typical=vars.typical,
|
||||
top_a=vars.top_a,
|
||||
numseqs=vars.numseqs,
|
||||
repetition_penalty=vars.rep_pen,
|
||||
rpslope=vars.rep_pen_slope,
|
||||
@ -4311,6 +4331,7 @@ def refresh_settings():
|
||||
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatetopa', 'data': vars.top_a}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
|
||||
|
Reference in New Issue
Block a user