Repetition penalty is now sampler #6 in the sampler order
This commit is contained in:
parent
9eecb61fea
commit
6ffaf43548
20
aiserver.py
20
aiserver.py
|
@ -1806,7 +1806,10 @@ def patch_transformers():
|
||||||
self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor())
|
self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor())
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs):
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs):
|
||||||
for k in vars.sampler_order:
|
sampler_order = vars.sampler_order[:]
|
||||||
|
if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present
|
||||||
|
sampler_order = [6] + sampler_order
|
||||||
|
for k in sampler_order:
|
||||||
scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
|
scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
@ -1939,7 +1942,7 @@ def reset_model_settings():
|
||||||
vars.badwordsids = []
|
vars.badwordsids = []
|
||||||
vars.fp32_model = False # Whether or not the most recently loaded HF model was in fp32 format
|
vars.fp32_model = False # Whether or not the most recently loaded HF model was in fp32 format
|
||||||
vars.modeldim = -1 # Embedding dimension of your model (e.g. it's 4096 for GPT-J-6B and 2560 for GPT-Neo-2.7B)
|
vars.modeldim = -1 # Embedding dimension of your model (e.g. it's 4096 for GPT-J-6B and 2560 for GPT-Neo-2.7B)
|
||||||
vars.sampler_order = [0, 1, 2, 3, 4, 5]
|
vars.sampler_order = [6, 0, 1, 2, 3, 4, 5]
|
||||||
vars.newlinemode = "n"
|
vars.newlinemode = "n"
|
||||||
vars.revision = None
|
vars.revision = None
|
||||||
|
|
||||||
|
@ -2550,8 +2553,11 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||||
vars.compiling = False
|
vars.compiling = False
|
||||||
|
|
||||||
def tpumtjgenerate_settings_callback() -> dict:
|
def tpumtjgenerate_settings_callback() -> dict:
|
||||||
|
sampler_order = vars.sampler_order[:]
|
||||||
|
if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present
|
||||||
|
sampler_order = [6] + sampler_order
|
||||||
return {
|
return {
|
||||||
"sampler_order": vars.sampler_order,
|
"sampler_order": sampler_order,
|
||||||
"top_p": float(vars.top_p),
|
"top_p": float(vars.top_p),
|
||||||
"temp": float(vars.temp),
|
"temp": float(vars.temp),
|
||||||
"top_k": int(vars.top_k),
|
"top_k": int(vars.top_k),
|
||||||
|
@ -3658,12 +3664,16 @@ def get_message(msg):
|
||||||
sendUSStatItems()
|
sendUSStatItems()
|
||||||
elif(msg['cmd'] == 'samplers'):
|
elif(msg['cmd'] == 'samplers'):
|
||||||
sampler_order = msg["data"]
|
sampler_order = msg["data"]
|
||||||
|
sampler_order_min_length = 6
|
||||||
|
sampler_order_max_length = 7
|
||||||
if(not isinstance(sampler_order, list)):
|
if(not isinstance(sampler_order, list)):
|
||||||
raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}")
|
raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}")
|
||||||
if(len(sampler_order) != len(vars.sampler_order)):
|
if(not (sampler_order_min_length <= len(sampler_order) <= sampler_order_max_length)):
|
||||||
raise ValueError(f"Sampler order must be a list of length {len(vars.sampler_order)}, but got a list of length {len(sampler_order)}")
|
raise ValueError(f"Sampler order must be a list of length greater than or equal to {sampler_order_min_length} and less than or equal to {sampler_order_max_length}, but got a list of length {len(sampler_order)}")
|
||||||
if(not all(isinstance(e, int) for e in sampler_order)):
|
if(not all(isinstance(e, int) for e in sampler_order)):
|
||||||
raise ValueError(f"Sampler order must be a list of ints, but got a list with at least one non-int element")
|
raise ValueError(f"Sampler order must be a list of ints, but got a list with at least one non-int element")
|
||||||
|
if(min(sampler_order) != 0 or max(sampler_order) != len(sampler_order) - 1 or len(set(sampler_order)) != len(sampler_order)):
|
||||||
|
raise ValueError(f"Sampler order list of length {len(sampler_order)} must be a permutation of the first {len(sampler_order)} nonnegative integers")
|
||||||
vars.sampler_order = sampler_order
|
vars.sampler_order = sampler_order
|
||||||
settingschanged()
|
settingschanged()
|
||||||
elif(msg['cmd'] == 'list_model'):
|
elif(msg['cmd'] == 'list_model'):
|
||||||
|
|
|
@ -312,10 +312,10 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
|
||||||
if k == 3 and tfs < 1.0: logits = tail_free_filter(logits)
|
if k == 3 and tfs < 1.0: logits = tail_free_filter(logits)
|
||||||
if k == 4 and typical < 1.0: logits = typical_filter(logits)
|
if k == 4 and typical < 1.0: logits = typical_filter(logits)
|
||||||
if k == 5 and temp != 1.0: logits = temp_filter(logits)
|
if k == 5 and temp != 1.0: logits = temp_filter(logits)
|
||||||
|
if k == 6 and rpargs[1] != 1.0: logits = apply_repetition_penalty_dynamic(logits, *rpargs)
|
||||||
# Finally, pick one token using the softmax thingy again (it gives
|
# Finally, pick one token using the softmax thingy again (it gives
|
||||||
# an array whose elements sum to 1 so it can be used nicely as a
|
# an array whose elements sum to 1 so it can be used nicely as a
|
||||||
# probability distribution)
|
# probability distribution)
|
||||||
logits = apply_repetition_penalty_dynamic(logits, *rpargs)
|
|
||||||
return jax.random.categorical(key, logits, -1).astype(np.uint32)
|
return jax.random.categorical(key, logits, -1).astype(np.uint32)
|
||||||
|
|
||||||
def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
|
def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
|
||||||
|
@ -498,10 +498,10 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray
|
||||||
logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits)
|
logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits)
|
||||||
logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
|
logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
|
||||||
logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
|
logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
|
||||||
|
logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), apply_repetition_penalty_static, lambda x, *_: x, logits, *rpargs)
|
||||||
# Finally, pick one token using the softmax thingy again (it gives
|
# Finally, pick one token using the softmax thingy again (it gives
|
||||||
# an array whose elements sum to 1 so it can be used nicely as a
|
# an array whose elements sum to 1 so it can be used nicely as a
|
||||||
# probability distribution)
|
# probability distribution)
|
||||||
logits = apply_repetition_penalty_static(logits, *rpargs)
|
|
||||||
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
|
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
|
||||||
|
|
||||||
pad_token_id = 50256
|
pad_token_id = 50256
|
||||||
|
@ -858,6 +858,9 @@ def infer_static(
|
||||||
maps.thread_resources.env = thread_resources_env
|
maps.thread_resources.env = thread_resources_env
|
||||||
if sampler_order is None:
|
if sampler_order is None:
|
||||||
sampler_order = utils.default_sampler_order.copy()
|
sampler_order = utils.default_sampler_order.copy()
|
||||||
|
sampler_order = sampler_order[:]
|
||||||
|
if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present
|
||||||
|
sampler_order = [6] + sampler_order
|
||||||
sampler_order = np.uint32(sampler_order)
|
sampler_order = np.uint32(sampler_order)
|
||||||
total_batch = 1
|
total_batch = 1
|
||||||
tokens = context
|
tokens = context
|
||||||
|
|
2
utils.py
2
utils.py
|
@ -33,7 +33,7 @@ layers_module_names: Optional[List[str]] = None
|
||||||
module_names: Optional[List[str]] = None
|
module_names: Optional[List[str]] = None
|
||||||
named_buffers: Optional[List[tuple]] = None
|
named_buffers: Optional[List[tuple]] = None
|
||||||
|
|
||||||
default_sampler_order = [0, 1, 2, 3, 4, 5]
|
default_sampler_order = [6, 0, 1, 2, 3, 4, 5]
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Decorator to prevent a function's actions from being run until
|
# Decorator to prevent a function's actions from being run until
|
||||||
|
|
Loading…
Reference in New Issue