mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00

Model: TPU should be ready for testing

model.py (2 changes)
@@ -2203,8 +2203,6 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
         # Get the model_type from the config or assume a model type if it isn't present
         try:
-            print("LMP:", self.get_local_model_path())
-            print("M:", utils.koboldai_vars.model)
             model_config = AutoConfig.from_pretrained(
                 self.get_local_model_path() or utils.koboldai_vars.model,
                 revision=utils.koboldai_vars.revision,
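For context, the code around this hunk resolves a model config by preferring a local checkout and falling back to the hub id. A minimal runnable sketch of that pattern, assuming the standard transformers AutoConfig API (the function name and fallback type below are illustrative, not the repo's code):

from transformers import AutoConfig

def resolve_model_type(local_path, hub_id, revision=None):
    # Same fallback as above: use the local path when present, else the hub id.
    try:
        config = AutoConfig.from_pretrained(local_path or hub_id, revision=revision)
        return config.model_type
    except OSError:
        # No config could be loaded; assume a model type (illustrative default).
        return "gpt_neo"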
tpu_mtj_backend.py (file header missing from the mirror page; inferred from the mesh_transformer imports below)

@@ -54,6 +54,8 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor
 from mesh_transformer.util import to_bf16
 import time
 
+import warpers
+from warpers import Warper
 
 socketio = None
@@ -213,6 +215,7 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
     before picking one token using the modified logits
     '''
+    """
     # Top-k (keep only the k tokens with the highest logits and remove
     # the rest, by setting their logits to negative infinity)
     def top_k_filter(logits):
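The lone """ added after the docstring opens an unassigned triple-quoted string: everything down to the matching """ in the next hunk becomes a string literal rather than executable code, disabling the old sampler chain without deleting it. A tiny illustration:

def f(x):
    '''Real docstring.'''
    """
    x = x + 100  # inert: this line is inside a throwaway string literal
    """
    return x

assert f(1) == 1  # the disabled line never runs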
@@ -344,6 +347,18 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
         if k == 4 and typical < 1.0: logits = typical_filter(logits)
         if k == 5 and temp != 1.0: logits = temp_filter(logits)
         if k == 6 and rpargs[1] != 1.0: logits = apply_repetition_penalty_dynamic(logits, *rpargs)
+    """
+    for sid in sampler_order:
+        warper = Warper.from_id(sid)
+        if not warper.value_is_valid():
+            continue
+
+        if warper == warpers.RepetitionPenalty:
+            print("ISREP", warper)
+            logits = warper.jax()
+        else:
+            print("AINTREP", warper)
+            logits = warper.jax_dynamic(logits, *rpargs)
     # Finally, pick one token using the softmax thingy again (it gives
     # an array whose elements sum to 1 so it can be used nicely as a
     # probability distribution)
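The new loop replaces the hard-coded k == n dispatch with a table lookup: Warper.from_id maps a sampler id to a class, value_is_valid skips samplers left at their no-op defaults, and repetition penalty is special-cased. A self-contained mock of that control flow (the id table and apply method are stand-ins, not the real warpers.py API):

import numpy as np

class TopK:
    top_k = 0  # 0 means "off", mirroring value_is_valid() below

    @classmethod
    def value_is_valid(cls):
        return cls.top_k > 0

    @classmethod
    def apply(cls, logits):
        # Keep the k largest logits, set the rest to -inf
        cutoff = np.sort(logits)[-cls.top_k]
        return np.where(logits < cutoff, -np.inf, logits)

class Temperature:
    temperature = 1.0

    @classmethod
    def value_is_valid(cls):
        return cls.temperature != 1.0

    @classmethod
    def apply(cls, logits):
        return logits / cls.temperature

WARPER_IDS = {0: TopK, 5: Temperature}  # hypothetical id -> class table

def warp(logits, sampler_order):
    for sid in sampler_order:
        warper = WARPER_IDS[sid]
        if not warper.value_is_valid():
            continue  # disabled samplers are skipped, as in the diff
        logits = warper.apply(logits)
    return logits

TopK.top_k = 2
print(warp(np.array([1.0, 3.0, 2.0]), [0, 5]))  # [-inf, 3.0, 2.0]; Temperature is a no-op here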
@@ -356,6 +371,7 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray
     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
     before picking one token using the modified logits
     '''
+    """
     # Top-k (keep only the k tokens with the highest logits and remove
     # the rest, by setting their logits to negative infinity)
     def top_k_filter(logits):
@@ -486,6 +502,18 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray
         logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
         logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
         logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: apply_repetition_penalty_static(*x), lambda x: x[0], (logits, *rpargs))
+    """
+    for sid in sampler_order:
+        warper = Warper.from_id(sid)
+        if not warper.value_is_valid():
+            continue
+
+        if warper == warpers.RepetitionPenalty:
+            print("ISREP", warper)
+            logits = warper.jax()
+        else:
+            print("AINTREP", warper)
+            logits = warper.jax_static(logits, *rpargs)
     # Finally, pick one token using the softmax thingy again (it gives
     # an array whose elements sum to 1 so it can be used nicely as a
     # probability distribution)
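The disabled static path gated each filter behind jax.lax.cond, which keeps control flow traceable under jax.jit: both branches must exist at trace time, so the "off" branch is an identity lambda x: x. A minimal runnable example of the same gating (values are illustrative):

import jax
import jax.numpy as jnp

@jax.jit
def maybe_temperature(logits, temp):
    # Apply the temperature filter only when temp != 1.0; the false branch
    # is an identity so both sides of the cond have the same signature.
    return jax.lax.cond(
        temp != 1.0,
        lambda x: x / temp,
        lambda x: x,
        logits,
    )

print(maybe_temperature(jnp.array([1.0, 2.0]), 0.5))  # [2.0, 4.0]
print(maybe_temperature(jnp.array([1.0, 2.0]), 1.0))  # [1.0, 2.0]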
@@ -515,11 +543,11 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
                 logits,
                 (
                     generated,
-                    repetition_penalty,
+                    # repetition_penalty,
                     generated_index,
                     gen_length,
-                    rpslope,
-                    rprange,
+                    # rpslope,
+                    # rprange,
                 ),
                 **sampler_options,
             )
@@ -605,11 +633,11 @@ class PenalizingCausalTransformer(CausalTransformer):
                 logits,
                 (
                     generated,
-                    repetition_penalty,
+                    # repetition_penalty,
                     generated_index,
-                    gen_length,
-                    rpslope,
-                    rprange,
+                    # gen_length,
+                    # rpslope,
+                    # rprange,
                 ),
                 **sampler_options,
             )
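Both call sites above comment out entries of the tuple that reaches the samplers as rpargs and is later splatted with *rpargs, so each commented line shrinks the argument list the receiving function sees. A small illustration of that arity effect (names are hypothetical):

def apply_penalty(logits, tokens, penalty):
    # Expects exactly two extra positional args after logits.
    return [x / penalty if i in tokens else x for i, x in enumerate(logits)]

rpargs = ([1], 2.0)                        # matches the signature: OK
print(apply_penalty([4.0, 4.0], *rpargs))  # [4.0, 2.0]

rpargs = ([1],)                            # one entry "commented out"
try:
    apply_penalty([4.0, 4.0], *rpargs)
except TypeError as e:
    print(e)                               # missing positional argument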
warpers.py (28 changes)
@@ -92,6 +92,10 @@ class Temperature(Warper):
     def jax(cls, scores: jnp.array) -> jnp.array:
         return scores / cls.temperature
 
+    @classmethod
+    def value_is_valid(cls) -> bool:
+        return cls.temperature != 1.0
+
 
 class TopP(Warper):
     """
@@ -140,6 +144,10 @@ class TopP(Warper):
         )
         return jnp.where(indices_to_remove, -jnp.inf, scores)
 
+    @classmethod
+    def value_is_valid(cls) -> bool:
+        return cls.top_p < 1.0
+
 
 class TopK(Warper):
     """
@@ -173,6 +181,10 @@ class TopK(Warper):
         )
         return np.where(indices_to_remove, -np.inf, scores)
 
+    @classmethod
+    def value_is_valid(cls) -> bool:
+        return cls.top_p > 0
+
 
 class TailFree(Warper):
     """
@@ -256,6 +268,10 @@ class TailFree(Warper):
         )
         return np.where(indices_to_remove, -np.inf, scores)
 
+    @classmethod
+    def value_is_valid(cls) -> bool:
+        return cls.tfs < 1.0
+
 
 class Typical(Warper):
     """Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf"""
@@ -332,6 +348,10 @@ class Typical(Warper):
         )
         return np.where(indices_to_remove, -jnp.inf, scores)
 
+    @classmethod
+    def value_is_valid(cls) -> bool:
+        return cls.typical < 1.0
+
 
 class TopA(Warper):
     """
@@ -370,6 +390,10 @@ class TopA(Warper):
             probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
         )
 
+    @classmethod
+    def value_is_valid(cls) -> bool:
+        return cls.top_a > 0.0
+
 
 class RepetitionPenalty(Warper):
     rep_pen: float = 1.0
@@ -543,3 +567,7 @@ class RepetitionPenalty(Warper):
         # positions in the logits array
         scores[tokens] = penalty_logits
         return scores
+
+    @classmethod
+    def value_is_valid(cls) -> bool:
+        return cls.rep_pen != 1.0
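Every warpers.py hunk adds the same shape of value_is_valid classmethod, giving callers a uniform "is this sampler actually on?" test against class-level settings (note that the TopK hunk, as committed, tests cls.top_p rather than cls.top_k). A condensed sketch of the pattern; the base class below is a guess at the structure, not the file's exact code:

import numpy as np

class Warper:
    """Base class: subclasses hold their setting as a class attribute."""

    @classmethod
    def value_is_valid(cls) -> bool:
        raise NotImplementedError

class Temperature(Warper):
    temperature: float = 1.0

    @classmethod
    def apply(cls, scores: np.ndarray) -> np.ndarray:
        return scores / cls.temperature

    @classmethod
    def value_is_valid(cls) -> bool:
        # 1.0 divides scores unchanged, so the warper is a no-op
        return cls.temperature != 1.0

Temperature.temperature = 0.7
assert Temperature.value_is_valid()
print(Temperature.apply(np.array([1.4, 0.7])))  # [2.0, 1.0]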