Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Model: Lazyload backends
Changed file: aiserver.py (17 changed lines)
--- a/aiserver.py
+++ b/aiserver.py
@@ -586,12 +586,6 @@ utils.socketio = socketio
 
 # Weird import position to steal koboldai_vars from utils
 from modeling.patches import patch_transformers
-from modeling.inference_models.api import APIInferenceModel
-from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
-from modeling.inference_models.legacy_gpt2_hf import CustomGPT2HFTorchInferenceModel
-from modeling.inference_models.hf_mtj import HFMTJInferenceModel
-from modeling.inference_models.horde import HordeInferenceModel
-from modeling.inference_models.openai import OpenAIAPIInferenceModel
 
 
 old_socketio_on = socketio.on
@@ -1877,12 +1871,16 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
         print(":P")
     elif koboldai_vars.model in ["Colab", "API", "CLUSTER", "OAI"]:
         if koboldai_vars.model == "Colab":
-            model = APIInferenceModel()
+            from modeling.inference_models.basic_api import BasicAPIInferenceModel
+            model = BasicAPIInferenceModel()
         elif koboldai_vars.model == "API":
+            from modeling.inference_models.api import APIInferenceModel
             model = APIInferenceModel()
         elif koboldai_vars.model == "CLUSTER":
+            from modeling.inference_models.horde import HordeInferenceModel
             model = HordeInferenceModel()
         elif koboldai_vars.model == "OAI":
+            from modeling.inference_models.openai import OpenAIAPIInferenceModel
             model = OpenAIAPIInferenceModel()
 
         koboldai_vars.colaburl = url or koboldai_vars.colaburl
@@ -1906,11 +1904,13 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
             pass
 
         if koboldai_vars.model_type == "gpt2":
+            from modeling.inference_models.legacy_gpt2_hf import CustomGPT2HFTorchInferenceModel
             model = CustomGPT2HFTorchInferenceModel(
                 koboldai_vars.model,
                 low_mem=args.lowmem
             )
         else:
+            from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
             model = GenericHFTorchInferenceModel(
                 koboldai_vars.model,
                 lazy_load=koboldai_vars.lazy_load,
@@ -1923,6 +1923,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
             logger.info(f"Pipeline created: {koboldai_vars.model}")
     else:
         # TPU
+        from modeling.inference_models.hf_mtj import HFMTJInferenceModel
         model = HFMTJInferenceModel(
             koboldai_vars.model
         )
@@ -5586,7 +5587,7 @@ def final_startup():
         file.close()
 
     # Precompile TPU backend if required
-    if isinstance(model, HFMTJInferenceModel):
+    if model and model.capabilties.uses_tpu:
         model.raw_generate([23403, 727, 20185], max_new=1)
 
     # Set the initial RNG seed
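The change moves each backend's import out of module scope and into the load_model branch that actually selects it, so importing aiserver.py no longer pulls in every inference backend at startup. Below is a minimal, self-contained sketch of that lazy-import pattern; it is illustrative only (make_backend and the stdlib modules it imports are stand-ins, not KoboldAI code):

import sys

def make_backend(kind: str):
    # Import the module only inside the branch that selects it, mirroring the
    # diff above: backends that are never chosen are never imported.
    if kind == "b64":
        import base64  # loaded only when the "b64" backend is chosen
        return base64.b64encode
    elif kind == "zip":
        import zlib    # loaded only when the "zip" backend is chosen
        return zlib.compress
    raise ValueError(f"unknown backend: {kind}")

print("zlib imported at startup?", "zlib" in sys.modules)
encode = make_backend("zip")   # first selection triggers the import
print("zlib imported after selection?", "zlib" in sys.modules)
print(encode(b"hello"))

In the same spirit, the TPU precompile check in final_startup switches from isinstance(model, HFMTJInferenceModel) to the model's capability flag (model.capabilties.uses_tpu), which removes the remaining need for a module-level HFMTJInferenceModel import in that function.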