From 0ba7ac96d36d02d9b6cf688695356be6f7cbe172 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ond=C5=99ej=20Benda?=
Date: Thu, 2 Mar 2023 00:04:17 +0100
Subject: [PATCH 1/2] Created SinglelineStopper, which interrupts token
 generation when a newline is reached if singleline mode is enabled

---
 aiserver.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/aiserver.py b/aiserver.py
index b845330e..91543514 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -2452,6 +2452,29 @@ def patch_transformers():
 
                 return True
             return False
+    class SinglelineStopper(StoppingCriteria):
+        # If singleline mode is enabled, it's pointless to generate output beyond the first newline.
+        def __init__(self, tokenizer):
+            self.tokenizer = tokenizer
+
+        def __call__(
+            self,
+            input_ids: torch.LongTensor,
+            scores: torch.FloatTensor,
+            **kwargs,
+        ) -> bool:
+            if not koboldai_vars.singleline:
+                return False
+
+            data = [tokenizer.decode(x) for x in input_ids]
+            if 'completed' not in self.__dict__:
+                self.completed = [False]*len(input_ids)
+
+            for i in range(len(input_ids)):
+                if (re.compile(r'\n').search(data[i][-1 * (len(koboldai_vars.chatname) + 1):]) is not None):
+                    self.completed[i] = True
+
+            return self.completed[i]
 
     class CoreStopper(StoppingCriteria):
         # Controls core generation stuff; aborting, counting generated tokens, etc
@@ -2560,6 +2583,7 @@ def patch_transformers():
 
             token_streamer = TokenStreamer(tokenizer=tokenizer)
             stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
+            stopping_criteria.insert(0, SinglelineStopper(tokenizer=tokenizer))
             stopping_criteria.insert(0, self.kai_scanner)
             token_streamer = TokenStreamer(tokenizer=tokenizer)
             stopping_criteria.insert(0, token_streamer)

From aa124b65dbb7aea858ecd85e922f4d8134d98000 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ond=C5=99ej=20Benda?=
Date: Thu, 2 Mar 2023 08:22:31 +0100
Subject: [PATCH 2/2] Fix: incorrect newline evaluation

---
 aiserver.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aiserver.py b/aiserver.py
index 91543514..8ee83707 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -2471,7 +2471,7 @@ def patch_transformers():
                 self.completed = [False]*len(input_ids)
 
             for i in range(len(input_ids)):
-                if (re.compile(r'\n').search(data[i][-1 * (len(koboldai_vars.chatname) + 1):]) is not None):
+                if data[i][-1] == "\n":
                     self.completed[i] = True
 
             return self.completed[i]
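
For context, below is a minimal, self-contained sketch (not part of the patches above) of how a newline-based StoppingCriteria like SinglelineStopper plugs into Hugging Face transformers' generate(). The class name NewlineStopper, the "gpt2" checkpoint, and the prompt are placeholders chosen for illustration; unlike the patched KoboldAI code, which keys off koboldai_vars.singleline inside aiserver.py, this sketch simply stops once every sequence in the batch has produced a newline.

# Illustrative sketch only; assumes the standard transformers StoppingCriteria API.
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)


class NewlineStopper(StoppingCriteria):
    """Stop generation once every sequence in the batch has emitted a newline."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.completed = None

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.completed is None:
            self.completed = [False] * len(input_ids)
        for i, ids in enumerate(input_ids):
            # Mirrors the fix in PATCH 2/2: only the most recently decoded character matters.
            if self.tokenizer.decode(ids).endswith("\n"):
                self.completed[i] = True
        # generate() halts the whole batch when a criterion returns True, so the
        # per-sequence flags are accumulated before the final check.
        return all(self.completed)


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("gpt2")        # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("You say: hello.\nThey reply:", return_tensors="pt")
    output = model.generate(
        **inputs,
        max_new_tokens=40,
        stopping_criteria=StoppingCriteriaList([NewlineStopper(tokenizer)]),
    )
    print(tokenizer.decode(output[0], skip_special_tokens=True))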