diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index cbf24a02..07261636 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -215,12 +215,17 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra before picking one token using the modified logits ''' for sid in jnp.array(sampler_order, int): - # sid = int(sid) - sid = sid.astype(int) + sid = int(sid) warper = warpers.Warper.from_id(sid) + if not warper.value_is_valid(): continue - logits = warper.jax_dynamic(logits) + + # Repetition Penalty needs more info about the context + if warper == warpers.RepetitionPenalty: + logits = warper.jax_dynamic(logits, *rpargs) + else: + logits = warper.jax_dynamic(logits) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a @@ -432,11 +437,7 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_ logits, ( generated, - # repetition_penalty, generated_index, - gen_length, - # rpslope, - # rprange, ), **sampler_options, )