mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Potential BadWords fix
This commit is contained in:
@@ -203,7 +203,7 @@ class HFInferenceModel(InferenceModel):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def _post_load(self) -> None:
|
def _post_load(self) -> None:
|
||||||
utils.koboldai_vars.badwordsids = koboldai_settings.badwordsids_default
|
self.badwordsids = koboldai_settings.badwordsids_default
|
||||||
self.model_type = str(self.model_config.model_type)
|
self.model_type = str(self.model_config.model_type)
|
||||||
# These are model specific tokenizer overrides if a model has bad defaults
|
# These are model specific tokenizer overrides if a model has bad defaults
|
||||||
if self.model_type == "llama":
|
if self.model_type == "llama":
|
||||||
@@ -305,17 +305,17 @@ class HFInferenceModel(InferenceModel):
|
|||||||
|
|
||||||
# Clean up tokens that cause issues
|
# Clean up tokens that cause issues
|
||||||
if (
|
if (
|
||||||
utils.koboldai_vars.badwordsids == koboldai_settings.badwordsids_default
|
self.badwordsids == koboldai_settings.badwordsids_default
|
||||||
and self.model_type not in ("gpt2", "gpt_neo", "gptj")
|
and self.model_type not in ("gpt2", "gpt_neo", "gptj")
|
||||||
):
|
):
|
||||||
utils.koboldai_vars.badwordsids = [
|
self.badwordsids = [
|
||||||
[v]
|
[v]
|
||||||
for k, v in self.tokenizer.get_vocab().items()
|
for k, v in self.tokenizer.get_vocab().items()
|
||||||
if any(c in str(k) for c in "[]")
|
if any(c in str(k) for c in "[]")
|
||||||
]
|
]
|
||||||
|
|
||||||
if utils.koboldai_vars.newlinemode == "n":
|
if utils.koboldai_vars.newlinemode == "n":
|
||||||
utils.koboldai_vars.badwordsids.append([self.tokenizer.eos_token_id])
|
self.badwordsids.append([self.tokenizer.eos_token_id])
|
||||||
|
|
||||||
return super()._post_load()
|
return super()._post_load()
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user