mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-17 20:20:45 +01:00
AvrilAI-style repetition penalty test
This commit is contained in:
parent
9356573ac9
commit
2db1f2f7bb
@ -722,8 +722,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||||||
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
|
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
|
||||||
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
|
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
|
||||||
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
|
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
|
||||||
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
|
|
||||||
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
|
|
||||||
|
|
||||||
class LuaLogitsProcessor(LogitsProcessor):
|
class LuaLogitsProcessor(LogitsProcessor):
|
||||||
|
|
||||||
@ -767,6 +765,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||||||
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||||
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||||
warper_list.append(TemperatureLogitsWarper(temperature=0.5))
|
warper_list.append(TemperatureLogitsWarper(temperature=0.5))
|
||||||
|
warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor())
|
||||||
return warper_list
|
return warper_list
|
||||||
|
|
||||||
def new_sample(self, *args, **kwargs):
|
def new_sample(self, *args, **kwargs):
|
||||||
@ -2771,7 +2770,7 @@ def _generate(txt, minimum, maximum, found_entries):
|
|||||||
do_sample=True,
|
do_sample=True,
|
||||||
min_length=minimum,
|
min_length=minimum,
|
||||||
max_length=int(2e9),
|
max_length=int(2e9),
|
||||||
repetition_penalty=1.1,
|
repetition_penalty=1.0,
|
||||||
bad_words_ids=vars.badwordsids,
|
bad_words_ids=vars.badwordsids,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
num_return_sequences=numseqs
|
num_return_sequences=numseqs
|
||||||
|
@ -149,7 +149,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
|
|||||||
logits[tokens] = penalty_logits
|
logits[tokens] = penalty_logits
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
def kobold_sample_dynamic(key, logits, rpargs, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
||||||
'''
|
'''
|
||||||
This gets called by generate_loop_fn to apply a series of 4 filters
|
This gets called by generate_loop_fn to apply a series of 4 filters
|
||||||
to the logits (top-k, then top-p, then TFS, then temperature) before
|
to the logits (top-k, then top-p, then TFS, then temperature) before
|
||||||
@ -245,6 +245,7 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
|||||||
# Finally, pick one token using the softmax thingy again (it gives
|
# Finally, pick one token using the softmax thingy again (it gives
|
||||||
# an array whose elements sum to 1 so it can be used nicely as a
|
# an array whose elements sum to 1 so it can be used nicely as a
|
||||||
# probability distribution)
|
# probability distribution)
|
||||||
|
logits = apply_repetition_penalty_dynamic(logits, *rpargs)
|
||||||
return jax.random.categorical(key, logits, -1).astype(np.uint32)
|
return jax.random.categorical(key, logits, -1).astype(np.uint32)
|
||||||
|
|
||||||
def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
|
def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
|
||||||
@ -292,7 +293,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
|
|||||||
# positions in the logits array
|
# positions in the logits array
|
||||||
return logits.at[tokens].set(penalty_logits)
|
return logits.at[tokens].set(penalty_logits)
|
||||||
|
|
||||||
def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
def kobold_sample_static(key, logits, rpargs, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
||||||
'''
|
'''
|
||||||
This gets called by generate_loop_fn to apply a series of 4 filters
|
This gets called by generate_loop_fn to apply a series of 4 filters
|
||||||
to the logits (top-k, then top-p, then TFS, then temperature) before
|
to the logits (top-k, then top-p, then TFS, then temperature) before
|
||||||
@ -387,6 +388,7 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
|||||||
# Finally, pick one token using the softmax thingy again (it gives
|
# Finally, pick one token using the softmax thingy again (it gives
|
||||||
# an array whose elements sum to 1 so it can be used nicely as a
|
# an array whose elements sum to 1 so it can be used nicely as a
|
||||||
# probability distribution)
|
# probability distribution)
|
||||||
|
logits = apply_repetition_penalty_static(logits, *rpargs)
|
||||||
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
|
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
|
||||||
|
|
||||||
pad_token_id = 50256
|
pad_token_id = 50256
|
||||||
@ -400,17 +402,6 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
|
|||||||
# Get the pseudo-random number generator key that will
|
# Get the pseudo-random number generator key that will
|
||||||
# be used by kobold_sample_dynamic to randomly pick a token
|
# be used by kobold_sample_dynamic to randomly pick a token
|
||||||
sample_key, new_key = jax.random.split(sample_key, num=2)
|
sample_key, new_key = jax.random.split(sample_key, num=2)
|
||||||
# Apply repetition penalty to all tokens that are
|
|
||||||
# currently inside the "generated" array
|
|
||||||
logits = apply_repetition_penalty_dynamic(
|
|
||||||
logits,
|
|
||||||
generated,
|
|
||||||
repetition_penalty,
|
|
||||||
generated_index,
|
|
||||||
gen_length,
|
|
||||||
rpslope,
|
|
||||||
rprange,
|
|
||||||
)
|
|
||||||
# Remove any tokens in the badwords list by setting
|
# Remove any tokens in the badwords list by setting
|
||||||
# their logits to negative infinity which effectively
|
# their logits to negative infinity which effectively
|
||||||
# makes their probabilities of being chosen zero
|
# makes their probabilities of being chosen zero
|
||||||
@ -422,6 +413,14 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
|
|||||||
next_token = kobold_sample_dynamic(
|
next_token = kobold_sample_dynamic(
|
||||||
sample_key,
|
sample_key,
|
||||||
logits,
|
logits,
|
||||||
|
(
|
||||||
|
generated,
|
||||||
|
repetition_penalty,
|
||||||
|
generated_index,
|
||||||
|
gen_length,
|
||||||
|
rpslope,
|
||||||
|
rprange,
|
||||||
|
),
|
||||||
**sampler_options,
|
**sampler_options,
|
||||||
)
|
)
|
||||||
# Remember what token was picked
|
# Remember what token was picked
|
||||||
@ -493,18 +492,6 @@ class PenalizingCausalTransformer(CausalTransformer):
|
|||||||
assert logits.shape == (1, config["n_vocab"])
|
assert logits.shape == (1, config["n_vocab"])
|
||||||
# Flatten it into a 1D array to make it easier to use
|
# Flatten it into a 1D array to make it easier to use
|
||||||
logits = logits[0]
|
logits = logits[0]
|
||||||
# Apply repetition penalty to all tokens that are
|
|
||||||
# currently inside the "generated" array
|
|
||||||
if repetition_penalty is not None:
|
|
||||||
logits = apply_repetition_penalty_static(
|
|
||||||
logits,
|
|
||||||
generated,
|
|
||||||
repetition_penalty,
|
|
||||||
generated_index,
|
|
||||||
gen_length,
|
|
||||||
rpslope,
|
|
||||||
rprange,
|
|
||||||
)
|
|
||||||
# Remove any tokens in the badwords list by setting
|
# Remove any tokens in the badwords list by setting
|
||||||
# their logits to negative infinity which effectively
|
# their logits to negative infinity which effectively
|
||||||
# makes their probabilities of being chosen zero
|
# makes their probabilities of being chosen zero
|
||||||
@ -516,6 +503,14 @@ class PenalizingCausalTransformer(CausalTransformer):
|
|||||||
next_token = kobold_sample_static(
|
next_token = kobold_sample_static(
|
||||||
sample_key,
|
sample_key,
|
||||||
logits,
|
logits,
|
||||||
|
(
|
||||||
|
generated,
|
||||||
|
repetition_penalty,
|
||||||
|
generated_index,
|
||||||
|
gen_length,
|
||||||
|
rpslope,
|
||||||
|
rprange,
|
||||||
|
),
|
||||||
**sampler_options,
|
**sampler_options,
|
||||||
)
|
)
|
||||||
# Remember what token was picked
|
# Remember what token was picked
|
||||||
|
@ -31,7 +31,7 @@ import torch
|
|||||||
from transformers import LogitsWarper, LogitsProcessor
|
from transformers import LogitsWarper, LogitsProcessor
|
||||||
|
|
||||||
|
|
||||||
class AdvancedRepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
class AdvancedRepetitionPenaltyLogitsProcessor(LogitsWarper):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user