diff --git a/modeling/inference_models/gptq_hf_torch/class.py b/modeling/inference_models/gptq_hf_torch/class.py index eb3d2475..eff71bc0 100644 --- a/modeling/inference_models/gptq_hf_torch/class.py +++ b/modeling/inference_models/gptq_hf_torch/class.py @@ -169,6 +169,7 @@ class model_backend(HFTorchInferenceModel): self.init_model_config() self.lazy_load = True + self.implementation = "occam" gpulayers = self.breakmodel_config.gpu_blocks @@ -323,22 +324,25 @@ class model_backend(HFTorchInferenceModel): enable=self.lazy_load, dematerialized_modules=False, ): - try: - if model_type == "gptj": - model = load_quant_offload_device_map(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias) - elif model_type == "gpt_neox": - model = load_quant_offload_device_map(gptneox_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias) - elif model_type == "llama": - model = load_quant_offload_device_map(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias) - elif model_type == "opt": - model = load_quant_offload_device_map(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias) - elif model_type == "mpt": - model = load_quant_offload_device_map(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias) - elif model_type == "gpt_bigcode": - model = load_quant_offload_device_map(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias).half() - else: - raise RuntimeError("Model not supported by Occam's GPTQ") - except: + if self.implementation == "occam": + try: + if model_type == "gptj": + model = load_quant_offload_device_map(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias) + elif model_type == "gpt_neox": + model = load_quant_offload_device_map(gptneox_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias) + elif model_type == "llama": + model = load_quant_offload_device_map(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias) + elif model_type == "opt": + model = load_quant_offload_device_map(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias) + elif model_tseype == "mpt": + model = load_quant_offload_device_map(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias) + elif model_type == "gpt_bigcode": + model = load_quant_offload_device_map(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias).half() + else: + raise RuntimeError("Model not supported by Occam's GPTQ") + except: + self.implementation = "AutoGPTQ" + if self.implementation == "AutoGPTQ": try: import auto_gptq from auto_gptq import AutoGPTQForCausalLM