From cbb6efb6561f5948f925e139fab66a59cbe689fc Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Wed, 24 Nov 2021 13:36:54 -0500 Subject: [PATCH] Move TFS warper code into aiserver.py --- aiserver.py | 85 ++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 77 insertions(+), 8 deletions(-) diff --git a/aiserver.py b/aiserver.py index 41d0c70a..1d02dd67 100644 --- a/aiserver.py +++ b/aiserver.py @@ -568,6 +568,83 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme except: pass + + # Patch transformers to use our custom logit warpers + from transformers import LogitsProcessorList, LogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper + class TailFreeLogitsWarper(LogitsWarper): + + def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + tfs = float(tfs) + if tfs < 0 or tfs > 1.0: + raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}") + self.tfs = tfs + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if self.filter_value >= 1.0: + return scores + sorted_logits, sorted_indices = torch.sort(scores, descending=True) + probs = sorted_logits.softmax(dim=-1) + + # Compute second derivative normalized CDF + d2 = probs.diff().diff().abs() + normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True) + normalized_d2_cdf = normalized_d2.cumsum(dim=-1) + + # Remove tokens with CDF value above the threshold (token with 0 are kept) + sorted_indices_to_remove = normalized_d2_cdf > self.tfs + + # Centre the distribution around the cutoff as in the original implementation of the algorithm + sorted_indices_to_remove = torch.cat( + ( + torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device), + sorted_indices_to_remove, + torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device), + ), + dim=-1, + ) + + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 + + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores + + def new_get_logits_warper( + top_k: int = None, + top_p: float = None, + tfs: float = None, + temp: float = None, + beams: int = 1, + ) -> LogitsProcessorList: + warper_list = LogitsProcessorList() + if(top_k is not None and top_k > 0): + warper_list.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1 + (beams > 1))) + if(top_p is not None and top_p < 1.0): + warper_list.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1 + (beams > 1))) + if(tfs is not None and tfs < 1.0): + warper_list.append(TailFreeLogitsWarper(tfs=tfs, min_tokens_to_keep=1 + (beams > 1))) + if(temp is not None and temp != 1.0): + warper_list.append(TemperatureLogitsWarper(temperature=temp)) + return warper_list + + def new_sample(self, *args, **kwargs): + assert kwargs.pop("logits_warper", None) is not None + kwargs["logits_warper"] = new_get_logits_warper( + vars.top_k, + vars.top_p, + vars.tfs, + vars.temp, + 1, + ) + return new_sample.old_sample(self, *args, **kwargs) + new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample + transformers.generation_utils.GenerationMixin.sample = new_sample + + # Sets up dynamic world info scanner class DynamicWorldInfoScanCriteria(StoppingCriteria): def __init__( @@ -1463,10 +1540,6 @@ def generate(txt, minimum, maximum, found_entries=None): # Submit input text to generator try: - top_p = vars.top_p if vars.top_p > 0.0 else None - top_k = vars.top_k if vars.top_k > 0 else None - tfs = vars.tfs if vars.tfs > 0.0 else None - gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long() if(vars.sp is not None): soft_tokens = torch.arange( @@ -1499,10 +1572,6 @@ def generate(txt, minimum, maximum, found_entries=None): min_length=minimum, max_length=maximum-already_generated, repetition_penalty=vars.rep_pen, - top_p=top_p, - top_k=top_k, - tfs=tfs, - temperature=vars.temp, bad_words_ids=vars.badwordsids, use_cache=True, num_return_sequences=numseqs