Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Move overrides to better places
@@ -197,12 +197,6 @@ class InferenceModel:
 
         Returns:
             AutoTokenizer: Tokenizer deemed fit for the location string. May be a fallback tokenizer.
         """
-        if utils.koboldai_vars.model_type == "xglm":
-            # Default to </s> newline mode if using XGLM
-            utils.koboldai_vars.newlinemode = "s"
-        elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
-            # Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
-            utils.koboldai_vars.newlinemode = "ns"
 
         std_kwargs = {"revision": utils.koboldai_vars.revision, "cache_dir": "cache"}
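For reference, the `std_kwargs` dict in this hunk is passed to Hugging Face's `AutoTokenizer.from_pretrained`. A minimal sketch of the fallback pattern the docstring describes, assuming the location string is a local path or a Hugging Face model id (the `get_tokenizer` name and the `gpt2` fallback are illustrative assumptions, not the repository's exact code):

    from transformers import AutoTokenizer

    def get_tokenizer(location: str, revision=None):
        # Mirrors the std_kwargs dict shown in the hunk above.
        std_kwargs = {"revision": revision, "cache_dir": "cache"}
        try:
            # Prefer the tokenizer that ships with the model itself.
            return AutoTokenizer.from_pretrained(location, **std_kwargs)
        except Exception:
            # Fall back to a generic tokenizer so loading can still proceed;
            # the fallback choice here is an assumption for illustration.
            return AutoTokenizer.from_pretrained("gpt2", **std_kwargs)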
@@ -32,6 +32,23 @@ class HFInferenceModel(InferenceModel):
         if utils.koboldai_vars.newlinemode == "n":
             utils.koboldai_vars.badwordsids.append([self.tokenizer.eos_token_id])
 
+        # These are model specific tokenizer overrides if a model has bad defaults
+        if utils.koboldai_vars.model_type == "llama":
+            self.tokenizer.decode_with_prefix_space = True
+            self.tokenizer.add_bos_token = False
+        elif utils.koboldai_vars.model_type == "opt":
+            self.tokenizer._koboldai_header = self.tokenizer.encode("")
+            self.tokenizer.add_bos_token = False
+            self.tokenizer.add_prefix_space = False
+
+        # Change newline behavior to match model quirks
+        if utils.koboldai_vars.model_type == "xglm":
+            # Default to </s> newline mode if using XGLM
+            utils.koboldai_vars.newlinemode = "s"
+        elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
+            # Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
+            utils.koboldai_vars.newlinemode = "ns"
+
         return super()._post_load()
 
     def get_local_model_path(
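The `newlinemode` values set in this hunk control how newlines in prompt text are encoded for the model. A hedged sketch of what the comments imply (the helper name and exact substitution are assumptions; KoboldAI's real newline handling lives elsewhere and covers more cases):

    def encode_newlines(text: str, newlinemode: str) -> str:
        # Hedged reading of the diff comments, not the project's code.
        if newlinemode == "s":
            # XGLM-style models were trained with </s> standing in for newlines.
            return text.replace("\n", "</s>")
        # "n" keeps literal newlines; "ns" also keeps them but means </s>
        # tokens are handled without converting newlines (Fairseq models).
        return text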
@@ -132,15 +132,6 @@ class HFTorchInferenceModel(HFInferenceModel):
         if not utils.koboldai_vars.model_type:
             utils.koboldai_vars.model_type = m_self.get_model_type()
 
-        # These are model specific overrides if a model has bad defaults
-        if utils.koboldai_vars.model_type == "llama":
-            m_self.tokenizer.decode_with_prefix_space = True
-            m_self.tokenizer.add_bos_token = False
-        elif utils.koboldai_vars.model_type == "opt":
-            m_self.tokenizer._koboldai_header = m_self.tokenizer.encode("")
-            m_self.tokenizer.add_bos_token = False
-            m_self.tokenizer.add_prefix_space = False
-
         # Patch stopping_criteria
         class PTHStopper(StoppingCriteria):
             def __call__(
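The `PTHStopper` class anchored at the end of this hunk subclasses `transformers.StoppingCriteria`, whose body is truncated in this view. A minimal self-contained subclass using the documented `__call__` signature (the stop-on-token logic is an illustrative assumption, not `PTHStopper`'s actual body):

    import torch
    from transformers import StoppingCriteria

    class StopOnToken(StoppingCriteria):
        def __init__(self, stop_token_id: int):
            self.stop_token_id = stop_token_id

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            # Halt generation as soon as any sequence in the batch
            # emits the configured stop token.
            return bool((input_ids[:, -1] == self.stop_token_id).any())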