Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Model: TPU Fixes
@@ -5519,6 +5519,7 @@ def final_startup():
    # Precompile TPU backend if required
    if isinstance(model, HFMTJInferenceModel):
        import tpu_mtj_backend
        soft_tokens = model.get_soft_tokens()
        if(koboldai_vars.dynamicscan or (not koboldai_vars.nogenmod and koboldai_vars.has_genmod)):
            tpool.execute(tpu_mtj_backend.infer_dynamic, np.tile(np.uint32((23403, 727, 20185)), (koboldai_vars.numseqs, 1)),
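The hard-coded token IDs above are just a dummy batch used to trigger compilation ahead of the first real request. As a loose illustration of the same warm-up idea (a generic jitted function, not the actual tpu_mtj_backend interface; the batch size 4 stands in for koboldai_vars.numseqs):

import jax
import numpy as np

@jax.jit
def generate(tokens):
    # Stand-in for the real forward pass; only the input shape matters for warm-up.
    return tokens.sum(axis=-1)

# Same shape trick as the snippet above: a (numseqs, 3) batch of dummy token IDs.
dummy = np.tile(np.uint32((23403, 727, 20185)), (4, 1))
generate(dummy)  # first call triggers XLA compilation; later calls with this shape reuse it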
model.py (116 changed lines)
@@ -27,7 +27,6 @@ from warpers import Warper
import torch
from torch.nn import Embedding
import numpy as np
import accelerate.utils
import transformers
from transformers import (
    StoppingCriteria,
@@ -48,6 +47,7 @@ import koboldai_settings

try:
    import breakmodel
    import accelerate.utils
except ModuleNotFoundError as e:
    if not utils.koboldai_vars.use_colab_tpu:
        raise e
@@ -889,7 +889,52 @@ class InferenceModel:
        hook(self, input_ids)


class HFMTJInferenceModel(InferenceModel):
class HFInferenceModel(InferenceModel):
    def __init__(self) -> None:
        super().__init__()
        self.model_config = None

    def get_local_model_path(
        self, legacy: bool = False, ignore_existance: bool = False
    ) -> Optional[str]:
        """
        Returns a string of the model's path locally, or None if it is not downloaded.
        If ignore_existance is true, it will always return a path.
        """

        basename = utils.koboldai_vars.model.replace("/", "_")
        if legacy:
            ret = basename
        else:
            ret = os.path.join("models", basename)

        if os.path.isdir(ret) or ignore_existance:
            return ret
        return None

    def init_model_config(self) -> None:
        # Get the model_type from the config or assume a model type if it isn't present
        try:
            self.model_config = AutoConfig.from_pretrained(
                self.get_local_model_path() or utils.koboldai_vars.model,
                revision=utils.koboldai_vars.revision,
                cache_dir="cache",
            )
            utils.koboldai_vars.model_type = self.model_config.model_type
        except ValueError:
            utils.koboldai_vars.model_type = {
                "NeoCustom": "gpt_neo",
                "GPT2Custom": "gpt2",
            }.get(utils.koboldai_vars.model)

            if not utils.koboldai_vars.model_type:
                logger.warning(
                    "No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)"
                )
                utils.koboldai_vars.model_type = "gpt_neo"
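The new get_local_model_path() helper resolves the configured model name to the folder KoboldAI downloads into. A small standalone sketch of the same behaviour (hypothetical model ID passed as a parameter; the real method reads utils.koboldai_vars.model instead):

import os

def local_model_path(model: str, legacy: bool = False, ignore_existance: bool = False):
    basename = model.replace("/", "_")                         # "org/name" -> "org_name"
    ret = basename if legacy else os.path.join("models", basename)
    if os.path.isdir(ret) or ignore_existance:
        return ret
    return None

print(local_model_path("EleutherAI/gpt-neo-2.7B", ignore_existance=True))
# -> models/EleutherAI_gpt-neo-2.7B (models\... on Windows)
print(local_model_path("EleutherAI/gpt-neo-2.7B"))
# -> None unless that folder actually exists on disk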


class HFMTJInferenceModel(HFInferenceModel):
    def __init__(
        self,
        model_name: str,
@@ -1012,8 +1057,6 @@ class HFMTJInferenceModel(InferenceModel):
            "rprange": int(utils.koboldai_vars.rep_pen_range),
        }

        self.load_mtj_backend()

        tpu_mtj_backend.socketio = utils.socketio

        if utils.koboldai_vars.model == "TPUMeshTransformerGPTNeoX":
@@ -1045,21 +1088,20 @@ class HFMTJInferenceModel(InferenceModel):
        tpu_mtj_backend.settings_callback = mtj_settings_callback

    def _load(self, save_model: bool, initial_load: bool) -> None:
        self.patch_transformers()
        self.setup_mtj()

        self.init_model_config()
        utils.koboldai_vars.allowsp = True
        # loadmodelsettings()
        # loadsettings()

        tpu_mtj_backend.load_model(
            utils.koboldai_vars.custmodpth,
            utils.koboldai_vars.model,
            hf_checkpoint=utils.koboldai_vars.model
            not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
            and utils.koboldai_vars.use_colab_tpu,
            socketio_queue=koboldai_settings.queue,
            initial_load=initial_load,
            logger=logger,
            **utils.koboldai_vars.modelconfig,
            **self.model_config.to_dict()
            # **utils.koboldai_vars.modelconfig,
        )

        # tpool.execute(tpu_mtj_backend.load_model, koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, **koboldai_vars.modelconfig)
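Swapping **utils.koboldai_vars.modelconfig for **self.model_config.to_dict() only changes where load_model's keyword arguments come from: the Hugging Face config is flattened into a dict and unpacked. A minimal sketch of that unpacking (generic names, not tpu_mtj_backend's real signature):

# `load` and its parameters are placeholders for illustration only.
def load(path, *, n_layers=0, d_model=0, **extra):
    print(path, n_layers, d_model, sorted(extra))

hf_config = {"n_layers": 28, "d_model": 4096, "vocab_size": 50400}
load("models/example", **hf_config)
# -> models/example 28 4096 ['vocab_size']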
@@ -1079,7 +1121,7 @@ class HFMTJInferenceModel(InferenceModel):
            if utils.koboldai_vars.newlinemode != "s" or str(k) != "</s>"
        ]

    def get_soft_tokens() -> np.array:
    def get_soft_tokens(self) -> np.array:
        soft_tokens = None

        if utils.koboldai_vars.sp is None:
@@ -1152,6 +1194,7 @@ class HFMTJInferenceModel(InferenceModel):
        genout = np.array(genout)

        return GenerationResult(
            self,
            out_batches=genout,
            prompt=prompt_tokens,
            is_whole_generation=True,
@@ -1159,7 +1202,7 @@ class HFMTJInferenceModel(InferenceModel):
        )


class HFTorchInferenceModel(InferenceModel):
class HFTorchInferenceModel(HFInferenceModel):
    def __init__(
        self,
        model_name: str,
@@ -1181,7 +1224,6 @@ class HFTorchInferenceModel(InferenceModel):

        self.model = None
        self.tokenizer = None
        self.model_config = None
        self.capabilties = ModelCapabilities(
            embedding_manipulation=True,
            post_token_hooks=True,
@@ -1198,10 +1240,8 @@ class HFTorchInferenceModel(InferenceModel):
            warper = Warper.from_id(sid)
            if warper == warpers.RepetitionPenalty:
                # Rep pen needs more data than other samplers
                print("is rep:", warper)
                scores = warper.torch(scores, input_ids=input_ids)
            else:
                print("aint rep:", warper)
                scores = warper.torch(scores)
        return scores
@@ -1616,24 +1656,6 @@ class HFTorchInferenceModel(InferenceModel):
            **tf_kwargs,
        )

    def get_local_model_path(
        self, legacy: bool = False, ignore_existance: bool = False
    ) -> Optional[str]:
        """
        Returns a string of the model's path locally, or None if it is not downloaded.
        If ignore_existance is true, it will always return a path.
        """

        basename = utils.koboldai_vars.model.replace("/", "_")
        if legacy:
            ret = basename
        else:
            ret = os.path.join("models", basename)

        if os.path.isdir(ret) or ignore_existance:
            return ret
        return None

    def get_hidden_size(self) -> int:
        return self.model.get_input_embeddings().embedding_dim
@@ -2206,25 +2228,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
                self.get_local_model_path(ignore_existance=True),
            )

        # Get the model_type from the config or assume a model type if it isn't present
        try:
            model_config = AutoConfig.from_pretrained(
                self.get_local_model_path() or utils.koboldai_vars.model,
                revision=utils.koboldai_vars.revision,
                cache_dir="cache",
            )
            utils.koboldai_vars.model_type = model_config.model_type
        except ValueError as e:
            utils.koboldai_vars.model_type = {
                "NeoCustom": "gpt_neo",
                "GPT2Custom": "gpt2",
            }.get(utils.koboldai_vars.model)

            if not utils.koboldai_vars.model_type:
                logger.warning(
                    "No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)"
                )
                utils.koboldai_vars.model_type = "gpt_neo"
        self.init_model_config()

        tf_kwargs = {
            "low_cpu_mem_usage": True,
@@ -2246,7 +2250,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
            and utils.koboldai_vars.breakmodel
            and not utils.koboldai_vars.nobreakmodel
        ):
            self.breakmodel_device_config(model_config)
            self.breakmodel_device_config(self.model_config)

        if utils.koboldai_vars.lazy_load:
            # If we're using lazy loader, we need to figure out what the model's hidden layers are called
@@ -2254,9 +2258,9 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
                dematerialized_modules=True, use_accelerate_init_empty_weights=True
            ):
                try:
                    metamodel = AutoModelForCausalLM.from_config(model_config)
                    metamodel = AutoModelForCausalLM.from_config(self.model_config)
                except Exception as e:
                    metamodel = GPTNeoForCausalLM.from_config(model_config)
                    metamodel = GPTNeoForCausalLM.from_config(self.model_config)
                utils.layers_module_names = utils.get_layers_module_names(metamodel)
                utils.module_names = list(metamodel.state_dict().keys())
                utils.named_buffers = list(metamodel.named_buffers(recurse=True))
@@ -2264,7 +2268,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
        # Download model from Huggingface if it does not exist, otherwise load locally
        with self._maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(
            enable=utils.koboldai_vars.lazy_load,
            callback=self._get_lazy_load_callback(utils.num_layers(model_config))
            callback=self._get_lazy_load_callback(utils.num_layers(self.model_config))
            if utils.koboldai_vars.lazy_load
            else None,
            dematerialized_modules=True,
tpu_mtj_backend.py
@@ -55,7 +55,6 @@ from mesh_transformer.util import to_bf16
import time

import warpers
from warpers import Warper

socketio = None
@@ -215,150 +214,13 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
    to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
    before picking one token using the modified logits
    '''
    """
    # Top-k (keep only the k tokens with the highest logits and remove
    # the rest, by setting their logits to negative infinity)
    def top_k_filter(logits):
        # After sorting the logits array in descending order,
        # sorted_indices_to_remove is a 1D array that is True for tokens
        # in the sorted logits array we want to remove and False for ones
        # we want to keep, in this case the first top_k elements will be
        # False and the rest will be True
        sorted_indices_to_remove = np.arange(len(logits)) >= top_k
        # Unsort the logits array back to its original configuration and
        # remove tokens we need to remove
        _, indices_to_remove = jax.lax.sort_key_val(
            np.argsort(-logits),
            sorted_indices_to_remove,
        )
        return np.where(indices_to_remove, -np.inf, logits)
    # Top-a (remove all tokens that have softmax probability less than
    # a*m^2 where m is the maximum softmax probability)
    def top_a_filter(logits):
        # Replace every element in the logits array
        # with e (Euler's number) to the power of that element, and divide
        # each element of the new array by the sum of the elements in the
        # new array
        probabilities = np.array(jax.nn.softmax(logits), copy=True)
        # Find the largest probability
        probs_max = probabilities.max()
        # Remove tokens
        return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits)
    # Top-p (after sorting the remaining tokens again in descending order of
    # logit, remove the ones that have cumulative softmax probability
    # greater than p)
    def top_p_filter(logits):
        # Sort the logits array in descending order, replace every element
        # with e (Euler's number) to the power of that element, and divide
        # each element of the new array by the sum of the elements in the
        # new array
        sorted_logits = -np.sort(-logits)
        probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
        # Calculate cumulative_probabilities as the prefix-sum array of
        # probabilities
        cumulative_probabilities = np.cumsum(probabilities, axis=-1)
        # We want to remove tokens with cumulative probability higher
        # than top_p
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Don't ever remove the token with the highest logit, even if
        # the probability is higher than top_p
        sorted_indices_to_remove[0] = False
        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            np.argsort(-logits),
            sorted_indices_to_remove,
        )
        return np.where(indices_to_remove, -np.inf, logits)
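A small worked example of the top-p rule in the filter above (illustrative numbers, not from the code): with sorted probabilities [0.5, 0.3, 0.15, 0.05] and top_p = 0.9, the cumulative sums are [0.5, 0.8, 0.95, 1.0], so only the last two tokens are masked. A runnable check with plain NumPy:

import numpy as np

probs = np.array([0.5, 0.3, 0.15, 0.05])   # already sorted in descending order
top_p = 0.9
cumulative = np.cumsum(probs)              # [0.5, 0.8, 0.95, 1.0]
remove = cumulative > top_p                # [False, False, True, True]
remove[0] = False                          # never drop the highest-probability token
print(remove)                              # [False False  True  True]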
    # Tail free sampling (basically top-p a second time on remaining tokens
    # except it's the "cumulative normalized absolute second finite
    # differences of the softmax probabilities" instead of just the
    # cumulative softmax probabilities)
    def tail_free_filter(logits):
        # Sort in descending order
        sorted_logits = -np.sort(-logits)
        # Softmax again
        probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
        # Calculate the second finite differences of that array (i.e.
        # calculate the difference array and then calculate the difference
        # array of the difference array)
        d2 = np.diff(np.diff(probabilities))
        # Get the absolute values of all those second finite differences
        d2 = np.abs(d2)
        # Normalize (all elements in the array are divided by the sum of the
        # array's elements)
        d2 = d2 / d2.sum(axis=-1, keepdims=True)
        # Get the prefix-sum array
        cumulative_d2 = np.cumsum(d2, axis=-1)
        # We will remove the tokens with a cumulative normalized absolute
        # second finite difference larger than the TFS value
        sorted_indices_to_remove = cumulative_d2 > tfs
        # Don't remove the token with the highest logit
        sorted_indices_to_remove[0] = False
        # Since the d2 array has two fewer elements than the logits array,
        # we'll add two extra Trues to the end
        sorted_indices_to_remove = np.pad(
            sorted_indices_to_remove,
            (0, 2),
            constant_values=True,
        )
        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            np.argsort(-logits),
            sorted_indices_to_remove,
        )
        return np.where(indices_to_remove, -np.inf, logits)
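A worked illustration of the tail-free rule (example numbers, not from the code): for sorted probabilities [0.5, 0.25, 0.15, 0.10], the first differences are [-0.25, -0.10, -0.05] and the second differences are [0.15, 0.05], which normalize to [0.75, 0.25]. With tfs = 0.9 the cumulative values [0.75, 1.0] mark only the second entry for removal, and the two padded positions are always removed:

import numpy as np

probs = np.array([0.5, 0.25, 0.15, 0.10])       # sorted descending
d2 = np.abs(np.diff(np.diff(probs)))            # [0.15, 0.05]
d2 = d2 / d2.sum()                              # [0.75, 0.25]
remove = np.cumsum(d2) > 0.9                    # [False, True]
remove[0] = False                               # keep the top token
remove = np.pad(remove, (0, 2), constant_values=True)
print(remove)                                   # [False  True  True  True]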
    # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
    def typical_filter(logits):
        # Compute softmax probabilities and the natural logarithms of them
        probs = jax.nn.softmax(logits)
        with np.errstate(divide="ignore"):
            log_probs = np.log(probs)
        # Compute the negative of entropy, which is the sum of p*ln(p) for all p
        # in the set of softmax probabilities of the logits
        neg_entropy = np.nansum(probs * log_probs, axis=-1, keepdims=True)
        # Determine absolute difference between the negative entropy and the
        # log probabilities
        entropy_deviation = np.abs(neg_entropy - log_probs)
        # Keep certain tokens such that the sum of the entropy_deviation of the
        # kept tokens is the smallest possible value such that the sum of the
        # softmax probabilities of the kept tokens is at least the threshold
        # value (by sorting the tokens in ascending order of entropy_deviation
        # and then keeping the smallest possible number of tokens from the
        # beginning such that sum of softmax probabilities is at or above the
        # threshold)
        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
        sorted_indices_to_remove = np.cumsum(sorted_logits, axis=-1) >= typical
        sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1, axis=-1)
        sorted_indices_to_remove[0] = False
        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(entropy_deviation),
            sorted_indices_to_remove,
        )
        return np.where(indices_to_remove, -jnp.inf, logits)
    # Temperature (just divide the logits by the temperature)
    def temp_filter(logits):
        return logits / temp
    for k in sampler_order:
        if k == 0 and top_k > 0: logits = top_k_filter(logits)
        if k == 1 and top_a > 0.0: logits = top_a_filter(logits)
        if k == 2 and top_p < 1.0: logits = top_p_filter(logits)
        if k == 3 and tfs < 1.0: logits = tail_free_filter(logits)
        if k == 4 and typical < 1.0: logits = typical_filter(logits)
        if k == 5 and temp != 1.0: logits = temp_filter(logits)
        if k == 6 and rpargs[1] != 1.0: logits = apply_repetition_penalty_dynamic(logits, *rpargs)
    """
    for sid in sampler_order:
        warper = Warper.from_id(sid)
    for sid in jnp.array(sampler_order, int):
        # sid = int(sid)
        warper = warpers.Warper.from_id(sid)
        if not warper.value_is_valid():
            continue
        logits = warper.jax_dynamic()

        if warper == warpers.RepetitionPenalty:
            print("ISREP", warper)
            logits = warper.jax()
        else:
            print("AINTREP", warper)
            logits = warper.jax_dynamic(logits, *rpargs)
    # Finally, pick one token using the softmax thingy again (it gives
    # an array whose elements sum to 1 so it can be used nicely as a
    # probability distribution)
@@ -371,152 +233,14 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray
    to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
    before picking one token using the modified logits
    '''
    """
    # Top-k (keep only the k tokens with the highest logits and remove
    # the rest, by setting their logits to negative infinity)
    def top_k_filter(logits):
        # After sorting the logits array in descending order,
        # sorted_indices_to_remove is a 1D array that is True for tokens
        # in the sorted logits array we want to remove and False for ones
        # we want to keep, in this case the first top_k elements will be
        # False and the rest will be True
        sorted_indices_to_remove = jnp.arange(len(logits)) >= top_k
        # Unsort the logits array back to its original configuration and
        # remove tokens we need to remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-logits),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, logits)
    # Top-a (remove all tokens that have softmax probability less than
    # a*m^2 where m is the maximum softmax probability)
    def top_a_filter(logits):
        # Replace every element in the logits array
        # with e (Euler's number) to the power of that element, and divide
        # each element of the new array by the sum of the elements in the
        # new array
        probabilities = jax.nn.softmax(logits)
        # Find the largest probability
        probs_max = probabilities.max()
        # Remove tokens
        return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits)
    # Top-p (after sorting the remaining tokens again in descending order of
    # logit, remove the ones that have cumulative softmax probability
    # greater than p)
    def top_p_filter(logits):
        # Sort the logits array in descending order, replace every element
        # with e (Euler's number) to the power of that element, and divide
        # each element of the new array by the sum of the elements in the
        # new array
        sorted_logits = -jnp.sort(-logits)
        probabilities = jax.nn.softmax(sorted_logits)
        # Calculate cumulative_probabilities as the prefix-sum array of
        # probabilities
        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
        # We want to remove tokens with cumulative probability higher
        # than top_p
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Don't ever remove the token with the highest logit, even if
        # the probability is higher than top_p
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-logits),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, logits)
    # Tail free sampling (basically top-p a second time on remaining tokens
    # except it's the "cumulative normalized absolute second finite
    # differences of the softmax probabilities" instead of just the
    # cumulative softmax probabilities)
    def tail_free_filter(logits):
        # Sort in descending order
        sorted_logits = -jnp.sort(-logits)
        # Softmax again
        probabilities = jax.nn.softmax(sorted_logits)
        # Calculate the second finite differences of that array (i.e.
        # calculate the difference array and then calculate the difference
        # array of the difference array)
        d2 = jnp.diff(jnp.diff(probabilities))
        # Get the absolute values of all those second finite differences
        d2 = jnp.abs(d2)
        # Normalize (all elements in the array are divided by the sum of the
        # array's elements)
        d2 = d2 / d2.sum(axis=-1, keepdims=True)
        # Get the prefix-sum array
        cumulative_d2 = jnp.cumsum(d2, axis=-1)
        # We will remove the tokens with a cumulative normalized absolute
        # second finite difference larger than the TFS value
        sorted_indices_to_remove = cumulative_d2 > tfs
        # Don't remove the token with the highest logit
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        # Since the d2 array has two fewer elements than the logits array,
        # we'll add two extra Trues to the end
        sorted_indices_to_remove = jnp.pad(
            sorted_indices_to_remove,
            (0, 2),
            constant_values=True,
        )
        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-logits),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, logits)
    # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
    def typical_filter(logits):
        # Compute softmax probabilities and the natural logarithms of them
        probs = jax.nn.softmax(logits)
        log_probs = jnp.log(probs)
        # Compute the negative of entropy, which is the sum of p*ln(p) for all p
        # in the set of softmax probabilities of the logits
        neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
        # Determine absolute difference between the negative entropy and the
        # log probabilities
        entropy_deviation = jnp.abs(neg_entropy - log_probs)
        # Keep certain tokens such that the sum of the entropy_deviation of the
        # kept tokens is the smallest possible value such that the sum of the
        # softmax probabilities of the kept tokens is at least the threshold
        # value (by sorting the tokens in ascending order of entropy_deviation
        # and then keeping the smallest possible number of tokens from the
        # beginning such that sum of softmax probabilities is at or above the
        # threshold)
        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
        sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= typical
        sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(entropy_deviation),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, logits)
    # Temperature (just divide the logits by the temperature)
    def temp_filter(logits):
        return logits / temp
    for k in sampler_order:
        logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), top_k_filter, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), top_a_filter, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), top_p_filter, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: apply_repetition_penalty_static(*x), lambda x: x[0], (logits, *rpargs))
    """
    for sid in sampler_order:
        warper = Warper.from_id(sid)
        if not warper.value_is_valid():
            continue

        if warper == warpers.RepetitionPenalty:
            print("ISREP", warper)
            logits = warper.jax()
        else:
            print("AINTREP", warper)
            logits = warper.jax_static(logits, *rpargs)
    # Finally, pick one token using the softmax thingy again (it gives
    # an array whose elements sum to 1 so it can be used nicely as a
    # probability distribution)
        logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), warpers.TopK.jax_static, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), warpers.TopA.jax_static, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), warpers.TopP.jax_static, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), warpers.TailFree.jax_static, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), warpers.Typical.jax_static, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), warpers.Temperature.jax_static, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: warpers.RepetitionPenalty.jax_static(*x), lambda x: x[0], (logits, *rpargs))
    return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
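The "softmax thingy" in the comment above is jax.random.categorical: it treats the filtered logits as unnormalized log-probabilities, so tokens that were masked to -inf get exactly zero probability. A tiny standalone sketch of the same call:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([2.0, 1.0, -jnp.inf, 0.5])  # the third token was filtered out
token = jax.random.categorical(key, logits)    # samples index 0, 1, or 3, never 2
print(int(token))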

pad_token_id = 50256
@@ -1101,6 +825,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
        "tokenizer": "gpt2",
    }


    # Try to convert HF config.json to MTJ config
    if hf_checkpoint:
        spec_path = os.path.join("maps", koboldai_vars.model_type + ".json")
@@ -1303,7 +1028,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
        if socketio is None:
            utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
        else:
            utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=Send_to_socketio())
            utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=utils.UIProgressBarFile())
        koboldai_vars.status_message = "Loading model"
        koboldai_vars.loaded_layers = 0
        koboldai_vars.total_layers = num_tensors
utils.py (2 changed lines)
@@ -8,7 +8,6 @@ from urllib.error import HTTPError
import requests
import requests.adapters
import time
import breakmodel
from transformers import PreTrainedModel
import packaging.version
from tqdm.auto import tqdm
@@ -653,6 +652,7 @@ def get_auxilary_device():
    if koboldai_vars.hascuda and koboldai_vars.usegpu:
        return koboldai_vars.gpu_device
    elif koboldai_vars.hascuda and koboldai_vars.breakmodel:
        import breakmodel
        return breakmodel.primary_device
    return "cpu"
warpers.py (110 changed lines)
@@ -46,8 +46,10 @@ try:
    import jax
    import jax.numpy as jnp
    import tpu_mtj_backend
except ImportError:
    assert not utils.koboldai_vars.use_colab_tpu
except ImportError as e:
    print(e)
    if utils.koboldai_vars.use_colab_tpu:
        raise e


def update_settings():
@@ -89,7 +91,11 @@ class Temperature(Warper):
        return scores / cls.temperature

    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
    def jax_dynamic(cls, scores: np.array) -> np.array:
        return scores / cls.temperature

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        return scores / cls.temperature

    @classmethod
@@ -121,7 +127,31 @@ class TopP(Warper):
        return scores.masked_fill(indices_to_remove, -np.inf)

    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
    def jax_dynamic(cls, scores: np.array) -> np.array:
        # Sort the logits array in descending order, replace every element
        # with e (Euler's number) to the power of that element, and divide
        # each element of the new array by the sum of the elements in the
        # new array
        sorted_logits = -np.sort(-scores)
        probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
        # Calculate cumulative_probabilities as the prefix-sum array of
        # probabilities
        cumulative_probabilities = np.cumsum(probabilities, axis=-1)
        # We want to remove tokens with cumulative probability higher
        # than top_p
        sorted_indices_to_remove = cumulative_probabilities > cls.top_p
        # Don't ever remove the token with the highest logit, even if
        # the probability is higher than top_p
        sorted_indices_to_remove[0] = False
        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            np.argsort(-scores),
            sorted_indices_to_remove,
        )
        return np.where(indices_to_remove, -np.inf, scores)

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        # Sort the logits array in descending order, replace every element
        # with e (Euler's number) to the power of that element, and divide
        # each element of the new array by the sum of the elements in the
@@ -166,7 +196,7 @@ class TopK(Warper):
        return scores

    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
    def jax_dynamic(cls, scores: np.array) -> np.array:
        # After sorting the logits array in descending order,
        # sorted_indices_to_remove is a 1D array that is True for tokens
        # in the sorted logits array we want to remove and False for ones
@@ -181,6 +211,16 @@ class TopK(Warper):
        )
        return np.where(indices_to_remove, -np.inf, scores)

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        sorted_indices_to_remove = jnp.arange(len(scores)) >= cls.top_k

        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-scores),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.top_p > 0
@@ -225,7 +265,7 @@ class TailFree(Warper):
        return scores

    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
    def jax_dynamic(cls, scores: np.array) -> np.array:
        # Sort in descending order
        sorted_logits = -np.sort(-scores)
@@ -268,6 +308,30 @@ class TailFree(Warper):
        )
        return np.where(indices_to_remove, -np.inf, scores)

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        sorted_logits = -jnp.sort(-scores)
        probabilities = jax.nn.softmax(sorted_logits)

        d2 = jnp.diff(jnp.diff(probabilities))
        d2 = jnp.abs(d2)
        d2 = d2 / d2.sum(axis=-1, keepdims=True)

        cumulative_d2 = jnp.cumsum(d2, axis=-1)
        sorted_indices_to_remove = cumulative_d2 > cls.tfs
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        sorted_indices_to_remove = jnp.pad(
            sorted_indices_to_remove,
            (0, 2),
            constant_values=True,
        )

        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-scores),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.tfs < 1.0
@@ -315,7 +379,7 @@ class Typical(Warper):
        return scores

    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
    def jax_dynamic(cls, scores: np.array) -> np.array:
        # Compute softmax probabilities and the natural logarithms of them
        probs = jax.nn.softmax(scores)
        with np.errstate(divide="ignore"):
@@ -348,6 +412,25 @@ class Typical(Warper):
        )
        return np.where(indices_to_remove, -jnp.inf, scores)

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        probs = jax.nn.softmax(scores)
        log_probs = jnp.log(probs)

        neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
        entropy_deviation = jnp.abs(neg_entropy - log_probs)

        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
        sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= cls.typical
        sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)

        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(entropy_deviation),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, scores)

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.typical < 1.0
@@ -377,7 +460,7 @@ class TopA(Warper):
        return scores

    @classmethod
    def jax(cls, scores: jnp.array) -> jnp.array:
    def jax_dynamic(cls, scores: np.array) -> np.array:
        # Replace every element in the logits array
        # with e (Euler's number) to the power of that element, and divide
        # each element of the new array by the sum of the elements in the
@@ -390,6 +473,14 @@ class TopA(Warper):
            probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
        )

    @classmethod
    def jax_static(cls, scores: jnp.array) -> jnp.array:
        probabilities = jax.nn.softmax(scores)
        probs_max = probabilities.max()
        return jnp.where(
            probabilities < probs_max * probs_max * cls.top_a, -jnp.inf, scores
        )

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.top_a > 0.0
@@ -436,7 +527,6 @@ class RepetitionPenalty(Warper):
        return scores

    @classmethod
    # def jax_static(cls, scores: jnp.array) -> jnp.array:
    def jax_static(
        cls,
        logits: jnp.array,
@@ -570,4 +660,4 @@ class RepetitionPenalty(Warper):

    @classmethod
    def value_is_valid(cls) -> bool:
        return cls.rep_pen != 1.0
        return cls.rep_pen != 1.0
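The broader pattern in warpers.py is that each warper's former jax() method is split in two: jax_dynamic() works on plain NumPy arrays on the host, where in-place writes like arr[0] = False are fine, while jax_static() works on traced jnp arrays so it stays usable inside jit-compiled code, where updates must go through .at[...].set(...) or jnp.where. A minimal sketch of that split with a made-up top-k warper (illustrative only, not the real classes):

import jax
import jax.numpy as jnp
import numpy as np

class ExampleTopK:
    top_k = 2  # hypothetical setting

    @classmethod
    def jax_dynamic(cls, scores: np.ndarray) -> np.ndarray:
        # Host-side NumPy path: in-place mutation is allowed.
        ranks = np.argsort(np.argsort(-scores))  # rank 0 = highest logit
        out = scores.copy()
        out[ranks >= cls.top_k] = -np.inf
        return out

    @classmethod
    def jax_static(cls, scores: jnp.ndarray) -> jnp.ndarray:
        # Traceable path: purely functional update, safe under jax.jit / jax.lax.cond.
        ranks = jnp.argsort(jnp.argsort(-scores))
        return jnp.where(ranks >= cls.top_k, -jnp.inf, scores)

scores = np.array([1.0, 3.0, 2.0])
print(ExampleTopK.jax_dynamic(scores))                      # [-inf  3.  2.]
print(jax.jit(ExampleTopK.jax_static)(jnp.array(scores)))   # same result, computed under jit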