Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
HFInferenceModel: Make badwordsids not unique to torch
Submodule KoboldAI-Horde-Bridge updated: d9014ebac9...7a7327804f
@@ -239,16 +239,6 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
             )
             shutil.rmtree("cache/")
 
-        if (
-            utils.koboldai_vars.badwordsids is koboldai_settings.badwordsids_default
-            and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
-        ):
-            utils.koboldai_vars.badwordsids = [
-                [v]
-                for k, v in self.tokenizer.get_vocab().items()
-                if any(c in str(k) for c in "[]")
-            ]
-
         self.patch_embedding()
 
         if utils.koboldai_vars.hascuda:
@@ -3,6 +3,7 @@ from typing import Optional
 from transformers import AutoConfig
 
 import utils
+import koboldai_settings
 from logger import logger
 from modeling.inference_model import InferenceModel
 
@@ -16,6 +17,23 @@ class HFInferenceModel(InferenceModel):
         self.model = None
         self.tokenizer = None
 
+    def _post_load(self) -> None:
+        # Clean up tokens that cause issues
+        if (
+            utils.koboldai_vars.badwordsids == koboldai_settings.badwordsids_default
+            and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
+        ):
+            utils.koboldai_vars.badwordsids = [
+                [v]
+                for k, v in self.tokenizer.get_vocab().items()
+                if any(c in str(k) for c in "[]")
+            ]
+
+        if utils.koboldai_vars.newlinemode == "n":
+            utils.koboldai_vars.badwordsids.append([self.tokenizer.eos_token_id])
+
+        return super()._post_load()
+
     def get_local_model_path(
         self, legacy: bool = False, ignore_existance: bool = False
     ) -> Optional[str]:
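The list comprehension moved into _post_load() is the heart of the change: it bans every vocabulary token whose text contains a square bracket. The sketch below runs the same comprehension against a toy vocabulary (a hypothetical stand-in for self.tokenizer.get_vocab(), which maps token strings to token ids) to show the shape it produces.

# A minimal sketch of the badwordsids filter above. "toy_vocab" is a
# hypothetical stand-in for self.tokenizer.get_vocab().
toy_vocab = {"hello": 0, "[PAD]": 1, "world": 2, "[SEP]": 3, "wo]rd": 4}

badwordsids = [
    [v]
    for k, v in toy_vocab.items()
    if any(c in str(k) for c in "[]")
]

print(badwordsids)  # [[1], [3], [4]]

Each banned id is wrapped in its own one-element list because consumers such as the bad_words_ids argument of transformers' generate() expect a list of token-id sequences rather than a flat list of ids.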
@@ -220,6 +220,8 @@ class HFTorchInferenceModel(HFInferenceModel):
         new_sample.old_sample = transformers.GenerationMixin.sample
         use_core_manipulations.sample = new_sample
 
+        return super()._post_load()
+
     def _raw_generate(
         self,
         prompt_tokens: Union[List[int], torch.Tensor],
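With this hunk, both _post_load() overrides end in return super()._post_load(), so the torch backend's hook chains into the shared HF cleanup added above. A minimal sketch of that call chain, using the real class names but hypothetical placeholder bodies:

# Sketch of the _post_load() chain this commit sets up; the print
# statements are placeholders for the real work in each class.
class InferenceModel:
    def _post_load(self) -> None:
        pass  # base hook

class HFInferenceModel(InferenceModel):
    def _post_load(self) -> None:
        print("shared HF cleanup: badwordsids, newline mode")
        return super()._post_load()

class HFTorchInferenceModel(HFInferenceModel):
    def _post_load(self) -> None:
        print("torch-specific patching: sampler hooks")
        return super()._post_load()

HFTorchInferenceModel()._post_load()
# prints "torch-specific patching: sampler hooks",
# then "shared HF cleanup: badwordsids, newline mode"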