Move bad token grabber until after newlinemode has been deduced

This commit is contained in:
somebody
2023-05-02 20:23:36 -05:00
parent efe268df60
commit a0f4ab5c6a

View File

@@ -18,20 +18,6 @@ class HFInferenceModel(InferenceModel):
self.tokenizer = None
def _post_load(self) -> None:
# Clean up tokens that cause issues
if (
utils.koboldai_vars.badwordsids == koboldai_settings.badwordsids_default
and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
):
utils.koboldai_vars.badwordsids = [
[v]
for k, v in self.tokenizer.get_vocab().items()
if any(c in str(k) for c in "[]")
]
if utils.koboldai_vars.newlinemode == "n":
utils.koboldai_vars.badwordsids.append([self.tokenizer.eos_token_id])
# These are model specific tokenizer overrides if a model has bad defaults
if utils.koboldai_vars.model_type == "llama":
self.tokenizer.decode_with_prefix_space = True
@@ -49,6 +35,20 @@ class HFInferenceModel(InferenceModel):
# Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
utils.koboldai_vars.newlinemode = "ns"
# Clean up tokens that cause issues
if (
utils.koboldai_vars.badwordsids == koboldai_settings.badwordsids_default
and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
):
utils.koboldai_vars.badwordsids = [
[v]
for k, v in self.tokenizer.get_vocab().items()
if any(c in str(k) for c in "[]")
]
if utils.koboldai_vars.newlinemode == "n":
utils.koboldai_vars.badwordsids.append([self.tokenizer.eos_token_id])
return super()._post_load()
def get_local_model_path(