mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge branch 'avril' into rep-pen-order
This commit is contained in:
@ -176,7 +176,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
|
||||
logits[tokens] = penalty_logits
|
||||
return logits
|
||||
|
||||
def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
|
||||
def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
|
||||
'''
|
||||
This gets called by generate_loop_fn to apply a series of 6 filters
|
||||
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
|
||||
@ -315,6 +315,7 @@ def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = Non
|
||||
# Finally, pick one token using the softmax thingy again (it gives
|
||||
# an array whose elements sum to 1 so it can be used nicely as a
|
||||
# probability distribution)
|
||||
logits = apply_repetition_penalty_dynamic(logits, *rpargs)
|
||||
return jax.random.categorical(key, logits, -1).astype(np.uint32)
|
||||
|
||||
def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
|
||||
@ -362,7 +363,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
|
||||
# positions in the logits array
|
||||
return logits.at[tokens].set(penalty_logits)
|
||||
|
||||
def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
|
||||
def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
|
||||
'''
|
||||
This gets called by generate_loop_fn to apply a series of 6 filters
|
||||
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
|
||||
@ -500,6 +501,7 @@ def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None
|
||||
# Finally, pick one token using the softmax thingy again (it gives
|
||||
# an array whose elements sum to 1 so it can be used nicely as a
|
||||
# probability distribution)
|
||||
logits = apply_repetition_penalty_static(logits, *rpargs)
|
||||
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
|
||||
|
||||
pad_token_id = 50256
|
||||
@ -513,17 +515,6 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
|
||||
# Get the pseudo-random number generator key that will
|
||||
# be used by kobold_sample_dynamic to randomly pick a token
|
||||
sample_key, new_key = jax.random.split(sample_key, num=2)
|
||||
# Apply repetition penalty to all tokens that are
|
||||
# currently inside the "generated" array
|
||||
logits = apply_repetition_penalty_dynamic(
|
||||
logits,
|
||||
generated,
|
||||
repetition_penalty,
|
||||
generated_index,
|
||||
gen_length,
|
||||
rpslope,
|
||||
rprange,
|
||||
)
|
||||
# Remove any tokens in the badwords list by setting
|
||||
# their logits to negative infinity which effectively
|
||||
# makes their probabilities of being chosen zero
|
||||
@ -535,6 +526,14 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
|
||||
next_token = kobold_sample_dynamic(
|
||||
sample_key,
|
||||
logits,
|
||||
(
|
||||
generated,
|
||||
repetition_penalty,
|
||||
generated_index,
|
||||
gen_length,
|
||||
rpslope,
|
||||
rprange,
|
||||
)
|
||||
**sampler_options,
|
||||
)
|
||||
# Remember what token was picked
|
||||
@ -606,18 +605,6 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||
assert logits.shape == (1, config["n_vocab"])
|
||||
# Flatten it into a 1D array to make it easier to use
|
||||
logits = logits[0]
|
||||
# Apply repetition penalty to all tokens that are
|
||||
# currently inside the "generated" array
|
||||
if repetition_penalty is not None:
|
||||
logits = apply_repetition_penalty_static(
|
||||
logits,
|
||||
generated,
|
||||
repetition_penalty,
|
||||
generated_index,
|
||||
gen_length,
|
||||
rpslope,
|
||||
rprange,
|
||||
)
|
||||
# Remove any tokens in the badwords list by setting
|
||||
# their logits to negative infinity which effectively
|
||||
# makes their probabilities of being chosen zero
|
||||
@ -629,6 +616,14 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||
next_token = kobold_sample_static(
|
||||
sample_key,
|
||||
logits,
|
||||
(
|
||||
generated,
|
||||
repetition_penalty,
|
||||
generated_index,
|
||||
gen_length,
|
||||
rpslope,
|
||||
rprange,
|
||||
),
|
||||
**sampler_options,
|
||||
)
|
||||
# Remember what token was picked
|
||||
|
Reference in New Issue
Block a user