mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Model: TPU should be ready for testing
This commit is contained in:
2
model.py
2
model.py
@@ -2203,8 +2203,6 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
|
|||||||
|
|
||||||
# Get the model_type from the config or assume a model type if it isn't present
|
# Get the model_type from the config or assume a model type if it isn't present
|
||||||
try:
|
try:
|
||||||
print("LMP:", self.get_local_model_path())
|
|
||||||
print("M:", utils.koboldai_vars.model)
|
|
||||||
model_config = AutoConfig.from_pretrained(
|
model_config = AutoConfig.from_pretrained(
|
||||||
self.get_local_model_path() or utils.koboldai_vars.model,
|
self.get_local_model_path() or utils.koboldai_vars.model,
|
||||||
revision=utils.koboldai_vars.revision,
|
revision=utils.koboldai_vars.revision,
|
||||||
|
@@ -54,6 +54,8 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor
|
|||||||
from mesh_transformer.util import to_bf16
|
from mesh_transformer.util import to_bf16
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import warpers
|
||||||
|
from warpers import Warper
|
||||||
|
|
||||||
socketio = None
|
socketio = None
|
||||||
|
|
||||||
@@ -213,6 +215,7 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
|
|||||||
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
|
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
|
||||||
before picking one token using the modified logits
|
before picking one token using the modified logits
|
||||||
'''
|
'''
|
||||||
|
"""
|
||||||
# Top-k (keep only the k tokens with the highest logits and remove
|
# Top-k (keep only the k tokens with the highest logits and remove
|
||||||
# the rest, by setting their logits to negative infinity)
|
# the rest, by setting their logits to negative infinity)
|
||||||
def top_k_filter(logits):
|
def top_k_filter(logits):
|
||||||
@@ -344,6 +347,18 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
|
|||||||
if k == 4 and typical < 1.0: logits = typical_filter(logits)
|
if k == 4 and typical < 1.0: logits = typical_filter(logits)
|
||||||
if k == 5 and temp != 1.0: logits = temp_filter(logits)
|
if k == 5 and temp != 1.0: logits = temp_filter(logits)
|
||||||
if k == 6 and rpargs[1] != 1.0: logits = apply_repetition_penalty_dynamic(logits, *rpargs)
|
if k == 6 and rpargs[1] != 1.0: logits = apply_repetition_penalty_dynamic(logits, *rpargs)
|
||||||
|
"""
|
||||||
|
for sid in sampler_order:
|
||||||
|
warper = Warper.from_id(sid)
|
||||||
|
if not warper.value_is_valid():
|
||||||
|
continue
|
||||||
|
|
||||||
|
if warper == warpers.RepetitionPenalty:
|
||||||
|
print("ISREP", warper)
|
||||||
|
logits = warper.jax()
|
||||||
|
else:
|
||||||
|
print("AINTREP", warper)
|
||||||
|
logits = warper.jax_dynamic(logits, *rpargs)
|
||||||
# Finally, pick one token using the softmax thingy again (it gives
|
# Finally, pick one token using the softmax thingy again (it gives
|
||||||
# an array whose elements sum to 1 so it can be used nicely as a
|
# an array whose elements sum to 1 so it can be used nicely as a
|
||||||
# probability distribution)
|
# probability distribution)
|
||||||
@@ -356,6 +371,7 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray
|
|||||||
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
|
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
|
||||||
before picking one token using the modified logits
|
before picking one token using the modified logits
|
||||||
'''
|
'''
|
||||||
|
"""
|
||||||
# Top-k (keep only the k tokens with the highest logits and remove
|
# Top-k (keep only the k tokens with the highest logits and remove
|
||||||
# the rest, by setting their logits to negative infinity)
|
# the rest, by setting their logits to negative infinity)
|
||||||
def top_k_filter(logits):
|
def top_k_filter(logits):
|
||||||
@@ -486,6 +502,18 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray
|
|||||||
logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
|
logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
|
||||||
logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
|
logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
|
||||||
logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: apply_repetition_penalty_static(*x), lambda x: x[0], (logits, *rpargs))
|
logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: apply_repetition_penalty_static(*x), lambda x: x[0], (logits, *rpargs))
|
||||||
|
"""
|
||||||
|
for sid in sampler_order:
|
||||||
|
warper = Warper.from_id(sid)
|
||||||
|
if not warper.value_is_valid():
|
||||||
|
continue
|
||||||
|
|
||||||
|
if warper == warpers.RepetitionPenalty:
|
||||||
|
print("ISREP", warper)
|
||||||
|
logits = warper.jax()
|
||||||
|
else:
|
||||||
|
print("AINTREP", warper)
|
||||||
|
logits = warper.jax_static(logits, *rpargs)
|
||||||
# Finally, pick one token using the softmax thingy again (it gives
|
# Finally, pick one token using the softmax thingy again (it gives
|
||||||
# an array whose elements sum to 1 so it can be used nicely as a
|
# an array whose elements sum to 1 so it can be used nicely as a
|
||||||
# probability distribution)
|
# probability distribution)
|
||||||
@@ -515,11 +543,11 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
|
|||||||
logits,
|
logits,
|
||||||
(
|
(
|
||||||
generated,
|
generated,
|
||||||
repetition_penalty,
|
# repetition_penalty,
|
||||||
generated_index,
|
generated_index,
|
||||||
gen_length,
|
gen_length,
|
||||||
rpslope,
|
# rpslope,
|
||||||
rprange,
|
# rprange,
|
||||||
),
|
),
|
||||||
**sampler_options,
|
**sampler_options,
|
||||||
)
|
)
|
||||||
@@ -605,11 +633,11 @@ class PenalizingCausalTransformer(CausalTransformer):
|
|||||||
logits,
|
logits,
|
||||||
(
|
(
|
||||||
generated,
|
generated,
|
||||||
repetition_penalty,
|
# repetition_penalty,
|
||||||
generated_index,
|
generated_index,
|
||||||
gen_length,
|
# gen_length,
|
||||||
rpslope,
|
# rpslope,
|
||||||
rprange,
|
# rprange,
|
||||||
),
|
),
|
||||||
**sampler_options,
|
**sampler_options,
|
||||||
)
|
)
|
||||||
|
28
warpers.py
28
warpers.py
@@ -92,6 +92,10 @@ class Temperature(Warper):
|
|||||||
def jax(cls, scores: jnp.array) -> jnp.array:
|
def jax(cls, scores: jnp.array) -> jnp.array:
|
||||||
return scores / cls.temperature
|
return scores / cls.temperature
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def value_is_valid(cls) -> bool:
|
||||||
|
return cls.temperature != 1.0
|
||||||
|
|
||||||
|
|
||||||
class TopP(Warper):
|
class TopP(Warper):
|
||||||
"""
|
"""
|
||||||
@@ -140,6 +144,10 @@ class TopP(Warper):
|
|||||||
)
|
)
|
||||||
return jnp.where(indices_to_remove, -jnp.inf, scores)
|
return jnp.where(indices_to_remove, -jnp.inf, scores)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def value_is_valid(cls) -> bool:
|
||||||
|
return cls.top_p < 1.0
|
||||||
|
|
||||||
|
|
||||||
class TopK(Warper):
|
class TopK(Warper):
|
||||||
"""
|
"""
|
||||||
@@ -173,6 +181,10 @@ class TopK(Warper):
|
|||||||
)
|
)
|
||||||
return np.where(indices_to_remove, -np.inf, scores)
|
return np.where(indices_to_remove, -np.inf, scores)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def value_is_valid(cls) -> bool:
|
||||||
|
return cls.top_p > 0
|
||||||
|
|
||||||
|
|
||||||
class TailFree(Warper):
|
class TailFree(Warper):
|
||||||
"""
|
"""
|
||||||
@@ -256,6 +268,10 @@ class TailFree(Warper):
|
|||||||
)
|
)
|
||||||
return np.where(indices_to_remove, -np.inf, scores)
|
return np.where(indices_to_remove, -np.inf, scores)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def value_is_valid(cls) -> bool:
|
||||||
|
return cls.tfs < 1.0
|
||||||
|
|
||||||
|
|
||||||
class Typical(Warper):
|
class Typical(Warper):
|
||||||
"""Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf"""
|
"""Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf"""
|
||||||
@@ -332,6 +348,10 @@ class Typical(Warper):
|
|||||||
)
|
)
|
||||||
return np.where(indices_to_remove, -jnp.inf, scores)
|
return np.where(indices_to_remove, -jnp.inf, scores)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def value_is_valid(cls) -> bool:
|
||||||
|
return cls.typical < 1.0
|
||||||
|
|
||||||
|
|
||||||
class TopA(Warper):
|
class TopA(Warper):
|
||||||
"""
|
"""
|
||||||
@@ -370,6 +390,10 @@ class TopA(Warper):
|
|||||||
probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
|
probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def value_is_valid(cls) -> bool:
|
||||||
|
return cls.top_a > 0.0
|
||||||
|
|
||||||
|
|
||||||
class RepetitionPenalty(Warper):
|
class RepetitionPenalty(Warper):
|
||||||
rep_pen: float = 1.0
|
rep_pen: float = 1.0
|
||||||
@@ -543,3 +567,7 @@ class RepetitionPenalty(Warper):
|
|||||||
# positions in the logits array
|
# positions in the logits array
|
||||||
scores[tokens] = penalty_logits
|
scores[tokens] = penalty_logits
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def value_is_valid(cls) -> bool:
|
||||||
|
return cls.rep_pen != 1.0
|
Reference in New Issue
Block a user