mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Merge latest changes, fix conflict
@@ -6,13 +6,12 @@ import torch
import shutil
from typing import Union

from transformers import AutoModelForCausalLM, GPTNeoForCausalLM
from modeling.inference_model import SuperLegacyModelError
from transformers import AutoModelForCausalLM, GPTNeoForCausalLM, GPT2LMHeadModel

import utils
import modeling.lazy_loader as lazy_loader
import koboldai_settings
from logger import logger, set_logger_verbosity, quiesce_logger
from logger import logger

try:
    import breakmodel
@@ -80,17 +79,12 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
            ):
                try:
                    metamodel = AutoModelForCausalLM.from_config(self.model_config)
                    utils.layers_module_names = utils.get_layers_module_names(metamodel)
                    utils.module_names = list(metamodel.state_dict().keys())
                    utils.named_buffers = list(metamodel.named_buffers(recurse=True))
                except Exception as e:
                    logger.error(f"Fell back to neo for metamodel due to {e}")
                    try:
                        metamodel = GPTNeoForCausalLM.from_config(self.model_config)
                    except Exception as e:
                        logger.error(f"Falling back again due to {e}")
                        raise SuperLegacyModelError

                utils.layers_module_names = utils.get_layers_module_names(metamodel)
                utils.module_names = list(metamodel.state_dict().keys())
                utils.named_buffers = list(metamodel.named_buffers(recurse=True))
                    logger.warning(f"Gave up on lazy loading due to {e}")
                    self.lazy_load = False

        # Download model from Huggingface if it does not exist, otherwise load locally
        with self._maybe_use_float16(), lazy_loader.use_lazy_load(
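Aside (annotation, not part of the diff): a minimal, self-contained sketch of the metamodel fallback this hunk introduces, assuming a `model_config` object is already loaded. `enumerate_module_names` is a hypothetical helper name, and the concrete-class fallback uses the plain constructor rather than the repository's `from_config` call.

    from transformers import AutoModelForCausalLM, GPTNeoForCausalLM

    def enumerate_module_names(model_config):
        # Instantiate the architecture from its config alone; no pretrained
        # weights are downloaded, we only need the module/parameter names so
        # the lazy loader knows what tensors to expect.
        try:
            metamodel = AutoModelForCausalLM.from_config(model_config)
        except Exception:
            # Mirrors the hunk's fallback: very old checkpoints are treated
            # as GPT-Neo style before giving up entirely.
            metamodel = GPTNeoForCausalLM(model_config)
        return list(metamodel.state_dict().keys())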
@@ -24,6 +24,24 @@ class HFInferenceModel(InferenceModel):
        If ignore_existance is true, it will always return a path.
        """

        if self.model_name in ["NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]:
            model_path = utils.koboldai_vars.custmodpth
            assert model_path

            # Path can be absolute or relative to models directory
            if os.path.exists(model_path):
                return model_path

            model_path = os.path.join("models", model_path)

            try:
                assert os.path.exists(model_path)
            except AssertionError:
                logger.error(f"Custom model does not exist at '{utils.koboldai_vars.custmodpth}' or '{model_path}'.")
                raise

            return model_path

        basename = utils.koboldai_vars.model.replace("/", "_")
        if legacy:
            ret = basename
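Aside (annotation, not part of the diff): the custom-model path resolution added in this hunk reduces to the following standalone sketch; `resolve_custom_model_path` is a hypothetical name, not a function in the repository.

    import os

    def resolve_custom_model_path(custom_path: str) -> str:
        # The configured path may be absolute, or relative to the models directory.
        if os.path.exists(custom_path):
            return custom_path
        candidate = os.path.join("models", custom_path)
        if os.path.exists(candidate):
            return candidate
        raise FileNotFoundError(
            f"Custom model does not exist at '{custom_path}' or '{candidate}'."
        )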
@@ -265,7 +265,6 @@ class HFMTJInferenceModel(HFInferenceModel):
        soft_tokens = self.get_soft_tokens()

        dynamic_inference = kwargs.get("tpu_dynamic_inference", False)
        logger.info(f"dynamic_inference={dynamic_inference}")

        if seed is not None:
            tpu_mtj_backend.set_rng_seed(seed)
@@ -18,6 +18,7 @@ import transformers
from transformers import (
    StoppingCriteria,
    GPTNeoForCausalLM,
    GPT2LMHeadModel,
    AutoModelForCausalLM,
    LogitsProcessorList,
)
@@ -131,10 +132,14 @@ class HFTorchInferenceModel(HFInferenceModel):
        if not utils.koboldai_vars.model_type:
            utils.koboldai_vars.model_type = m_self.get_model_type()

        # Model specific overrides if a model has bad defaults
        # These are model specific overrides if a model has bad defaults
        if utils.koboldai_vars.model_type == "llama":
            m_self.tokenizer.decode_with_prefix_space = True
            m_self.tokenizer.add_bos_token = False
        elif utils.koboldai_vars.model_type == "opt":
            m_self.tokenizer._koboldai_header = m_self.tokenizer.encode("")
            m_self.tokenizer.add_bos_token = False
            m_self.tokenizer.add_prefix_space = False

        # Patch stopping_criteria
        class PTHStopper(StoppingCriteria):
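Aside (annotation, not part of the diff): the per-model-type tokenizer overrides above boil down to this pattern; `apply_tokenizer_overrides` is a hypothetical helper, while the attribute names are taken directly from the hunk.

    def apply_tokenizer_overrides(tokenizer, model_type: str) -> None:
        # Work around bad tokenizer defaults for specific model families.
        if model_type == "llama":
            tokenizer.decode_with_prefix_space = True
            tokenizer.add_bos_token = False
        elif model_type == "opt":
            # Remember the tokens an empty encode produces so downstream code
            # can account for the header OPT tokenizers prepend.
            tokenizer._koboldai_header = tokenizer.encode("")
            tokenizer.add_bos_token = False
            tokenizer.add_prefix_space = False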
@@ -265,27 +270,33 @@ class HFTorchInferenceModel(HFInferenceModel):
        )

    def _get_model(self, location: str, tf_kwargs: Dict):
        try:
            return AutoModelForCausalLM.from_pretrained(
                location,
                revision=utils.koboldai_vars.revision,
                cache_dir="cache",
                **tf_kwargs,
            )
        except Exception as e:
            logger.warning(f"Fell back to GPTNeoForCausalLM due to {e}")
            tf_kwargs["revision"] = utils.koboldai_vars.revision
            tf_kwargs["cache_dir"] = "cache"

            # If we have model hints for legacy model, use them rather than fall back.
            try:
                if self.model_name == "GPT2Custom":
                    return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
                elif self.model_name == "NeoCustom":
                    return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)
            except Exception as e:
                logger.warning(f"{self.model_name} is a no-go; {e} - Falling back to auto.")

            # Try to determine model type from either AutoModel or falling back to legacy
            try:
                return AutoModelForCausalLM.from_pretrained(location, **tf_kwargs)
            except Exception as e:
                if "out of memory" in traceback.format_exc().lower():
                    raise RuntimeError(
                        "One of your GPUs ran out of memory when KoboldAI tried to load your model."
                    )

                return GPTNeoForCausalLM.from_pretrained(
                    location,
                    revision=utils.koboldai_vars.revision,
                    cache_dir="cache",
                    **tf_kwargs,
                )
                logger.warning(f"Fell back to GPT2LMHeadModel due to {e}")
                try:
                    return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
                except Exception as e:
                    logger.warning(f"Fell back to GPTNeoForCausalLM due to {e}")
                    return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)

    def get_hidden_size(self) -> int:
        return self.model.get_input_embeddings().embedding_dim
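Aside (annotation, not part of the diff): the loading logic in this hunk is a cascade of attempts. The sketch below, under the hypothetical name `load_causal_lm` and assuming `tf_kwargs` already carries the revision and cache settings, shows the shape of that cascade.

    import traceback

    from transformers import AutoModelForCausalLM, GPT2LMHeadModel, GPTNeoForCausalLM

    def load_causal_lm(location: str, model_name: str, tf_kwargs: dict):
        # 1. Explicit legacy hints win over auto-detection.
        try:
            if model_name == "GPT2Custom":
                return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
            elif model_name == "NeoCustom":
                return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)
        except Exception:
            pass  # Hint failed; fall through to auto-detection.

        # 2. Let AutoModelForCausalLM pick the architecture from config.json.
        try:
            return AutoModelForCausalLM.from_pretrained(location, **tf_kwargs)
        except Exception:
            # A CUDA OOM is not an architecture problem; re-raise instead of retrying.
            if "out of memory" in traceback.format_exc().lower():
                raise

        # 3. Last resorts: GPT-2 first, then GPT-Neo.
        try:
            return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
        except Exception:
            return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)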
@@ -1,74 +0,0 @@
from __future__ import annotations

import os
import json
import traceback

from transformers import GPT2LMHeadModel, GPT2Tokenizer

import utils
from modeling.inference_models.hf_torch import HFTorchInferenceModel


class CustomGPT2HFTorchInferenceModel(HFTorchInferenceModel):
    def _load(self, save_model: bool, initial_load: bool) -> None:
        self.lazy_load = False

        model_path = None

        for possible_config_path in [
            utils.koboldai_vars.custmodpth,
            os.path.join("models", utils.koboldai_vars.custmodpth),
            self.model_name
        ]:
            try:
                with open(
                    os.path.join(possible_config_path, "config.json"), "r"
                ) as file:
                    self.model_config = json.load(file)
                model_path = possible_config_path
                break
            except FileNotFoundError:
                pass

        if not model_path:
            raise RuntimeError("Empty model_path!")

        with self._maybe_use_float16():
            try:
                self.model = GPT2LMHeadModel.from_pretrained(
                    model_path,
                    revision=utils.koboldai_vars.revision,
                    cache_dir="cache",
                    local_files_only=True
                )
                self.tokenizer = GPT2Tokenizer.from_pretrained(
                    model_path,
                    revision=utils.koboldai_vars.revision,
                    cache_dir="cache",
                )
            except Exception as e:
                if "out of memory" in traceback.format_exc().lower():
                    raise RuntimeError(
                        "One of your GPUs ran out of memory when KoboldAI tried to load your model."
                    ) from e
                raise e

        if save_model:
            self.model.save_pretrained(
                self.get_local_model_path(ignore_existance=True),
                max_shard_size="500MiB",
            )
            self.tokenizer.save_pretrained(
                self.get_local_model_path(ignore_existance=True)
            )

        utils.koboldai_vars.modeldim = self.get_hidden_size()

        # Is CUDA available? If so, use GPU, otherwise fall back to CPU
        if utils.koboldai_vars.hascuda and utils.koboldai_vars.usegpu:
            self.model = self.model.half().to(utils.koboldai_vars.gpu_device)
        else:
            self.model = self.model.to("cpu").float()

        self.patch_embedding()