HFInferenceModel: Make badwordsids not unique to torch

This commit is contained in:
somebody
2023-05-01 17:13:33 -05:00
parent c95be636a4
commit 933dbd634a
4 changed files with 21 additions and 11 deletions

View File

@@ -239,16 +239,6 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
)
shutil.rmtree("cache/")
if (
utils.koboldai_vars.badwordsids is koboldai_settings.badwordsids_default
and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
):
utils.koboldai_vars.badwordsids = [
[v]
for k, v in self.tokenizer.get_vocab().items()
if any(c in str(k) for c in "[]")
]
self.patch_embedding()
if utils.koboldai_vars.hascuda:

View File

@@ -3,6 +3,7 @@ from typing import Optional
from transformers import AutoConfig
import utils
import koboldai_settings
from logger import logger
from modeling.inference_model import InferenceModel
@@ -16,6 +17,23 @@ class HFInferenceModel(InferenceModel):
self.model = None
self.tokenizer = None
def _post_load(self) -> None:
# Post-load hook: filter out tokenizer vocab entries that tend to corrupt
# generation, then chain to the base-class hook.
# NOTE(review): indentation here was flattened by the diff renderer; the
# nesting shown below is reconstructed reading order — confirm against the
# real source file.
# Clean up tokens that cause issues
if (
# Only override badwordsids when the user has left it at the shipped
# default (== compares value, not identity) AND the model family is not
# one of the GPT variants listed here.
utils.koboldai_vars.badwordsids == koboldai_settings.badwordsids_default
and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
):
# Ban every vocab token containing a square bracket — presumably to stop
# the model emitting "[Author's note: ...]"-style markup; each entry is a
# single-token list as expected by the bad_words_ids format.
utils.koboldai_vars.badwordsids = [
[v]
for k, v in self.tokenizer.get_vocab().items()
if any(c in str(k) for c in "[]")
]
# In "n" newline mode, additionally ban the EOS token.
# NOTE(review): this append mutates badwordsids in place on every call —
# repeated _post_load calls would duplicate the entry; verify hook runs once.
if utils.koboldai_vars.newlinemode == "n":
utils.koboldai_vars.badwordsids.append([self.tokenizer.eos_token_id])
return super()._post_load()
def get_local_model_path(
self, legacy: bool = False, ignore_existance: bool = False
) -> Optional[str]:

View File

@@ -220,6 +220,8 @@ class HFTorchInferenceModel(HFInferenceModel):
new_sample.old_sample = transformers.GenerationMixin.sample
use_core_manipulations.sample = new_sample
return super()._post_load()
def _raw_generate(
self,
prompt_tokens: Union[List[int], torch.Tensor],