Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Fall back to autogptq if available and model not supported by gptq-koboldai
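In outline: model types that the gptq-koboldai loaders handle explicitly (the `elif model_type == ...` chain, e.g. opt and mpt in the hunks below) keep using those loaders, and only an unrecognized type falls back to AutoGPTQ, and only if `auto_gptq` imported successfully. A minimal sketch of that dispatch, assuming a hypothetical `GPTQ_KOBOLDAI_LOADERS` registry in place of the real branch chain; the `from_quantized` arguments are the ones the diff itself uses:

from pathlib import Path

# Hypothetical registry standing in for the per-architecture gptq-koboldai
# loaders the real code dispatches on (gptj, opt, mpt, ...).
GPTQ_KOBOLDAI_LOADERS = {}

try:
    from auto_gptq import AutoGPTQForCausalLM
    autogptq_support = True
except ImportError:
    autogptq_support = False

def load_4bit_model(model_type, location, gptq_file, **loader_kwargs):
    """Try the gptq-koboldai loaders first, then fall back to AutoGPTQ if installed."""
    if model_type in GPTQ_KOBOLDAI_LOADERS:
        return GPTQ_KOBOLDAI_LOADERS[model_type](location, gptq_file, **loader_kwargs)
    if autogptq_support:
        # Same call the commit adds in the third hunk.
        return AutoGPTQForCausalLM.from_quantized(
            location,
            model_basename=Path(gptq_file).stem,
            use_safetensors=gptq_file.endswith(".safetensors"),
        )
    raise RuntimeError(f"4-bit load failed. Model type {model_type} not supported in 4-bit")

The real code additionally monkey patches auto_gptq's config and model resolution onto hf_bleeding_edge (third hunk) so the fallback does not need trust_remote_code.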
@@ -10,6 +10,7 @@ import sys
 from typing import Union
 
 from transformers import GPTNeoForCausalLM, AutoTokenizer, LlamaTokenizer
+import hf_bleeding_edge
 from hf_bleeding_edge import AutoModelForCausalLM
 
 import utils
@@ -37,6 +38,13 @@ from gptq.opt import load_quant as opt_load_quant
 from gptq.mpt import load_quant as mpt_load_quant
 from gptq.offload import load_quant_offload
 
+autogptq_support = True
+try:
+    import auto_gptq
+    from auto_gptq import AutoGPTQForCausalLM
+except ImportError:
+    autogptq_support = False
+
 
 model_backend_name = "Huggingface GPTQ"
 
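The `autogptq_support` flag is the usual optional-dependency guard: resolve the import once at module load and branch on the flag later. If you only need to know whether the package is installed without importing it yet, the standard library offers an equivalent probe; a small sketch, not part of the commit:

import importlib.util

# Probe for auto_gptq without importing it; find_spec returns None when the
# package cannot be located on sys.path.
autogptq_support = importlib.util.find_spec("auto_gptq") is not None

if autogptq_support:
    # Safe to import now that the spec was found.
    from auto_gptq import AutoGPTQForCausalLM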
@@ -212,6 +220,26 @@ class model_backend(HFTorchInferenceModel):
             model = load_quant_offload(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list)
         elif model_type == "mpt":
             model = load_quant_offload(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list)
+        elif autogptq_support:
+            # Monkey patch in hf_bleeding_edge to avoid having to trust remote code
+            auto_gptq.modeling._utils.AutoConfig = hf_bleeding_edge.AutoConfig
+            auto_gptq.modeling._base.AutoConfig = hf_bleeding_edge.AutoConfig
+            auto_gptq.modeling._base.AutoModelForCausalLM = hf_bleeding_edge.AutoModelForCausalLM
+            model = AutoGPTQForCausalLM.from_quantized(location, model_basename=Path(gptq_file).stem, use_safetensors=gptq_file.endswith(".safetensors"))
+
+            # Patch in embeddings function
+            def get_input_embeddings(self):
+                return self.model.get_input_embeddings()
+
+            type(model).get_input_embeddings = get_input_embeddings
+
+            # Patch in args support..
+            def generate(self, *args, **kwargs):
+                """shortcut for model.generate"""
+                with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
+                    return self.model.generate(*args, **kwargs)
+
+            type(model).generate = generate
         else:
             raise RuntimeError(f"4-bit load failed. Model type {model_type} not supported in 4-bit")
 
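The two shims at the end are attached to the model's class, not the instance: `type(model).get_input_embeddings = ...` and `type(model).generate = ...` make the functions behave as ordinary methods, so `self` is bound when they are called (assigning a bare function to the instance would not bind `self`). The `generate` shim additionally wraps the forwarded call in `torch.inference_mode()` and autocast. A minimal, self-contained sketch of the class-patching technique, with `Wrapper` and `Inner` as hypothetical stand-ins for the AutoGPTQ wrapper and the underlying Hugging Face model:

class Inner:
    """Stand-in for the wrapped model."""
    def generate(self, prompt):
        return f"generated from {prompt!r}"

class Wrapper:
    """Stand-in for the AutoGPTQ wrapper, which keeps the real model in .model."""
    def __init__(self, model):
        self.model = model

def generate(self, *args, **kwargs):
    """Shortcut that forwards to the wrapped model, as the commit's shim does."""
    return self.model.generate(*args, **kwargs)

w = Wrapper(Inner())
# Patch the class, not the instance: every Wrapper object gains .generate,
# and it is looked up like a normal bound method with self = w.
type(w).generate = generate
print(w.generate("hello"))  # -> "generated from 'hello'"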