From 2d3db7b4ba388f566aaec88a0e76678fe4fade8d Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Mon, 13 Jun 2022 19:12:23 -0400 Subject: [PATCH] Implement support for sampler order in the backend code --- aiserver.py | 27 +++++++++++++++++++-------- tpu_mtj_backend.py | 46 +++++++++++++++++++++++++++------------------- utils.py | 2 ++ 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/aiserver.py b/aiserver.py index 6267aec2..0bed5ad8 100644 --- a/aiserver.py +++ b/aiserver.py @@ -306,6 +306,7 @@ class vars: acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses) comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)') # Pattern for matching comments to remove them before sending them to the AI comregex_ui = re.compile(r'(<\|(?:.|\n)*?\|>)') # Pattern for matching comments in the editor + sampler_order = utils.default_sampler_order.copy() chatmode = False chatname = "You" adventure = False @@ -1448,15 +1449,23 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor + class KoboldLogitsWarperList(LogitsProcessorList): + def __init__(self, beams: int = 1, **kwargs): + self.__warper_list: List[LogitsWarper] = [] + self.__warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))) + self.__warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1))) + self.__warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) + self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) + self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) + self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5)) + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs): + for k in vars.sampler_order: + scores = self.__warper_list[k](input_ids, scores, *args, **kwargs) + return scores + def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList: - warper_list = LogitsProcessorList() - warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TemperatureLogitsWarper(temperature=0.5)) - return warper_list + return KoboldLogitsWarperList(beams=beams) def new_sample(self, *args, **kwargs): assert kwargs.pop("logits_warper", None) is not None @@ -1816,6 +1825,7 @@ else: def tpumtjgenerate_settings_callback() -> dict: return { + "sampler_order": vars.sampler_order, "top_p": float(vars.top_p), "temp": float(vars.temp), "top_k": int(vars.top_k), @@ -3910,6 +3920,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): rprange=vars.rep_pen_range, soft_embeddings=vars.sp, soft_tokens=soft_tokens, + sampler_order=vars.sampler_order, ) past = genout for i in range(vars.numseqs): diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index f66ad53c..67e006d6 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -65,6 +65,7 @@ def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List def settings_callback() -> dict: return { + "sampler_order": utils.default_sampler_order.copy(), "top_p": 0.9, "temp": 0.5, "top_k": 0, @@ -159,7 +160,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat logits[tokens] = penalty_logits return logits -def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): +def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): ''' This gets called by generate_loop_fn to apply a series of 6 filters to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) @@ -181,8 +182,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty sorted_indices_to_remove, ) return np.where(indices_to_remove, -np.inf, logits) - if top_k > 0: - logits = top_k_filter(logits) # Top-a (remove all tokens that have softmax probability less than # a*m^2 where m is the maximum softmax probability) def top_a_filter(logits): @@ -195,8 +194,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty probs_max = probabilities.max() # Remove tokens return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits) - if top_a > 0.0: - logits = top_a_filter(logits) # Top-p (after sorting the remaining tokens again in descending order of # logit, remove the ones that have cumulative softmax probability # greater than p) @@ -222,8 +219,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty sorted_indices_to_remove, ) return np.where(indices_to_remove, -np.inf, logits) - if top_p < 1.0: - logits = top_p_filter(logits) # Tail free sampling (basically top-p a second time on remaining tokens # except it's the "cumulative normalized absolute second finite # differences of the softmax probabilities" instead of just the @@ -262,8 +257,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty sorted_indices_to_remove, ) return np.where(indices_to_remove, -np.inf, logits) - if tfs < 1.0: - logits = tail_free_filter(logits) # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) def typical_filter(logits): # Compute softmax probabilities and the natural logarithms of them @@ -293,10 +286,16 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty sorted_indices_to_remove, ) return np.where(indices_to_remove, -jnp.inf, logits) - if typical < 1.0: - logits = typical_filter(logits) # Temperature (just divide the logits by the temperature) - logits /= temp + def temp_filter(logits): + return logits / temp + for k in sampler_order: + if k == 0 and top_k > 0: logits = top_k_filter(logits) + if k == 1 and top_a > 0.0: logits = top_a_filter(logits) + if k == 2 and top_p < 1.0: logits = top_p_filter(logits) + if k == 3 and tfs < 1.0: logits = tail_free_filter(logits) + if k == 4 and typical < 1.0: logits = typical_filter(logits) + if k == 5 and temp != 1.0: logits = temp_filter(logits) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) @@ -347,7 +346,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate # positions in the logits array return logits.at[tokens].set(penalty_logits) -def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): +def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): ''' This gets called by generate_loop_fn to apply a series of 6 filters to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) @@ -369,7 +368,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ sorted_indices_to_remove, ) return jnp.where(indices_to_remove, -jnp.inf, logits) - logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits) # Top-a (remove all tokens that have softmax probability less than # a*m^2 where m is the maximum softmax probability) def top_a_filter(logits): @@ -382,7 +380,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ probs_max = probabilities.max() # Remove tokens return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits) - logits = jax.lax.cond(top_a > 0.0, top_a_filter, lambda x: x, logits) # Top-p (after sorting the remaining tokens again in descending order of # logit, remove the ones that have cumulative softmax probability # greater than p) @@ -408,7 +405,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ sorted_indices_to_remove, ) return jnp.where(indices_to_remove, -jnp.inf, logits) - logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits) # Tail free sampling (basically top-p a second time on remaining tokens # except it's the "cumulative normalized absolute second finite # differences of the softmax probabilities" instead of just the @@ -447,7 +443,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ sorted_indices_to_remove, ) return jnp.where(indices_to_remove, -jnp.inf, logits) - logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits) # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) def typical_filter(logits): # Compute softmax probabilities and the natural logarithms of them @@ -476,11 +471,16 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ sorted_indices_to_remove, ) return jnp.where(indices_to_remove, -jnp.inf, logits) - logits = jax.lax.cond(typical < 1.0, typical_filter, lambda x: x, logits) # Temperature (just divide the logits by the temperature) def temp_filter(logits): return logits / temp - logits = jax.lax.cond(True, temp_filter, lambda x: x, logits) + for k in sampler_order: + logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), top_k_filter, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), top_a_filter, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), top_p_filter, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits) + logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) @@ -842,8 +842,12 @@ def infer_static( gen_len=80, soft_embeddings: Optional[np.array] = None, soft_tokens: Optional[np.array] = None, + sampler_order: Optional[List[int]] = None, ) -> List[np.array]: maps.thread_resources.env = thread_resources_env + if sampler_order is None: + sampler_order = utils.default_sampler_order.copy() + sampler_order = np.uint32(sampler_order) total_batch = 1 tokens = context if(soft_tokens is not None): @@ -854,6 +858,7 @@ def infer_static( batched_tokens = np.array([padded_tokens] * total_batch) samples = [] batched_generator_params = { + "sampler_order": np.repeat(sampler_order[np.newaxis], total_batch, axis=0), "temp": temp * np.ones(total_batch), "top_p": top_p * np.ones(total_batch), "tfs": tfs * np.ones(total_batch), @@ -1015,6 +1020,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2): def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None: global thread_resources_env, seq, tokenizer, network, params + if not hasattr(vars, "sampler_order") or not vars.sampler_order: + vars.sampler_order = utils.default_sampler_order.copy() + default_params = { "compat": "j", "layers": 28, diff --git a/utils.py b/utils.py index bc085412..96606269 100644 --- a/utils.py +++ b/utils.py @@ -20,6 +20,8 @@ from_pretrained_index_filename: Optional[str] = None from_pretrained_kwargs = {} bar = None +default_sampler_order = [0, 1, 2, 3, 4, 5] + #==================================================================# # Decorator to prevent a function's actions from being run until # at least x seconds have passed without the function being called