diff --git a/modeling/inference_models/4bit_hf_torch/class.py b/modeling/inference_models/4bit_hf_torch/class.py
index 62f04bfb..7d7dfc00 100644
--- a/modeling/inference_models/4bit_hf_torch/class.py
+++ b/modeling/inference_models/4bit_hf_torch/class.py
@@ -10,6 +10,7 @@ import sys
 from typing import Union
 
 from transformers import GPTNeoForCausalLM, AutoTokenizer, LlamaTokenizer
+import hf_bleeding_edge
 from hf_bleeding_edge import AutoModelForCausalLM
 
 import utils
@@ -37,6 +38,13 @@ from gptq.opt import load_quant as opt_load_quant
 from gptq.mpt import load_quant as mpt_load_quant
 from gptq.offload import load_quant_offload
 
+autogptq_support = True
+try:
+    import auto_gptq
+    from auto_gptq import AutoGPTQForCausalLM
+except ImportError:
+    autogptq_support = False
+
 
 model_backend_name = "Huggingface GPTQ"
 
@@ -212,6 +220,26 @@ class model_backend(HFTorchInferenceModel):
             model = load_quant_offload(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list)
         elif model_type == "mpt":
             model = load_quant_offload(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list)
+        elif autogptq_support:
+            # Monkey patch in hf_bleeding_edge to avoid having to trust remote code
+            auto_gptq.modeling._utils.AutoConfig = hf_bleeding_edge.AutoConfig
+            auto_gptq.modeling._base.AutoConfig = hf_bleeding_edge.AutoConfig
+            auto_gptq.modeling._base.AutoModelForCausalLM = hf_bleeding_edge.AutoModelForCausalLM
+            model = AutoGPTQForCausalLM.from_quantized(location, model_basename=Path(gptq_file).stem, use_safetensors=gptq_file.endswith(".safetensors"))
+
+            # Patch in embeddings function
+            def get_input_embeddings(self):
+                return self.model.get_input_embeddings()
+
+            type(model).get_input_embeddings = get_input_embeddings
+
+            # Patch in args support
+            def generate(self, *args, **kwargs):
+                """Shortcut for model.generate"""
+                with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
+                    return self.model.generate(*args, **kwargs)
+
+            type(model).generate = generate
         else:
             raise RuntimeError(f"4-bit load failed. Model type {model_type} not supported in 4-bit")
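
For context on the `type(model).get_input_embeddings = ...` and `type(model).generate = ...` assignments in the last hunk: functions assigned to a class become bound methods on all of its instances, which is what lets the already-constructed AutoGPTQ wrapper pick up the forwarding methods. A minimal, self-contained sketch of that pattern (the Inner/Wrapper names are hypothetical stand-ins, not part of the patch):

    class Inner:
        def generate(self, *args, **kwargs):
            return f"inner generate called with {args} {kwargs}"

    class Wrapper:
        def __init__(self):
            self.model = Inner()

    wrapped = Wrapper()

    # Assigning a plain function to the class makes it a bound method on
    # every instance, including ones created before the assignment.
    def generate(self, *args, **kwargs):
        """Forward generate calls to the wrapped model."""
        return self.model.generate(*args, **kwargs)

    type(wrapped).generate = generate

    print(wrapped.generate(1, temperature=0.7))
    # -> inner generate called with (1,) {'temperature': 0.7}

One caveat of this approach: because the assignment mutates the class itself, every instance of that wrapper class is patched, not just the one that was loaded.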