Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Merge pull request #342 from one-some/model-structure-and-maybe-rwkv
Move overrides to better places
@@ -1916,9 +1916,6 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
     if koboldai_vars.model == "ReadOnly":
         koboldai_vars.noai = True

-    loadmodelsettings()
-    loadsettings()
-
     # TODO: InferKit
     if koboldai_vars.model == "ReadOnly" or koboldai_vars.noai:
         pass
@@ -1984,6 +1981,9 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
     if model:
         tokenizer = model.tokenizer

+    loadmodelsettings()
+    loadsettings()
+
     lua_startup()
     # Load scripts
     load_lua_scripts()
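Taken together, the two hunks above move loadmodelsettings() and loadsettings() from before model construction to after the backend has produced a model and tokenizer. A self-contained toy sketch of the ordering this enforces; the names below are illustrative stand-ins, not the real aiserver.py functions:

# Toy illustration of the ordering dependency (hypothetical names; the real
# loadmodelsettings()/loadsettings() live in aiserver.py).

class ToyModel:
    def __init__(self):
        # stands in for the backend-built model and its tokenizer
        self.tokenizer = {"add_bos_token": True}

def post_load(model):
    # stands in for the backend's tokenizer overrides (_post_load)
    model.tokenizer["add_bos_token"] = False

def load_settings(model):
    # settings code that inspects the loaded model must see its final state
    return dict(model.tokenizer)

model = ToyModel()
post_load(model)
settings = load_settings(model)  # runs after the model exists, as in the new order
assert settings["add_bos_token"] is False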
@@ -197,12 +197,6 @@ class InferenceModel:
         Returns:
             AutoTokenizer: Tokenizer deemed fit for the location string. May be a fallback tokenizer.
         """
-        if utils.koboldai_vars.model_type == "xglm":
-            # Default to </s> newline mode if using XGLM
-            utils.koboldai_vars.newlinemode = "s"
-        elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
-            # Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
-            utils.koboldai_vars.newlinemode = "ns"

         std_kwargs = {"revision": utils.koboldai_vars.revision, "cache_dir": "cache"}
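This hunk deletes the newline-mode selection from InferenceModel's tokenizer getter; the next hunk re-adds it in HFInferenceModel._post_load(). For reference, a hedged sketch of what the two modes appear to govern, inferred from the diff's own comments rather than from the exact KoboldAI implementation:

# "s"  - newlines are represented by the </s> token (XGLM-style)
# "ns" - </s> is handled, but newlines are not converted (OPT/BLOOM/Fairseq)

def apply_newlinemode(text: str, newlinemode: str) -> str:
    if newlinemode == "s":
        # substitute </s> for newlines before tokenization
        return text.replace("\n", "</s>")
    # "ns" (and the default) leaves newlines as the model saw them in training
    return text

assert apply_newlinemode("line one\nline two", "s") == "line one</s>line two"
assert apply_newlinemode("line one\nline two", "ns") == "line one\nline two"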
@@ -18,6 +18,23 @@ class HFInferenceModel(InferenceModel):
         self.tokenizer = None

     def _post_load(self) -> None:
+        # These are model specific tokenizer overrides if a model has bad defaults
+        if utils.koboldai_vars.model_type == "llama":
+            self.tokenizer.decode_with_prefix_space = True
+            self.tokenizer.add_bos_token = False
+        elif utils.koboldai_vars.model_type == "opt":
+            self.tokenizer._koboldai_header = self.tokenizer.encode("")
+            self.tokenizer.add_bos_token = False
+            self.tokenizer.add_prefix_space = False
+
+        # Change newline behavior to match model quirks
+        if utils.koboldai_vars.model_type == "xglm":
+            # Default to </s> newline mode if using XGLM
+            utils.koboldai_vars.newlinemode = "s"
+        elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
+            # Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
+            utils.koboldai_vars.newlinemode = "ns"
+
         # Clean up tokens that cause issues
         if (
             utils.koboldai_vars.badwordsids == koboldai_settings.badwordsids_default
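_post_load() reads as a template-method hook: the base class presumably drives the load sequence and calls it once the tokenizer exists, so every backend inheriting from HFInferenceModel picks up these overrides. A minimal sketch of that pattern; the load()/_load() structure here is an assumption about the surrounding code, not a copy of it:

class InferenceModelSketch:
    """Illustrative base class; not the real InferenceModel."""

    def __init__(self):
        self.tokenizer = None

    def load(self):
        self._load()       # backend-specific: build model and tokenizer
        self._post_load()  # hook: runs once the tokenizer exists

    def _load(self):
        raise NotImplementedError

    def _post_load(self):
        pass  # default: no fixups needed

class HFSketch(InferenceModelSketch):
    def _load(self):
        self.tokenizer = type("Tok", (), {"add_bos_token": True})()

    def _post_load(self):
        # mirrors the diff's llama branch: correct a bad tokenizer default
        self.tokenizer.add_bos_token = False

model = HFSketch()
model.load()
assert model.tokenizer.add_bos_token is False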
@@ -132,15 +132,6 @@ class HFTorchInferenceModel(HFInferenceModel):
         if not utils.koboldai_vars.model_type:
             utils.koboldai_vars.model_type = m_self.get_model_type()

-        # These are model specific overrides if a model has bad defaults
-        if utils.koboldai_vars.model_type == "llama":
-            m_self.tokenizer.decode_with_prefix_space = True
-            m_self.tokenizer.add_bos_token = False
-        elif utils.koboldai_vars.model_type == "opt":
-            m_self.tokenizer._koboldai_header = m_self.tokenizer.encode("")
-            m_self.tokenizer.add_bos_token = False
-            m_self.tokenizer.add_prefix_space = False
-
         # Patch stopping_criteria
         class PTHStopper(StoppingCriteria):
             def __call__(
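Net effect of the last two hunks: the llama/opt tokenizer overrides previously living in HFTorchInferenceModel (under the m_self alias) now appear once in HFInferenceModel._post_load(), so any HF-based backend inherits them and there is a single place to maintain them, matching the commit message's "better places".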