diff --git a/aiserver.py b/aiserver.py
index b1a16447..b6858d53 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -586,12 +586,6 @@ utils.socketio = socketio
 
 # Weird import position to steal koboldai_vars from utils
 from modeling.patches import patch_transformers
-from modeling.inference_models.api import APIInferenceModel
-from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
-from modeling.inference_models.legacy_gpt2_hf import CustomGPT2HFTorchInferenceModel
-from modeling.inference_models.hf_mtj import HFMTJInferenceModel
-from modeling.inference_models.horde import HordeInferenceModel
-from modeling.inference_models.openai import OpenAIAPIInferenceModel
 
 old_socketio_on = socketio.on
 
@@ -1877,12 +1871,16 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
         print(":P")
     elif koboldai_vars.model in ["Colab", "API", "CLUSTER", "OAI"]:
         if koboldai_vars.model == "Colab":
-            model = APIInferenceModel()
+            from modeling.inference_models.basic_api import BasicAPIInferenceModel
+            model = BasicAPIInferenceModel()
         elif koboldai_vars.model == "API":
+            from modeling.inference_models.api import APIInferenceModel
             model = APIInferenceModel()
         elif koboldai_vars.model == "CLUSTER":
+            from modeling.inference_models.horde import HordeInferenceModel
             model = HordeInferenceModel()
         elif koboldai_vars.model == "OAI":
+            from modeling.inference_models.openai import OpenAIAPIInferenceModel
             model = OpenAIAPIInferenceModel()
 
         koboldai_vars.colaburl = url or koboldai_vars.colaburl
@@ -1906,11 +1904,13 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
                 pass
 
             if koboldai_vars.model_type == "gpt2":
+                from modeling.inference_models.legacy_gpt2_hf import CustomGPT2HFTorchInferenceModel
                 model = CustomGPT2HFTorchInferenceModel(
                     koboldai_vars.model,
                     low_mem=args.lowmem
                 )
             else:
+                from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
                 model = GenericHFTorchInferenceModel(
                     koboldai_vars.model,
                     lazy_load=koboldai_vars.lazy_load,
@@ -1923,6 +1923,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
             logger.info(f"Pipeline created: {koboldai_vars.model}")
     else:
         # TPU
+        from modeling.inference_models.hf_mtj import HFMTJInferenceModel
         model = HFMTJInferenceModel(
             koboldai_vars.model
         )
@@ -5586,7 +5587,7 @@ def final_startup():
         file.close()
 
     # Precompile TPU backend if required
-    if isinstance(model, HFMTJInferenceModel):
+    if model and model.capabilties.uses_tpu:
         model.raw_generate([23403, 727, 20185], max_new=1)
 
     # Set the initial RNG seed
diff --git a/modeling/inference_model.py b/modeling/inference_model.py
index 6a08b0d5..4eb63618 100644
--- a/modeling/inference_model.py
+++ b/modeling/inference_model.py
@@ -156,6 +156,9 @@ class ModelCapabilities:
     # Some models cannot be hosted over the API, namely the API itself.
     api_host: bool = True
 
+    # Some models need to warm up the TPU before use
+    uses_tpu: bool = False
+
 
 class InferenceModel:
     """Root class for all models."""
diff --git a/modeling/inference_models/hf_mtj.py b/modeling/inference_models/hf_mtj.py
index f8993f56..39095976 100644
--- a/modeling/inference_models/hf_mtj.py
+++ b/modeling/inference_models/hf_mtj.py
@@ -38,6 +38,7 @@ class HFMTJInferenceModel(HFInferenceModel):
             post_token_hooks=False,
             stopper_hooks=False,
             post_token_probs=False,
+            uses_tpu=True,
         )
 
     def setup_mtj(self) -> None:
diff --git a/modeling/test_generation.py b/modeling/test_generation.py
index 947a83c4..0f700d0b 100644
--- a/modeling/test_generation.py
+++ b/modeling/test_generation.py
@@ -1,11 +1,11 @@
 import torch
 
 # We have to go through aiserver to initalize koboldai_vars :(
-from aiserver import GenericHFTorchInferenceModel
 from aiserver import koboldai_vars
 from modeling.inference_model import InferenceModel
 from modeling.inference_models.api import APIInferenceModel
+from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
 from modeling.inference_models.horde import HordeInferenceModel
 
 
 model: InferenceModel
diff --git a/modeling/warpers.py b/modeling/warpers.py
index 1f51a5d9..4c7dbac4 100644
--- a/modeling/warpers.py
+++ b/modeling/warpers.py
@@ -42,13 +42,10 @@ import utils
 import torch
 import numpy as np
 
-try:
+if utils.koboldai_vars.use_colab_tpu:
     import jax
     import jax.numpy as jnp
     import tpu_mtj_backend
-except ImportError as e:
-    if utils.koboldai_vars.use_colab_tpu:
-        raise e
 
 
 def update_settings():