Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Model: Lazyload backends
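In outline, the commit moves each inference backend's import out of aiserver.py's module scope and into the load_model branch that actually selects that backend, so a backend's dependencies are loaded only when it is used. A minimal sketch of the deferred-import pattern (load_backend is a hypothetical stand-in for the real load_model shown below):

def load_backend(name: str):
    if name == "CLUSTER":
        # Imported only when the Horde backend is selected, so its
        # dependencies never load for users of other backends.
        from modeling.inference_models.horde import HordeInferenceModel
        return HordeInferenceModel()
    elif name == "OAI":
        from modeling.inference_models.openai import OpenAIAPIInferenceModel
        return OpenAIAPIInferenceModel()
    raise ValueError(f"unsupported backend: {name}")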
aiserver.py (17 changed lines)
@@ -586,12 +586,6 @@ utils.socketio = socketio

 # Weird import position to steal koboldai_vars from utils
 from modeling.patches import patch_transformers
-from modeling.inference_models.api import APIInferenceModel
-from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
-from modeling.inference_models.legacy_gpt2_hf import CustomGPT2HFTorchInferenceModel
-from modeling.inference_models.hf_mtj import HFMTJInferenceModel
-from modeling.inference_models.horde import HordeInferenceModel
-from modeling.inference_models.openai import OpenAIAPIInferenceModel


 old_socketio_on = socketio.on
@@ -1877,12 +1871,16 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
         print(":P")
     elif koboldai_vars.model in ["Colab", "API", "CLUSTER", "OAI"]:
         if koboldai_vars.model == "Colab":
-            model = APIInferenceModel()
+            from modeling.inference_models.basic_api import BasicAPIInferenceModel
+            model = BasicAPIInferenceModel()
         elif koboldai_vars.model == "API":
+            from modeling.inference_models.api import APIInferenceModel
             model = APIInferenceModel()
         elif koboldai_vars.model == "CLUSTER":
+            from modeling.inference_models.horde import HordeInferenceModel
             model = HordeInferenceModel()
         elif koboldai_vars.model == "OAI":
+            from modeling.inference_models.openai import OpenAIAPIInferenceModel
             model = OpenAIAPIInferenceModel()

         koboldai_vars.colaburl = url or koboldai_vars.colaburl
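Function-local imports like these are cheap after the first call: Python caches every imported module in sys.modules, so re-executing the import statement is only a cache lookup plus a name binding. A self-contained illustration:

import sys

def handler():
    # Only the first call does real import work; later calls just bind
    # the name from the sys.modules cache.
    import json
    return json.dumps({"ok": True})

handler()
assert "json" in sys.modules  # cached after first use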
@@ -1906,11 +1904,13 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
                 pass

         if koboldai_vars.model_type == "gpt2":
+            from modeling.inference_models.legacy_gpt2_hf import CustomGPT2HFTorchInferenceModel
             model = CustomGPT2HFTorchInferenceModel(
                 koboldai_vars.model,
                 low_mem=args.lowmem
             )
         else:
+            from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
             model = GenericHFTorchInferenceModel(
                 koboldai_vars.model,
                 lazy_load=koboldai_vars.lazy_load,
@@ -1923,6 +1923,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
             logger.info(f"Pipeline created: {koboldai_vars.model}")
     else:
         # TPU
+        from modeling.inference_models.hf_mtj import HFMTJInferenceModel
         model = HFMTJInferenceModel(
             koboldai_vars.model
         )
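Deferring the hf_mtj import matters most here: that module pulls in the JAX/MTJ stack, so CPU and GPU users no longer pay its import cost at startup. A rough way to see what deferral saves is to time the one-time cost of a heavy import (decimal is only a stand-in for the real TPU dependencies):

import importlib
import time

start = time.perf_counter()
importlib.import_module("decimal")  # stand-in for a heavy backend module
print(f"one-time cost, paid only when this branch runs: "
      f"{time.perf_counter() - start:.4f}s")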
@@ -5586,7 +5587,7 @@ def final_startup():
         file.close()

     # Precompile TPU backend if required
-    if isinstance(model, HFMTJInferenceModel):
+    if model and model.capabilties.uses_tpu:
         model.raw_generate([23403, 727, 20185], max_new=1)

     # Set the initial RNG seed
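The warmup guard switches from an isinstance() test to a capability flag: the old test kept HFMTJInferenceModel imported at aiserver.py's module scope, which is exactly what this commit removes. The fixed token IDs just give raw_generate a short throwaway prompt so the TPU's one-time compilation happens at startup rather than on the first user request. Sketched as a helper (warmup_if_needed is hypothetical; the attribute really is spelled capabilties in the repo):

def warmup_if_needed(model) -> None:
    # Any backend that declares uses_tpu gets a 1-token generation to
    # trigger compilation, with no reference to a concrete backend class.
    if model and model.capabilties.uses_tpu:
        model.raw_generate([23403, 727, 20185], max_new=1)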
@@ -156,6 +156,9 @@ class ModelCapabilities:
     # Some models cannot be hosted over the API, namely the API itself.
     api_host: bool = True

+    # Some models need to warm up the TPU before use
+    uses_tpu: bool = False
+

 class InferenceModel:
     """Root class for all models."""
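uses_tpu follows the existing ModelCapabilities pattern: backends advertise traits as dataclass fields and callers branch on the flag rather than on the concrete type. A minimal self-contained version of the pattern (names are stand-ins, not the repo's):

from dataclasses import dataclass

@dataclass
class Capabilities:
    api_host: bool = True
    uses_tpu: bool = False

class DummyBackend:
    def __init__(self) -> None:
        self.caps = Capabilities(uses_tpu=True)

# Callers test the declared trait instead of isinstance(...), so they
# never need to import the backend's class.
backend = DummyBackend()
if backend.caps.uses_tpu:
    print("warm up TPU before first request")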
@@ -38,6 +38,7 @@ class HFMTJInferenceModel(HFInferenceModel):
             post_token_hooks=False,
             stopper_hooks=False,
             post_token_probs=False,
+            uses_tpu=True
         )

     def setup_mtj(self) -> None:
@@ -1,11 +1,11 @@
 import torch

 # We have to go through aiserver to initalize koboldai_vars :(
-from aiserver import GenericHFTorchInferenceModel
+from aiserver import koboldai_vars

 from modeling.inference_model import InferenceModel
 from modeling.inference_models.api import APIInferenceModel
 from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
 from modeling.inference_models.horde import HordeInferenceModel

 model: InferenceModel
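Here the aiserver import exists only for its side effect: importing any name from a module runs that module's top-level code once, which is what initializes koboldai_vars. With the backend classes gone from aiserver's namespace, importing koboldai_vars itself names the one thing the file actually needs. A small illustration of the mechanism:

import sys

from logging import getLogger  # first import runs logging's module-level setup

assert "logging" in sys.modules  # later imports reuse the cached module
log = getLogger(__name__)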
@@ -42,13 +42,10 @@ import utils
 import torch
 import numpy as np

-try:
-    import jax
-    import jax.numpy as jnp
-    import tpu_mtj_backend
-except ImportError as e:
-    if utils.koboldai_vars.use_colab_tpu:
-        raise e
+if utils.koboldai_vars.use_colab_tpu:
+    import jax
+    import jax.numpy as jnp
+    import tpu_mtj_backend


 def update_settings():
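The old code imported the TPU stack unconditionally and used the except branch to re-raise only for TPU users; gating the import on use_colab_tpu states the intent directly and stops ImportErrors from being silently swallowed for everyone else. The pattern, with a hypothetical flag standing in for utils.koboldai_vars.use_colab_tpu:

USE_COLAB_TPU = False  # hypothetical stand-in flag

if USE_COLAB_TPU:
    # Optional heavy dependencies load only for configurations that use
    # them; a missing package now fails only where it matters.
    import jax
    import jax.numpy as jnp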