diff --git a/aiserver.py b/aiserver.py
index 14ebc2eb..d22c770c 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -2453,6 +2453,36 @@ def patch_transformers():
                 return True
         return False
 
+    class SinglelineStopper(StoppingCriteria):
+        # If singleline mode is enabled, it's pointless to generate output beyond the first newline.
+        def __init__(self, tokenizer):
+            self.tokenizer = tokenizer
+            # Per-sequence "hit a newline" flags; sized lazily to the batch width.
+            self.completed = None
+
+        def __call__(
+            self,
+            input_ids: torch.LongTensor,
+            scores: torch.FloatTensor,
+            **kwargs,
+        ) -> bool:
+            if not koboldai_vars.singleline:
+                return False
+
+            # (Re)initialize the flags when a new batch starts or its width changes,
+            # so state from a previous generation cannot leak into this one.
+            if self.completed is None or len(self.completed) != len(input_ids):
+                self.completed = [False] * len(input_ids)
+
+            # Mark every sequence whose decoded text now ends with a newline.
+            # endswith() is safe on an empty decode, unlike indexing [-1].
+            for i, sequence in enumerate(input_ids):
+                if self.tokenizer.decode(sequence).endswith("\n"):
+                    self.completed[i] = True
+
+            # Stop only once ALL sequences in the batch are done, then reset
+            # the flags so the stopper can be reused for the next generation.
+            if all(self.completed):
+                self.completed = None
+                return True
+            return False
 
     class CoreStopper(StoppingCriteria):
         # Controls core generation stuff; aborting, counting generated tokens, etc
@@ -2561,6 +2591,7 @@ def patch_transformers():
             token_streamer = TokenStreamer(tokenizer=tokenizer)
 
             stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
+            stopping_criteria.insert(0, SinglelineStopper(tokenizer=tokenizer))
             stopping_criteria.insert(0, self.kai_scanner)
             token_streamer = TokenStreamer(tokenizer=tokenizer)
             stopping_criteria.insert(0, token_streamer)