Model: TPU should be ready for testing

somebody
2023-02-27 19:08:44 -06:00
parent 1839de1483
commit ed83362dee
3 changed files with 63 additions and 9 deletions

View File

@@ -2203,8 +2203,6 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
         # Get the model_type from the config or assume a model type if it isn't present
         try:
-            print("LMP:", self.get_local_model_path())
-            print("M:", utils.koboldai_vars.model)
             model_config = AutoConfig.from_pretrained(
                 self.get_local_model_path() or utils.koboldai_vars.model,
                 revision=utils.koboldai_vars.revision,
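The context that remains around the deleted debug prints shows the lookup the prints were inspecting: the backend prefers a local model directory and only falls back to the model identifier kept in koboldai_vars. A minimal standalone sketch of that fallback, with hypothetical local_model_path and model_id parameters standing in for self.get_local_model_path() and utils.koboldai_vars.model:

from typing import Optional

from transformers import AutoConfig


def load_model_config(local_model_path: Optional[str], model_id: str, revision: Optional[str] = None):
    # Prefer a local checkout; fall back to the hub identifier when no local path exists.
    return AutoConfig.from_pretrained(local_model_path or model_id, revision=revision)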

View File

@@ -54,6 +54,8 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor
 from mesh_transformer.util import to_bf16
 import time
+import warpers
+from warpers import Warper

 socketio = None
@@ -213,6 +215,7 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
     before picking one token using the modified logits
     '''
+    """
     # Top-k (keep only the k tokens with the highest logits and remove
     # the rest, by setting their logits to negative infinity)
     def top_k_filter(logits):
@@ -344,6 +347,18 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
         if k == 4 and typical < 1.0: logits = typical_filter(logits)
         if k == 5 and temp != 1.0: logits = temp_filter(logits)
         if k == 6 and rpargs[1] != 1.0: logits = apply_repetition_penalty_dynamic(logits, *rpargs)
+    """
+
+    for sid in sampler_order:
+        warper = Warper.from_id(sid)
+        if not warper.value_is_valid():
+            continue
+        if warper == warpers.RepetitionPenalty:
+            print("ISREP", warper)
+            logits = warper.jax()
+        else:
+            print("AINTREP", warper)
+            logits = warper.jax_dynamic(logits, *rpargs)
     # Finally, pick one token using the softmax thingy again (it gives
     # an array whose elements sum to 1 so it can be used nicely as a
     # probability distribution)
@@ -356,6 +371,7 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray
     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
     before picking one token using the modified logits
     '''
+    """
     # Top-k (keep only the k tokens with the highest logits and remove
     # the rest, by setting their logits to negative infinity)
     def top_k_filter(logits):
@@ -486,6 +502,18 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray
         logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
         logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
         logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: apply_repetition_penalty_static(*x), lambda x: x[0], (logits, *rpargs))
+    """
+
+    for sid in sampler_order:
+        warper = Warper.from_id(sid)
+        if not warper.value_is_valid():
+            continue
+        if warper == warpers.RepetitionPenalty:
+            print("ISREP", warper)
+            logits = warper.jax()
+        else:
+            print("AINTREP", warper)
+            logits = warper.jax_static(logits, *rpargs)
     # Finally, pick one token using the softmax thingy again (it gives
     # an array whose elements sum to 1 so it can be used nicely as a
     # probability distribution)
@@ -515,11 +543,11 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
                 logits,
                 (
                     generated,
-                    repetition_penalty,
+                    # repetition_penalty,
                     generated_index,
                     gen_length,
-                    rpslope,
-                    rprange,
+                    # rpslope,
+                    # rprange,
                 ),
                 **sampler_options,
             )
@@ -605,11 +633,11 @@ class PenalizingCausalTransformer(CausalTransformer):
                 logits,
                 (
                     generated,
-                    repetition_penalty,
+                    # repetition_penalty,
                     generated_index,
-                    gen_length,
-                    rpslope,
-                    rprange,
+                    # gen_length,
+                    # rpslope,
+                    # rprange,
                 ),
                 **sampler_options,
             )
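Both sampling functions now replace the hard-coded k == 0…6 chain with a lookup by sampler id plus a per-warper no-op check; only Warper.from_id, value_is_valid, jax, jax_dynamic and jax_static come from this commit. Below is a minimal sketch of that dispatch pattern under those assumptions, using NumPy logits and an illustrative apply_warpers helper in place of the real jax code:

import numpy as np


class Warper:
    """Simplified stand-in for the warpers.Warper base class."""

    @classmethod
    def from_id(cls, sampler_id: int) -> type:
        # The real implementation maps KoboldAI sampler ids to Warper subclasses;
        # id 5 is temperature in the old hard-coded chain above.
        return {5: Temperature}[sampler_id]

    @classmethod
    def value_is_valid(cls) -> bool:
        raise NotImplementedError

    @classmethod
    def jax_dynamic(cls, logits: np.ndarray) -> np.ndarray:
        raise NotImplementedError


class Temperature(Warper):
    temperature: float = 1.0

    @classmethod
    def value_is_valid(cls) -> bool:
        # Mirrors the check added in warpers.py: 1.0 leaves the logits unchanged.
        return cls.temperature != 1.0

    @classmethod
    def jax_dynamic(cls, logits: np.ndarray) -> np.ndarray:
        return logits / cls.temperature


def apply_warpers(logits: np.ndarray, sampler_order: list) -> np.ndarray:
    # Mirrors the loop added to kobold_sample_dynamic: look each warper up by id,
    # skip it while its setting is at the neutral default, otherwise apply it.
    for sid in sampler_order:
        warper = Warper.from_id(sid)
        if not warper.value_is_valid():
            continue
        logits = warper.jax_dynamic(logits)
    return logits

With Temperature.temperature set to 0.7 and sampler_order = [5], apply_warpers divides the logits by 0.7; at the default of 1.0 the warper is skipped and the logits pass through untouched.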

View File

@@ -92,6 +92,10 @@ class Temperature(Warper):
     def jax(cls, scores: jnp.array) -> jnp.array:
         return scores / cls.temperature

+    @classmethod
+    def value_is_valid(cls) -> bool:
+        return cls.temperature != 1.0
+

 class TopP(Warper):
     """
@@ -140,6 +144,10 @@ class TopP(Warper):
         )
         return jnp.where(indices_to_remove, -jnp.inf, scores)

+    @classmethod
+    def value_is_valid(cls) -> bool:
+        return cls.top_p < 1.0
+

 class TopK(Warper):
     """
@@ -173,6 +181,10 @@ class TopK(Warper):
         )
         return np.where(indices_to_remove, -np.inf, scores)

+    @classmethod
+    def value_is_valid(cls) -> bool:
+        return cls.top_p > 0
+

 class TailFree(Warper):
     """
@@ -256,6 +268,10 @@ class TailFree(Warper):
         )
         return np.where(indices_to_remove, -np.inf, scores)

+    @classmethod
+    def value_is_valid(cls) -> bool:
+        return cls.tfs < 1.0
+

 class Typical(Warper):
     """Typical sampling, described in https://arxiv.org/pdf/2202.00666.pdf"""
@@ -332,6 +348,10 @@ class Typical(Warper):
         )
         return np.where(indices_to_remove, -jnp.inf, scores)

+    @classmethod
+    def value_is_valid(cls) -> bool:
+        return cls.typical < 1.0
+

 class TopA(Warper):
     """
@@ -370,6 +390,10 @@ class TopA(Warper):
             probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
         )

+    @classmethod
+    def value_is_valid(cls) -> bool:
+        return cls.top_a > 0.0
+

 class RepetitionPenalty(Warper):
     rep_pen: float = 1.0
@@ -543,3 +567,7 @@ class RepetitionPenalty(Warper):
         # positions in the logits array
         scores[tokens] = penalty_logits
         return scores
+
+    @classmethod
+    def value_is_valid(cls) -> bool:
+        return cls.rep_pen != 1.0
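Each added value_is_valid classmethod reports whether the warper's current setting would actually change the logits, which is what lets the new sampling loop above skip warpers that are still at their neutral defaults. An illustrative check of that contract, assuming the repository's warpers module is importable and its class attributes are assigned directly, as the TPU backend does before sampling:

import warpers

warpers.Temperature.temperature = 1.0
assert not warpers.Temperature.value_is_valid()  # neutral default, warper is skipped

warpers.Temperature.temperature = 0.7
assert warpers.Temperature.value_is_valid()  # setting takes effect, warper is applied

warpers.RepetitionPenalty.rep_pen = 1.2
assert warpers.RepetitionPenalty.value_is_valid()  # penalty engaged, warper is applied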