Potential BadWords fix

commit 068173b24a (parent c61e2b676a)
Author: ebolam
Date: 2023-05-24 09:08:34 -04:00


@@ -203,7 +203,7 @@ class HFInferenceModel(InferenceModel):
         pass

     def _post_load(self) -> None:
-        utils.koboldai_vars.badwordsids = koboldai_settings.badwordsids_default
+        self.badwordsids = koboldai_settings.badwordsids_default
         self.model_type = str(self.model_config.model_type)
         # These are model specific tokenizer overrides if a model has bad defaults
         if self.model_type == "llama":
@@ -305,17 +305,17 @@ class HFInferenceModel(InferenceModel):
         # Clean up tokens that cause issues
         if (
-            utils.koboldai_vars.badwordsids == koboldai_settings.badwordsids_default
+            self.badwordsids == koboldai_settings.badwordsids_default
             and self.model_type not in ("gpt2", "gpt_neo", "gptj")
         ):
-            utils.koboldai_vars.badwordsids = [
+            self.badwordsids = [
                 [v]
                 for k, v in self.tokenizer.get_vocab().items()
                 if any(c in str(k) for c in "[]")
             ]

             if utils.koboldai_vars.newlinemode == "n":
-                utils.koboldai_vars.badwordsids.append([self.tokenizer.eos_token_id])
+                self.badwordsids.append([self.tokenizer.eos_token_id])

         return super()._post_load()
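
The diff moves the bad-word token list from the shared utils.koboldai_vars.badwordsids global onto the HFInferenceModel instance as self.badwordsids, so the list belongs to the loaded model rather than to global state. Below is a minimal sketch of that pattern using the Hugging Face transformers generate() API; the ModelWrapper class and default_badwordsids parameter are hypothetical names for illustration only and are not part of KoboldAI.

# Minimal sketch of per-instance bad-word IDs, assuming a standard
# transformers causal LM. ModelWrapper and default_badwordsids are
# illustrative names, not KoboldAI code.
from transformers import AutoModelForCausalLM, AutoTokenizer


class ModelWrapper:
    def __init__(self, model_name: str, default_badwordsids: list):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        # Each instance keeps its own copy, mirroring self.badwordsids above,
        # so one model's token cleanup cannot alter another model's list.
        self.badwordsids = list(default_badwordsids)

    def generate(self, prompt: str, max_new_tokens: int = 40) -> str:
        inputs = self.tokenizer(prompt, return_tensors="pt")
        output = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # transformers' generate() accepts bad_words_ids as a list of
            # token-id lists that are suppressed during sampling.
            bad_words_ids=self.badwordsids or None,
        )
        return self.tokenizer.decode(output[0], skip_special_tokens=True)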