Model: Fix assorted bugs

and ignore warnings in pytest
This commit is contained in:
somebody
2023-03-09 20:59:27 -06:00
parent 3646aa9e83
commit 8c8bdfaf6a
6 changed files with 17 additions and 23 deletions

View File

@@ -8,9 +8,13 @@ from modeling.inference_model import InferenceModel
class HFInferenceModel(InferenceModel):
def __init__(self) -> None:
def __init__(self, model_name: str) -> None:
super().__init__()
self.model_config = None
self.model_name = model_name
self.model = None
self.tokenizer = None
def get_local_model_path(
self, legacy: bool = False, ignore_existance: bool = False
@@ -34,7 +38,7 @@ class HFInferenceModel(InferenceModel):
# Get the model_type from the config or assume a model type if it isn't present
try:
self.model_config = AutoConfig.from_pretrained(
self.get_local_model_path() or utils.koboldai_vars.model,
self.get_local_model_path() or self.model_name,
revision=utils.koboldai_vars.revision,
cache_dir="cache",
)