Model: Lazyload backends

This commit is contained in:
somebody
2023-03-13 20:29:29 -05:00
parent adc11fdbc9
commit b93c339145
5 changed files with 15 additions and 13 deletions

View File

@@ -586,12 +586,6 @@ utils.socketio = socketio
# Weird import position to steal koboldai_vars from utils
from modeling.patches import patch_transformers
from modeling.inference_models.api import APIInferenceModel
from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
from modeling.inference_models.legacy_gpt2_hf import CustomGPT2HFTorchInferenceModel
from modeling.inference_models.hf_mtj import HFMTJInferenceModel
from modeling.inference_models.horde import HordeInferenceModel
from modeling.inference_models.openai import OpenAIAPIInferenceModel
old_socketio_on = socketio.on
@@ -1877,12 +1871,16 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
print(":P")
elif koboldai_vars.model in ["Colab", "API", "CLUSTER", "OAI"]:
if koboldai_vars.model == "Colab":
model = APIInferenceModel()
from modeling.inference_models.basic_api import BasicAPIInferenceModel
model = BasicAPIInferenceModel()
elif koboldai_vars.model == "API":
from modeling.inference_models.api import APIInferenceModel
model = APIInferenceModel()
elif koboldai_vars.model == "CLUSTER":
from modeling.inference_models.horde import HordeInferenceModel
model = HordeInferenceModel()
elif koboldai_vars.model == "OAI":
from modeling.inference_models.openai import OpenAIAPIInferenceModel
model = OpenAIAPIInferenceModel()
koboldai_vars.colaburl = url or koboldai_vars.colaburl
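
The hunks above show the commit's core pattern: rather than importing every backend class at the top of aiserver.py (the imports deleted in the first hunk), each backend module is now imported inside the branch that actually selects it, so server startup no longer pays the import cost of backends that will never be used. A minimal sketch of the idea, using hypothetical module and class names rather than KoboldAI's real ones:

# Sketch of lazy backend selection: the heavy module is imported only when the
# matching backend is chosen. "backends.api" and "backends.horde" are placeholders.
def create_backend(name: str):
    if name == "api":
        from backends.api import APIBackend      # imported only on this path
        return APIBackend()
    elif name == "horde":
        from backends.horde import HordeBackend  # imported only on this path
        return HordeBackend()
    raise ValueError(f"unknown backend: {name}")
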
@@ -1906,11 +1904,13 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
            pass
        if koboldai_vars.model_type == "gpt2":
            from modeling.inference_models.legacy_gpt2_hf import CustomGPT2HFTorchInferenceModel
            model = CustomGPT2HFTorchInferenceModel(
                koboldai_vars.model,
                low_mem=args.lowmem
            )
        else:
            from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
            model = GenericHFTorchInferenceModel(
                koboldai_vars.model,
                lazy_load=koboldai_vars.lazy_load,
@@ -1923,6 +1923,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
        logger.info(f"Pipeline created: {koboldai_vars.model}")
    else:
        # TPU
        from modeling.inference_models.hf_mtj import HFMTJInferenceModel
        model = HFMTJInferenceModel(
            koboldai_vars.model
        )
@@ -5586,7 +5587,7 @@ def final_startup():
        file.close()

    # Precompile TPU backend if required
    if isinstance(model, HFMTJInferenceModel):
    if model and model.capabilties.uses_tpu:
        model.raw_generate([23403, 727, 20185], max_new=1)

    # Set the initial RNG seed
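
The last hunk above also swaps an isinstance check against HFMTJInferenceModel for a capability flag, which is what lets aiserver.py drop its top-level hf_mtj import: final_startup now just asks the loaded model whether it needs a TPU warm-up pass. A rough, self-contained sketch of that capability-flag pattern (names are illustrative, not the project's):

# Hypothetical sketch: callers test a capability flag on the model rather than
# its concrete class, so no backend type has to be imported at the call site.
from dataclasses import dataclass
from typing import Optional

@dataclass
class Capabilities:
    uses_tpu: bool = False  # True if the backend needs a warm-up generation

class Backend:
    capabilities = Capabilities()

    def warm_up(self) -> None:
        print("running a dummy generation to precompile the backend")

class TPUBackend(Backend):
    capabilities = Capabilities(uses_tpu=True)

def precompile_if_needed(model: Optional[Backend]) -> None:
    if model and model.capabilities.uses_tpu:
        model.warm_up()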

View File

@@ -156,6 +156,9 @@ class ModelCapabilities:
    # Some models cannot be hosted over the API, namely the API itself.
    api_host: bool = True

    # Some models need to warm up the TPU before use
    uses_tpu: bool = False


class InferenceModel:
    """Root class for all models."""

View File

@@ -38,6 +38,7 @@ class HFMTJInferenceModel(HFInferenceModel):
            post_token_hooks=False,
            stopper_hooks=False,
            post_token_probs=False,
            uses_tpu=True
        )

    def setup_mtj(self) -> None:

View File

@@ -1,11 +1,11 @@
import torch
# We have to go through aiserver to initialize koboldai_vars :(
from aiserver import GenericHFTorchInferenceModel
from aiserver import koboldai_vars
from modeling.inference_model import InferenceModel
from modeling.inference_models.api import APIInferenceModel
from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
from modeling.inference_models.horde import HordeInferenceModel
model: InferenceModel

View File

@@ -42,13 +42,10 @@ import utils
import torch
import numpy as np

try:
    if utils.koboldai_vars.use_colab_tpu:
        import jax
        import jax.numpy as jnp
        import tpu_mtj_backend
except ImportError as e:
    if utils.koboldai_vars.use_colab_tpu:
        raise e
def update_settings():
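
The hunk above keeps the JAX/TPU imports optional: the try/except swallows a missing dependency unless the TPU path was actually requested, in which case the original ImportError is re-raised. A small sketch of that guard pattern, with a hypothetical flag standing in for koboldai_vars.use_colab_tpu:

# Sketch of an optional-dependency guard: the import failure is ignored unless
# the feature that needs the dependency was explicitly enabled.
USE_TPU = False  # hypothetical stand-in for utils.koboldai_vars.use_colab_tpu

try:
    import jax  # heavy, optional dependency
except ImportError as e:
    if USE_TPU:
        # Only fatal when the TPU path is actually in use.
        raise e
    jax = None  # non-TPU code paths can continue without it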