Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Fix legacy model loading
aiserver.py (25 changed lines)
@@ -11,6 +11,8 @@ from enum import Enum
 import random
 import shutil
 import eventlet
+
+from modeling.inference_model import SuperLegacyModelError
 eventlet.monkey_patch(all=True, thread=False, os=False)
 import os, inspect
 os.system("")
@@ -1942,24 +1944,31 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
         except:
             pass
-        if koboldai_vars.model_type == "gpt2":
-            from modeling.inference_models.legacy_gpt2_hf import CustomGPT2HFTorchInferenceModel
-            model = CustomGPT2HFTorchInferenceModel(
-                koboldai_vars.model,
-                lazy_load=koboldai_vars.lazy_load,
-                low_mem=args.lowmem
-            )
-        else:
+
+        try:
             from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
             model = GenericHFTorchInferenceModel(
                 koboldai_vars.model,
                 lazy_load=koboldai_vars.lazy_load,
                 low_mem=args.lowmem
             )
 
             model.load(
                 save_model=not (args.colab or args.cacheonly) or args.savemodel,
                 initial_load=initial_load,
             )
+        except SuperLegacyModelError:
+            from modeling.inference_models.legacy_gpt2_hf import CustomGPT2HFTorchInferenceModel
+            model = CustomGPT2HFTorchInferenceModel(
+                koboldai_vars.model,
+                lazy_load=koboldai_vars.lazy_load,
+                low_mem=args.lowmem
+            )
+
+            model.load(
+                save_model=not (args.colab or args.cacheonly) or args.savemodel,
+                initial_load=initial_load,
+            )
+
         logger.info(f"Pipeline created: {koboldai_vars.model}")
     else:
         # TPU
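Taken together, the new code in this hunk replaces the up-front model_type == "gpt2" branch with a try/except: the generic Hugging Face loader is always attempted first, and the GPT-2-specific legacy loader only runs when SuperLegacyModelError escapes from it, with both paths calling model.load() with the same arguments. A minimal, self-contained sketch of that fallback shape (the stub classes below stand in for the real inference models and are not part of the commit):

# Sketch only: StubGenericLoader / StubLegacyLoader are hypothetical stand-ins
# for GenericHFTorchInferenceModel and CustomGPT2HFTorchInferenceModel.

class SuperLegacyModelError(RuntimeError):
    pass


class StubGenericLoader:
    def __init__(self, model_name):
        self.model_name = model_name

    def load(self, initial_load=False):
        # The real loader raises SuperLegacyModelError when the checkpoint is
        # too old for the generic code path; simulate that unconditionally here.
        raise SuperLegacyModelError


class StubLegacyLoader(StubGenericLoader):
    def load(self, initial_load=False):
        print(f"loaded {self.model_name} with the legacy GPT-2 path")


def load_model(model_name, initial_load=False):
    try:
        model = StubGenericLoader(model_name)
        model.load(initial_load=initial_load)
    except SuperLegacyModelError:
        # Only this specific sentinel triggers the fallback; other errors propagate.
        model = StubLegacyLoader(model_name)
        model.load(initial_load=initial_load)
    return model


load_model("gpt2-classic")  # prints: loaded gpt2-classic with the legacy GPT-2 path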
modeling/inference_model.py
@@ -17,6 +17,9 @@ from modeling import logits_processors
 
 import utils
 
+class SuperLegacyModelError(RuntimeError):
+    pass
+
 # We only want to use logit manipulations and such on our core text model
 class use_core_manipulations:
     """Use in a `with` block to patch functions for core story model sampling."""
modeling/inference_models/generic_hf_torch.py
@@ -7,6 +7,7 @@ import shutil
 from typing import Union
 
 from transformers import AutoModelForCausalLM, GPTNeoForCausalLM
+from modeling.inference_model import SuperLegacyModelError
 
 import utils
 import modeling.lazy_loader as lazy_loader
@@ -81,7 +82,12 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
             metamodel = AutoModelForCausalLM.from_config(self.model_config)
         except Exception as e:
             logger.error(f"Fell back to neo for metamodel due to {e}")
-            metamodel = GPTNeoForCausalLM.from_config(self.model_config)
+            try:
+                metamodel = GPTNeoForCausalLM.from_config(self.model_config)
+            except Exception as e:
+                logger.error(f"Falling back again due to {e}")
+                raise SuperLegacyModelError
+
         utils.layers_module_names = utils.get_layers_module_names(metamodel)
         utils.module_names = list(metamodel.state_dict().keys())
         utils.named_buffers = list(metamodel.named_buffers(recurse=True))
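For context, this hunk changes the metamodel construction so a second failure no longer crashes the generic loader: AutoModelForCausalLM is tried first, GPTNeoForCausalLM second, and if both fail the loader raises SuperLegacyModelError so load_model in aiserver.py can retry with the legacy GPT-2 path. A simplified, function-shaped sketch of the same two-level fallback (the surrounding class and lazy-load machinery are omitted, and print stands in for the logger):

from transformers import AutoModelForCausalLM, GPTNeoForCausalLM


class SuperLegacyModelError(RuntimeError):
    pass


def build_metamodel(model_config):
    """Build a config-only metamodel used to enumerate layer and buffer names."""
    try:
        return AutoModelForCausalLM.from_config(model_config)
    except Exception as e:
        print(f"Fell back to neo for metamodel due to {e}")
        try:
            return GPTNeoForCausalLM.from_config(model_config)
        except Exception as e:
            print(f"Falling back again due to {e}")
            # Signal the caller to switch to the legacy GPT-2 loader.
            raise SuperLegacyModelError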
modeling/inference_models/legacy_gpt2_hf.py
@@ -19,6 +19,7 @@ class CustomGPT2HFTorchInferenceModel(HFTorchInferenceModel):
         for possible_config_path in [
             utils.koboldai_vars.custmodpth,
             os.path.join("models", utils.koboldai_vars.custmodpth),
+            self.model_name
         ]:
             try:
                 with open(
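The added self.model_name entry means the legacy loader's config probe now also checks the bare model name after the two custom-model paths. The hunk cuts off at the `with open(` line, so the exact file being opened is not visible here; the sketch below assumes it is the model's config.json and uses a hypothetical helper name to show the probing pattern:

import json
import os


def find_legacy_config(custmodpth: str, model_name: str):
    """Return the first candidate directory with a readable config, else (None, None).

    Hypothetical helper; the candidate order mirrors the hunk above.
    """
    for possible_config_path in [
        custmodpth,
        os.path.join("models", custmodpth),
        model_name,
    ]:
        try:
            # Assumption: the probed file is the model's config.json.
            with open(os.path.join(possible_config_path, "config.json")) as f:
                return possible_config_path, json.load(f)
        except OSError:
            continue
    return None, None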
@@ -36,12 +37,13 @@ class CustomGPT2HFTorchInferenceModel(HFTorchInferenceModel):
         with self._maybe_use_float16():
             try:
                 self.model = GPT2LMHeadModel.from_pretrained(
-                    utils.koboldai_vars.custmodpth,
+                    model_path,
                     revision=utils.koboldai_vars.revision,
                     cache_dir="cache",
+                    local_files_only=True
                 )
                 self.tokenizer = GPT2Tokenizer.from_pretrained(
-                    utils.koboldai_vars.custmodpth,
+                    model_path,
                     revision=utils.koboldai_vars.revision,
                     cache_dir="cache",
                 )
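With this hunk, both from_pretrained calls use model_path (presumably the directory found by the config probe shown earlier) instead of always using utils.koboldai_vars.custmodpth, and the model load gains local_files_only=True so transformers will not try to fetch an old checkpoint from the Hugging Face Hub. A minimal sketch of the resulting call pattern; the path and revision values are placeholders:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

model_path = "models/my-legacy-gpt2"  # placeholder for the probed model directory
revision = None                       # placeholder for utils.koboldai_vars.revision

model = GPT2LMHeadModel.from_pretrained(
    model_path,
    revision=revision,
    cache_dir="cache",
    local_files_only=True,  # never download; the checkpoint must already exist locally
)
tokenizer = GPT2Tokenizer.from_pretrained(
    model_path,
    revision=revision,
    cache_dir="cache",
)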
@@ -69,4 +71,4 @@ class CustomGPT2HFTorchInferenceModel(HFTorchInferenceModel):
         else:
             self.model = self.model.to("cpu").float()
 
-        self.patch_causal_lm()
+        self.patch_embedding()