Always use all logit warpers

Now that the logit warper parameters can be changed mid-generation by
generation modifiers, the logit warpers have to be always on.
This commit is contained in:
Gnome Ann 2021-12-22 17:29:07 -05:00
parent 1e1b45d47a
commit c549ea04a9
1 changed files with 4 additions and 8 deletions

View File

@ -717,14 +717,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
beams: int = 1, beams: int = 1,
) -> LogitsProcessorList: ) -> LogitsProcessorList:
warper_list = LogitsProcessorList() warper_list = LogitsProcessorList()
if(top_k is not None and top_k > 0): warper_list.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1 + (beams > 1)))
if(top_p is not None and top_p < 1.0): warper_list.append(TailFreeLogitsWarper(tfs=tfs, min_tokens_to_keep=1 + (beams > 1)))
warper_list.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1 + (beams > 1))) warper_list.append(TemperatureLogitsWarper(temperature=temp))
if(tfs is not None and tfs < 1.0):
warper_list.append(TailFreeLogitsWarper(tfs=tfs, min_tokens_to_keep=1 + (beams > 1)))
if(temp is not None and temp != 1.0):
warper_list.append(TemperatureLogitsWarper(temperature=temp))
return warper_list return warper_list
def new_sample(self, *args, **kwargs): def new_sample(self, *args, **kwargs):