Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Model: Lazyload backends
aiserver.py (17 changed lines)
@@ -586,12 +586,6 @@ utils.socketio = socketio
 
 # Weird import position to steal koboldai_vars from utils
 from modeling.patches import patch_transformers
-from modeling.inference_models.api import APIInferenceModel
-from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
-from modeling.inference_models.legacy_gpt2_hf import CustomGPT2HFTorchInferenceModel
-from modeling.inference_models.hf_mtj import HFMTJInferenceModel
-from modeling.inference_models.horde import HordeInferenceModel
-from modeling.inference_models.openai import OpenAIAPIInferenceModel
 
 
 old_socketio_on = socketio.on
@@ -1877,12 +1871,16 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
         print(":P")
     elif koboldai_vars.model in ["Colab", "API", "CLUSTER", "OAI"]:
         if koboldai_vars.model == "Colab":
-            model = APIInferenceModel()
+            from modeling.inference_models.basic_api import BasicAPIInferenceModel
+            model = BasicAPIInferenceModel()
         elif koboldai_vars.model == "API":
+            from modeling.inference_models.api import APIInferenceModel
             model = APIInferenceModel()
         elif koboldai_vars.model == "CLUSTER":
+            from modeling.inference_models.horde import HordeInferenceModel
             model = HordeInferenceModel()
         elif koboldai_vars.model == "OAI":
+            from modeling.inference_models.openai import OpenAIAPIInferenceModel
             model = OpenAIAPIInferenceModel()
 
         koboldai_vars.colaburl = url or koboldai_vars.colaburl
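The hunk above is the core of the lazy-load change: each remote/API-style backend is imported only inside the branch that selects it, instead of at module import time. Below is a minimal sketch of the same deferred-import pattern; the import paths match the ones used in this commit, but the function name, signature, and control flow are simplified stand-ins rather than the real load_model.

# Sketch only: simplified stand-in for load_model's backend dispatch.
def build_backend(model_name: str):
    if model_name == "OAI":
        # Imported only when the OpenAI backend is actually requested,
        # so startup never pays for backends it will not use.
        from modeling.inference_models.openai import OpenAIAPIInferenceModel
        return OpenAIAPIInferenceModel()
    if model_name == "CLUSTER":
        from modeling.inference_models.horde import HordeInferenceModel
        return HordeInferenceModel()
    # Default: heavy local HF/torch backend, also deferred until needed.
    from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
    return GenericHFTorchInferenceModel(model_name)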
@@ -1906,11 +1904,13 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
                 pass
 
         if koboldai_vars.model_type == "gpt2":
+            from modeling.inference_models.legacy_gpt2_hf import CustomGPT2HFTorchInferenceModel
             model = CustomGPT2HFTorchInferenceModel(
                 koboldai_vars.model,
                 low_mem=args.lowmem
             )
         else:
+            from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
             model = GenericHFTorchInferenceModel(
                 koboldai_vars.model,
                 lazy_load=koboldai_vars.lazy_load,
@@ -1923,6 +1923,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
         logger.info(f"Pipeline created: {koboldai_vars.model}")
     else:
         # TPU
+        from modeling.inference_models.hf_mtj import HFMTJInferenceModel
         model = HFMTJInferenceModel(
             koboldai_vars.model
         )
@@ -5586,7 +5587,7 @@ def final_startup():
        file.close()
 
    # Precompile TPU backend if required
-   if isinstance(model, HFMTJInferenceModel):
+   if model and model.capabilties.uses_tpu:
        model.raw_generate([23403, 727, 20185], max_new=1)
 
    # Set the initial RNG seed
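With the eager HFMTJInferenceModel import gone from the top of aiserver.py, final_startup can no longer rely on an isinstance check to decide whether to warm up the TPU, so the commit switches it to the new uses_tpu capability. A hedged sketch of that call-site pattern follows; the helper name is hypothetical, and the attribute spelling `capabilties` is kept as it appears in the repository.

# Hypothetical helper mirroring the capability-based check in final_startup.
def precompile_if_needed(model) -> None:
    # No isinstance() against a concrete backend class, so aiserver does not
    # need to import HFMTJInferenceModel just to ask this question.
    if model and model.capabilties.uses_tpu:
        # A short dummy generation forces the TPU backend to compile before
        # the first real request (token IDs taken from the diff above).
        model.raw_generate([23403, 727, 20185], max_new=1)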
@@ -156,6 +156,9 @@ class ModelCapabilities:
     # Some models cannot be hosted over the API, namely the API itself.
     api_host: bool = True
 
+    # Some models need to warm up the TPU before use
+    uses_tpu: bool = False
+
 
 class InferenceModel:
     """Root class for all models."""
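The declaration side of that check: ModelCapabilities gains a uses_tpu flag that defaults to False, and TPU-backed models opt in (as HFMTJInferenceModel does in the next hunk). A small self-contained sketch, assuming ModelCapabilities is a dataclass as its annotated defaults suggest; the backend class here is a made-up stand-in, not a class from the repository.

from dataclasses import dataclass

@dataclass
class ModelCapabilities:
    # Some models cannot be hosted over the API, namely the API itself.
    api_host: bool = True
    # Some models need to warm up the TPU before use
    uses_tpu: bool = False

class FakeTPUBackend:
    def __init__(self) -> None:
        # A TPU backend opts in; every other backend keeps the False default.
        self.capabilties = ModelCapabilities(uses_tpu=True)

model = FakeTPUBackend()
assert model.capabilties.uses_tpu  # drives the warm-up path in final_startup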
@@ -38,6 +38,7 @@ class HFMTJInferenceModel(HFInferenceModel):
             post_token_hooks=False,
             stopper_hooks=False,
             post_token_probs=False,
+            uses_tpu=True
         )
 
     def setup_mtj(self) -> None:

@@ -1,11 +1,11 @@
 import torch
 
 # We have to go through aiserver to initalize koboldai_vars :(
-from aiserver import GenericHFTorchInferenceModel
 from aiserver import koboldai_vars
 
 from modeling.inference_model import InferenceModel
 from modeling.inference_models.api import APIInferenceModel
+from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
 from modeling.inference_models.horde import HordeInferenceModel
 
 model: InferenceModel

@@ -42,13 +42,10 @@ import utils
 import torch
 import numpy as np
 
-try:
+if utils.koboldai_vars.use_colab_tpu:
     import jax
     import jax.numpy as jnp
     import tpu_mtj_backend
-except ImportError as e:
-    if utils.koboldai_vars.use_colab_tpu:
-        raise e
 
 
 def update_settings():
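The last hunk replaces a try/except around the JAX imports (which re-raised only when TPU mode was on) with a plain conditional, so non-TPU installs never attempt the import at all. A minimal sketch of the guarded-import idea, with a hypothetical helper name and a boolean parameter standing in for utils.koboldai_vars.use_colab_tpu:

def maybe_import_tpu_stack(use_colab_tpu: bool):
    """Import the optional JAX/TPU modules only when Colab TPU mode is enabled."""
    if not use_colab_tpu:
        # On CPU/GPU installs these packages may not even be present, and
        # skipping the import avoids swallowing or re-raising ImportError.
        return None
    import jax
    import jax.numpy as jnp
    import tpu_mtj_backend
    return jax, jnp, tpu_mtj_backend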