Fall back to autogptq if available and model not supported by gptq-koboldai

commit b7838c7dde
parent cf886de18b
Author: 0cc4m
Date: 2023-06-04 08:06:48 +02:00

@@ -10,6 +10,7 @@ import sys
from typing import Union
from transformers import GPTNeoForCausalLM, AutoTokenizer, LlamaTokenizer
import hf_bleeding_edge
from hf_bleeding_edge import AutoModelForCausalLM
import utils
@@ -37,6 +38,13 @@ from gptq.opt import load_quant as opt_load_quant
from gptq.mpt import load_quant as mpt_load_quant
from gptq.offload import load_quant_offload

autogptq_support = True
try:
    import auto_gptq
    from auto_gptq import AutoGPTQForCausalLM
except ImportError:
    autogptq_support = False

model_backend_name = "Huggingface GPTQ"
@@ -212,6 +220,26 @@ class model_backend(HFTorchInferenceModel):
            model = load_quant_offload(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list)
        elif model_type == "mpt":
            model = load_quant_offload(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list)
        elif autogptq_support:
            # Monkey patch in hf_bleeding_edge to avoid having to trust remote code
            auto_gptq.modeling._utils.AutoConfig = hf_bleeding_edge.AutoConfig
            auto_gptq.modeling._base.AutoConfig = hf_bleeding_edge.AutoConfig
            auto_gptq.modeling._base.AutoModelForCausalLM = hf_bleeding_edge.AutoModelForCausalLM

            model = AutoGPTQForCausalLM.from_quantized(location, model_basename=Path(gptq_file).stem, use_safetensors=gptq_file.endswith(".safetensors"))

            # Patch in embeddings function
            def get_input_embeddings(self):
                return self.model.get_input_embeddings()

            type(model).get_input_embeddings = get_input_embeddings

            # Patch in args support..
            def generate(self, *args, **kwargs):
                """shortcut for model.generate"""
                with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
                    return self.model.generate(*args, **kwargs)

            type(model).generate = generate
        else:
            raise RuntimeError(f"4-bit load failed. Model type {model_type} not supported in 4-bit")
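In short, the change adds an import-guarded fallback around the existing per-architecture loaders. The sketch below condenses the diff above into its core pattern; load_gptq_koboldai and KOBOLDAI_SUPPORTED are illustrative placeholders, not names from the codebase:

# Minimal sketch of the fallback pattern this commit introduces (hypothetical helper names).
from pathlib import Path

autogptq_support = True
try:
    from auto_gptq import AutoGPTQForCausalLM
except ImportError:
    autogptq_support = False

def load_4bit(location, gptq_file, model_type):
    if model_type in KOBOLDAI_SUPPORTED:  # gptj, gpt_neox, llama, opt, mpt, ...
        return load_gptq_koboldai(model_type, location, gptq_file)
    if autogptq_support:  # fall back to AutoGPTQ for anything else it can handle
        return AutoGPTQForCausalLM.from_quantized(
            location,
            model_basename=Path(gptq_file).stem,
            use_safetensors=gptq_file.endswith(".safetensors"),
        )
    raise RuntimeError(f"4-bit load failed. Model type {model_type} not supported in 4-bit")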