Dynamic Fix

This commit is contained in:
somebody
2023-04-28 18:12:39 -05:00
parent 455b8257a9
commit bfef79d2b8

View File

@@ -215,12 +215,17 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
before picking one token using the modified logits before picking one token using the modified logits
''' '''
for sid in jnp.array(sampler_order, int): for sid in jnp.array(sampler_order, int):
# sid = int(sid) sid = int(sid)
sid = sid.astype(int)
warper = warpers.Warper.from_id(sid) warper = warpers.Warper.from_id(sid)
if not warper.value_is_valid(): if not warper.value_is_valid():
continue continue
logits = warper.jax_dynamic(logits)
# Repetition Penalty needs more info about the context
if warper == warpers.RepetitionPenalty:
logits = warper.jax_dynamic(logits, *rpargs)
else:
logits = warper.jax_dynamic(logits)
# Finally, pick one token using the softmax thingy again (it gives # Finally, pick one token using the softmax thingy again (it gives
# an array whose elements sum to 1 so it can be used nicely as a # an array whose elements sum to 1 so it can be used nicely as a
@@ -432,11 +437,7 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
logits, logits,
( (
generated, generated,
# repetition_penalty,
generated_index, generated_index,
gen_length,
# rpslope,
# rprange,
), ),
**sampler_options, **sampler_options,
) )