diff --git a/aiserver.py b/aiserver.py index bedc0e76..47fae0e0 100644 --- a/aiserver.py +++ b/aiserver.py @@ -655,6 +655,14 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme transformers.generation_utils.GenerationMixin.sample = new_sample + # Allow bad words filter to ban <|endoftext|> token + import transformers.generation_logits_process + def new_init(self, bad_words_ids: List[List[int]], eos_token_id: int): + return new_init.old_init(self, bad_words_ids, -1) + new_init.old_init = transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ + transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ = new_init + + # Sets up dynamic world info scanner class DynamicWorldInfoScanCriteria(StoppingCriteria): def __init__(