Created SinglelineStopper, which interrupts token generation when a newline is reached if singleline mode is enabled

This commit is contained in:
Ondřej Benda
2023-03-02 00:04:17 +01:00
parent 9265be58e9
commit 0ba7ac96d3

View File

@@ -2452,6 +2452,29 @@ def patch_transformers():
return True
return False
class SinglelineStopper(StoppingCriteria):
# If singleline mode is enabled, it's pointless to generate output beyond the first newline.
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
**kwargs,
) -> bool:
if not koboldai_vars.singleline:
return False
data = [tokenizer.decode(x) for x in input_ids]
if 'completed' not in self.__dict__:
self.completed = [False]*len(input_ids)
for i in range(len(input_ids)):
if (re.compile(r'\n').search(data[i][-1 * (len(koboldai_vars.chatname) + 1):]) is not None):
self.completed[i] = True
return self.completed[i]
class CoreStopper(StoppingCriteria):
# Controls core generation stuff; aborting, counting generated tokens, etc
@@ -2560,6 +2583,7 @@ def patch_transformers():
token_streamer = TokenStreamer(tokenizer=tokenizer)
stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
stopping_criteria.insert(0, SinglelineStopper(tokenizer=tokenizer))
stopping_criteria.insert(0, self.kai_scanner)
token_streamer = TokenStreamer(tokenizer=tokenizer)
stopping_criteria.insert(0, token_streamer)