Merge branch 'avril' into rep-pen-order
This commit is contained in:
commit
74922966bd
|
@ -1727,8 +1727,6 @@ def patch_transformers():
|
|||
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
|
||||
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
|
||||
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
|
||||
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
|
||||
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
|
||||
|
||||
class LuaLogitsProcessor(LogitsProcessor):
|
||||
|
||||
|
@ -1805,6 +1803,7 @@ def patch_transformers():
|
|||
self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||
self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||
self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5))
|
||||
self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor())
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs):
|
||||
for k in vars.sampler_order:
|
||||
|
@ -4617,7 +4616,7 @@ def _generate(txt, minimum, maximum, found_entries):
|
|||
gen_in,
|
||||
do_sample=True,
|
||||
max_length=int(2e9),
|
||||
repetition_penalty=1.1,
|
||||
repetition_penalty=1.0,
|
||||
bad_words_ids=vars.badwordsids,
|
||||
use_cache=True,
|
||||
num_return_sequences=numseqs
|
||||
|
|
|
@ -176,7 +176,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
|
|||
logits[tokens] = penalty_logits
|
||||
return logits
|
||||
|
||||
def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
|
||||
def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
|
||||
'''
|
||||
This gets called by generate_loop_fn to apply a series of 6 filters
|
||||
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
|
||||
|
@ -315,6 +315,7 @@ def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = Non
|
|||
# Finally, pick one token using the softmax thingy again (it gives
|
||||
# an array whose elements sum to 1 so it can be used nicely as a
|
||||
# probability distribution)
|
||||
logits = apply_repetition_penalty_dynamic(logits, *rpargs)
|
||||
return jax.random.categorical(key, logits, -1).astype(np.uint32)
|
||||
|
||||
def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
|
||||
|
@ -362,7 +363,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
|
|||
# positions in the logits array
|
||||
return logits.at[tokens].set(penalty_logits)
|
||||
|
||||
def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
|
||||
def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
|
||||
'''
|
||||
This gets called by generate_loop_fn to apply a series of 6 filters
|
||||
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
|
||||
|
@ -500,6 +501,7 @@ def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None
|
|||
# Finally, pick one token using the softmax thingy again (it gives
|
||||
# an array whose elements sum to 1 so it can be used nicely as a
|
||||
# probability distribution)
|
||||
logits = apply_repetition_penalty_static(logits, *rpargs)
|
||||
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
|
||||
|
||||
pad_token_id = 50256
|
||||
|
@ -513,17 +515,6 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
|
|||
# Get the pseudo-random number generator key that will
|
||||
# be used by kobold_sample_dynamic to randomly pick a token
|
||||
sample_key, new_key = jax.random.split(sample_key, num=2)
|
||||
# Apply repetition penalty to all tokens that are
|
||||
# currently inside the "generated" array
|
||||
logits = apply_repetition_penalty_dynamic(
|
||||
logits,
|
||||
generated,
|
||||
repetition_penalty,
|
||||
generated_index,
|
||||
gen_length,
|
||||
rpslope,
|
||||
rprange,
|
||||
)
|
||||
# Remove any tokens in the badwords list by setting
|
||||
# their logits to negative infinity which effectively
|
||||
# makes their probabilities of being chosen zero
|
||||
|
@ -535,6 +526,14 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
|
|||
next_token = kobold_sample_dynamic(
|
||||
sample_key,
|
||||
logits,
|
||||
(
|
||||
generated,
|
||||
repetition_penalty,
|
||||
generated_index,
|
||||
gen_length,
|
||||
rpslope,
|
||||
rprange,
|
||||
)
|
||||
**sampler_options,
|
||||
)
|
||||
# Remember what token was picked
|
||||
|
@ -606,18 +605,6 @@ class PenalizingCausalTransformer(CausalTransformer):
|
|||
assert logits.shape == (1, config["n_vocab"])
|
||||
# Flatten it into a 1D array to make it easier to use
|
||||
logits = logits[0]
|
||||
# Apply repetition penalty to all tokens that are
|
||||
# currently inside the "generated" array
|
||||
if repetition_penalty is not None:
|
||||
logits = apply_repetition_penalty_static(
|
||||
logits,
|
||||
generated,
|
||||
repetition_penalty,
|
||||
generated_index,
|
||||
gen_length,
|
||||
rpslope,
|
||||
rprange,
|
||||
)
|
||||
# Remove any tokens in the badwords list by setting
|
||||
# their logits to negative infinity which effectively
|
||||
# makes their probabilities of being chosen zero
|
||||
|
@ -629,6 +616,14 @@ class PenalizingCausalTransformer(CausalTransformer):
|
|||
next_token = kobold_sample_static(
|
||||
sample_key,
|
||||
logits,
|
||||
(
|
||||
generated,
|
||||
repetition_penalty,
|
||||
generated_index,
|
||||
gen_length,
|
||||
rpslope,
|
||||
rprange,
|
||||
),
|
||||
**sampler_options,
|
||||
)
|
||||
# Remember what token was picked
|
||||
|
|
|
@ -31,7 +31,7 @@ import torch
|
|||
from transformers import LogitsWarper, LogitsProcessor
|
||||
|
||||
|
||||
class AdvancedRepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
||||
class AdvancedRepetitionPenaltyLogitsProcessor(LogitsWarper):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
|
Loading…
Reference in New Issue