diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index effb3de0..29ac4b42 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -533,7 +533,7 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_ gen_length, rpslope, rprange, - ) + ), **sampler_options, ) # Remember what token was picked