diff --git a/aiserver.py b/aiserver.py
index d6c0f96c..bb09a20c 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -11,6 +11,8 @@ from enum import Enum
 import random
 import shutil
 import eventlet
+
+from modeling.inference_model import SuperLegacyModelError
 eventlet.monkey_patch(all=True, thread=False, os=False)
 import os, inspect
 os.system("")
@@ -1942,24 +1944,31 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
         except:
             pass
 
-        if koboldai_vars.model_type == "gpt2":
-            from modeling.inference_models.legacy_gpt2_hf import CustomGPT2HFTorchInferenceModel
-            model = CustomGPT2HFTorchInferenceModel(
-                koboldai_vars.model,
-                lazy_load=koboldai_vars.lazy_load,
-                low_mem=args.lowmem
-            )
-        else:
+        try:
             from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
             model = GenericHFTorchInferenceModel(
                 koboldai_vars.model,
                 lazy_load=koboldai_vars.lazy_load,
                 low_mem=args.lowmem
             )
-        model.load(
-            save_model=not (args.colab or args.cacheonly) or args.savemodel,
-            initial_load=initial_load,
-        )
+
+            model.load(
+                save_model=not (args.colab or args.cacheonly) or args.savemodel,
+                initial_load=initial_load,
+            )
+        except SuperLegacyModelError:
+            from modeling.inference_models.legacy_gpt2_hf import CustomGPT2HFTorchInferenceModel
+            model = CustomGPT2HFTorchInferenceModel(
+                koboldai_vars.model,
+                lazy_load=koboldai_vars.lazy_load,
+                low_mem=args.lowmem
+            )
+
+            model.load(
+                save_model=not (args.colab or args.cacheonly) or args.savemodel,
+                initial_load=initial_load,
+            )
+
         logger.info(f"Pipeline created: {koboldai_vars.model}")
     else:
         # TPU
diff --git a/modeling/inference_model.py b/modeling/inference_model.py
index a8f1f6f7..94110ee5 100644
--- a/modeling/inference_model.py
+++ b/modeling/inference_model.py
@@ -17,6 +17,9 @@ from modeling import logits_processors
 
 import utils
 
+class SuperLegacyModelError(RuntimeError):
+    pass
+
 # We only want to use logit manipulations and such on our core text model
 class use_core_manipulations:
     """Use in a `with` block to patch functions for core story model sampling."""
diff --git a/modeling/inference_models/generic_hf_torch.py b/modeling/inference_models/generic_hf_torch.py
index 7598a424..a9fc6370 100644
--- a/modeling/inference_models/generic_hf_torch.py
+++ b/modeling/inference_models/generic_hf_torch.py
@@ -7,6 +7,7 @@ import shutil
 from typing import Union
 
 from transformers import AutoModelForCausalLM, GPTNeoForCausalLM
+from modeling.inference_model import SuperLegacyModelError
 
 import utils
 import modeling.lazy_loader as lazy_loader
@@ -81,7 +82,12 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
                     metamodel = AutoModelForCausalLM.from_config(self.model_config)
                 except Exception as e:
                     logger.error(f"Fell back to neo for metamodel due to {e}")
-                    metamodel = GPTNeoForCausalLM.from_config(self.model_config)
+                    try:
+                        metamodel = GPTNeoForCausalLM.from_config(self.model_config)
+                    except Exception as e:
+                        logger.error(f"Falling back again due to {e}")
+                        raise SuperLegacyModelError
+
                 utils.layers_module_names = utils.get_layers_module_names(metamodel)
                 utils.module_names = list(metamodel.state_dict().keys())
                 utils.named_buffers = list(metamodel.named_buffers(recurse=True))
diff --git a/modeling/inference_models/legacy_gpt2_hf.py b/modeling/inference_models/legacy_gpt2_hf.py
index b710ac9f..9bcdde95 100644
--- a/modeling/inference_models/legacy_gpt2_hf.py
+++ b/modeling/inference_models/legacy_gpt2_hf.py
@@ -19,6 +19,7 @@ class CustomGPT2HFTorchInferenceModel(HFTorchInferenceModel):
         for possible_config_path in [
             utils.koboldai_vars.custmodpth,
             os.path.join("models", utils.koboldai_vars.custmodpth),
+            self.model_name
         ]:
             try:
                 with open(
@@ -36,12 +37,13 @@ class CustomGPT2HFTorchInferenceModel(HFTorchInferenceModel):
         with self._maybe_use_float16():
             try:
                 self.model = GPT2LMHeadModel.from_pretrained(
-                    utils.koboldai_vars.custmodpth,
+                    model_path,
                     revision=utils.koboldai_vars.revision,
                     cache_dir="cache",
+                    local_files_only=True
                 )
                 self.tokenizer = GPT2Tokenizer.from_pretrained(
-                    utils.koboldai_vars.custmodpth,
+                    model_path,
                     revision=utils.koboldai_vars.revision,
                     cache_dir="cache",
                 )
@@ -69,4 +71,4 @@ class CustomGPT2HFTorchInferenceModel(HFTorchInferenceModel):
         else:
             self.model = self.model.to("cpu").float()
 
-        self.patch_causal_lm()
+        self.patch_embedding()