mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Added stopping criteria for chat mode to prevent it from generating text for user
This commit is contained in:
44
aiserver.py
44
aiserver.py
@@ -2176,6 +2176,46 @@ def patch_transformers():
|
||||
data = [applyoutputformatting(utils.decodenewlines(tokenizer.decode(x[-1])), no_sentence_trimming=True) for x in input_ids]
|
||||
koboldai_vars.actions.stream_tokens(data)
|
||||
return False
|
||||
|
||||
class ChatModeStopper(StoppingCriteria):
|
||||
# A StoppingCriteria is used here because it seems to run after
|
||||
# everything has been evaluated score-wise.
|
||||
def __init__(self, tokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
scores: torch.FloatTensor,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
|
||||
if not koboldai_vars.chatmode:
|
||||
return False
|
||||
|
||||
data = [tokenizer.decode(x) for x in input_ids]
|
||||
null_character = tokenizer.encode(chr(0))[0]
|
||||
if 'completed' not in self.__dict__:
|
||||
self.completed = [False]*len(input_ids)
|
||||
for i in range(len(input_ids)):
|
||||
if input_ids[i][-2] == null_character:
|
||||
input_ids[i][-1] = tokenizer.encode(chr(0))[0]
|
||||
elif data[i][-1*(len(koboldai_vars.chatname)+1):] == koboldai_vars.chatname + ":":
|
||||
#We now have the user name in chat mode with a :. We want to remove that from the data
|
||||
#We do need to check if the first token includes more than the chatname (Ie " You") so we don't loose the extra data
|
||||
chatname_tokens = len(tokenizer.encode(koboldai_vars.chatname+":"))
|
||||
if input_ids[i][-1*chatname_tokens] != koboldai_vars.chatname[0]:
|
||||
input_ids[i][-1*chatname_tokens] = tokenizer.encode(tokenizer.decode(input_ids[i][-1*chatname_tokens]))[0]
|
||||
else:
|
||||
input_ids[i][-1*chatname_tokens] = null_character
|
||||
for j in range(len(koboldai_vars.chatname)+1):
|
||||
input_ids[i][-j] = tokenizer.encode(chr(0))[0]
|
||||
self.completed[i] = True
|
||||
if all(self.completed):
|
||||
koboldai_vars.generated_tkns = koboldai_vars.genamt
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class CoreStopper(StoppingCriteria):
|
||||
# Controls core generation stuff; aborting, counting generated tokens, etc
|
||||
@@ -2286,6 +2326,7 @@ def patch_transformers():
|
||||
stopping_criteria.insert(0, self.kai_scanner)
|
||||
token_streamer = TokenStreamer(tokenizer=tokenizer)
|
||||
stopping_criteria.insert(0, token_streamer)
|
||||
stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
|
||||
return stopping_criteria
|
||||
transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
|
||||
|
||||
@@ -3090,6 +3131,9 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||
to_use["Other"]['Official'].append(preset)
|
||||
|
||||
koboldai_vars.presets = to_use
|
||||
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
|
||||
koboldai_vars.aibusy = False
|
||||
if not os.path.exists("./softprompts"):
|
||||
|
Reference in New Issue
Block a user