Better AutoGPTQ fallback

Henk
2023-08-10 18:10:48 +02:00
parent f2d7ef3aca
commit 9c7ebe3b04


@@ -169,6 +169,7 @@ class model_backend(HFTorchInferenceModel):
         self.init_model_config()
         self.lazy_load = True
+        self.implementation = "occam"
         gpulayers = self.breakmodel_config.gpu_blocks
@@ -323,22 +324,25 @@ class model_backend(HFTorchInferenceModel):
             enable=self.lazy_load,
             dematerialized_modules=False,
         ):
-            try:
-                if model_type == "gptj":
-                    model = load_quant_offload_device_map(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
-                elif model_type == "gpt_neox":
-                    model = load_quant_offload_device_map(gptneox_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
-                elif model_type == "llama":
-                    model = load_quant_offload_device_map(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
-                elif model_type == "opt":
-                    model = load_quant_offload_device_map(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
-                elif model_type == "mpt":
-                    model = load_quant_offload_device_map(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
-                elif model_type == "gpt_bigcode":
-                    model = load_quant_offload_device_map(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias).half()
-                else:
-                    raise RuntimeError("Model not supported by Occam's GPTQ")
-            except:
+            if self.implementation == "occam":
+                try:
+                    if model_type == "gptj":
+                        model = load_quant_offload_device_map(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+                    elif model_type == "gpt_neox":
+                        model = load_quant_offload_device_map(gptneox_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+                    elif model_type == "llama":
+                        model = load_quant_offload_device_map(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+                    elif model_type == "opt":
+                        model = load_quant_offload_device_map(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+                    elif model_type == "mpt":
+                        model = load_quant_offload_device_map(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+                    elif model_type == "gpt_bigcode":
+                        model = load_quant_offload_device_map(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias).half()
+                    else:
+                        raise RuntimeError("Model not supported by Occam's GPTQ")
+                except:
+                    self.implementation = "AutoGPTQ"
+            if self.implementation == "AutoGPTQ":
                 try:
                     import auto_gptq
                     from auto_gptq import AutoGPTQForCausalLM
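
The change replaces the old nested try/except with an explicit backend flag: self.implementation defaults to "occam", is demoted to "AutoGPTQ" when Occam's GPTQ loader raises, and the AutoGPTQ path then runs off the flag instead of living inside the exception handler. Below is a minimal sketch of that fallback pattern under stated assumptions; the loader methods here (load_with_occam, load_with_autogptq) are hypothetical stand-ins for the repository's real load_quant_offload_device_map and AutoGPTQForCausalLM calls.

    # Sketch of the flag-based fallback introduced by this commit (not the real loaders).
    class QuantizedLoader:
        def __init__(self):
            # Prefer Occam's GPTQ by default (mirrors self.implementation = "occam").
            self.implementation = "occam"

        def load(self, model_type):
            model = None
            if self.implementation == "occam":
                try:
                    model = self.load_with_occam(model_type)
                except Exception:
                    # Any failure demotes the backend; the fallback is no longer
                    # nested inside the exception handler.
                    self.implementation = "AutoGPTQ"
            if self.implementation == "AutoGPTQ":
                model = self.load_with_autogptq(model_type)
            return model

        def load_with_occam(self, model_type):
            # Placeholder for the real quantized load via Occam's GPTQ.
            if model_type not in ("gptj", "gpt_neox", "llama", "opt", "mpt", "gpt_bigcode"):
                raise RuntimeError("Model not supported by Occam's GPTQ")
            return f"occam-loaded:{model_type}"

        def load_with_autogptq(self, model_type):
            # Placeholder for the AutoGPTQ load path.
            return f"autogptq-loaded:{model_type}"

Because the flag survives the failed attempt, code running after the load can tell which backend actually produced the model, which the bare nested try/except could not report cleanly.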