Model: And another refactor

This commit is contained in:
somebody
2023-03-01 19:16:35 -06:00
parent 225dcf1a0a
commit 54cecd4d5d
18 changed files with 3045 additions and 2911 deletions

View File

@@ -0,0 +1,52 @@
import os
from typing import Optional
from transformers import AutoConfig
import utils
from logger import logger
from modeling.inference_model import InferenceModel
class HFInferenceModel(InferenceModel):
def __init__(self) -> None:
super().__init__()
self.model_config = None
def get_local_model_path(
self, legacy: bool = False, ignore_existance: bool = False
) -> Optional[str]:
"""
Returns a string of the model's path locally, or None if it is not downloaded.
If ignore_existance is true, it will always return a path.
"""
basename = utils.koboldai_vars.model.replace("/", "_")
if legacy:
ret = basename
else:
ret = os.path.join("models", basename)
if os.path.isdir(ret) or ignore_existance:
return ret
return None
def init_model_config(self) -> None:
# Get the model_type from the config or assume a model type if it isn't present
try:
self.model_config = AutoConfig.from_pretrained(
self.get_local_model_path() or utils.koboldai_vars.model,
revision=utils.koboldai_vars.revision,
cache_dir="cache",
)
utils.koboldai_vars.model_type = self.model_config.model_type
except ValueError:
utils.koboldai_vars.model_type = {
"NeoCustom": "gpt_neo",
"GPT2Custom": "gpt2",
}.get(utils.koboldai_vars.model)
if not utils.koboldai_vars.model_type:
logger.warning(
"No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)"
)
utils.koboldai_vars.model_type = "gpt_neo"