Model: Fix assorted bugs

Also ignore warnings in pytest (adds --disable-warnings to addopts).
This commit is contained in:
somebody
2023-03-09 20:59:27 -06:00
parent 3646aa9e83
commit 8c8bdfaf6a
6 changed files with 17 additions and 23 deletions

View File

@@ -552,9 +552,6 @@ class InferenceModel:
assert isinstance(prompt_tokens, np.ndarray)
assert len(prompt_tokens.shape) == 1
if utils.koboldai_vars.model == "ReadOnly":
raise NotImplementedError("No loaded model")
time_start = time.time()
with use_core_manipulations():

View File

@@ -33,10 +33,11 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
# if utils.koboldai_vars.model not in ["NeoCustom", "GPT2Custom"]:
# utils.koboldai_vars.custmodpth = utils.koboldai_vars.model
if utils.koboldai_vars.model == "NeoCustom":
utils.koboldai_vars.model = os.path.basename(
if self.model_name == "NeoCustom":
self.model_name = os.path.basename(
os.path.normpath(utils.koboldai_vars.custmodpth)
)
utils.koboldai_vars.model = self.model_name
# If we specify a model and it's in the root directory, we need to move
# it to the models directory (legacy folder structure to new)
@@ -123,8 +124,8 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
return old_rebuild_tensor(storage, storage_offset, shape, stride)
torch._utils._rebuild_tensor = new_rebuild_tensor
self.model = self._get_model(utils.koboldai_vars.model, tf_kwargs)
self.tokenizer = self._get_tokenizer(utils.koboldai_vars.model)
self.model = self._get_model(self.model_name, tf_kwargs)
self.tokenizer = self._get_tokenizer(self.model_name)
torch._utils._rebuild_tensor = old_rebuild_tensor
if save_model:
@@ -153,7 +154,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
shutil.move(
os.path.realpath(
huggingface_hub.hf_hub_download(
utils.koboldai_vars.model,
self.model_name,
transformers.configuration_utils.CONFIG_NAME,
revision=utils.koboldai_vars.revision,
cache_dir="cache",
@@ -177,7 +178,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
shutil.move(
os.path.realpath(
huggingface_hub.hf_hub_download(
utils.koboldai_vars.model,
self.model_name,
possible_weight_name,
revision=utils.koboldai_vars.revision,
cache_dir="cache",
@@ -214,7 +215,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
shutil.move(
os.path.realpath(
huggingface_hub.hf_hub_download(
utils.koboldai_vars.model,
self.model_name,
filename,
revision=utils.koboldai_vars.revision,
cache_dir="cache",

View File

@@ -8,9 +8,13 @@ from modeling.inference_model import InferenceModel
class HFInferenceModel(InferenceModel):
def __init__(self) -> None:
def __init__(self, model_name: str) -> None:
super().__init__()
self.model_config = None
self.model_name = model_name
self.model = None
self.tokenizer = None
def get_local_model_path(
self, legacy: bool = False, ignore_existance: bool = False
@@ -34,7 +38,7 @@ class HFInferenceModel(InferenceModel):
# Get the model_type from the config or assume a model type if it isn't present
try:
self.model_config = AutoConfig.from_pretrained(
self.get_local_model_path() or utils.koboldai_vars.model,
self.get_local_model_path() or self.model_name,
revision=utils.koboldai_vars.revision,
cache_dir="cache",
)

View File

@@ -30,12 +30,8 @@ class HFMTJInferenceModel(HFInferenceModel):
self,
model_name: str,
) -> None:
super().__init__()
super().__init__(model_name)
self.model_name = model_name
self.model = None
self.tokenizer = None
self.model_config = None
self.capabilties = ModelCapabilities(
embedding_manipulation=False,

View File

@@ -61,9 +61,7 @@ class HFTorchInferenceModel(HFInferenceModel):
lazy_load: bool,
low_mem: bool,
) -> None:
super().__init__()
self.model_name = model_name
super().__init__(model_name)
self.lazy_load = lazy_load
self.low_mem = low_mem
@@ -78,8 +76,6 @@ class HFTorchInferenceModel(HFInferenceModel):
Stoppers.chat_mode_stopper,
]
self.model = None
self.tokenizer = None
self.capabilties = ModelCapabilities(
embedding_manipulation=True,
post_token_hooks=True,

View File

@@ -1,3 +1,3 @@
[pytest]
addopts = --ignore=miniconda3 --ignore=runtime --html=unit_test_report.html --self-contained-html -vv
addopts = --ignore=miniconda3 --ignore=runtime --html=unit_test_report.html --self-contained-html --disable-warnings -vv
norecursedirs = .git