From ef1155291f5c8d11a2db182f55e66470fd2463ba Mon Sep 17 00:00:00 2001 From: somebody Date: Mon, 27 Feb 2023 19:29:55 -0600 Subject: [PATCH] Model: TPU Fixes --- aiserver.py | 1 + model.py | 116 ++++++++--------- tpu_mtj_backend.py | 301 ++------------------------------------------- utils.py | 2 +- warpers.py | 110 +++++++++++++++-- 5 files changed, 175 insertions(+), 355 deletions(-) diff --git a/aiserver.py b/aiserver.py index dc8c417b..258bb109 100644 --- a/aiserver.py +++ b/aiserver.py @@ -5519,6 +5519,7 @@ def final_startup(): # Precompile TPU backend if required if isinstance(model, HFMTJInferenceModel): + import tpu_mtj_backend soft_tokens = model.get_soft_tokens() if(koboldai_vars.dynamicscan or (not koboldai_vars.nogenmod and koboldai_vars.has_genmod)): tpool.execute(tpu_mtj_backend.infer_dynamic, np.tile(np.uint32((23403, 727, 20185)), (koboldai_vars.numseqs, 1)), diff --git a/model.py b/model.py index 44180341..bcc30cd2 100644 --- a/model.py +++ b/model.py @@ -27,7 +27,6 @@ from warpers import Warper import torch from torch.nn import Embedding import numpy as np -import accelerate.utils import transformers from transformers import ( StoppingCriteria, @@ -48,6 +47,7 @@ import koboldai_settings try: import breakmodel + import accelerate.utils except ModuleNotFoundError as e: if not utils.koboldai_vars.use_colab_tpu: raise e @@ -889,7 +889,52 @@ class InferenceModel: hook(self, input_ids) -class HFMTJInferenceModel(InferenceModel): +class HFInferenceModel(InferenceModel): + def __init__(self) -> None: + super().__init__() + self.model_config = None + + def get_local_model_path( + self, legacy: bool = False, ignore_existance: bool = False + ) -> Optional[str]: + """ + Returns a string of the model's path locally, or None if it is not downloaded. + If ignore_existance is true, it will always return a path. 
+ """ + + basename = utils.koboldai_vars.model.replace("/", "_") + if legacy: + ret = basename + else: + ret = os.path.join("models", basename) + + if os.path.isdir(ret) or ignore_existance: + return ret + return None + + def init_model_config(self) -> None: + # Get the model_type from the config or assume a model type if it isn't present + try: + self.model_config = AutoConfig.from_pretrained( + self.get_local_model_path() or utils.koboldai_vars.model, + revision=utils.koboldai_vars.revision, + cache_dir="cache", + ) + utils.koboldai_vars.model_type = self.model_config.model_type + except ValueError: + utils.koboldai_vars.model_type = { + "NeoCustom": "gpt_neo", + "GPT2Custom": "gpt2", + }.get(utils.koboldai_vars.model) + + if not utils.koboldai_vars.model_type: + logger.warning( + "No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)" + ) + utils.koboldai_vars.model_type = "gpt_neo" + + +class HFMTJInferenceModel(HFInferenceModel): def __init__( self, model_name: str, @@ -1012,8 +1057,6 @@ class HFMTJInferenceModel(InferenceModel): "rprange": int(utils.koboldai_vars.rep_pen_range), } - self.load_mtj_backend() - tpu_mtj_backend.socketio = utils.socketio if utils.koboldai_vars.model == "TPUMeshTransformerGPTNeoX": @@ -1045,21 +1088,20 @@ class HFMTJInferenceModel(InferenceModel): tpu_mtj_backend.settings_callback = mtj_settings_callback def _load(self, save_model: bool, initial_load: bool) -> None: - self.patch_transformers() self.setup_mtj() - + self.init_model_config() utils.koboldai_vars.allowsp = True - # loadmodelsettings() - # loadsettings() + tpu_mtj_backend.load_model( - utils.koboldai_vars.custmodpth, + utils.koboldai_vars.model, hf_checkpoint=utils.koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and utils.koboldai_vars.use_colab_tpu, socketio_queue=koboldai_settings.queue, initial_load=initial_load, logger=logger, - **utils.koboldai_vars.modelconfig, + **self.model_config.to_dict() + # **utils.koboldai_vars.modelconfig, ) # tpool.execute(tpu_mtj_backend.load_model, koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, **koboldai_vars.modelconfig) @@ -1079,7 +1121,7 @@ class HFMTJInferenceModel(InferenceModel): if utils.koboldai_vars.newlinemode != "s" or str(k) != "" ] - def get_soft_tokens() -> np.array: + def get_soft_tokens(self) -> np.array: soft_tokens = None if utils.koboldai_vars.sp is None: @@ -1152,6 +1194,7 @@ class HFMTJInferenceModel(InferenceModel): genout = np.array(genout) return GenerationResult( + self, out_batches=genout, prompt=prompt_tokens, is_whole_generation=True, @@ -1159,7 +1202,7 @@ class HFMTJInferenceModel(InferenceModel): ) -class HFTorchInferenceModel(InferenceModel): +class HFTorchInferenceModel(HFInferenceModel): def __init__( self, model_name: str, @@ -1181,7 +1224,6 @@ class HFTorchInferenceModel(InferenceModel): self.model = None self.tokenizer = None - self.model_config = None self.capabilties = ModelCapabilities( embedding_manipulation=True, post_token_hooks=True, @@ -1198,10 +1240,8 @@ class HFTorchInferenceModel(InferenceModel): warper = Warper.from_id(sid) if warper == warpers.RepetitionPenalty: # Rep pen needs more data than other samplers - print("is rep:", warper) scores = warper.torch(scores, input_ids=input_ids) else: - print("aint rep:", warper) scores = warper.torch(scores) return scores @@ -1616,24 +1656,6 @@ class 
HFTorchInferenceModel(InferenceModel): **tf_kwargs, ) - def get_local_model_path( - self, legacy: bool = False, ignore_existance: bool = False - ) -> Optional[str]: - """ - Returns a string of the model's path locally, or None if it is not downloaded. - If ignore_existance is true, it will always return a path. - """ - - basename = utils.koboldai_vars.model.replace("/", "_") - if legacy: - ret = basename - else: - ret = os.path.join("models", basename) - - if os.path.isdir(ret) or ignore_existance: - return ret - return None - def get_hidden_size(self) -> int: return self.model.get_input_embeddings().embedding_dim @@ -2206,25 +2228,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel): self.get_local_model_path(ignore_existance=True), ) - # Get the model_type from the config or assume a model type if it isn't present - try: - model_config = AutoConfig.from_pretrained( - self.get_local_model_path() or utils.koboldai_vars.model, - revision=utils.koboldai_vars.revision, - cache_dir="cache", - ) - utils.koboldai_vars.model_type = model_config.model_type - except ValueError as e: - utils.koboldai_vars.model_type = { - "NeoCustom": "gpt_neo", - "GPT2Custom": "gpt2", - }.get(utils.koboldai_vars.model) - - if not utils.koboldai_vars.model_type: - logger.warning( - "No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)" - ) - utils.koboldai_vars.model_type = "gpt_neo" + self.init_model_config() tf_kwargs = { "low_cpu_mem_usage": True, @@ -2246,7 +2250,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel): and utils.koboldai_vars.breakmodel and not utils.koboldai_vars.nobreakmodel ): - self.breakmodel_device_config(model_config) + self.breakmodel_device_config(self.model_config) if utils.koboldai_vars.lazy_load: # If we're using lazy loader, we need to figure out what the model's hidden layers are called @@ -2254,9 +2258,9 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel): dematerialized_modules=True, use_accelerate_init_empty_weights=True ): try: - metamodel = AutoModelForCausalLM.from_config(model_config) + metamodel = AutoModelForCausalLM.from_config(self.model_config) except Exception as e: - metamodel = GPTNeoForCausalLM.from_config(model_config) + metamodel = GPTNeoForCausalLM.from_config(self.model_config) utils.layers_module_names = utils.get_layers_module_names(metamodel) utils.module_names = list(metamodel.state_dict().keys()) utils.named_buffers = list(metamodel.named_buffers(recurse=True)) @@ -2264,7 +2268,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel): # Download model from Huggingface if it does not exist, otherwise load locally with self._maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load( enable=utils.koboldai_vars.lazy_load, - callback=self._get_lazy_load_callback(utils.num_layers(model_config)) + callback=self._get_lazy_load_callback(utils.num_layers(self.model_config)) if utils.koboldai_vars.lazy_load else None, dematerialized_modules=True, diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 9a56ffd5..5afeff50 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -55,7 +55,6 @@ from mesh_transformer.util import to_bf16 import time import warpers -from warpers import Warper socketio = None @@ -215,150 +214,13 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) before picking one token using the modified logits ''' - """ - # Top-k (keep only 
the k tokens with the highest logits and remove - # the rest, by setting their logits to negative infinity) - def top_k_filter(logits): - # After sorting the logits array in descending order, - # sorted_indices_to_remove is a 1D array that is True for tokens - # in the sorted logits array we want to remove and False for ones - # we want to keep, in this case the first top_k elements will be - # False and the rest will be True - sorted_indices_to_remove = np.arange(len(logits)) >= top_k - # Unsort the logits array back to its original configuration and - # remove tokens we need to remove - _, indices_to_remove = jax.lax.sort_key_val( - np.argsort(-logits), - sorted_indices_to_remove, - ) - return np.where(indices_to_remove, -np.inf, logits) - # Top-a (remove all tokens that have softmax probability less than - # a*m^2 where m is the maximum softmax probability) - def top_a_filter(logits): - # Replace every element in the logits array - # with e (Euler's number) to the power of that element, and divide - # each element of the new array by the sum of the elements in the - # new array - probabilities = np.array(jax.nn.softmax(logits), copy=True) - # Find the largest probability - probs_max = probabilities.max() - # Remove tokens - return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits) - # Top-p (after sorting the remaining tokens again in descending order of - # logit, remove the ones that have cumulative softmax probability - # greater than p) - def top_p_filter(logits): - # Sort the logits array in descending order, replace every element - # with e (Euler's number) to the power of that element, and divide - # each element of the new array by the sum of the elements in the - # new array - sorted_logits = -np.sort(-logits) - probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True) - # Calculate cumulative_probabilities as the prefix-sum array of - # probabilities - cumulative_probabilities = np.cumsum(probabilities, axis=-1) - # We want to remove tokens with cumulative probability higher - # than top_p - sorted_indices_to_remove = cumulative_probabilities > top_p - # Don't ever remove the token with the highest logit, even if - # the probability is higher than top_p - sorted_indices_to_remove[0] = False - # Unsort and remove - _, indices_to_remove = jax.lax.sort_key_val( - np.argsort(-logits), - sorted_indices_to_remove, - ) - return np.where(indices_to_remove, -np.inf, logits) - # Tail free sampling (basically top-p a second time on remaining tokens - # except it's the "cumulative normalized absolute second finite - # differences of the softmax probabilities" instead of just the - # cumulative softmax probabilities) - def tail_free_filter(logits): - # Sort in descending order - sorted_logits = -np.sort(-logits) - # Softmax again - probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True) - # Calculate the second finite differences of that array (i.e. 
- # calculate the difference array and then calculate the difference - # array of the difference array) - d2 = np.diff(np.diff(probabilities)) - # Get the absolute values of all those second finite differences - d2 = np.abs(d2) - # Normalize (all elements in the array are divided by the sum of the - # array's elements) - d2 = d2 / d2.sum(axis=-1, keepdims=True) - # Get the prefix-sum array - cumulative_d2 = np.cumsum(d2, axis=-1) - # We will remove the tokens with a cumulative normalized absolute - # second finite difference larger than the TFS value - sorted_indices_to_remove = cumulative_d2 > tfs - # Don't remove the token with the highest logit - sorted_indices_to_remove[0] = False - # Since the d2 array has two fewer elements than the logits array, - # we'll add two extra Trues to the end - sorted_indices_to_remove = np.pad( - sorted_indices_to_remove, - (0, 2), - constant_values=True, - ) - # Unsort and remove - _, indices_to_remove = jax.lax.sort_key_val( - np.argsort(-logits), - sorted_indices_to_remove, - ) - return np.where(indices_to_remove, -np.inf, logits) - # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) - def typical_filter(logits): - # Compute softmax probabilities and the natural logarithms of them - probs = jax.nn.softmax(logits) - with np.errstate(divide="ignore"): - log_probs = np.log(probs) - # Compute the negative of entropy, which is the sum of p*ln(p) for all p - # in the set of softmax probabilities of the logits - neg_entropy = np.nansum(probs * log_probs, axis=-1, keepdims=True) - # Determine absolute difference between the negative entropy and the - # log probabilities - entropy_deviation = np.abs(neg_entropy - log_probs) - # Keep certain tokens such that the sum of the entropy_deviation of the - # kept tokens is the smallest possible value such that the sum of the - # softmax probabilities of the kept tokens is at least the threshold - # value (by sorting the tokens in ascending order of entropy_deviation - # and then keeping the smallest possible number of tokens from the - # beginning such that sum of softmax probabilities is at or above the - # threshold) - _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs) - sorted_indices_to_remove = np.cumsum(sorted_logits, axis=-1) >= typical - sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1, axis=-1) - sorted_indices_to_remove[0] = False - # Unsort and remove - _, indices_to_remove = jax.lax.sort_key_val( - jnp.argsort(entropy_deviation), - sorted_indices_to_remove, - ) - return np.where(indices_to_remove, -jnp.inf, logits) - # Temperature (just divide the logits by the temperature) - def temp_filter(logits): - return logits / temp - for k in sampler_order: - if k == 0 and top_k > 0: logits = top_k_filter(logits) - if k == 1 and top_a > 0.0: logits = top_a_filter(logits) - if k == 2 and top_p < 1.0: logits = top_p_filter(logits) - if k == 3 and tfs < 1.0: logits = tail_free_filter(logits) - if k == 4 and typical < 1.0: logits = typical_filter(logits) - if k == 5 and temp != 1.0: logits = temp_filter(logits) - if k == 6 and rpargs[1] != 1.0: logits = apply_repetition_penalty_dynamic(logits, *rpargs) - """ - for sid in sampler_order: - warper = Warper.from_id(sid) + for sid in jnp.array(sampler_order, int): + # sid = int(sid) + warper = warpers.Warper.from_id(sid) if not warper.value_is_valid(): continue + logits = warper.jax_dynamic() - if warper == warpers.RepetitionPenalty: - print("ISREP", warper) - logits = warper.jax() - else: - print("AINTREP", warper) - logits = 
warper.jax_dynamic(logits, *rpargs) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) @@ -371,152 +233,14 @@ def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) before picking one token using the modified logits ''' - """ - # Top-k (keep only the k tokens with the highest logits and remove - # the rest, by setting their logits to negative infinity) - def top_k_filter(logits): - # After sorting the logits array in descending order, - # sorted_indices_to_remove is a 1D array that is True for tokens - # in the sorted logits array we want to remove and False for ones - # we want to keep, in this case the first top_k elements will be - # False and the rest will be True - sorted_indices_to_remove = jnp.arange(len(logits)) >= top_k - # Unsort the logits array back to its original configuration and - # remove tokens we need to remove - _, indices_to_remove = jax.lax.sort_key_val( - jnp.argsort(-logits), - sorted_indices_to_remove, - ) - return jnp.where(indices_to_remove, -jnp.inf, logits) - # Top-a (remove all tokens that have softmax probability less than - # a*m^2 where m is the maximum softmax probability) - def top_a_filter(logits): - # Replace every element in the logits array - # with e (Euler's number) to the power of that element, and divide - # each element of the new array by the sum of the elements in the - # new array - probabilities = jax.nn.softmax(logits) - # Find the largest probability - probs_max = probabilities.max() - # Remove tokens - return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits) - # Top-p (after sorting the remaining tokens again in descending order of - # logit, remove the ones that have cumulative softmax probability - # greater than p) - def top_p_filter(logits): - # Sort the logits array in descending order, replace every element - # with e (Euler's number) to the power of that element, and divide - # each element of the new array by the sum of the elements in the - # new array - sorted_logits = -jnp.sort(-logits) - probabilities = jax.nn.softmax(sorted_logits) - # Calculate cumulative_probabilities as the prefix-sum array of - # probabilities - cumulative_probabilities = jnp.cumsum(probabilities, axis=-1) - # We want to remove tokens with cumulative probability higher - # than top_p - sorted_indices_to_remove = cumulative_probabilities > top_p - # Don't ever remove the token with the highest logit, even if - # the probability is higher than top_p - sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False) - # Unsort and remove - _, indices_to_remove = jax.lax.sort_key_val( - jnp.argsort(-logits), - sorted_indices_to_remove, - ) - return jnp.where(indices_to_remove, -jnp.inf, logits) - # Tail free sampling (basically top-p a second time on remaining tokens - # except it's the "cumulative normalized absolute second finite - # differences of the softmax probabilities" instead of just the - # cumulative softmax probabilities) - def tail_free_filter(logits): - # Sort in descending order - sorted_logits = -jnp.sort(-logits) - # Softmax again - probabilities = jax.nn.softmax(sorted_logits) - # Calculate the second finite differences of that array (i.e. 
- # calculate the difference array and then calculate the difference - # array of the difference array) - d2 = jnp.diff(jnp.diff(probabilities)) - # Get the absolute values of all those second finite differences - d2 = jnp.abs(d2) - # Normalize (all elements in the array are divided by the sum of the - # array's elements) - d2 = d2 / d2.sum(axis=-1, keepdims=True) - # Get the prefix-sum array - cumulative_d2 = jnp.cumsum(d2, axis=-1) - # We will remove the tokens with a cumulative normalized absolute - # second finite difference larger than the TFS value - sorted_indices_to_remove = cumulative_d2 > tfs - # Don't remove the token with the highest logit - sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False) - # Since the d2 array has two fewer elements than the logits array, - # we'll add two extra Trues to the end - sorted_indices_to_remove = jnp.pad( - sorted_indices_to_remove, - (0, 2), - constant_values=True, - ) - # Unsort and remove - _, indices_to_remove = jax.lax.sort_key_val( - jnp.argsort(-logits), - sorted_indices_to_remove, - ) - return jnp.where(indices_to_remove, -jnp.inf, logits) - # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) - def typical_filter(logits): - # Compute softmax probabilities and the natural logarithms of them - probs = jax.nn.softmax(logits) - log_probs = jnp.log(probs) - # Compute the negative of entropy, which is the sum of p*ln(p) for all p - # in the set of softmax probabilities of the logits - neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True) - # Determine absolute difference between the negative entropy and the - # log probabilities - entropy_deviation = jnp.abs(neg_entropy - log_probs) - # Keep certain tokens such that the sum of the entropy_deviation of the - # kept tokens is the smallest possible value such that the sum of the - # softmax probabilities of the kept tokens is at least the threshold - # value (by sorting the tokens in ascending order of entropy_deviation - # and then keeping the smallest possible number of tokens from the - # beginning such that sum of softmax probabilities is at or above the - # threshold) - _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs) - sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= typical - sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1) - sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False) - # Unsort and remove - _, indices_to_remove = jax.lax.sort_key_val( - jnp.argsort(entropy_deviation), - sorted_indices_to_remove, - ) - return jnp.where(indices_to_remove, -jnp.inf, logits) - # Temperature (just divide the logits by the temperature) - def temp_filter(logits): - return logits / temp for k in sampler_order: - logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), top_k_filter, lambda x: x, logits) - logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), top_a_filter, lambda x: x, logits) - logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), top_p_filter, lambda x: x, logits) - logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits) - logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits) - logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits) - logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: apply_repetition_penalty_static(*x), lambda x: x[0], (logits, *rpargs)) - """ - for sid in sampler_order: - warper = Warper.from_id(sid) 
- if not warper.value_is_valid(): - continue - - if warper == warpers.RepetitionPenalty: - print("ISREP", warper) - logits = warper.jax() - else: - print("AINTREP", warper) - logits = warper.jax_static(logits, *rpargs) - # Finally, pick one token using the softmax thingy again (it gives - # an array whose elements sum to 1 so it can be used nicely as a - # probability distribution) + logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), warpers.TopK.jax_static, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), warpers.TopA.jax_static, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), warpers.TopP.jax_static, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), warpers.TailFree.jax_static, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), warpers.Typical.jax_static, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), warpers.Temperature.jax_static, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: warpers.RepetitionPenalty.jax_static(*x), lambda x: x[0], (logits, *rpargs)) return jax.random.categorical(key, logits, -1).astype(jnp.uint32) pad_token_id = 50256 @@ -1101,6 +825,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo "tokenizer": "gpt2", } + # Try to convert HF config.json to MTJ config if hf_checkpoint: spec_path = os.path.join("maps", koboldai_vars.model_type + ".json") @@ -1303,7 +1028,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo if socketio is None: utils.bar = tqdm(total=num_tensors, desc="Loading model tensors") else: - utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=Send_to_socketio()) + utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=utils.UIProgressBarFile()) koboldai_vars.status_message = "Loading model" koboldai_vars.loaded_layers = 0 koboldai_vars.total_layers = num_tensors diff --git a/utils.py b/utils.py index 1483d4d4..651269af 100644 --- a/utils.py +++ b/utils.py @@ -8,7 +8,6 @@ from urllib.error import HTTPError import requests import requests.adapters import time -import breakmodel from transformers import PreTrainedModel import packaging.version from tqdm.auto import tqdm @@ -653,6 +652,7 @@ def get_auxilary_device(): if koboldai_vars.hascuda and koboldai_vars.usegpu: return koboldai_vars.gpu_device elif koboldai_vars.hascuda and koboldai_vars.breakmodel: + import breakmodel return breakmodel.primary_device return "cpu" diff --git a/warpers.py b/warpers.py index 9c78ab2a..c63b34e2 100644 --- a/warpers.py +++ b/warpers.py @@ -46,8 +46,10 @@ try: import jax import jax.numpy as jnp import tpu_mtj_backend -except ImportError: - assert not utils.koboldai_vars.use_colab_tpu +except ImportError as e: + print(e) + if utils.koboldai_vars.use_colab_tpu: + raise e def update_settings(): @@ -89,7 +91,11 @@ class Temperature(Warper): return scores / cls.temperature @classmethod - def jax(cls, scores: jnp.array) -> jnp.array: + def jax_dynamic(cls, scores: np.array) -> np.array: + return scores / cls.temperature + + @classmethod + def jax_static(cls, scores: jnp.array) -> jnp.array: return scores / cls.temperature @classmethod @@ -121,7 +127,31 @@ class TopP(Warper): return scores.masked_fill(indices_to_remove, -np.inf) @classmethod - def jax(cls, scores: jnp.array) -> jnp.array: + def jax_dynamic(cls, scores: np.array) -> np.array: + # Sort the 
logits array in descending order, replace every element + # with e (Euler's number) to the power of that element, and divide + # each element of the new array by the sum of the elements in the + # new array + sorted_logits = -np.sort(-scores) + probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True) + # Calculate cumulative_probabilities as the prefix-sum array of + # probabilities + cumulative_probabilities = np.cumsum(probabilities, axis=-1) + # We want to remove tokens with cumulative probability higher + # than top_p + sorted_indices_to_remove = cumulative_probabilities > cls.top_p + # Don't ever remove the token with the highest logit, even if + # the probability is higher than top_p + sorted_indices_to_remove[0] = False + # Unsort and remove + _, indices_to_remove = jax.lax.sort_key_val( + np.argsort(-scores), + sorted_indices_to_remove, + ) + return np.where(indices_to_remove, -np.inf, scores) + + @classmethod + def jax_static(cls, scores: jnp.array) -> jnp.array: # Sort the logits array in descending order, replace every element # with e (Euler's number) to the power of that element, and divide # each element of the new array by the sum of the elements in the @@ -166,7 +196,7 @@ class TopK(Warper): return scores @classmethod - def jax(cls, scores: jnp.array) -> jnp.array: + def jax_dynamic(cls, scores: np.array) -> np.array: # After sorting the logits array in descending order, # sorted_indices_to_remove is a 1D array that is True for tokens # in the sorted logits array we want to remove and False for ones @@ -181,6 +211,16 @@ class TopK(Warper): ) return np.where(indices_to_remove, -np.inf, scores) + @classmethod + def jax_static(cls, scores: jnp.array) -> jnp.array: + sorted_indices_to_remove = jnp.arange(len(scores)) >= cls.top_k + + _, indices_to_remove = jax.lax.sort_key_val( + jnp.argsort(-scores), + sorted_indices_to_remove, + ) + return jnp.where(indices_to_remove, -jnp.inf, scores) + @classmethod def value_is_valid(cls) -> bool: return cls.top_p > 0 @@ -225,7 +265,7 @@ class TailFree(Warper): return scores @classmethod - def jax(cls, scores: jnp.array) -> jnp.array: + def jax_dynamic(cls, scores: np.array) -> np.array: # Sort in descending order sorted_logits = -np.sort(-scores) @@ -268,6 +308,30 @@ class TailFree(Warper): ) return np.where(indices_to_remove, -np.inf, scores) + @classmethod + def jax_static(cls, scores: jnp.array) -> jnp.array: + sorted_logits = -jnp.sort(-scores) + probabilities = jax.nn.softmax(sorted_logits) + + d2 = jnp.diff(jnp.diff(probabilities)) + d2 = jnp.abs(d2) + d2 = d2 / d2.sum(axis=-1, keepdims=True) + + cumulative_d2 = jnp.cumsum(d2, axis=-1) + sorted_indices_to_remove = cumulative_d2 > cls.tfs + sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False) + sorted_indices_to_remove = jnp.pad( + sorted_indices_to_remove, + (0, 2), + constant_values=True, + ) + + _, indices_to_remove = jax.lax.sort_key_val( + jnp.argsort(-scores), + sorted_indices_to_remove, + ) + return jnp.where(indices_to_remove, -jnp.inf, scores) + @classmethod def value_is_valid(cls) -> bool: return cls.tfs < 1.0 @@ -315,7 +379,7 @@ class Typical(Warper): return scores @classmethod - def jax(cls, scores: jnp.array) -> jnp.array: + def jax_dynamic(cls, scores: np.array) -> np.array: # Compute softmax probabilities and the natural logarithms of them probs = jax.nn.softmax(scores) with np.errstate(divide="ignore"): @@ -348,6 +412,25 @@ class Typical(Warper): ) return np.where(indices_to_remove, -jnp.inf, scores) + @classmethod + def jax_static(cls, 
scores: jnp.array) -> jnp.array: + probs = jax.nn.softmax(scores) + log_probs = jnp.log(probs) + + neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True) + entropy_deviation = jnp.abs(neg_entropy - log_probs) + + _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs) + sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= cls.typical + sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1) + sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False) + + _, indices_to_remove = jax.lax.sort_key_val( + jnp.argsort(entropy_deviation), + sorted_indices_to_remove, + ) + return jnp.where(indices_to_remove, -jnp.inf, scores) + @classmethod def value_is_valid(cls) -> bool: return cls.typical < 1.0 @@ -377,7 +460,7 @@ class TopA(Warper): return scores @classmethod - def jax(cls, scores: jnp.array) -> jnp.array: + def jax_dynamic(cls, scores: np.array) -> np.array: # Replace every element in the logits array # with e (Euler's number) to the power of that element, and divide # each element of the new array by the sum of the elements in the @@ -390,6 +473,14 @@ class TopA(Warper): probabilities < probs_max * probs_max * cls.top_a, -np.inf, scores ) + @classmethod + def jax_static(cls, scores: jnp.array) -> jnp.array: + probabilities = jax.nn.softmax(scores) + probs_max = probabilities.max() + return jnp.where( + probabilities < probs_max * probs_max * cls.top_a, -jnp.inf, scores + ) + @classmethod def value_is_valid(cls) -> bool: return cls.top_a > 0.0 @@ -436,7 +527,6 @@ class RepetitionPenalty(Warper): return scores @classmethod - # def jax_static(cls, scores: jnp.array) -> jnp.array: def jax_static( cls, logits: jnp.array, @@ -570,4 +660,4 @@ class RepetitionPenalty(Warper): @classmethod def value_is_valid(cls) -> bool: - return cls.rep_pen != 1.0 \ No newline at end of file + return cls.rep_pen != 1.0
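
Below is a minimal, illustrative sketch (not part of the patch itself) of how the refactored warpers are presumably meant to be driven on the TPU path. The split of the old Warper.jax() into jax_dynamic (the eager NumPy path used by kobold_sample_dynamic) and jax_static (the jit-traceable path selected through jax.lax.cond in kobold_sample_static) mirrors the torch-side _apply_warpers in model.py, where RepetitionPenalty receives extra data and every other warper only receives the scores. The helper name apply_warpers_dynamic is invented for illustration, and rpargs is assumed to carry the repetition-penalty arguments in the order RepetitionPenalty expects.

    # Hypothetical sketch of the dynamic (non-jit) sampler path; not applied by this patch.
    import numpy as np
    import warpers

    def apply_warpers_dynamic(logits: np.ndarray, rpargs: tuple, sampler_order) -> np.ndarray:
        for sid in sampler_order:
            warper = warpers.Warper.from_id(int(sid))
            if not warper.value_is_valid():
                continue
            if warper == warpers.RepetitionPenalty:
                # Repetition penalty needs the generated tokens and penalty
                # settings in addition to the logits.
                logits = warper.jax_dynamic(logits, *rpargs)
            else:
                logits = warper.jax_dynamic(logits)
        return logits

On the static side, kobold_sample_static runs under jit, so the patch selects between each warper's jax_static and the identity with jax.lax.cond rather than Python control flow; that is also why those implementations use jnp operations and .at[...].set(...) instead of in-place NumPy assignment.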