This commit is contained in:
Gnome Ann 2022-01-17 13:54:02 -05:00
parent 6502af086f
commit 31735c4239

View File

@ -74,6 +74,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty):
This gets called by generate_loop_fn to apply repetition penalty This gets called by generate_loop_fn to apply repetition penalty
to the 1D array logits using the provided 1D array of tokens to penalize to the 1D array logits using the provided 1D array of tokens to penalize
''' '''
tokens = np.minimum(tokens, params["n_vocab"]-1) # https://github.com/google/jax/issues/3774
# Make a new array with the same length as the tokens array but with # Make a new array with the same length as the tokens array but with
# each element replaced by the value at the corresponding index in the # each element replaced by the value at the corresponding index in the
# logits array; e.g. # logits array; e.g.