Potential BadWords fix

ebolam
2023-05-24 09:08:34 -04:00
parent c61e2b676a
commit 068173b24a

@@ -203,7 +203,7 @@ class HFInferenceModel(InferenceModel):
         pass

     def _post_load(self) -> None:
-        utils.koboldai_vars.badwordsids = koboldai_settings.badwordsids_default
+        self.badwordsids = koboldai_settings.badwordsids_default
         self.model_type = str(self.model_config.model_type)
         # These are model specific tokenizer overrides if a model has bad defaults
         if self.model_type == "llama":
@@ -305,17 +305,17 @@ class HFInferenceModel(InferenceModel):
         # Clean up tokens that cause issues
         if (
-            utils.koboldai_vars.badwordsids == koboldai_settings.badwordsids_default
+            self.badwordsids == koboldai_settings.badwordsids_default
            and self.model_type not in ("gpt2", "gpt_neo", "gptj")
         ):
-            utils.koboldai_vars.badwordsids = [
+            self.badwordsids = [
                 [v]
                 for k, v in self.tokenizer.get_vocab().items()
                 if any(c in str(k) for c in "[]")
             ]

             if utils.koboldai_vars.newlinemode == "n":
-                utils.koboldai_vars.badwordsids.append([self.tokenizer.eos_token_id])
+                self.badwordsids.append([self.tokenizer.eos_token_id])

         return super()._post_load()
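
This change moves the bad-words list from the shared utils.koboldai_vars object onto the model instance (self.badwordsids), so each loaded model keeps its own token blacklist instead of mutating global state. As a rough illustration of how such a list is consumed, the sketch below rebuilds the same bracket-token blacklist for a stock Hugging Face model and passes it to generate() through its bad_words_ids parameter; the gpt2 checkpoint, prompt, and generation settings are placeholders, not KoboldAI's actual generation path.

# Hedged sketch, not KoboldAI code: shows how a per-model bad-words list
# like self.badwordsids can be fed to transformers' generate().
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Same construction as in the diff: ban every vocab token containing "[" or "]"
badwordsids = [
    [v] for k, v in tokenizer.get_vocab().items() if any(c in str(k) for c in "[]")
]

inputs = tokenizer("Once upon a time", return_tensors="pt")
output = model.generate(
    **inputs,
    max_new_tokens=20,
    bad_words_ids=badwordsids,  # token-id sequences that may not be generated
)
print(tokenizer.decode(output[0], skip_special_tokens=True))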