diff --git a/aiserver.py b/aiserver.py index 1a5db449..6f9bb30d 100644 --- a/aiserver.py +++ b/aiserver.py @@ -621,16 +621,19 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme # Patch transformers to use our custom logit warpers from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor - def dynamic_processor_wrap(cls, field_name, var_name): + def dynamic_processor_wrap(cls, field_name, var_name, cond=None): old_call = cls.__call__ def new_call(self, *args, **kwargs): setattr(self, field_name, getattr(vars, var_name)) - return old_call(self, *args, **kwargs) + assert len(args) == 2 + if(cond is None or cond(getattr(vars, var_name))): + return old_call(self, *args, **kwargs) + return args[1] cls.__call__ = new_call - dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen") - dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k") - dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p") - dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp") + dynamic_processor_wrap(RepetitionPenaltyLogitsProcessor, "penalty", "rep_pen", cond=lambda x: x != 1.0) + dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0) + dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0) + dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0) class TailFreeLogitsWarper(LogitsWarper): @@ -712,27 +715,17 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor - def new_get_logits_warper( - top_k: int = None, - top_p: float = None, - tfs: float = None, - temp: float = None, - beams: int = 1, - ) -> LogitsProcessorList: + def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList: warper_list = LogitsProcessorList() - warper_list.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TailFreeLogitsWarper(tfs=tfs, min_tokens_to_keep=1 + (beams > 1))) - warper_list.append(TemperatureLogitsWarper(temperature=temp)) + warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))) + warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) + warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) + warper_list.append(TemperatureLogitsWarper(temperature=0.5)) return warper_list def new_sample(self, *args, **kwargs): assert kwargs.pop("logits_warper", None) is not None kwargs["logits_warper"] = new_get_logits_warper( - top_k=vars.top_k, - top_p=vars.top_p, - tfs=vars.tfs, - temp=vars.temp, beams=1, ) return new_sample.old_sample(self, *args, **kwargs)