mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge branch 'united' into mkultra
This commit is contained in:
35
aiserver.py
35
aiserver.py
@@ -964,7 +964,10 @@ def loadmodelsettings():
|
||||
if("nobreakmodel" in js):
|
||||
vars.nobreakmodel = js["nobreakmodel"]
|
||||
if("sampler_order" in js):
|
||||
vars.sampler_order = js["sampler_order"]
|
||||
sampler_order = vars.sampler_order
|
||||
if(len(sampler_order) < 7):
|
||||
sampler_order = [6] + sampler_order
|
||||
vars.sampler_order = sampler_order
|
||||
if("temp" in js):
|
||||
vars.temp = js["temp"]
|
||||
if("top_p" in js):
|
||||
@@ -1095,7 +1098,10 @@ def processsettings(js):
|
||||
if("andepth" in js):
|
||||
vars.andepth = js["andepth"]
|
||||
if("sampler_order" in js):
|
||||
vars.sampler_order = js["sampler_order"]
|
||||
sampler_order = vars.sampler_order
|
||||
if(len(sampler_order) < 7):
|
||||
sampler_order = [6] + sampler_order
|
||||
vars.sampler_order = sampler_order
|
||||
if("temp" in js):
|
||||
vars.temp = js["temp"]
|
||||
if("top_p" in js):
|
||||
@@ -1732,8 +1738,6 @@ def patch_transformers():
|
||||
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
|
||||
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
|
||||
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
|
||||
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
|
||||
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
|
||||
|
||||
class LuaLogitsProcessor(LogitsProcessor):
|
||||
|
||||
@@ -1810,9 +1814,13 @@ def patch_transformers():
|
||||
self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||
self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||
self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5))
|
||||
self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor())
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs):
|
||||
for k in vars.sampler_order:
|
||||
sampler_order = vars.sampler_order[:]
|
||||
if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present
|
||||
sampler_order = [6] + sampler_order
|
||||
for k in sampler_order:
|
||||
scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
|
||||
return scores
|
||||
|
||||
@@ -1945,7 +1953,7 @@ def reset_model_settings():
|
||||
vars.badwordsids = []
|
||||
vars.fp32_model = False # Whether or not the most recently loaded HF model was in fp32 format
|
||||
vars.modeldim = -1 # Embedding dimension of your model (e.g. it's 4096 for GPT-J-6B and 2560 for GPT-Neo-2.7B)
|
||||
vars.sampler_order = [0, 1, 2, 3, 4, 5]
|
||||
vars.sampler_order = [6, 0, 1, 2, 3, 4, 5]
|
||||
vars.newlinemode = "n"
|
||||
vars.revision = None
|
||||
|
||||
@@ -2558,8 +2566,11 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||
vars.compiling = False
|
||||
|
||||
def tpumtjgenerate_settings_callback() -> dict:
|
||||
sampler_order = vars.sampler_order[:]
|
||||
if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present
|
||||
sampler_order = [6] + sampler_order
|
||||
return {
|
||||
"sampler_order": vars.sampler_order,
|
||||
"sampler_order": sampler_order,
|
||||
"top_p": float(vars.top_p),
|
||||
"temp": float(vars.temp),
|
||||
"top_k": int(vars.top_k),
|
||||
@@ -3666,12 +3677,16 @@ def get_message(msg):
|
||||
sendUSStatItems()
|
||||
elif(msg['cmd'] == 'samplers'):
|
||||
sampler_order = msg["data"]
|
||||
sampler_order_min_length = 6
|
||||
sampler_order_max_length = 7
|
||||
if(not isinstance(sampler_order, list)):
|
||||
raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}")
|
||||
if(len(sampler_order) != len(vars.sampler_order)):
|
||||
raise ValueError(f"Sampler order must be a list of length {len(vars.sampler_order)}, but got a list of length {len(sampler_order)}")
|
||||
if(not (sampler_order_min_length <= len(sampler_order) <= sampler_order_max_length)):
|
||||
raise ValueError(f"Sampler order must be a list of length greater than or equal to {sampler_order_min_length} and less than or equal to {sampler_order_max_length}, but got a list of length {len(sampler_order)}")
|
||||
if(not all(isinstance(e, int) for e in sampler_order)):
|
||||
raise ValueError(f"Sampler order must be a list of ints, but got a list with at least one non-int element")
|
||||
if(min(sampler_order) != 0 or max(sampler_order) != len(sampler_order) - 1 or len(set(sampler_order)) != len(sampler_order)):
|
||||
raise ValueError(f"Sampler order list of length {len(sampler_order)} must be a permutation of the first {len(sampler_order)} nonnegative integers")
|
||||
vars.sampler_order = sampler_order
|
||||
settingschanged()
|
||||
elif(msg['cmd'] == 'list_model'):
|
||||
@@ -4624,7 +4639,7 @@ def _generate(txt, minimum, maximum, found_entries):
|
||||
gen_in,
|
||||
do_sample=True,
|
||||
max_length=int(2e9),
|
||||
repetition_penalty=1.1,
|
||||
repetition_penalty=1.0,
|
||||
bad_words_ids=vars.badwordsids,
|
||||
use_cache=True,
|
||||
num_return_sequences=numseqs
|
||||
|
Reference in New Issue
Block a user