Model: TPU Fixes

somebody committed 2023-02-27 19:29:55 -06:00
parent bd3bbdaad8
commit ef1155291f
5 changed files with 175 additions and 355 deletions

View File

@@ -5519,6 +5519,7 @@ def final_startup():
# Precompile TPU backend if required
if isinstance(model, HFMTJInferenceModel):
import tpu_mtj_backend
soft_tokens = model.get_soft_tokens()
if(koboldai_vars.dynamicscan or (not koboldai_vars.nogenmod and koboldai_vars.has_genmod)):
tpool.execute(tpu_mtj_backend.infer_dynamic, np.tile(np.uint32((23403, 727, 20185)), (koboldai_vars.numseqs, 1)),

model.py (116 changes)
View File

@@ -27,7 +27,6 @@ from warpers import Warper
import torch
from torch.nn import Embedding
import numpy as np
import accelerate.utils
import transformers
from transformers import (
StoppingCriteria,
@@ -48,6 +47,7 @@ import koboldai_settings
try:
import breakmodel
import accelerate.utils
except ModuleNotFoundError as e:
if not utils.koboldai_vars.use_colab_tpu:
raise e
@@ -889,7 +889,52 @@ class InferenceModel:
hook(self, input_ids)
class HFMTJInferenceModel(InferenceModel):
class HFInferenceModel(InferenceModel):
def __init__(self) -> None:
super().__init__()
self.model_config = None
def get_local_model_path(
self, legacy: bool = False, ignore_existance: bool = False
) -> Optional[str]:
"""
Returns a string of the model's path locally, or None if it is not downloaded.
If ignore_existance is true, it will always return a path.
"""
basename = utils.koboldai_vars.model.replace("/", "_")
if legacy:
ret = basename
else:
ret = os.path.join("models", basename)
if os.path.isdir(ret) or ignore_existance:
return ret
return None
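For reference, a minimal standalone sketch of the same path rule (not taken from this diff; the model id is a hypothetical example):

import os

def local_model_path(model_id: str, legacy: bool = False) -> str:
    # Slashes in the Hugging Face id become underscores; non-legacy paths
    # live under the models/ directory.
    basename = model_id.replace("/", "_")
    return basename if legacy else os.path.join("models", basename)

# local_model_path("KoboldAI/GPT-J-6B-Skein") -> "models/KoboldAI_GPT-J-6B-Skein"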
def init_model_config(self) -> None:
# Get the model_type from the config or assume a model type if it isn't present
try:
self.model_config = AutoConfig.from_pretrained(
self.get_local_model_path() or utils.koboldai_vars.model,
revision=utils.koboldai_vars.revision,
cache_dir="cache",
)
utils.koboldai_vars.model_type = self.model_config.model_type
except ValueError:
utils.koboldai_vars.model_type = {
"NeoCustom": "gpt_neo",
"GPT2Custom": "gpt2",
}.get(utils.koboldai_vars.model)
if not utils.koboldai_vars.model_type:
logger.warning(
"No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)"
)
utils.koboldai_vars.model_type = "gpt_neo"
class HFMTJInferenceModel(HFInferenceModel):
def __init__(
self,
model_name: str,
@@ -1012,8 +1057,6 @@ class HFMTJInferenceModel(InferenceModel):
"rprange": int(utils.koboldai_vars.rep_pen_range),
}
self.load_mtj_backend()
tpu_mtj_backend.socketio = utils.socketio
if utils.koboldai_vars.model == "TPUMeshTransformerGPTNeoX":
@@ -1045,21 +1088,20 @@ class HFMTJInferenceModel(InferenceModel):
tpu_mtj_backend.settings_callback = mtj_settings_callback
def _load(self, save_model: bool, initial_load: bool) -> None:
self.patch_transformers()
self.setup_mtj()
self.init_model_config()
utils.koboldai_vars.allowsp = True
# loadmodelsettings()
# loadsettings()
tpu_mtj_backend.load_model(
utils.koboldai_vars.custmodpth,
utils.koboldai_vars.model,
hf_checkpoint=utils.koboldai_vars.model
not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
and utils.koboldai_vars.use_colab_tpu,
socketio_queue=koboldai_settings.queue,
initial_load=initial_load,
logger=logger,
**utils.koboldai_vars.modelconfig,
**self.model_config.to_dict()
# **utils.koboldai_vars.modelconfig,
)
# tpool.execute(tpu_mtj_backend.load_model, koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, **koboldai_vars.modelconfig)
@@ -1079,7 +1121,7 @@ class HFMTJInferenceModel(InferenceModel):
if utils.koboldai_vars.newlinemode != "s" or str(k) != "</s>"
]
def get_soft_tokens() -> np.array:
def get_soft_tokens(self) -> np.array:
soft_tokens = None
if utils.koboldai_vars.sp is None:
@@ -1152,6 +1194,7 @@ class HFMTJInferenceModel(InferenceModel):
genout = np.array(genout)
return GenerationResult(
self,
out_batches=genout,
prompt=prompt_tokens,
is_whole_generation=True,
@@ -1159,7 +1202,7 @@ class HFMTJInferenceModel(InferenceModel):
)
class HFTorchInferenceModel(InferenceModel):
class HFTorchInferenceModel(HFInferenceModel):
def __init__(
self,
model_name: str,
@@ -1181,7 +1224,6 @@ class HFTorchInferenceModel(InferenceModel):
self.model = None
self.tokenizer = None
self.model_config = None
self.capabilties = ModelCapabilities(
embedding_manipulation=True,
post_token_hooks=True,
@@ -1198,10 +1240,8 @@ class HFTorchInferenceModel(InferenceModel):
warper = Warper.from_id(sid)
if warper == warpers.RepetitionPenalty:
# Rep pen needs more data than other samplers
print("is rep:", warper)
scores = warper.torch(scores, input_ids=input_ids)
else:
print("aint rep:", warper)
scores = warper.torch(scores)
return scores
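Repetition penalty is the one warper that needs the token history: it rescales the scores of tokens already present in input_ids so they are less likely to be picked again. A minimal sketch of that rule (the usual HF-style formulation, not the warpers.py implementation):

import torch

def apply_rep_pen(scores: torch.Tensor, input_ids: torch.Tensor, rep_pen: float) -> torch.Tensor:
    # scores: (batch, vocab_size), input_ids: (batch, seq_len)
    penalized = scores.gather(-1, input_ids)
    # Positive scores are divided, negative ones multiplied, so both move
    # away from being sampled again.
    penalized = torch.where(penalized > 0, penalized / rep_pen, penalized * rep_pen)
    return scores.scatter(-1, input_ids, penalized)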
@@ -1616,24 +1656,6 @@ class HFTorchInferenceModel(InferenceModel):
**tf_kwargs,
)
def get_local_model_path(
self, legacy: bool = False, ignore_existance: bool = False
) -> Optional[str]:
"""
Returns a string of the model's path locally, or None if it is not downloaded.
If ignore_existance is true, it will always return a path.
"""
basename = utils.koboldai_vars.model.replace("/", "_")
if legacy:
ret = basename
else:
ret = os.path.join("models", basename)
if os.path.isdir(ret) or ignore_existance:
return ret
return None
def get_hidden_size(self) -> int:
return self.model.get_input_embeddings().embedding_dim
@@ -2206,25 +2228,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
self.get_local_model_path(ignore_existance=True),
)
# Get the model_type from the config or assume a model type if it isn't present
try:
model_config = AutoConfig.from_pretrained(
self.get_local_model_path() or utils.koboldai_vars.model,
revision=utils.koboldai_vars.revision,
cache_dir="cache",
)
utils.koboldai_vars.model_type = model_config.model_type
except ValueError as e:
utils.koboldai_vars.model_type = {
"NeoCustom": "gpt_neo",
"GPT2Custom": "gpt2",
}.get(utils.koboldai_vars.model)
if not utils.koboldai_vars.model_type:
logger.warning(
"No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)"
)
utils.koboldai_vars.model_type = "gpt_neo"
self.init_model_config()
tf_kwargs = {
"low_cpu_mem_usage": True,
@@ -2246,7 +2250,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
and utils.koboldai_vars.breakmodel
and not utils.koboldai_vars.nobreakmodel
):
self.breakmodel_device_config(model_config)
self.breakmodel_device_config(self.model_config)
if utils.koboldai_vars.lazy_load:
# If we're using lazy loader, we need to figure out what the model's hidden layers are called
@@ -2254,9 +2258,9 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
dematerialized_modules=True, use_accelerate_init_empty_weights=True
):
try:
metamodel = AutoModelForCausalLM.from_config(model_config)
metamodel = AutoModelForCausalLM.from_config(self.model_config)
except Exception as e:
metamodel = GPTNeoForCausalLM.from_config(model_config)
metamodel = GPTNeoForCausalLM.from_config(self.model_config)
utils.layers_module_names = utils.get_layers_module_names(metamodel)
utils.module_names = list(metamodel.state_dict().keys())
utils.named_buffers = list(metamodel.named_buffers(recurse=True))
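The metamodel exists only so the loader can learn the model's module and buffer names without allocating real weights. A rough standalone sketch of the same idea using accelerate's init_empty_weights (the model id is an example; the project routes this through its own torch_lazy_loader context instead):

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-125M")  # example model id
with init_empty_weights():
    # Parameters are created on the "meta" device, so no real memory is spent.
    metamodel = AutoModelForCausalLM.from_config(config)

module_names = list(metamodel.state_dict().keys())
buffer_names = [name for name, _ in metamodel.named_buffers(recurse=True)]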
@@ -2264,7 +2268,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
# Download model from Huggingface if it does not exist, otherwise load locally
with self._maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(
enable=utils.koboldai_vars.lazy_load,
callback=self._get_lazy_load_callback(utils.num_layers(model_config))
callback=self._get_lazy_load_callback(utils.num_layers(self.model_config))
if utils.koboldai_vars.lazy_load
else None,
dematerialized_modules=True,

View File

@@ -55,7 +55,6 @@ from mesh_transformer.util import to_bf16
import time
import warpers
from warpers import Warper
socketio = None
@@ -215,150 +214,13 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
before picking one token using the modified logits
'''
"""
# Top-k (keep only the k tokens with the highest logits and remove
# the rest, by setting their logits to negative infinity)
def top_k_filter(logits):
# After sorting the logits array in descending order,
# sorted_indices_to_remove is a 1D array that is True for tokens
# in the sorted logits array we want to remove and False for ones
# we want to keep, in this case the first top_k elements will be
# False and the rest will be True
sorted_indices_to_remove = np.arange(len(logits)) >= top_k
# Unsort the logits array back to its original configuration and
# remove tokens we need to remove
_, indices_to_remove = jax.lax.sort_key_val(
np.argsort(-logits),
sorted_indices_to_remove,
)
return np.where(indices_to_remove, -np.inf, logits)
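A tiny worked example of the sort/unsort trick above, in plain numpy (jax.lax.sort_key_val(keys, vals) sorts vals by keys, which is what scatters the mask back onto the original token positions):

import numpy as np

logits = np.array([1.0, 4.0, 2.0, 3.0])
top_k = 2

order = np.argsort(-logits)                    # [1, 3, 2, 0], largest logit first
sorted_mask = np.arange(len(logits)) >= top_k  # [False, False, True, True]

mask = np.empty_like(sorted_mask)
mask[order] = sorted_mask                      # unsorted: [True, False, True, False]

print(np.where(mask, -np.inf, logits))         # [-inf, 4., -inf, 3.]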
# Top-a (remove all tokens that have softmax probability less than
# a*m^2 where m is the maximum softmax probability)
def top_a_filter(logits):
# Replace every element in the logits array
# with e (Euler's number) to the power of that element, and divide
# each element of the new array by the sum of the elements in the
# new array
probabilities = np.array(jax.nn.softmax(logits), copy=True)
# Find the largest probability
probs_max = probabilities.max()
# Remove tokens
return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits)
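Toy numbers for the top-a rule: with top_a = 0.2 and a maximum probability m = 0.5, the cutoff is 0.2 * 0.5^2 = 0.05, so anything below 5% probability is dropped.

import numpy as np

probs = np.array([0.5, 0.3, 0.15, 0.04, 0.01])
top_a = 0.2
threshold = top_a * probs.max() ** 2   # 0.05
keep = probs >= threshold              # [True, True, True, False, False]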
# Top-p (after sorting the remaining tokens again in descending order of
# logit, remove the ones that have cumulative softmax probability
# greater than p)
def top_p_filter(logits):
# Sort the logits array in descending order, replace every element
# with e (Euler's number) to the power of that element, and divide
# each element of the new array by the sum of the elements in the
# new array
sorted_logits = -np.sort(-logits)
probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
# Calculate cumulative_probabilities as the prefix-sum array of
# probabilities
cumulative_probabilities = np.cumsum(probabilities, axis=-1)
# We want to remove tokens with cumulative probability higher
# than top_p
sorted_indices_to_remove = cumulative_probabilities > top_p
# Don't ever remove the token with the highest logit, even if
# the probability is higher than top_p
sorted_indices_to_remove[0] = False
# Unsort and remove
_, indices_to_remove = jax.lax.sort_key_val(
np.argsort(-logits),
sorted_indices_to_remove,
)
return np.where(indices_to_remove, -np.inf, logits)
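A small worked example of the nucleus cutoff, in plain numpy:

import numpy as np

logits = np.array([3.0, 1.0, 0.0, -1.0])
top_p = 0.8

sorted_logits = -np.sort(-logits)      # already descending here
probs = np.exp(sorted_logits) / np.exp(sorted_logits).sum()
cum = np.cumsum(probs)                 # ~[0.83, 0.94, 0.98, 1.0]
drop = cum > top_p                     # [True, True, True, True]
drop[0] = False                        # the best token always survives
# With top_p = 0.8 only the highest-probability token is kept here.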
# Tail free sampling (basically top-p a second time on remaining tokens
# except it's the "cumulative normalized absolute second finite
# differences of the softmax probabilities" instead of just the
# cumulative softmax probabilities)
def tail_free_filter(logits):
# Sort in descending order
sorted_logits = -np.sort(-logits)
# Softmax again
probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
# Calculate the second finite differences of that array (i.e.
# calculate the difference array and then calculate the difference
# array of the difference array)
d2 = np.diff(np.diff(probabilities))
# Get the absolute values of all those second finite differences
d2 = np.abs(d2)
# Normalize (all elements in the array are divided by the sum of the
# array's elements)
d2 = d2 / d2.sum(axis=-1, keepdims=True)
# Get the prefix-sum array
cumulative_d2 = np.cumsum(d2, axis=-1)
# We will remove the tokens with a cumulative normalized absolute
# second finite difference larger than the TFS value
sorted_indices_to_remove = cumulative_d2 > tfs
# Don't remove the token with the highest logit
sorted_indices_to_remove[0] = False
# Since the d2 array has two fewer elements than the logits array,
# we'll add two extra Trues to the end
sorted_indices_to_remove = np.pad(
sorted_indices_to_remove,
(0, 2),
constant_values=True,
)
# Unsort and remove
_, indices_to_remove = jax.lax.sort_key_val(
np.argsort(-logits),
sorted_indices_to_remove,
)
return np.where(indices_to_remove, -np.inf, logits)
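The same computation on a toy, already-sorted distribution:

import numpy as np

probs = np.array([0.5, 0.25, 0.15, 0.07, 0.03])   # sorted, descending
d2 = np.abs(np.diff(np.diff(probs)))              # ~[0.15, 0.02, 0.04]
d2 = d2 / d2.sum()                                # ~[0.71, 0.10, 0.19]
cum_d2 = np.cumsum(d2)                            # ~[0.71, 0.81, 1.00]
tfs = 0.9
drop_sorted = cum_d2 > tfs                        # [False, False, True]
# d2 is two entries shorter than probs, so the last two tokens always go:
drop_sorted = np.pad(drop_sorted, (0, 2), constant_values=True)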
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
def typical_filter(logits):
# Compute softmax probabilities and the natural logarithms of them
probs = jax.nn.softmax(logits)
with np.errstate(divide="ignore"):
log_probs = np.log(probs)
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
# in the set of softmax probabilities of the logits
neg_entropy = np.nansum(probs * log_probs, axis=-1, keepdims=True)
# Determine absolute difference between the negative entropy and the
# log probabilities
entropy_deviation = np.abs(neg_entropy - log_probs)
# Keep certain tokens such that the sum of the entropy_deviation of the
# kept tokens is the smallest possible value such that the sum of the
# softmax probabilities of the kept tokens is at least the threshold
# value (by sorting the tokens in ascending order of entropy_deviation
# and then keeping the smallest possible number of tokens from the
# beginning such that sum of softmax probabilities is at or above the
# threshold)
_, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
sorted_indices_to_remove = np.cumsum(sorted_logits, axis=-1) >= typical
sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1, axis=-1)
sorted_indices_to_remove[0] = False
# Unsort and remove
_, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(entropy_deviation),
sorted_indices_to_remove,
)
return np.where(indices_to_remove, -jnp.inf, logits)
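A toy run of the typicality ranking above, in plain numpy:

import numpy as np

probs = np.array([0.6, 0.2, 0.15, 0.05])
log_probs = np.log(probs)
neg_entropy = np.sum(probs * log_probs)      # -H(p), about -1.06
deviation = np.abs(neg_entropy - log_probs)  # distance from typical surprisal
order = np.argsort(deviation)                # most "typical" tokens first: [1, 0, 2, 3]
kept_mass = np.cumsum(probs[order])          # [0.2, 0.8, 0.95, 1.0]
# With typical = 0.9, the first three tokens in this ordering are kept.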
# Temperature (just divide the logits by the temperature)
def temp_filter(logits):
return logits / temp
for k in sampler_order:
if k == 0 and top_k > 0: logits = top_k_filter(logits)
if k == 1 and top_a > 0.0: logits = top_a_filter(logits)
if k == 2 and top_p < 1.0: logits = top_p_filter(logits)
if k == 3 and tfs < 1.0: logits = tail_free_filter(logits)
if k == 4 and typical < 1.0: logits = typical_filter(logits)
if k == 5 and temp != 1.0: logits = temp_filter(logits)
if k == 6 and rpargs[1] != 1.0: logits = apply_repetition_penalty_dynamic(logits, *rpargs)
"""
for sid in sampler_order:
warper = Warper.from_id(sid)
for sid in jnp.array(sampler_order, int):
# sid = int(sid)
warper = warpers.Warper.from_id(sid)
if not warper.value_is_valid():
continue
logits = warper.jax_dynamic()
if warper == warpers.RepetitionPenalty:
print("ISREP", warper)
logits = warper.jax()
else:
print("AINTREP", warper)
logits = warper.jax_dynamic(logits, *rpargs)
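The loop above leans on a small id-to-class registry in warpers. A hypothetical minimal version of that dispatch (the real warpers.Warper may differ):

class Warper:
    registry = {}

    def __init_subclass__(cls, sampler_id=None, **kwargs):
        super().__init_subclass__(**kwargs)
        if sampler_id is not None:
            Warper.registry[sampler_id] = cls

    @classmethod
    def from_id(cls, sampler_id):
        return cls.registry[int(sampler_id)]

class Temperature(Warper, sampler_id=5):
    temperature = 0.7

    @classmethod
    def value_is_valid(cls):
        return cls.temperature != 1.0

    @classmethod
    def jax_dynamic(cls, scores):
        return scores / cls.temperature

# Warper.from_id(5) -> Temperature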
# Finally, pick one token using the softmax thingy again (it gives
# an array whose elements sum to 1 so it can be used nicely as a
# probability distribution)
@@ -371,152 +233,14 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray
to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
before picking one token using the modified logits
'''
"""
# Top-k (keep only the k tokens with the highest logits and remove
# the rest, by setting their logits to negative infinity)
def top_k_filter(logits):
# After sorting the logits array in descending order,
# sorted_indices_to_remove is a 1D array that is True for tokens
# in the sorted logits array we want to remove and False for ones
# we want to keep, in this case the first top_k elements will be
# False and the rest will be True
sorted_indices_to_remove = jnp.arange(len(logits)) >= top_k
# Unsort the logits array back to its original configuration and
# remove tokens we need to remove
_, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(-logits),
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
# Top-a (remove all tokens that have softmax probability less than
# a*m^2 where m is the maximum softmax probability)
def top_a_filter(logits):
# Replace every element in the logits array
# with e (Euler's number) to the power of that element, and divide
# each element of the new array by the sum of the elements in the
# new array
probabilities = jax.nn.softmax(logits)
# Find the largest probability
probs_max = probabilities.max()
# Remove tokens
return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits)
# Top-p (after sorting the remaining tokens again in descending order of
# logit, remove the ones that have cumulative softmax probability
# greater than p)
def top_p_filter(logits):
# Sort the logits array in descending order, replace every element
# with e (Euler's number) to the power of that element, and divide
# each element of the new array by the sum of the elements in the
# new array
sorted_logits = -jnp.sort(-logits)
probabilities = jax.nn.softmax(sorted_logits)
# Calculate cumulative_probabilities as the prefix-sum array of
# probabilities
cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
# We want to remove tokens with cumulative probability higher
# than top_p
sorted_indices_to_remove = cumulative_probabilities > top_p
# Don't ever remove the token with the highest logit, even if
# the probability is higher than top_p
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
# Unsort and remove
_, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(-logits),
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
# Tail free sampling (basically top-p a second time on remaining tokens
# except it's the "cumulative normalized absolute second finite
# differences of the softmax probabilities" instead of just the
# cumulative softmax probabilities)
def tail_free_filter(logits):
# Sort in descending order
sorted_logits = -jnp.sort(-logits)
# Softmax again
probabilities = jax.nn.softmax(sorted_logits)
# Calculate the second finite differences of that array (i.e.
# calculate the difference array and then calculate the difference
# array of the difference array)
d2 = jnp.diff(jnp.diff(probabilities))
# Get the absolute values of all those second finite differences
d2 = jnp.abs(d2)
# Normalize (all elements in the array are divided by the sum of the
# array's elements)
d2 = d2 / d2.sum(axis=-1, keepdims=True)
# Get the prefix-sum array
cumulative_d2 = jnp.cumsum(d2, axis=-1)
# We will remove the tokens with a cumulative normalized absolute
# second finite difference larger than the TFS value
sorted_indices_to_remove = cumulative_d2 > tfs
# Don't remove the token with the highest logit
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
# Since the d2 array has two fewer elements than the logits array,
# we'll add two extra Trues to the end
sorted_indices_to_remove = jnp.pad(
sorted_indices_to_remove,
(0, 2),
constant_values=True,
)
# Unsort and remove
_, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(-logits),
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
# Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
def typical_filter(logits):
# Compute softmax probabilities and the natural logarithms of them
probs = jax.nn.softmax(logits)
log_probs = jnp.log(probs)
# Compute the negative of entropy, which is the sum of p*ln(p) for all p
# in the set of softmax probabilities of the logits
neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
# Determine absolute difference between the negative entropy and the
# log probabilities
entropy_deviation = jnp.abs(neg_entropy - log_probs)
# Keep certain tokens such that the sum of the entropy_deviation of the
# kept tokens is the smallest possible value such that the sum of the
# softmax probabilities of the kept tokens is at least the threshold
# value (by sorting the tokens in ascending order of entropy_deviation
# and then keeping the smallest possible number of tokens from the
# beginning such that sum of softmax probabilities is at or above the
# threshold)
_, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= typical
sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
# Unsort and remove
_, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(entropy_deviation),
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
# Temperature (just divide the logits by the temperature)
def temp_filter(logits):
return logits / temp
for k in sampler_order:
logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), top_k_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), top_a_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), top_p_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: apply_repetition_penalty_static(*x), lambda x: x[0], (logits, *rpargs))
"""
for sid in sampler_order:
warper = Warper.from_id(sid)
if not warper.value_is_valid():
continue
if warper == warpers.RepetitionPenalty:
print("ISREP", warper)
logits = warper.jax()
else:
print("AINTREP", warper)
logits = warper.jax_static(logits, *rpargs)
# Finally, pick one token using the softmax thingy again (it gives
# an array whose elements sum to 1 so it can be used nicely as a
# probability distribution)
logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), warpers.TopK.jax_static, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), warpers.TopA.jax_static, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), warpers.TopP.jax_static, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), warpers.TailFree.jax_static, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), warpers.Typical.jax_static, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), warpers.Temperature.jax_static, lambda x: x, logits)
logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: warpers.RepetitionPenalty.jax_static(*x), lambda x: x[0], (logits, *rpargs))
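The static path runs under jit, where the sampler ids are traced values, so plain Python if statements cannot gate the filters; jax.lax.cond does it instead, and both branches must accept and return arrays of the same shape. A minimal standalone sketch of that gating pattern (the temperature value is an assumed example):

import jax
import jax.numpy as jnp

def temp_filter(logits):
    return logits / 0.7

def maybe_temp(sampler_id, logits):
    # Both branches take and return the logits array with the same shape/dtype.
    return jax.lax.cond(sampler_id == 5, temp_filter, lambda x: x, logits)

print(maybe_temp(5, jnp.array([1.0, 2.0, 3.0])))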
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
pad_token_id = 50256
@@ -1101,6 +825,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
"tokenizer": "gpt2",
}
# Try to convert HF config.json to MTJ config
if hf_checkpoint:
spec_path = os.path.join("maps", koboldai_vars.model_type + ".json")
@@ -1303,7 +1028,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
if socketio is None:
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
else:
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=Send_to_socketio())
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=utils.UIProgressBarFile())
koboldai_vars.status_message = "Loading model"
koboldai_vars.loaded_layers = 0
koboldai_vars.total_layers = num_tensors
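tqdm accepts any file-like object with write and flush, which is how the loading bar gets routed to the web UI instead of stdout. A rough sketch of such an adapter (EmitFile is a hypothetical stand-in for the project's utils.UIProgressBarFile):

from tqdm.auto import tqdm

class EmitFile:
    def __init__(self, emit):
        self.emit = emit  # e.g. a socketio.emit-style callable

    def write(self, text):
        text = text.strip()
        if text:
            self.emit("progress", text)

    def flush(self):
        pass

bar = tqdm(total=10, desc="Loading model tensors",
           file=EmitFile(lambda event, text: print(event, text)))
for _ in range(10):
    bar.update(1)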

View File

@@ -8,7 +8,6 @@ from urllib.error import HTTPError
import requests
import requests.adapters
import time
import breakmodel
from transformers import PreTrainedModel
import packaging.version
from tqdm.auto import tqdm
@@ -653,6 +652,7 @@ def get_auxilary_device():
if koboldai_vars.hascuda and koboldai_vars.usegpu:
return koboldai_vars.gpu_device
elif koboldai_vars.hascuda and koboldai_vars.breakmodel:
import breakmodel
return breakmodel.primary_device
return "cpu"

View File

@@ -46,8 +46,10 @@ try:
import jax
import jax.numpy as jnp
import tpu_mtj_backend
except ImportError:
assert not utils.koboldai_vars.use_colab_tpu
except ImportError as e:
print(e)
if utils.koboldai_vars.use_colab_tpu:
raise e
def update_settings():
@@ -89,7 +91,11 @@ class Temperature(Warper):
return scores / cls.temperature
@classmethod
def jax(cls, scores: jnp.array) -> jnp.array:
def jax_dynamic(cls, scores: np.array) -> np.array:
return scores / cls.temperature
@classmethod
def jax_static(cls, scores: jnp.array) -> jnp.array:
return scores / cls.temperature
@classmethod
@@ -121,7 +127,31 @@ class TopP(Warper):
return scores.masked_fill(indices_to_remove, -np.inf)
@classmethod
def jax(cls, scores: jnp.array) -> jnp.array:
def jax_dynamic(cls, scores: np.array) -> np.array:
# Sort the logits array in descending order, replace every element
# with e (Euler's number) to the power of that element, and divide
# each element of the new array by the sum of the elements in the
# new array
sorted_logits = -np.sort(-scores)
probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
# Calculate cumulative_probabilities as the prefix-sum array of
# probabilities
cumulative_probabilities = np.cumsum(probabilities, axis=-1)
# We want to remove tokens with cumulative probability higher
# than top_p
sorted_indices_to_remove = cumulative_probabilities > cls.top_p
# Don't ever remove the token with the highest logit, even if
# the probability is higher than top_p
sorted_indices_to_remove[0] = False
# Unsort and remove
_, indices_to_remove = jax.lax.sort_key_val(
np.argsort(-scores),
sorted_indices_to_remove,
)
return np.where(indices_to_remove, -np.inf, scores)
@classmethod
def jax_static(cls, scores: jnp.array) -> jnp.array:
# Sort the logits array in descending order, replace every element
# with e (Euler's number) to the power of that element, and divide
# each element of the new array by the sum of the elements in the
@@ -166,7 +196,7 @@ class TopK(Warper):
return scores
@classmethod
def jax(cls, scores: jnp.array) -> jnp.array:
def jax_dynamic(cls, scores: np.array) -> np.array:
# After sorting the logits array in descending order,
# sorted_indices_to_remove is a 1D array that is True for tokens
# in the sorted logits array we want to remove and False for ones
@@ -181,6 +211,16 @@ class TopK(Warper):
)
return np.where(indices_to_remove, -np.inf, scores)
@classmethod
def jax_static(cls, scores: jnp.array) -> jnp.array:
sorted_indices_to_remove = jnp.arange(len(scores)) >= cls.top_k
_, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(-scores),
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, scores)
@classmethod
def value_is_valid(cls) -> bool:
return cls.top_p > 0
@@ -225,7 +265,7 @@ class TailFree(Warper):
return scores
@classmethod
def jax(cls, scores: jnp.array) -> jnp.array:
def jax_dynamic(cls, scores: np.array) -> np.array:
# Sort in descending order
sorted_logits = -np.sort(-scores)
@@ -268,6 +308,30 @@ class TailFree(Warper):
)
return np.where(indices_to_remove, -np.inf, scores)
@classmethod
def jax_static(cls, scores: jnp.array) -> jnp.array:
sorted_logits = -jnp.sort(-scores)
probabilities = jax.nn.softmax(sorted_logits)
d2 = jnp.diff(jnp.diff(probabilities))
d2 = jnp.abs(d2)
d2 = d2 / d2.sum(axis=-1, keepdims=True)
cumulative_d2 = jnp.cumsum(d2, axis=-1)
sorted_indices_to_remove = cumulative_d2 > cls.tfs
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
sorted_indices_to_remove = jnp.pad(
sorted_indices_to_remove,
(0, 2),
constant_values=True,
)
_, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(-scores),
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, scores)
@classmethod
def value_is_valid(cls) -> bool:
return cls.tfs < 1.0
@@ -315,7 +379,7 @@ class Typical(Warper):
return scores
@classmethod
def jax(cls, scores: jnp.array) -> jnp.array:
def jax_dynamic(cls, scores: np.array) -> np.array:
# Compute softmax probabilities and the natural logarithms of them
probs = jax.nn.softmax(scores)
with np.errstate(divide="ignore"):
@@ -348,6 +412,25 @@ class Typical(Warper):
)
return np.where(indices_to_remove, -jnp.inf, scores)
@classmethod
def jax_static(cls, scores: jnp.array) -> jnp.array:
probs = jax.nn.softmax(scores)
log_probs = jnp.log(probs)
neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
entropy_deviation = jnp.abs(neg_entropy - log_probs)
_, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= cls.typical
sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
_, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(entropy_deviation),
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, scores)
@classmethod
def value_is_valid(cls) -> bool:
return cls.typical < 1.0
@@ -377,7 +460,7 @@ class TopA(Warper):
return scores
@classmethod
def jax(cls, scores: jnp.array) -> jnp.array:
def jax_dynamic(cls, scores: np.array) -> np.array:
# Replace every element in the logits array
# with e (Euler's number) to the power of that element, and divide
# each element of the new array by the sum of the elements in the
@@ -390,6 +473,14 @@ class TopA(Warper):
probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores
)
@classmethod
def jax_static(cls, scores: jnp.array) -> jnp.array:
probabilities = jax.nn.softmax(scores)
probs_max = probabilities.max()
return jnp.where(
probabilities < probs_max * probs_max * cls.top_a, -jnp.inf, scores
)
@classmethod
def value_is_valid(cls) -> bool:
return cls.top_a > 0.0
@@ -436,7 +527,6 @@ class RepetitionPenalty(Warper):
return scores
@classmethod
# def jax_static(cls, scores: jnp.array) -> jnp.array:
def jax_static(
cls,
logits: jnp.array,
@@ -570,4 +660,4 @@ class RepetitionPenalty(Warper):
@classmethod
def value_is_valid(cls) -> bool:
return cls.rep_pen != 1.0