diff --git a/aiserver.py b/aiserver.py
index 34906559..a27b6455 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -2176,6 +2176,46 @@ def patch_transformers():
             data = [applyoutputformatting(utils.decodenewlines(tokenizer.decode(x[-1])), no_sentence_trimming=True) for x in input_ids]
             koboldai_vars.actions.stream_tokens(data)
             return False
+
+    class ChatModeStopper(StoppingCriteria):
+        # A StoppingCriteria is used here because it seems to run after
+        # everything has been evaluated score-wise.
+        def __init__(self, tokenizer):
+            self.tokenizer = tokenizer
+
+        def __call__(
+            self,
+            input_ids: torch.LongTensor,
+            scores: torch.FloatTensor,
+            **kwargs,
+        ) -> bool:
+
+            if not koboldai_vars.chatmode:
+                return False
+
+            data = [self.tokenizer.decode(x) for x in input_ids]
+            null_character = self.tokenizer.encode(chr(0))[0]
+            if 'completed' not in self.__dict__:
+                self.completed = [False]*len(input_ids)
+            for i in range(len(input_ids)):
+                if input_ids[i][-2] == null_character:
+                    input_ids[i][-1] = null_character
+                elif data[i][-(len(koboldai_vars.chatname)+1):] == koboldai_vars.chatname + ":":
+                    # The sequence now ends with the user's chatname plus ":"; strip that from the data.
+                    # The first token of that span may cover more than the chatname (e.g. " You"), so check it before nulling it out so we don't lose the extra data.
+                    chatname_tokens = len(self.tokenizer.encode(koboldai_vars.chatname + ":"))
+                    if self.tokenizer.decode(input_ids[i][-chatname_tokens]) != koboldai_vars.chatname[0]:  # decode before comparing; a raw token id never equals a character
+                        input_ids[i][-chatname_tokens] = self.tokenizer.encode(self.tokenizer.decode(input_ids[i][-chatname_tokens]))[0]
+                    else:
+                        input_ids[i][-chatname_tokens] = null_character
+                    for j in range(1, chatname_tokens):  # start at 1 (index -0 is index 0) and leave the first span token, handled above, intact
+                        input_ids[i][-j] = null_character
+                    self.completed[i] = True
+            if all(self.completed):
+                koboldai_vars.generated_tkns = koboldai_vars.genamt
+                return True
+            return False
+
 
     class CoreStopper(StoppingCriteria):
         # Controls core generation stuff; aborting, counting generated tokens, etc
@@ -2286,6 +2326,7 @@
         stopping_criteria.insert(0, self.kai_scanner)
         token_streamer = TokenStreamer(tokenizer=tokenizer)
         stopping_criteria.insert(0, token_streamer)
+        stopping_criteria.insert(0, ChatModeStopper(tokenizer=tokenizer))
         return stopping_criteria
     transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
@@ -3090,6 +3131,9 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
                     to_use["Other"]['Official'].append(preset)
         koboldai_vars.presets = to_use
+
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
     koboldai_vars.aibusy = False
    if not os.path.exists("./softprompts"):
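
Note for reviewers: transformers invokes every entry in the StoppingCriteriaList once per generated token, after the logits for that step have already been processed, which is why a StoppingCriteria is a convenient place to inspect and rewrite input_ids even though its nominal job is only to return a stop flag. Below is a minimal, self-contained sketch of the same hook outside KoboldAI; the gpt2 checkpoint and the "You:" stop string are placeholder choices for illustration, not part of this patch.

    import torch
    from transformers import (AutoModelForCausalLM, AutoTokenizer,
                              StoppingCriteria, StoppingCriteriaList)

    class StopOnString(StoppingCriteria):
        # Halts generation once every sequence in the batch ends with `stop`.
        def __init__(self, tokenizer, stop):
            self.tokenizer = tokenizer
            self.stop = stop

        def __call__(self, input_ids: torch.LongTensor,
                     scores: torch.FloatTensor, **kwargs) -> bool:
            return all(self.tokenizer.decode(ids).endswith(self.stop)
                       for ids in input_ids)

    tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    prompt = tokenizer("You: Hi there.\nBot:", return_tensors="pt")
    out = model.generate(**prompt, max_new_tokens=40,
                         stopping_criteria=StoppingCriteriaList(
                             [StopOnString(tokenizer, "You:")]))
    print(tokenizer.decode(out[0]))

Because the criteria run after scoring, they see the token that was just appended, which is exactly what ChatModeStopper relies on when it overwrites the trailing chatname tokens in place.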
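On the [PAD] hunk in load_model: add_special_tokens grows the tokenizer's vocabulary, so the new pad id lies outside the model's embedding matrix. If that id is ever actually fed to the model (e.g. when padding batched inputs), the embeddings must be resized to match. Whether KoboldAI's generation path ever embeds the pad token is not shown in this diff, so treat the following sketch as the usual companion step rather than a required fix; the gpt2 checkpoint is again a placeholder.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # The new token id equals the old vocab size; without this resize,
        # an embedding lookup for [PAD] would index out of range.
        model.resize_token_embeddings(len(tokenizer))

A common alternative that avoids touching the embedding matrix at all is tokenizer.pad_token = tokenizer.eos_token.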