diff --git a/aiserver.py b/aiserver.py index b845330e..91543514 100644 --- a/aiserver.py +++ b/aiserver.py @@ -2452,6 +2452,29 @@ def patch_transformers(): return True return False + class SinglelineStopper(StoppingCriteria): + # If singleline mode is enabled, it's pointless to generate output beyond the first newline. + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def __call__( + self, + input_ids: torch.LongTensor, + scores: torch.FloatTensor, + **kwargs, + ) -> bool: + if not koboldai_vars.singleline: + return False + + data = [tokenizer.decode(x) for x in input_ids] + if 'completed' not in self.__dict__: + self.completed = [False]*len(input_ids) + + for i in range(len(input_ids)): + if (re.compile(r'\n').search(data[i][-1 * (len(koboldai_vars.chatname) + 1):]) is not None): + self.completed[i] = True + + return self.completed[i] class CoreStopper(StoppingCriteria): # Controls core generation stuff; aborting, counting generated tokens, etc @@ -2560,6 +2583,7 @@ def patch_transformers(): token_streamer = TokenStreamer(tokenizer=tokenizer) stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer)) + stopping_criteria.insert(0, SinglelineStopper(tokenizer=tokenizer)) stopping_criteria.insert(0, self.kai_scanner) token_streamer = TokenStreamer(tokenizer=tokenizer) stopping_criteria.insert(0, token_streamer)