From 49fa63052f22ae1a0dd6470d9ea5afdce89b0269 Mon Sep 17 00:00:00 2001
From: Henk
Date: Tue, 29 Aug 2023 20:51:09 +0200
Subject: [PATCH] Allow EOS unbanning

---
 aiserver.py                           |  5 +++-
 gensettings.py                        | 16 ++++++++++
 koboldai_settings.py                  |  1 +
 modeling/inference_models/hf_torch.py | 42 ++++++++++++++++++++-------
 4 files changed, 52 insertions(+), 12 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index ba3be3d4..40ff9c5a 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -930,7 +930,7 @@ tags = [
 api_version = None # This gets set automatically so don't change this value
 
 api_v1 = KoboldAPISpec(
-    version="1.2.3",
+    version="1.2.4",
     prefixes=["/api/v1", "/api/latest"],
     tags=tags,
 )
@@ -8161,6 +8161,7 @@ class GenerationInputSchema(SamplerSettingsSchema):
     frmtrmblln: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, replaces all occurrences of two or more consecutive newlines in the output with one newline.\n\nIf `disable_output_formatting` is `true`, this defaults to `false` instead of the value in the KoboldAI GUI."})
     frmtrmspch: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes `#/@%{}+=~|\^<>` from the output.\n\nIf `disable_output_formatting` is `true`, this defaults to `false` instead of the value in the KoboldAI GUI."})
     singleline: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes everything after the first line of the output, including the newline.\n\nIf `disable_output_formatting` is `true`, this defaults to `false` instead of the value in the KoboldAI GUI."})
+    use_default_badwordids: bool = fields.Boolean(load_default=True, metadata={"description": "Ban tokens that commonly worsen the writing experience for continuous story writing"})
     disable_input_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, all input formatting options default to `false` instead of the value in the KoboldAI GUI"})
     frmtadsnsp: Optional[bool] = fields.Boolean(metadata={"description": "Input formatting option. When enabled, adds a leading space to your input if there is no trailing whitespace at the end of the previous action.\n\nIf `disable_input_formatting` is `true`, this defaults to `false` instead of the value in the KoboldAI GUI."})
     quiet: Optional[bool] = fields.Boolean(metadata={"description": "When enabled, Generated output will not be displayed in the console."})
@@ -8169,6 +8170,7 @@ class GenerationInputSchema(SamplerSettingsSchema):
     sampler_full_determinism: Optional[bool] = fields.Boolean(metadata={"description": "If enabled, the generated text will always be the same as long as you use the same RNG seed, input and settings. If disabled, only the *sequence* of generated texts that you get when repeatedly generating text will be the same given the same RNG seed, input and settings."})
     stop_sequence: Optional[List[str]] = fields.List(fields.String(),metadata={"description": "An array of string sequences where the API will stop generating further tokens. The returned text WILL contain the stop sequence."}, validate=[validate.Length(max=10)])
+
 
 
 class GenerationResultSchema(KoboldSchema):
     text: str = fields.String(required=True, metadata={"description": "Generated output as plain text."})
@@ -8311,6 +8313,7 @@ def _generate_text(body: GenerationInputSchema):
         "sampler_order": ("koboldai_vars", "sampler_order", None),
         "sampler_full_determinism": ("koboldai_vars", "full_determinism", None),
         "stop_sequence": ("koboldai_vars", "stop_sequence", None),
+        "use_default_badwordids": ("koboldai_vars", "use_default_badwordids", None),
     }
     saved_settings = {}
     set_aibusy(1)
diff --git a/gensettings.py b/gensettings.py
index 4b395266..8bb28513 100644
--- a/gensettings.py
+++ b/gensettings.py
@@ -396,6 +396,22 @@ gensettingstf = [
     "name": "output_streaming",
     "ui_level": 1
     },
+    {
+    "uitype": "toggle",
+    "unit": "bool",
+    "label": "Ban Bad Tokens",
+    "id": "setusedefaultbadwordids",
+    "min": 0,
+    "max": 1,
+    "step": 1,
+    "default": 1,
+    "tooltip": "Ban tokens that commonly worsen the writing experience for continuous story writing.",
+    "menu_path": "Settings",
+    "sub_path": "Sampling",
+    "classname": "model",
+    "name": "use_default_badwordids",
+    "ui_level": 0
+    },
     {
     "uitype": "toggle",
     "unit": "bool",
diff --git a/koboldai_settings.py b/koboldai_settings.py
index 30d7f0f7..5598eb62 100644
--- a/koboldai_settings.py
+++ b/koboldai_settings.py
@@ -693,6 +693,7 @@ class model_settings(settings):
         self._koboldai_vars = koboldai_vars
         self.alt_multi_gen = False
         self.bit_8_available = None
+        self.use_default_badwordids = True
         self.supported_gen_modes = []
 
     def reset_for_model_load(self):
diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index 82e60304..5e6e0a95 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -330,19 +330,39 @@ class HFTorchInferenceModel(HFInferenceModel):
         if seed is not None:
             torch.manual_seed(seed)
 
+        if utils.koboldai_vars.use_default_badwordids:
+            self.active_badwordids = self.badwordsids + additional_bad_words_ids
+        else:
+            if additional_bad_words_ids:
+                self.active_badwordids = additional_bad_words_ids
+            else:
+                self.active_badwordids = None
+
         with torch.no_grad():
             start_time = time.time()
-            genout = self.model.generate(
-                input_ids=gen_in,
-                do_sample=True,
-                max_length=min(
-                    len(prompt_tokens) + max_new, utils.koboldai_vars.max_length
-                ),
-                repetition_penalty=1.0,
-                bad_words_ids=self.badwordsids + additional_bad_words_ids,
-                use_cache=True,
-                num_return_sequences=batch_count,
-            )
+            if self.active_badwordids: ## I know duplicating this is ugly, but HF checks if its present and accepts nothing but actual token bans if its there (Which I can't guarantee would be universal enough).... - Henk
+                genout = self.model.generate(
+                    input_ids=gen_in,
+                    do_sample=True,
+                    max_length=min(
+                        len(prompt_tokens) + max_new, utils.koboldai_vars.max_length
+                    ),
+                    repetition_penalty=1.0,
+                    bad_words_ids=self.active_badwordids,
+                    use_cache=True,
+                    num_return_sequences=batch_count,
+                )
+            else:
+                genout = self.model.generate(
+                    input_ids=gen_in,
+                    do_sample=True,
+                    max_length=min(
+                        len(prompt_tokens) + max_new, utils.koboldai_vars.max_length
+                    ),
+                    repetition_penalty=1.0,
+                    use_cache=True,
+                    num_return_sequences=batch_count,
+                )
         logger.debug(
             "torch_raw_generate: run generator {}s".format(time.time() - start_time)
         )
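
Notes (not part of the patch): the new use_default_badwordids setting is exposed both as a GUI toggle ("Ban Bad Tokens") and as an optional boolean field on generation requests, which is presumably why the API spec version is bumped from 1.2.3 to 1.2.4. As a minimal sketch of the client side, the Python snippet below shows how a caller could opt out of the default bad-word bans so the model is free to emit tokens such as EOS. The endpoint path, host/port, and response shape are assumptions based on a default local KoboldAI install and are not part of this diff.

import requests

# Assumed local endpoint; adjust to wherever the KoboldAI server is listening.
API_URL = "http://localhost:5000/api/v1/generate"

payload = {
    "prompt": "The knight drew her sword and said,",
    "max_length": 48,
    # New field from this patch: when False, the default bad-word list is not
    # applied, so tokens such as EOS are no longer banned during generation.
    # It defaults to True (load_default=True), preserving the old behaviour.
    "use_default_badwordids": False,
}

resp = requests.post(API_URL, json=payload)
resp.raise_for_status()
# Assumed response shape: {"results": [{"text": "..."}]}
print(resp.json()["results"][0]["text"])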