This commit is contained in:
Gnome Ann 2022-01-17 13:54:02 -05:00
parent 6502af086f
commit 31735c4239

View File

@ -74,6 +74,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty):
This gets called by generate_loop_fn to apply repetition penalty
to the 1D array logits using the provided 1D array of tokens to penalize
'''
tokens = np.minimum(tokens, params["n_vocab"]-1) # https://github.com/google/jax/issues/3774
# Make a new array with the same length as the tokens array but with
# each element replaced by the value at the corresponding index in the
# logits array; e.g.