Make `dynamic_processor_wrap` execute warper conditionally
The top-k warper doesn't work properly with an argument of 0, so `dynamic_processor_wrap` now accepts an optional `cond` callable and skips the warper, returning the scores unchanged, whenever that condition is not met.
parent 7e06c25011
commit 2a4d7448be

aiserver.py (35 changed lines)
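As a standalone sketch of the idea (illustration only; this is not the commit's code, which reads its live values from KoboldAI's `vars` object, and `wrap_conditionally` is an invented name): wrap a warper class's `__call__` so that a failing condition hands the scores tensor back untouched.

import torch
from transformers import TopKLogitsWarper

def wrap_conditionally(cls, cond):
    old_call = cls.__call__
    def new_call(self, input_ids, scores):
        # Warpers are invoked as warper(input_ids, scores); "skipping" one
        # simply means returning `scores` exactly as it came in.
        if cond(self):
            return old_call(self, input_ids, scores)
        return scores
    cls.__call__ = new_call

# Skip top-k filtering whenever top_k is 0, the value that misbehaves.
wrap_conditionally(TopKLogitsWarper, cond=lambda self: self.top_k > 0)

warper = TopKLogitsWarper(top_k=1)  # 0 is rejected by the constructor...
warper.top_k = 0                    # ...so set the problematic value afterwards
scores = torch.randn(1, 16)
out = warper(torch.zeros(1, 4, dtype=torch.long), scores)
assert torch.equal(out, scores)     # warper was skipped; scores untouched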
@@ -621,16 +621,19 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     # Patch transformers to use our custom logit warpers
     from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
 
-    def dynamic_processor_wrap(cls, field_name, var_name):
+    def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
         old_call = cls.__call__
         def new_call(self, *args, **kwargs):
             setattr(self, field_name, getattr(vars, var_name))
-            return old_call(self, *args, **kwargs)
+            assert len(args) == 2
+            if(cond is None or cond(getattr(vars, var_name))):
+                return old_call(self, *args, **kwargs)
+            return args[1]
         cls.__call__ = new_call
-    dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen")
-    dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k")
-    dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p")
-    dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp")
+    dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen", cond=lambda x: x != 1.0)
+    dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
+    dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
+    dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
 
     class TailFreeLogitsWarper(LogitsWarper):
 
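Since logit warpers are invoked as `warper(input_ids, scores)`, the asserted two-argument form makes `args[1]` the scores tensor, so skipping a warper returns the logits unchanged. Each new `cond` lambda tests for its processor's identity setting: a repetition penalty of 1.0 multiplies or divides by 1, a top-k of 0 conventionally means "disabled" (and is the value `TopKLogitsWarper` rejects), a top-p of 1.0 keeps the whole distribution, and a temperature of 1.0 divides by 1. A quick check of one of these (illustration, not from the commit):

import torch
from transformers import TemperatureLogitsWarper

ids = torch.zeros(1, 4, dtype=torch.long)
scores = torch.randn(1, 8)
# temperature == 1.0 computes scores / 1.0, so the values are identical:
# exactly the situation cond=lambda x: x != 1.0 detects and skips.
assert torch.equal(TemperatureLogitsWarper(temperature=1.0)(ids, scores), scores)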
@@ -712,27 +715,17 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
     transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
 
-    def new_get_logits_warper(
-        top_k: int = None,
-        top_p: float = None,
-        tfs: float = None,
-        temp: float = None,
-        beams: int = 1,
-    ) -> LogitsProcessorList:
+    def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
         warper_list = LogitsProcessorList()
-        warper_list.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TailFreeLogitsWarper(tfs=tfs, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TemperatureLogitsWarper(temperature=temp))
+        warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
+        warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
+        warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
+        warper_list.append(TemperatureLogitsWarper(temperature=0.5))
         return warper_list
 
     def new_sample(self, *args, **kwargs):
         assert kwargs.pop("logits_warper", None) is not None
         kwargs["logits_warper"] = new_get_logits_warper(
-            top_k=vars.top_k,
-            top_p=vars.top_p,
-            tfs=vars.tfs,
-            temp=vars.temp,
             beams=1,
         )
         return new_sample.old_sample(self, *args, **kwargs)
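The literal `top_k=1`, `top_p=0.5`, `tfs=0.5`, and `temperature=0.5` arguments are only placeholders that satisfy the constructors: because `dynamic_processor_wrap` patches each `__call__` to `setattr` the live value from `vars` before running, whatever was passed at construction time is overwritten on every call. A sketch of that effect (illustration only; `fake_vars` is an invented stand-in for KoboldAI's `vars`):

import torch
from types import SimpleNamespace
from transformers import TopKLogitsWarper

fake_vars = SimpleNamespace(top_k=3)   # stand-in for the app's settings object

old_call = TopKLogitsWarper.__call__
def new_call(self, input_ids, scores):
    self.top_k = fake_vars.top_k       # constructor value is overwritten here
    return old_call(self, input_ids, scores)
TopKLogitsWarper.__call__ = new_call

warper = TopKLogitsWarper(top_k=1)     # 1 is just a placeholder
out = warper(torch.zeros(1, 4, dtype=torch.long), torch.randn(1, 16))
assert (out > -float("inf")).sum().item() == 3   # the live value 3 won, not 1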