Move TFS warper code into aiserver.py
This commit is contained in:
parent
96e1d98b7e
commit
cbb6efb656
85
aiserver.py
85
aiserver.py
|
@ -568,6 +568,83 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||
except:
|
||||
pass
|
||||
|
||||
|
||||
# Patch transformers to use our custom logit warpers
|
||||
from transformers import LogitsProcessorList, LogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
|
||||
class TailFreeLogitsWarper(LogitsWarper):
|
||||
|
||||
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
tfs = float(tfs)
|
||||
if tfs < 0 or tfs > 1.0:
|
||||
raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}")
|
||||
self.tfs = tfs
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.filter_value >= 1.0:
|
||||
return scores
|
||||
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
||||
probs = sorted_logits.softmax(dim=-1)
|
||||
|
||||
# Compute second derivative normalized CDF
|
||||
d2 = probs.diff().diff().abs()
|
||||
normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
|
||||
normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
|
||||
|
||||
# Remove tokens with CDF value above the threshold (token with 0 are kept)
|
||||
sorted_indices_to_remove = normalized_d2_cdf > self.tfs
|
||||
|
||||
# Centre the distribution around the cutoff as in the original implementation of the algorithm
|
||||
sorted_indices_to_remove = torch.cat(
|
||||
(
|
||||
torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
|
||||
sorted_indices_to_remove,
|
||||
torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if self.min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep
|
||||
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
|
||||
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
def new_get_logits_warper(
|
||||
top_k: int = None,
|
||||
top_p: float = None,
|
||||
tfs: float = None,
|
||||
temp: float = None,
|
||||
beams: int = 1,
|
||||
) -> LogitsProcessorList:
|
||||
warper_list = LogitsProcessorList()
|
||||
if(top_k is not None and top_k > 0):
|
||||
warper_list.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1 + (beams > 1)))
|
||||
if(top_p is not None and top_p < 1.0):
|
||||
warper_list.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1 + (beams > 1)))
|
||||
if(tfs is not None and tfs < 1.0):
|
||||
warper_list.append(TailFreeLogitsWarper(tfs=tfs, min_tokens_to_keep=1 + (beams > 1)))
|
||||
if(temp is not None and temp != 1.0):
|
||||
warper_list.append(TemperatureLogitsWarper(temperature=temp))
|
||||
return warper_list
|
||||
|
||||
def new_sample(self, *args, **kwargs):
|
||||
assert kwargs.pop("logits_warper", None) is not None
|
||||
kwargs["logits_warper"] = new_get_logits_warper(
|
||||
vars.top_k,
|
||||
vars.top_p,
|
||||
vars.tfs,
|
||||
vars.temp,
|
||||
1,
|
||||
)
|
||||
return new_sample.old_sample(self, *args, **kwargs)
|
||||
new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
|
||||
transformers.generation_utils.GenerationMixin.sample = new_sample
|
||||
|
||||
|
||||
# Sets up dynamic world info scanner
|
||||
class DynamicWorldInfoScanCriteria(StoppingCriteria):
|
||||
def __init__(
|
||||
|
@ -1463,10 +1540,6 @@ def generate(txt, minimum, maximum, found_entries=None):
|
|||
|
||||
# Submit input text to generator
|
||||
try:
|
||||
top_p = vars.top_p if vars.top_p > 0.0 else None
|
||||
top_k = vars.top_k if vars.top_k > 0 else None
|
||||
tfs = vars.tfs if vars.tfs > 0.0 else None
|
||||
|
||||
gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long()
|
||||
if(vars.sp is not None):
|
||||
soft_tokens = torch.arange(
|
||||
|
@ -1499,10 +1572,6 @@ def generate(txt, minimum, maximum, found_entries=None):
|
|||
min_length=minimum,
|
||||
max_length=maximum-already_generated,
|
||||
repetition_penalty=vars.rep_pen,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
tfs=tfs,
|
||||
temperature=vars.temp,
|
||||
bad_words_ids=vars.badwordsids,
|
||||
use_cache=True,
|
||||
num_return_sequences=numseqs
|
||||
|
|
Loading…
Reference in New Issue