Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Added stopping criteria for chat mode to prevent it from generating text for the user
 aiserver.py | 44 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)
@@ -2177,6 +2217,46 @@ def patch_transformers():
             koboldai_vars.actions.stream_tokens(data)
             return False
 
+    class ChatModeStopper(StoppingCriteria):
+        # A StoppingCriteria is used here because it seems to run after
+        # everything has been evaluated score-wise.
+        def __init__(self, tokenizer):
+            self.tokenizer = tokenizer
+
+        def __call__(
+            self,
+            input_ids: torch.LongTensor,
+            scores: torch.FloatTensor,
+            **kwargs,
+        ) -> bool:
+
+            if not koboldai_vars.chatmode:
+                return False
+
+            data = [self.tokenizer.decode(x) for x in input_ids]
+            null_character = self.tokenizer.encode(chr(0))[0]
+            if 'completed' not in self.__dict__:
+                self.completed = [False] * len(input_ids)
+            for i in range(len(input_ids)):
+                if input_ids[i][-2] == null_character:
+                    input_ids[i][-1] = null_character  # keep padding a sequence that already stopped
+                elif data[i][-(len(koboldai_vars.chatname) + 1):] == koboldai_vars.chatname + ":":
+                    # The model has started writing the user's next line ("<chatname>:"), so remove that from the data.
+                    # The first of those tokens may contain more than the chat name (e.g. " You"), so re-encode it rather than lose the extra data.
+                    chatname_tokens = len(self.tokenizer.encode(koboldai_vars.chatname + ":"))
+                    if self.tokenizer.decode(input_ids[i][-chatname_tokens])[:1] != koboldai_vars.chatname[0]:
+                        input_ids[i][-chatname_tokens] = self.tokenizer.encode(self.tokenizer.decode(input_ids[i][-chatname_tokens]))[0]
+                    else:
+                        input_ids[i][-chatname_tokens] = null_character
+                    for j in range(1, chatname_tokens):
+                        input_ids[i][-j] = null_character
+                    self.completed[i] = True
+            if all(self.completed):
+                koboldai_vars.generated_tkns = koboldai_vars.genamt
+                return True
+            return False
+
+
     class CoreStopper(StoppingCriteria):
         # Controls core generation stuff; aborting, counting generated tokens, etc
         def __init__(self):
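Note on the hunk above: `ChatModeStopper` follows the standard `transformers` `StoppingCriteria` protocol, which `generate()` calls once per new token, after that step's scores have been computed. Below is a minimal, self-contained sketch of that protocol; the `StopOnString` class, the gpt2 checkpoint, and the prompt are illustrative placeholders, not KoboldAI code (KoboldAI wires its criteria in by patching `_get_stopping_criteria`, shown in the next hunk).

```python
# Minimal sketch (not KoboldAI code) of the StoppingCriteria protocol that
# ChatModeStopper implements. StopOnString and the gpt2 names are placeholders.
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

class StopOnString(StoppingCriteria):
    """Stop generation once every sequence in the batch ends with a string."""

    def __init__(self, tokenizer, stop_string):
        self.tokenizer = tokenizer
        self.stop_string = stop_string

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Called once per generated token, after the step's scores are
        # computed -- the same hook ChatModeStopper relies on.
        return all(
            self.tokenizer.decode(seq).endswith(self.stop_string)
            for seq in input_ids
        )

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
prompt = tok("You: Hi!\nBot:", return_tensors="pt")
out = model.generate(
    **prompt,
    max_new_tokens=40,
    stopping_criteria=StoppingCriteriaList([StopOnString(tok, "You:")]),
)
print(tok.decode(out[0]))
```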
@@ -2286,6 +2326,7 @@ def patch_transformers():
         stopping_criteria.insert(0, self.kai_scanner)
         token_streamer = TokenStreamer(tokenizer=tokenizer)
         stopping_criteria.insert(0, token_streamer)
+        stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
         return stopping_criteria
     transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
 
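A note on the ordering: in the `transformers` versions this file targets (the old `generation_utils` module path), `StoppingCriteriaList.__call__` is effectively `any(criterion(input_ids, scores) for criterion in self)`, which runs the criteria in order and short-circuits at the first `True`. Inserting `ChatModeStopper` at index 0 therefore lets it inspect (and blank out) the batch before `CoreStopper` or the token streamer can end the step. A small probe, using a hypothetical `Probe` class, illustrates this assumed behaviour:

```python
# Sketch of the ordering that the insert(0, ...) calls rely on.
# Probe is a hypothetical stand-in for CoreStopper/ChatModeStopper.
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class Probe(StoppingCriteria):
    def __init__(self, name, stop):
        self.name, self.stop = name, stop

    def __call__(self, input_ids, scores, **kwargs):
        print("ran:", self.name)
        return self.stop

criteria = StoppingCriteriaList([Probe("core", True)])
criteria.insert(0, Probe("chat", False))   # same pattern as the diff

ids = torch.zeros((1, 4), dtype=torch.long)
print("stop?", criteria(ids, scores=None))
# On generation_utils-era versions this prints:
#   ran: chat
#   ran: core
#   stop? True
```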
@@ -3091,6 +3132,9 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
 
     koboldai_vars.presets = to_use
 
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+
     koboldai_vars.aibusy = False
     if not os.path.exists("./softprompts"):
         os.mkdir("./softprompts")
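The last hunk guards against tokenizers that ship without a pad token (common for GPT-style causal models), since batched encoding and generation need a pad id. Below is a standalone sketch of the same guard under assumed placeholder names; the `resize_token_embeddings` call is an extra step the diff itself omits, presumably because the new pad id is only used for padding here and never fed through the model.

```python
# Sketch of the pad-token guard outside KoboldAI. gpt2 is a placeholder model.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

if tok.pad_token is None:  # same guard as the diff
    tok.add_special_tokens({'pad_token': '[PAD]'})
    # '[PAD]' is a brand-new token, so grow the embedding matrix to cover
    # its id (the diff skips this step; see the note above).
    model.resize_token_embeddings(len(tok))

batch = tok(["Hello", "A longer prompt"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape, tok.pad_token_id)
```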