Added stopping criteria for chat mode to prevent the model from generating text for the user

ebolam
2022-10-15 18:14:10 -04:00
parent 9b9f9e9a44
commit c3f44adea8
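
For context: Hugging Face's generate() polls each StoppingCriteria once per generated step and halts as soon as one returns True. A minimal sketch of that contract (StopOnToken is a hypothetical example, not part of this commit; KoboldAI instead injects its criteria by monkey-patching _get_stopping_criteria, as in the second hunk below):

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnToken(StoppingCriteria):
    def __init__(self, stop_token_id: int):
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop once every sequence in the batch ends with the stop token.
        return bool((input_ids[:, -1] == self.stop_token_id).all())

# Usage: model.generate(input_ids, stopping_criteria=StoppingCriteriaList([StopOnToken(tokenizer.eos_token_id)]))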


@@ -2176,6 +2176,46 @@ def patch_transformers():
                data = [applyoutputformatting(utils.decodenewlines(tokenizer.decode(x[-1])), no_sentence_trimming=True) for x in input_ids]
                koboldai_vars.actions.stream_tokens(data)
            return False
    class ChatModeStopper(StoppingCriteria):
        # A StoppingCriteria is used here because it seems to run after
        # everything has been evaluated score-wise.
        def __init__(self, tokenizer):
            self.tokenizer = tokenizer

        def __call__(
            self,
            input_ids: torch.LongTensor,
            scores: torch.FloatTensor,
            **kwargs,
        ) -> bool:
            if not koboldai_vars.chatmode:
                return False
            data = [self.tokenizer.decode(x) for x in input_ids]
            null_character = self.tokenizer.encode(chr(0))[0]
            if 'completed' not in self.__dict__:
                self.completed = [False] * len(input_ids)
            for i in range(len(input_ids)):
                if input_ids[i][-2] == null_character:
                    # This sequence already stopped; keep padding it with the
                    # null token until the rest of the batch finishes.
                    input_ids[i][-1] = null_character
                elif data[i][-(len(koboldai_vars.chatname) + 1):] == koboldai_vars.chatname + ":":
                    # We now have the user name with a ":" in chat mode; we want to
                    # remove that from the data. We do need to check whether the
                    # first token includes more than the chatname (e.g. " You") so
                    # we don't lose the extra data.
                    chatname_tokens = len(self.tokenizer.encode(koboldai_vars.chatname + ":"))
                    first_token = self.tokenizer.decode(input_ids[i][-chatname_tokens])
                    if first_token[0] != koboldai_vars.chatname[0]:
                        # The token carries extra text before the chatname, so keep it.
                        input_ids[i][-chatname_tokens] = self.tokenizer.encode(first_token)[0]
                    else:
                        input_ids[i][-chatname_tokens] = null_character
                    # Blank out the remaining chatname tokens with the null token.
                    for j in range(1, chatname_tokens):
                        input_ids[i][-j] = null_character
                    self.completed[i] = True
            if all(self.completed):
                koboldai_vars.generated_tkns = koboldai_vars.genamt
                return True
            return False
    class CoreStopper(StoppingCriteria):
        # Controls core generation stuff: aborting, counting generated tokens, etc.
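
The core trick in ChatModeStopper is the chr(0) padding: once a sequence has produced the "<chatname>:" marker, its trailing tokens are overwritten with a null token, and on later steps the null at position -2 marks that row as finished so it just keeps absorbing nulls until the whole batch is done. A toy sketch of that bookkeeping (plain Python lists standing in for tensors; chat_stop_step, NULL, and stop_id are hypothetical stand-ins, not code from this commit):

NULL = 0  # stand-in for tokenizer.encode(chr(0))[0]

def chat_stop_step(batch, completed, stop_id):
    for i, seq in enumerate(batch):
        if seq[-2] == NULL:
            seq[-1] = NULL            # row already finished: keep it inert
        elif seq[-1] == stop_id:      # stand-in for the "<chatname>:" text check
            seq[-1] = NULL            # blank out the user-turn marker
            completed[i] = True
    return all(completed)             # True tells generate() to stop

batch = [[5, 7, 9], [5, 8, 6]]        # 9 plays the role of "<chatname>:"
completed = [False, False]
print(chat_stop_step(batch, completed, 9))  # False: only the first row finished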
@@ -2286,6 +2326,7 @@ def patch_transformers():
        stopping_criteria.insert(0, self.kai_scanner)
        token_streamer = TokenStreamer(tokenizer=tokenizer)
        stopping_criteria.insert(0, token_streamer)
        stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
        return stopping_criteria
    transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
@@ -3090,6 +3131,9 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
            to_use["Other"]['Official'].append(preset)
    koboldai_vars.presets = to_use
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    koboldai_vars.aibusy = False
    if not os.path.exists("./softprompts"):
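
A side note on the third hunk: add_special_tokens() grows the tokenizer's vocabulary, so code using this pattern outside KoboldAI would typically also resize the model's embedding matrix to match. A hedged sketch of the usual standalone pattern (gpt2 is just an example of a model that ships without a pad token):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))  # keep embeddings in sync with the vocab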