mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Dynamic Fix
This commit is contained in:
@@ -215,12 +215,17 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
|
||||
before picking one token using the modified logits
|
||||
'''
|
||||
for sid in jnp.array(sampler_order, int):
|
||||
# sid = int(sid)
|
||||
sid = sid.astype(int)
|
||||
sid = int(sid)
|
||||
warper = warpers.Warper.from_id(sid)
|
||||
|
||||
if not warper.value_is_valid():
|
||||
continue
|
||||
logits = warper.jax_dynamic(logits)
|
||||
|
||||
# Repetition Penalty needs more info about the context
|
||||
if warper == warpers.RepetitionPenalty:
|
||||
logits = warper.jax_dynamic(logits, *rpargs)
|
||||
else:
|
||||
logits = warper.jax_dynamic(logits)
|
||||
|
||||
# Finally, pick one token using the softmax thingy again (it gives
|
||||
# an array whose elements sum to 1 so it can be used nicely as a
|
||||
@@ -432,11 +437,7 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
|
||||
logits,
|
||||
(
|
||||
generated,
|
||||
# repetition_penalty,
|
||||
generated_index,
|
||||
gen_length,
|
||||
# rpslope,
|
||||
# rprange,
|
||||
),
|
||||
**sampler_options,
|
||||
)
|
||||
|
Reference in New Issue
Block a user