Make transformers warpers dynamically update their parameters

So that if you change a setting such as `top_p` from a Lua generation
modifier or from the settings menu during generation, the rest of the
generation will use the new value instead of retaining the values it
had when generation began.
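
As an illustration of the technique, a minimal self-contained sketch
(the `Settings` class and `FakeTopPWarper` below are stand-ins invented
for this example; the actual patch reads from KoboldAI's global `vars`
object and wraps the real transformers warpers):

class Settings:
    top_p = 0.9  # mutable at any time, e.g. from a Lua modifier or the UI

settings = Settings()

class FakeTopPWarper:
    def __init__(self, top_p):
        self.top_p = top_p

    def __call__(self, input_ids, scores):
        return scores  # a real warper would filter `scores` using self.top_p

def dynamic_processor_wrap(cls, field_name, var_name):
    # Swap out cls.__call__ for a version that refreshes one attribute from
    # the settings object before delegating to the original __call__, so each
    # call sees the current setting instead of the constructor-time value.
    old_call = cls.__call__
    def new_call(self, *args, **kwargs):
        setattr(self, field_name, getattr(settings, var_name))
        return old_call(self, *args, **kwargs)
    cls.__call__ = new_call

dynamic_processor_wrap(FakeTopPWarper, "top_p", "top_p")

warper = FakeTopPWarper(top_p=0.9)
warper(None, None)
settings.top_p = 0.5          # changed mid-generation
warper(None, None)
assert warper.top_p == 0.5    # the warper picked up the new value

Because the wrap happens on the class rather than on instances, every
warper constructed afterwards, including ones transformers builds
internally for `generate()`, gets the refresh behavior without any
changes to transformers itself.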
Gnome Ann 2021-12-21 22:12:24 -05:00
parent 91b6289897
commit 380b54167a
1 changed file with 15 additions and 0 deletions

@@ -7,6 +7,7 @@
# External packages
import eventlet
from transformers.generation_logits_process import RepetitionPenaltyLogitsProcessor
eventlet.monkey_patch()
import os
os.environ['EVENTLET_THREADPOOL_SIZE'] = '1'
@@ -617,6 +618,18 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
# Patch transformers to use our custom logit warpers
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
def dynamic_processor_wrap(cls, field_name, var_name):
    old_call = cls.__call__
    def new_call(self, *args, **kwargs):
        setattr(self, field_name, getattr(vars, var_name))
        return old_call(self, *args, **kwargs)
    cls.__call__ = new_call
dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen")
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k")
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p")
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp")
class TailFreeLogitsWarper(LogitsWarper):
    def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@@ -628,6 +641,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
        self.min_tokens_to_keep = min_tokens_to_keep
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        self.tfs = vars.tfs
        if self.tfs >= 1.0:
            return scores
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
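
For context, in tail free sampling the sort above feeds a computation
that cuts the tail where the cumulative normalized second derivative of
the sorted probability curve crosses the `tfs` threshold. A
self-contained sketch of that computation, assuming the standard
algorithm and torch >= 1.8 for `Tensor.diff` (illustrative, not a
verbatim continuation of the method above):

import torch

def tail_free_filter(scores: torch.Tensor, tfs: float,
                     filter_value: float = -float("Inf")) -> torch.Tensor:
    # Sort logits from most to least likely and work on the probability curve.
    sorted_logits, sorted_indices = torch.sort(scores, descending=True)
    probs = sorted_logits.softmax(dim=-1)

    # Absolute second finite difference of the sorted probabilities,
    # normalized into a distribution, then accumulated into a CDF.
    d2 = probs.diff(dim=-1).diff(dim=-1).abs()
    normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
    cdf = normalized_d2.cumsum(dim=-1)

    # Remove the "tail": tokens past the point where the CDF exceeds tfs.
    sorted_indices_to_remove = cdf > tfs

    # Two diffs shrink the vocab dimension by two, so pad the mask back to
    # full width: always keep the first token, always drop the last one.
    sorted_indices_to_remove = torch.cat(
        (
            torch.zeros_like(sorted_indices_to_remove[..., :1]),
            sorted_indices_to_remove,
            torch.ones_like(sorted_indices_to_remove[..., :1]),
        ),
        dim=-1,
    )

    # Map the mask from sorted order back to vocabulary order and filter.
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, filter_value)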