diff --git a/modeling/inference_models/gptq_hf_torch/class.py b/modeling/inference_models/gptq_hf_torch/class.py index b44fcd7a..eb3d2475 100644 --- a/modeling/inference_models/gptq_hf_torch/class.py +++ b/modeling/inference_models/gptq_hf_torch/class.py @@ -350,7 +350,7 @@ class model_backend(HFTorchInferenceModel): auto_gptq.modeling._base.AutoConfig = hf_bleeding_edge.AutoConfig auto_gptq.modeling._base.AutoModelForCausalLM = hf_bleeding_edge.AutoModelForCausalLM - model = AutoGPTQForCausalLM.from_quantized(location, model_basename=Path(gptq_file).stem, use_safetensors=gptq_file.endswith(".safetensors")) + model = AutoGPTQForCausalLM.from_quantized(location, model_basename=Path(gptq_file).stem, use_safetensors=gptq_file.endswith(".safetensors"), device_map=device_map) # Patch in embeddings function def get_input_embeddings(self):