Better AutoGPTQ fallback

This commit is contained in:
Henk
2023-08-10 18:10:48 +02:00
parent f2d7ef3aca
commit 9c7ebe3b04

View File

@@ -169,6 +169,7 @@ class model_backend(HFTorchInferenceModel):
self.init_model_config() self.init_model_config()
self.lazy_load = True self.lazy_load = True
self.implementation = "occam"
gpulayers = self.breakmodel_config.gpu_blocks gpulayers = self.breakmodel_config.gpu_blocks
@@ -323,6 +324,7 @@ class model_backend(HFTorchInferenceModel):
enable=self.lazy_load, enable=self.lazy_load,
dematerialized_modules=False, dematerialized_modules=False,
): ):
if self.implementation == "occam":
try: try:
if model_type == "gptj": if model_type == "gptj":
model = load_quant_offload_device_map(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias) model = load_quant_offload_device_map(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
@@ -332,13 +334,15 @@ class model_backend(HFTorchInferenceModel):
model = load_quant_offload_device_map(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias) model = load_quant_offload_device_map(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
elif model_type == "opt": elif model_type == "opt":
model = load_quant_offload_device_map(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias) model = load_quant_offload_device_map(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
elif model_type == "mpt": elif model_tseype == "mpt":
model = load_quant_offload_device_map(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias) model = load_quant_offload_device_map(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
elif model_type == "gpt_bigcode": elif model_type == "gpt_bigcode":
model = load_quant_offload_device_map(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias).half() model = load_quant_offload_device_map(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias).half()
else: else:
raise RuntimeError("Model not supported by Occam's GPTQ") raise RuntimeError("Model not supported by Occam's GPTQ")
except: except:
self.implementation = "AutoGPTQ"
if self.implementation == "AutoGPTQ":
try: try:
import auto_gptq import auto_gptq
from auto_gptq import AutoGPTQForCausalLM from auto_gptq import AutoGPTQForCausalLM