Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Commit: Model: Fix assorted bugs and ignore warnings in pytest
@@ -552,9 +552,6 @@ class InferenceModel:
         assert isinstance(prompt_tokens, np.ndarray)
         assert len(prompt_tokens.shape) == 1
 
-        if utils.koboldai_vars.model == "ReadOnly":
-            raise NotImplementedError("No loaded model")
-
         time_start = time.time()
 
         with use_core_manipulations():
@@ -33,10 +33,11 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
         # if utils.koboldai_vars.model not in ["NeoCustom", "GPT2Custom"]:
         #     utils.koboldai_vars.custmodpth = utils.koboldai_vars.model
 
-        if utils.koboldai_vars.model == "NeoCustom":
-            utils.koboldai_vars.model = os.path.basename(
+        if self.model_name == "NeoCustom":
+            self.model_name = os.path.basename(
                 os.path.normpath(utils.koboldai_vars.custmodpth)
             )
+            utils.koboldai_vars.model = self.model_name
 
         # If we specify a model and it's in the root directory, we need to move
         # it to the models directory (legacy folder structure to new)
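A note on the unchanged os.path.basename(os.path.normpath(...)) pairing: basename alone returns an empty string when the stored path ends in a separator, so normpath is what makes the name extraction robust. A quick illustration with a hypothetical custmodpth value:

    import os

    path = "models/my-neo-model/"              # hypothetical path with trailing slash
    os.path.basename(path)                     # "" -- trailing slash defeats basename
    os.path.basename(os.path.normpath(path))   # "my-neo-model"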
@@ -123,8 +124,8 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
                 return old_rebuild_tensor(storage, storage_offset, shape, stride)
 
             torch._utils._rebuild_tensor = new_rebuild_tensor
-            self.model = self._get_model(utils.koboldai_vars.model, tf_kwargs)
-            self.tokenizer = self._get_tokenizer(utils.koboldai_vars.model)
+            self.model = self._get_model(self.model_name, tf_kwargs)
+            self.tokenizer = self._get_tokenizer(self.model_name)
             torch._utils._rebuild_tensor = old_rebuild_tensor
 
         if save_model:
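This hunk sits inside a save/patch/restore sequence: the original torch._utils._rebuild_tensor is stashed, a wrapper is installed while the model loads, and the original is put back afterwards. A minimal sketch of that pattern written as a context manager, for illustration only (the commit patches inline, which would leave the patch in place if loading raised; patched_rebuild_tensor and replacement are names invented here):

    import contextlib
    import torch

    @contextlib.contextmanager
    def patched_rebuild_tensor(replacement):
        # Temporarily swap torch._utils._rebuild_tensor, restoring on exit
        # even if the body raises.
        original = torch._utils._rebuild_tensor
        torch._utils._rebuild_tensor = replacement
        try:
            yield
        finally:
            torch._utils._rebuild_tensor = original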
@@ -153,7 +154,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
                 shutil.move(
                     os.path.realpath(
                         huggingface_hub.hf_hub_download(
-                            utils.koboldai_vars.model,
+                            self.model_name,
                             transformers.configuration_utils.CONFIG_NAME,
                             revision=utils.koboldai_vars.revision,
                             cache_dir="cache",
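In this hunk and the two near-identical ones below, the relevant detail is that huggingface_hub.hf_hub_download fetches a single file into the cache and returns its local path, which is why the result can be fed straight through os.path.realpath into shutil.move. A sketch with a hypothetical repo id (transformers' CONFIG_NAME constant is "config.json"):

    import huggingface_hub

    local_path = huggingface_hub.hf_hub_download(
        "EleutherAI/gpt-neo-2.7B",  # hypothetical repo id (what self.model_name holds)
        "config.json",              # file to fetch from the repo
        revision=None,              # pin a branch/tag/commit if needed
        cache_dir="cache",          # matches the cache dir used in the diff
    )
    # local_path now points at the cached copy of config.json.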
@@ -177,7 +178,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
                 shutil.move(
                     os.path.realpath(
                         huggingface_hub.hf_hub_download(
-                            utils.koboldai_vars.model,
+                            self.model_name,
                             possible_weight_name,
                             revision=utils.koboldai_vars.revision,
                             cache_dir="cache",
@@ -214,7 +215,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
                 shutil.move(
                     os.path.realpath(
                         huggingface_hub.hf_hub_download(
-                            utils.koboldai_vars.model,
+                            self.model_name,
                             filename,
                             revision=utils.koboldai_vars.revision,
                             cache_dir="cache",
@@ -8,9 +8,13 @@ from modeling.inference_model import InferenceModel
 
 
 class HFInferenceModel(InferenceModel):
-    def __init__(self) -> None:
+    def __init__(self, model_name: str) -> None:
         super().__init__()
         self.model_config = None
+        self.model_name = model_name
+
+        self.model = None
+        self.tokenizer = None
 
     def get_local_model_path(
         self, legacy: bool = False, ignore_existance: bool = False
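Taken together with the hunks below, this moves ownership of model_name, model, and tokenizer into HFInferenceModel, so subclasses pass the name up the constructor chain instead of re-declaring the same fields. A condensed sketch of the resulting shape (bodies abridged, drawn directly from the diff):

    class HFInferenceModel(InferenceModel):
        def __init__(self, model_name: str) -> None:
            super().__init__()
            self.model_config = None
            self.model_name = model_name

            self.model = None
            self.tokenizer = None

    class HFMTJInferenceModel(HFInferenceModel):
        def __init__(self, model_name: str) -> None:
            super().__init__(model_name)  # no duplicated field setup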
@@ -34,7 +38,7 @@ class HFInferenceModel(InferenceModel):
         # Get the model_type from the config or assume a model type if it isn't present
         try:
             self.model_config = AutoConfig.from_pretrained(
-                self.get_local_model_path() or utils.koboldai_vars.model,
+                self.get_local_model_path() or self.model_name,
                 revision=utils.koboldai_vars.revision,
                 cache_dir="cache",
             )
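The `or` fallback relies on get_local_model_path returning a falsy value (presumably None) when no local copy exists, so AutoConfig.from_pretrained receives either a local directory or a Hub model id. A sketch under that assumption, with a hypothetical Hub id:

    from transformers import AutoConfig

    local = None  # stands in for get_local_model_path() with nothing on disk
    config = AutoConfig.from_pretrained(
        local or "EleutherAI/gpt-neo-2.7B",  # falls back to the Hub id
        cache_dir="cache",
    )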
@@ -30,12 +30,8 @@ class HFMTJInferenceModel(HFInferenceModel):
         self,
         model_name: str,
     ) -> None:
-        super().__init__()
+        super().__init__(model_name)
 
-        self.model_name = model_name
-
-        self.model = None
-        self.tokenizer = None
         self.model_config = None
         self.capabilties = ModelCapabilities(
             embedding_manipulation=False,
@@ -61,9 +61,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         lazy_load: bool,
         low_mem: bool,
     ) -> None:
-        super().__init__()
-
-        self.model_name = model_name
+        super().__init__(model_name)
         self.lazy_load = lazy_load
         self.low_mem = low_mem
 
@@ -78,8 +76,6 @@ class HFTorchInferenceModel(HFInferenceModel):
             Stoppers.chat_mode_stopper,
         ]
 
-        self.model = None
-        self.tokenizer = None
         self.capabilties = ModelCapabilities(
             embedding_manipulation=True,
             post_token_hooks=True,
@@ -1,3 +1,3 @@
 [pytest]
-addopts = --ignore=miniconda3 --ignore=runtime --html=unit_test_report.html --self-contained-html -vv
+addopts = --ignore=miniconda3 --ignore=runtime --html=unit_test_report.html --self-contained-html --disable-warnings -vv
 norecursedirs = .git
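--disable-warnings is pytest's built-in alias for --disable-pytest-warnings: warnings are still captured during the run, but the warnings summary is omitted from the terminal report, which keeps the verbose (-vv) output readable.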