Repetition penalty is now sampler #6 in the sampler order

This commit is contained in:
vfbd
2022-08-23 15:10:21 -04:00
parent 9eecb61fea
commit 6ffaf43548
3 changed files with 21 additions and 8 deletions

View File

@ -1806,7 +1806,10 @@ def patch_transformers():
self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor())
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs):
for k in vars.sampler_order:
sampler_order = vars.sampler_order[:]
if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present
sampler_order = [6] + sampler_order
for k in sampler_order:
scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
return scores
@ -1939,7 +1942,7 @@ def reset_model_settings():
vars.badwordsids = []
vars.fp32_model = False # Whether or not the most recently loaded HF model was in fp32 format
vars.modeldim = -1 # Embedding dimension of your model (e.g. it's 4096 for GPT-J-6B and 2560 for GPT-Neo-2.7B)
vars.sampler_order = [0, 1, 2, 3, 4, 5]
vars.sampler_order = [6, 0, 1, 2, 3, 4, 5]
vars.newlinemode = "n"
vars.revision = None
@ -2550,8 +2553,11 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
vars.compiling = False
def tpumtjgenerate_settings_callback() -> dict:
sampler_order = vars.sampler_order[:]
if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present
sampler_order = [6] + sampler_order
return {
"sampler_order": vars.sampler_order,
"sampler_order": sampler_order,
"top_p": float(vars.top_p),
"temp": float(vars.temp),
"top_k": int(vars.top_k),
@ -3658,12 +3664,16 @@ def get_message(msg):
sendUSStatItems()
elif(msg['cmd'] == 'samplers'):
sampler_order = msg["data"]
sampler_order_min_length = 6
sampler_order_max_length = 7
if(not isinstance(sampler_order, list)):
raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}")
if(len(sampler_order) != len(vars.sampler_order)):
raise ValueError(f"Sampler order must be a list of length {len(vars.sampler_order)}, but got a list of length {len(sampler_order)}")
if(not (sampler_order_min_length <= len(sampler_order) <= sampler_order_max_length)):
raise ValueError(f"Sampler order must be a list of length greater than or equal to {sampler_order_min_length} and less than or equal to {sampler_order_max_length}, but got a list of length {len(sampler_order)}")
if(not all(isinstance(e, int) for e in sampler_order)):
raise ValueError(f"Sampler order must be a list of ints, but got a list with at least one non-int element")
if(min(sampler_order) != 0 or max(sampler_order) != len(sampler_order) - 1 or len(set(sampler_order)) != len(sampler_order)):
raise ValueError(f"Sampler order list of length {len(sampler_order)} must be a permutation of the first {len(sampler_order)} nonnegative integers")
vars.sampler_order = sampler_order
settingschanged()
elif(msg['cmd'] == 'list_model'):