Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Model: Fix assorted bugs and ignore warnings in pytest
@@ -552,9 +552,6 @@ class InferenceModel:
         assert isinstance(prompt_tokens, np.ndarray)
         assert len(prompt_tokens.shape) == 1
 
-        if utils.koboldai_vars.model == "ReadOnly":
-            raise NotImplementedError("No loaded model")
-
         time_start = time.time()
 
         with use_core_manipulations():
@@ -33,10 +33,11 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
         # if utils.koboldai_vars.model not in ["NeoCustom", "GPT2Custom"]:
         # utils.koboldai_vars.custmodpth = utils.koboldai_vars.model
 
-        if utils.koboldai_vars.model == "NeoCustom":
-            utils.koboldai_vars.model = os.path.basename(
+        if self.model_name == "NeoCustom":
+            self.model_name = os.path.basename(
                 os.path.normpath(utils.koboldai_vars.custmodpth)
             )
+            utils.koboldai_vars.model = self.model_name
 
         # If we specify a model and it's in the root directory, we need to move
         # it to the models directory (legacy folder structure to new)
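
Note on the hunk above: the basename(normpath(...)) pair reduces a user-supplied model directory to a bare folder name. A minimal sketch of the idiom, using a hypothetical path:

import os

custmodpth = "models/gpt-neo-2.7B/"  # hypothetical user-supplied path
# normpath strips the trailing separator first, so basename returns the
# final folder name rather than an empty string
name = os.path.basename(os.path.normpath(custmodpth))
print(name)  # gpt-neo-2.7B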
@@ -123,8 +124,8 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
             return old_rebuild_tensor(storage, storage_offset, shape, stride)
 
         torch._utils._rebuild_tensor = new_rebuild_tensor
-        self.model = self._get_model(utils.koboldai_vars.model, tf_kwargs)
-        self.tokenizer = self._get_tokenizer(utils.koboldai_vars.model)
+        self.model = self._get_model(self.model_name, tf_kwargs)
+        self.tokenizer = self._get_tokenizer(self.model_name)
         torch._utils._rebuild_tensor = old_rebuild_tensor
 
         if save_model:
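
The context lines in this hunk bracket the model load with a temporary monkey-patch of torch._utils._rebuild_tensor. A generic sketch of that bracket, assuming only the signature already visible above; the try/finally (not present in the original) just makes the restore robust to a failed load:

import torch

old_rebuild_tensor = torch._utils._rebuild_tensor

def new_rebuild_tensor(storage, storage_offset, shape, stride):
    # ...inspect or adjust the tensor being rebuilt here...
    return old_rebuild_tensor(storage, storage_offset, shape, stride)

torch._utils._rebuild_tensor = new_rebuild_tensor
try:
    ...  # load the checkpoint while the patch is active
finally:
    torch._utils._rebuild_tensor = old_rebuild_tensor  # always restore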
@@ -153,7 +154,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
                     shutil.move(
                         os.path.realpath(
                             huggingface_hub.hf_hub_download(
-                                utils.koboldai_vars.model,
+                                self.model_name,
                                 transformers.configuration_utils.CONFIG_NAME,
                                 revision=utils.koboldai_vars.revision,
                                 cache_dir="cache",
@@ -177,7 +178,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
                     shutil.move(
                         os.path.realpath(
                             huggingface_hub.hf_hub_download(
-                                utils.koboldai_vars.model,
+                                self.model_name,
                                 possible_weight_name,
                                 revision=utils.koboldai_vars.revision,
                                 cache_dir="cache",
@@ -214,7 +215,7 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
                     shutil.move(
                         os.path.realpath(
                             huggingface_hub.hf_hub_download(
-                                utils.koboldai_vars.model,
+                                self.model_name,
                                 filename,
                                 revision=utils.koboldai_vars.revision,
                                 cache_dir="cache",
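The three hunks above repeat one download-then-move pattern: huggingface_hub.hf_hub_download fetches a single file into the local cache and returns its path, which is then moved into the legacy model folder. A sketch of the pattern in isolation; the repo id and destination are illustrative:

import os
import shutil

import huggingface_hub

# hf_hub_download returns the cached file's path (possibly a symlink),
# so realpath resolves it before the move
cached = os.path.realpath(
    huggingface_hub.hf_hub_download(
        "EleutherAI/gpt-neo-2.7B",  # illustrative repo id
        "config.json",
        revision=None,
        cache_dir="cache",
    )
)
shutil.move(cached, os.path.join("models", "gpt-neo-2.7B", "config.json"))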
@@ -8,9 +8,13 @@ from modeling.inference_model import InferenceModel
 
 
 class HFInferenceModel(InferenceModel):
-    def __init__(self) -> None:
+    def __init__(self, model_name: str) -> None:
         super().__init__()
         self.model_config = None
+        self.model_name = model_name
+
+        self.model = None
+        self.tokenizer = None
 
     def get_local_model_path(
         self, legacy: bool = False, ignore_existance: bool = False
@@ -34,7 +38,7 @@ class HFInferenceModel(InferenceModel):
         # Get the model_type from the config or assume a model type if it isn't present
         try:
             self.model_config = AutoConfig.from_pretrained(
-                self.get_local_model_path() or utils.koboldai_vars.model,
+                self.get_local_model_path() or self.model_name,
                 revision=utils.koboldai_vars.revision,
                 cache_dir="cache",
             )
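The `or` in this hunk prefers a local checkout and falls back to the Hub id only when get_local_model_path() returns None. A sketch of that lookup, assuming transformers is installed; the model id is illustrative:

from transformers import AutoConfig

local_path = None  # e.g. no local copy was found
model_name = "EleutherAI/gpt-neo-2.7B"  # illustrative Hub id

config = AutoConfig.from_pretrained(
    local_path or model_name,  # falls through to the Hub when local_path is falsy
    revision=None,
    cache_dir="cache",
)
print(config.model_type)  # gpt_neo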
@@ -30,12 +30,8 @@ class HFMTJInferenceModel(HFInferenceModel):
         self,
         model_name: str,
     ) -> None:
-        super().__init__()
+        super().__init__(model_name)
 
-        self.model_name = model_name
-
-        self.model = None
-        self.tokenizer = None
         self.model_config = None
         self.capabilties = ModelCapabilities(
             embedding_manipulation=False,
@@ -61,9 +61,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         lazy_load: bool,
         low_mem: bool,
     ) -> None:
-        super().__init__()
-
-        self.model_name = model_name
+        super().__init__(model_name)
         self.lazy_load = lazy_load
         self.low_mem = low_mem
 
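
Taken together with the HFInferenceModel hunk above, the net effect is that model_name, model, and tokenizer are now owned by the base class, and each subclass only forwards the name. A stripped-down sketch of the resulting hierarchy (unrelated attributes omitted, InferenceModel stubbed):

class InferenceModel:  # stub standing in for modeling.inference_model
    def __init__(self) -> None:
        pass

class HFInferenceModel(InferenceModel):
    def __init__(self, model_name: str) -> None:
        super().__init__()
        self.model_name = model_name  # stored once, here
        self.model = None
        self.tokenizer = None

class HFTorchInferenceModel(HFInferenceModel):
    def __init__(self, model_name: str, lazy_load: bool, low_mem: bool) -> None:
        super().__init__(model_name)  # no duplicate self.model_name assignment
        self.lazy_load = lazy_load
        self.low_mem = low_mem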
@@ -78,8 +76,6 @@ class HFTorchInferenceModel(HFInferenceModel):
             Stoppers.chat_mode_stopper,
         ]
 
-        self.model = None
-        self.tokenizer = None
         self.capabilties = ModelCapabilities(
             embedding_manipulation=True,
             post_token_hooks=True,
@@ -1,3 +1,3 @@
 [pytest]
-addopts = --ignore=miniconda3 --ignore=runtime --html=unit_test_report.html --self-contained-html -vv
+addopts = --ignore=miniconda3 --ignore=runtime --html=unit_test_report.html --self-contained-html --disable-warnings -vv
 norecursedirs = .git
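Note: --disable-warnings only hides pytest's warnings summary in the test report; it does not change Python's warning filters. If the intent were to silence specific categories rather than the whole summary, pytest's filterwarnings ini option would be the finer-grained alternative, e.g. (illustrative):

[pytest]
filterwarnings =
    ignore::DeprecationWarning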