Make `dynamic_processor_wrap` execute warper conditionally
The top-k warper doesn't work properly with an argument of 0, so there is now the ability to not execute the warper if a condition is not met.
This commit is contained in:
parent
7e06c25011
commit
2a4d7448be
35
aiserver.py
35
aiserver.py
|
@ -621,16 +621,19 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||||
# Patch transformers to use our custom logit warpers
|
# Patch transformers to use our custom logit warpers
|
||||||
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
|
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
|
||||||
|
|
||||||
def dynamic_processor_wrap(cls, field_name, var_name):
|
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
|
||||||
old_call = cls.__call__
|
old_call = cls.__call__
|
||||||
def new_call(self, *args, **kwargs):
|
def new_call(self, *args, **kwargs):
|
||||||
setattr(self, field_name, getattr(vars, var_name))
|
setattr(self, field_name, getattr(vars, var_name))
|
||||||
return old_call(self, *args, **kwargs)
|
assert len(args) == 2
|
||||||
|
if(cond is None or cond(getattr(vars, var_name))):
|
||||||
|
return old_call(self, *args, **kwargs)
|
||||||
|
return args[1]
|
||||||
cls.__call__ = new_call
|
cls.__call__ = new_call
|
||||||
dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen")
|
dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen", cond=lambda x: x != 1.0)
|
||||||
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k")
|
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
|
||||||
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p")
|
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
|
||||||
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp")
|
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
|
||||||
|
|
||||||
class TailFreeLogitsWarper(LogitsWarper):
|
class TailFreeLogitsWarper(LogitsWarper):
|
||||||
|
|
||||||
|
@ -712,27 +715,17 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||||
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
|
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
|
||||||
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
|
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
|
||||||
|
|
||||||
def new_get_logits_warper(
|
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
|
||||||
top_k: int = None,
|
|
||||||
top_p: float = None,
|
|
||||||
tfs: float = None,
|
|
||||||
temp: float = None,
|
|
||||||
beams: int = 1,
|
|
||||||
) -> LogitsProcessorList:
|
|
||||||
warper_list = LogitsProcessorList()
|
warper_list = LogitsProcessorList()
|
||||||
warper_list.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1 + (beams > 1)))
|
warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
|
||||||
warper_list.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1 + (beams > 1)))
|
warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||||
warper_list.append(TailFreeLogitsWarper(tfs=tfs, min_tokens_to_keep=1 + (beams > 1)))
|
warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
|
||||||
warper_list.append(TemperatureLogitsWarper(temperature=temp))
|
warper_list.append(TemperatureLogitsWarper(temperature=0.5))
|
||||||
return warper_list
|
return warper_list
|
||||||
|
|
||||||
def new_sample(self, *args, **kwargs):
|
def new_sample(self, *args, **kwargs):
|
||||||
assert kwargs.pop("logits_warper", None) is not None
|
assert kwargs.pop("logits_warper", None) is not None
|
||||||
kwargs["logits_warper"] = new_get_logits_warper(
|
kwargs["logits_warper"] = new_get_logits_warper(
|
||||||
top_k=vars.top_k,
|
|
||||||
top_p=vars.top_p,
|
|
||||||
tfs=vars.tfs,
|
|
||||||
temp=vars.temp,
|
|
||||||
beams=1,
|
beams=1,
|
||||||
)
|
)
|
||||||
return new_sample.old_sample(self, *args, **kwargs)
|
return new_sample.old_sample(self, *args, **kwargs)
|
||||||
|
|
Loading…
Reference in New Issue