diff --git a/modeling/inference_models/basic_hf/class.py b/modeling/inference_models/basic_hf/class.py index ecbc55cc..9d4b643b 100644 --- a/modeling/inference_models/basic_hf/class.py +++ b/modeling/inference_models/basic_hf/class.py @@ -55,7 +55,7 @@ class model_backend(HFInferenceModel): self.init_model_config() self.model = AutoModelForCausalLM.from_pretrained( - self.get_local_model_path(), low_cpu_mem_usage=True + self.get_local_model_path(), low_cpu_mem_usage=True, device_map="auto" ) if self.usegpu: