Potential fix for gpt-neo and gpt-j

This commit replaces reads and writes of the global utils.koboldai_vars.model_type inside HFInferenceModel with an instance attribute, self.model_type, set from the detected model config (with a fallback mapping for the custom NeoCustom/GPT2Custom model names), so the detected type no longer depends on shared global state.
@@ -199,8 +199,9 @@ class HFInferenceModel(InferenceModel):
            pass

    def _post_load(self) -> None:
+        self.model_type = str(self.model_config.model_type)
         # These are model specific tokenizer overrides if a model has bad defaults
-        if utils.koboldai_vars.model_type == "llama":
+        if self.model_type == "llama":
             # Note: self.tokenizer is a GenericTokenizer, and self.tokenizer.tokenizer is the actual LlamaTokenizer
             self.tokenizer.add_bos_token = False

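
For context, a minimal sketch of the pattern this hunk introduces: caching the detected model type on the instance during _post_load instead of consulting shared global state. Apart from model_type and model_config, the names below (the class name, the SimpleNamespace stand-in) are illustrative and not taken from the repo.

from types import SimpleNamespace

class HFInferenceModelSketch:
    """Hypothetical stand-in for HFInferenceModel, for illustration only."""

    def __init__(self, model_config):
        self.model_config = model_config  # stand-in for the Hugging Face config object
        self.model_type = None

    def _post_load(self) -> None:
        # Mirrors the added line: cache the type as a plain string on self
        self.model_type = str(self.model_config.model_type)
        if self.model_type == "llama":
            pass  # model-specific tokenizer overrides would go here

model = HFInferenceModelSketch(SimpleNamespace(model_type="gpt_neo"))
model._post_load()
print(model.model_type)  # -> gpt_neo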
@@ -284,23 +285,23 @@ class HFInferenceModel(InferenceModel):
                return result
            object.__setattr__(self.tokenizer, '__call__', call_wrapper.__get__(self.tokenizer))

-        elif utils.koboldai_vars.model_type == "opt":
+        elif self.model_type == "opt":
             self.tokenizer._koboldai_header = self.tokenizer.encode("")
             self.tokenizer.add_bos_token = False
             self.tokenizer.add_prefix_space = False

         # Change newline behavior to match model quirks
-        if utils.koboldai_vars.model_type == "xglm":
+        if self.model_type == "xglm":
             # Default to </s> newline mode if using XGLM
             utils.koboldai_vars.newlinemode = "s"
-        elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
+        elif self.model_type in ["opt", "bloom"]:
             # Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
             utils.koboldai_vars.newlinemode = "ns"

         # Clean up tokens that cause issues
         if (
             utils.koboldai_vars.badwordsids == koboldai_settings.badwordsids_default
-            and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
+            and self.model_type not in ("gpt2", "gpt_neo", "gptj")
         ):
             utils.koboldai_vars.badwordsids = [
                 [v]
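
As a hedged illustration of the newline-mode branching this hunk rewires: the mode strings "s" and "ns" come straight from the diff, but factoring the logic into a standalone function (and its name pick_newline_mode) is hypothetical.

def pick_newline_mode(model_type: str, current_mode: str) -> str:
    # XGLM models default to </s> newline mode ("s"); OPT and BLOOM
    # handle </s> without converting newlines ("ns"), per the diff comments.
    if model_type == "xglm":
        return "s"
    if model_type in ("opt", "bloom"):
        return "ns"
    return current_mode  # leave other model families untouched

assert pick_newline_mode("xglm", "n") == "s"
assert pick_newline_mode("opt", "n") == "ns"
assert pick_newline_mode("gpt_neo", "n") == "n"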
@@ -357,15 +358,15 @@ class HFInferenceModel(InferenceModel):
                 revision=utils.koboldai_vars.revision,
                 cache_dir="cache",
             )
-            utils.koboldai_vars.model_type = self.model_config.model_type
+            self.model_type = self.model_config.model_type
         except ValueError:
-            utils.koboldai_vars.model_type = {
+            self.model_type = {
                 "NeoCustom": "gpt_neo",
                 "GPT2Custom": "gpt2",
-            }.get(utils.koboldai_vars.model)
+            }.get(self.model)

-            if not utils.koboldai_vars.model_type:
+            if not self.model_type:
                 logger.warning(
                     "No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)"
                 )
-                utils.koboldai_vars.model_type = "gpt_neo"
+                self.model_type = "gpt_neo"
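
Finally, a trimmed sketch of the fallback path in this last hunk: when AutoConfig raises ValueError, the custom model name is mapped to a model type, defaulting to "gpt_neo" with a warning. The function name detect_fallback_model_type and the print in place of logger.warning are illustrative only.

def detect_fallback_model_type(model_name: str) -> str:
    model_type = {
        "NeoCustom": "gpt_neo",
        "GPT2Custom": "gpt2",
    }.get(model_name)  # .get returns None for unknown names
    if not model_type:
        print(
            "No model type detected, assuming Neo "
            "(If this is a GPT2 model use the other menu option or --model GPT2Custom)"
        )
        model_type = "gpt_neo"
    return model_type

assert detect_fallback_model_type("GPT2Custom") == "gpt2"
assert detect_fallback_model_type("mystery-model") == "gpt_neo"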