Make `dynamic_processor_wrap` execute warper conditionally

The top-k warper doesn't work properly when its argument is 0, so
`dynamic_processor_wrap` now accepts an optional `cond` predicate and skips
the wrapped warper (returning the scores unchanged) when the condition is
not met.
Gnome Ann 2021-12-22 23:46:25 -05:00
parent 7e06c25011
commit 2a4d7448be
1 changed file with 14 additions and 21 deletions

@@ -621,16 +621,19 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     # Patch transformers to use our custom logit warpers
     from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
 
-    def dynamic_processor_wrap(cls, field_name, var_name):
+    def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
         old_call = cls.__call__
         def new_call(self, *args, **kwargs):
             setattr(self, field_name, getattr(vars, var_name))
-            return old_call(self, *args, **kwargs)
+            assert len(args) == 2
+            if(cond is None or cond(getattr(vars, var_name))):
+                return old_call(self, *args, **kwargs)
+            return args[1]
         cls.__call__ = new_call
-    dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen")
-    dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k")
-    dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p")
-    dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp")
+    dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen", cond=lambda x: x != 1.0)
+    dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
+    dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
+    dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
 
     class TailFreeLogitsWarper(LogitsWarper):
@@ -712,27 +715,17 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
     transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
 
-    def new_get_logits_warper(
-        top_k: int = None,
-        top_p: float = None,
-        tfs: float = None,
-        temp: float = None,
-        beams: int = 1,
-    ) -> LogitsProcessorList:
+    def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
         warper_list = LogitsProcessorList()
-        warper_list.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TailFreeLogitsWarper(tfs=tfs, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TemperatureLogitsWarper(temperature=temp))
+        warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
+        warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
+        warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
+        warper_list.append(TemperatureLogitsWarper(temperature=0.5))
         return warper_list
 
     def new_sample(self, *args, **kwargs):
         assert kwargs.pop("logits_warper", None) is not None
         kwargs["logits_warper"] = new_get_logits_warper(
-            top_k=vars.top_k,
-            top_p=vars.top_p,
-            tfs=vars.tfs,
-            temp=vars.temp,
             beams=1,
         )
         return new_sample.old_sample(self, *args, **kwargs)
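
For illustration, below is a minimal, self-contained sketch of the conditional-wrapping pattern this commit introduces. `Settings` and `StubTopKWarper` are hypothetical stand-ins for KoboldAI's `vars` object and the real transformers warpers; only the body of `dynamic_processor_wrap` mirrors the diff above.

    # Illustrative names only: Settings/StubTopKWarper stand in for
    # KoboldAI's `vars` object and transformers' TopKLogitsWarper.
    class Settings:
        top_k = 0  # 0 means "disabled" in the UI, which the stock warper can't handle

    vars = Settings()

    class StubTopKWarper:
        def __init__(self, top_k):
            self.top_k = top_k  # placeholder; overwritten on every call (see below)

        def __call__(self, input_ids, scores):
            # Stand-in for the real filtering logic, which misbehaves when top_k == 0
            assert self.top_k > 0
            return scores

    def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
        old_call = cls.__call__
        def new_call(self, *args, **kwargs):
            # Refresh the parameter from the live settings object on each call,
            # so UI changes take effect without rebuilding the warper list
            setattr(self, field_name, getattr(vars, var_name))
            assert len(args) == 2  # (input_ids, scores)
            if cond is None or cond(getattr(vars, var_name)):
                return old_call(self, *args, **kwargs)
            return args[1]  # condition failed: skip the warper, pass scores through
        cls.__call__ = new_call

    dynamic_processor_wrap(StubTopKWarper, "top_k", "top_k", cond=lambda x: x > 0)

    warper = StubTopKWarper(top_k=1)  # constructor value is a dummy
    print(warper(None, [0.1, 0.9]))   # vars.top_k == 0, so scores pass through: [0.1, 0.9]

This is also why the second hunk can construct the warpers with constant placeholder arguments (top_k=1, top_p=0.5, and so on): `new_call` refreshes those fields from `vars` on every invocation, so the constructor values never influence sampling.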