Merge branch 'avril' into rep-pen-order

vfbd
2022-08-23 14:47:29 -04:00
3 changed files with 23 additions and 29 deletions

View File

@@ -176,7 +176,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
     logits[tokens] = penalty_logits
     return logits
 
-def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
+def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
     '''
     This gets called by generate_loop_fn to apply a series of 6 filters
     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
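The "series of 6 filters" in this docstring (top-k, top-a, top-p, TFS, typical, temperature) is applied in the order given by the new sampler_order parameter. A minimal sketch of that dispatch pattern, assuming a hypothetical filters table keyed by sampler ID (the real filter implementations live elsewhere in this file):

    def apply_filters(logits, sampler_order, filters):
        # filters maps a sampler ID to a logits -> logits function;
        # walking sampler_order applies them in the user-configured
        # order rather than a hard-coded one
        for sampler_id in sampler_order:
            logits = filters[int(sampler_id)](logits)
        return logits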
@@ -315,6 +315,7 @@ def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = Non
     # Finally, pick one token using the softmax thingy again (it gives
     # an array whose elements sum to 1 so it can be used nicely as a
     # probability distribution)
+    logits = apply_repetition_penalty_dynamic(logits, *rpargs)
     return jax.random.categorical(key, logits, -1).astype(np.uint32)
 
 def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
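The "softmax thingy" in the comment above is jax.random.categorical, which treats the logits as unnormalized log-probabilities and draws a single index from the implied distribution. A self-contained usage sketch:

    import jax
    import numpy as np

    key = jax.random.PRNGKey(0)
    logits = np.array([0.1, 2.0, -1.0])
    # categorical() softmaxes the logits internally, so higher
    # logits are proportionally more likely to be drawn
    token = jax.random.categorical(key, logits, -1).astype(np.uint32)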
@@ -362,7 +363,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
     # positions in the logits array
     return logits.at[tokens].set(penalty_logits)
 
-def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
+def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
     '''
     This gets called by generate_loop_fn to apply a series of 6 filters
     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
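Note the functional update in the context above: under jit, JAX arrays are immutable, so the static variant ends with logits.at[tokens].set(penalty_logits) rather than the in-place logits[tokens] = penalty_logits used by the dynamic variant. A tiny illustration:

    import jax.numpy as jnp

    logits = jnp.zeros(4)
    tokens = jnp.array([1, 3])
    # .at[...].set(...) leaves the original array untouched and
    # returns an updated copy with the given positions overwritten
    logits = logits.at[tokens].set(-1.0)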
@@ -500,6 +501,7 @@ def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None
     # Finally, pick one token using the softmax thingy again (it gives
     # an array whose elements sum to 1 so it can be used nicely as a
     # probability distribution)
+    logits = apply_repetition_penalty_static(logits, *rpargs)
     return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
 
 pad_token_id = 50256
@@ -513,17 +515,6 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
         # Get the pseudo-random number generator key that will
         # be used by kobold_sample_dynamic to randomly pick a token
         sample_key, new_key = jax.random.split(sample_key, num=2)
-        # Apply repetition penalty to all tokens that are
-        # currently inside the "generated" array
-        logits = apply_repetition_penalty_dynamic(
-            logits,
-            generated,
-            repetition_penalty,
-            generated_index,
-            gen_length,
-            rpslope,
-            rprange,
-        )
         # Remove any tokens in the badwords list by setting
         # their logits to negative infinity which effectively
         # makes their probabilities of being chosen zero
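For context on the arguments being removed here (and re-packed as a tuple in the next hunk): the penalty divides positive logits of already-generated tokens and multiplies negative ones, with rpslope and rprange shaping how the penalty decays over recent tokens. A simplified sketch that drops the slope/range handling (simple_repetition_penalty is a hypothetical name, not the function in this file):

    import numpy as np

    def simple_repetition_penalty(logits, tokens, repetition_penalty):
        # tokens holds the token IDs generated so far; penalizing
        # exactly those entries discourages verbatim repetition
        penalty_logits = logits[tokens]
        penalty_logits = np.where(
            penalty_logits > 0,
            penalty_logits / repetition_penalty,  # shrink likely repeats
            penalty_logits * repetition_penalty,  # push negatives further down
        )
        logits[tokens] = penalty_logits
        return logits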
@@ -535,6 +526,14 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
         next_token = kobold_sample_dynamic(
             sample_key,
             logits,
+            (
+                generated,
+                repetition_penalty,
+                generated_index,
+                gen_length,
+                rpslope,
+                rprange,
+            ),
             **sampler_options,
         )
         # Remember what token was picked
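The tuple added here is the new rpargs parameter from the signature change above; kobold_sample_dynamic splats it back into positional arguments, so the call is equivalent to the block just deleted from sample_func:

    rpargs = (generated, repetition_penalty, generated_index,
              gen_length, rpslope, rprange)
    # inside the sampler:
    #     apply_repetition_penalty_dynamic(logits, *rpargs)
    # unpacks to:
    #     apply_repetition_penalty_dynamic(logits, generated, repetition_penalty,
    #                                      generated_index, gen_length, rpslope, rprange)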
@@ -606,18 +605,6 @@ class PenalizingCausalTransformer(CausalTransformer):
             assert logits.shape == (1, config["n_vocab"])
             # Flatten it into a 1D array to make it easier to use
             logits = logits[0]
-            # Apply repetition penalty to all tokens that are
-            # currently inside the "generated" array
-            if repetition_penalty is not None:
-                logits = apply_repetition_penalty_static(
-                    logits,
-                    generated,
-                    repetition_penalty,
-                    generated_index,
-                    gen_length,
-                    rpslope,
-                    rprange,
-                )
             # Remove any tokens in the badwords list by setting
             # their logits to negative infinity which effectively
             # makes their probabilities of being chosen zero
@@ -629,6 +616,14 @@ class PenalizingCausalTransformer(CausalTransformer):
             next_token = kobold_sample_static(
                 sample_key,
                 logits,
+                (
+                    generated,
+                    repetition_penalty,
+                    generated_index,
+                    gen_length,
+                    rpslope,
+                    rprange,
+                ),
                 **sampler_options,
             )
             # Remember what token was picked