Use device_map="auto"

This commit is contained in:
somebody
2023-07-12 17:27:48 -05:00
parent 60473d4c23
commit d17ce8461d

View File

@@ -55,7 +55,7 @@ class model_backend(HFInferenceModel):
self.init_model_config() self.init_model_config()
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
self.get_local_model_path(), low_cpu_mem_usage=True self.get_local_model_path(), low_cpu_mem_usage=True, device_map="auto"
) )
if self.usegpu: if self.usegpu: