Merge pull request #297 from aibosan/united-singleline-stopper

Created SinglelineStopper...
This commit is contained in:
henk717
2023-03-02 20:51:54 +01:00
committed by GitHub

View File

@@ -2453,6 +2453,29 @@ def patch_transformers():
return True
return False
class SinglelineStopper(StoppingCriteria):
    """Stop generation at the first newline when single-line mode is active.

    If singleline mode is enabled, it's pointless to generate output beyond
    the first newline, so this criteria halts generation once every sequence
    in the batch has produced one.
    """

    def __init__(self, tokenizer):
        # Tokenizer used to decode generated ids back to text.
        self.tokenizer = tokenizer

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        # No-op unless the user has single-line mode turned on.
        if not koboldai_vars.singleline:
            return False

        # BUGFIX: use the tokenizer stored on this instance, not the module
        # global `tokenizer` — the constructor argument was previously unused.
        data = [self.tokenizer.decode(x) for x in input_ids]

        # Lazily allocate per-sequence completion flags on the first call,
        # once the batch size is known.
        if 'completed' not in self.__dict__:
            self.completed = [False] * len(input_ids)

        # Mark any sequence whose decoded text now ends in a newline.
        # endswith() is safe on an empty decode, unlike data[i][-1].
        for i in range(len(input_ids)):
            if data[i].endswith("\n"):
                self.completed[i] = True

        # BUGFIX: stop only when *every* sequence in the batch is done.
        # The original returned self.completed[i], i.e. only the last
        # sequence's flag, via the leaked loop variable.
        return all(self.completed)
class CoreStopper(StoppingCriteria):
# Controls core generation stuff; aborting, counting generated tokens, etc
@@ -2561,6 +2584,7 @@ def patch_transformers():
token_streamer = TokenStreamer(tokenizer=tokenizer)
stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
stopping_criteria.insert(0, SinglelineStopper(tokenizer=tokenizer))
stopping_criteria.insert(0, self.kai_scanner)
token_streamer = TokenStreamer(tokenizer=tokenizer)
stopping_criteria.insert(0, token_streamer)