GPTQ improvements

This commit is contained in:
Henk
2023-08-19 14:45:45 +02:00
parent 13b68c67d1
commit d93631c889
3 changed files with 27 additions and 19 deletions

View File

@@ -155,7 +155,7 @@ class model_backend(HFTorchInferenceModel):
def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}):
requested_parameters = super().get_requested_parameters(model_name, model_path, menu_path, parameters)
if model_name != 'customhuggingface' or "custom_model_name" in parameters:
if model_name != 'customgptq' or "custom_model_name" in parameters:
if os.path.exists("settings/{}.generic_hf_torch.model_backend.settings".format(model_name.replace("/", "_"))) and 'base_url' not in vars(self):
with open("settings/{}.generic_hf_torch.model_backend.settings".format(model_name.replace("/", "_")), "r") as f:
temp = json.load(f)
@@ -232,6 +232,7 @@ class model_backend(HFTorchInferenceModel):
print(self.get_local_model_path())
from huggingface_hub import snapshot_download
target_dir = "models/" + self.model_name.replace("/", "_")
print(self.model_name)
snapshot_download(self.model_name, local_dir=target_dir, local_dir_use_symlinks=False, cache_dir="cache/")
self.model = self._get_model(self.get_local_model_path())
@@ -352,20 +353,24 @@ class model_backend(HFTorchInferenceModel):
dematerialized_modules=False,
):
if self.implementation == "occam":
if model_type == "gptj":
model = load_quant_offload_device_map(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
elif model_type == "gpt_neox":
model = load_quant_offload_device_map(gptneox_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
elif model_type == "llama":
model = load_quant_offload_device_map(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
elif model_type == "opt":
model = load_quant_offload_device_map(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
elif model_tseype == "mpt":
model = load_quant_offload_device_map(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
elif model_type == "gpt_bigcode":
model = load_quant_offload_device_map(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias).half()
else:
raise RuntimeError("Model not supported by Occam's GPTQ")
try:
if model_type == "gptj":
model = load_quant_offload_device_map(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
elif model_type == "gpt_neox":
model = load_quant_offload_device_map(gptneox_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
elif model_type == "llama":
model = load_quant_offload_device_map(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
elif model_type == "opt":
model = load_quant_offload_device_map(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
elif model_tseype == "mpt":
model = load_quant_offload_device_map(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
elif model_type == "gpt_bigcode":
model = load_quant_offload_device_map(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias).half()
else:
raise RuntimeError("Model not supported by Occam's GPTQ")
except:
self.implementation = "AutoGPTQ"
if self.implementation == "AutoGPTQ":
try:
import auto_gptq
@@ -378,11 +383,13 @@ class model_backend(HFTorchInferenceModel):
auto_gptq.modeling._base.AutoConfig = hf_bleeding_edge.AutoConfig
auto_gptq.modeling._base.AutoModelForCausalLM = hf_bleeding_edge.AutoModelForCausalLM
autogptq_failed = False
try:
model = AutoGPTQForCausalLM.from_quantized(location, model_basename=Path(gptq_file).stem, use_safetensors=gptq_file.endswith(".safetensors"), device_map=device_map)
except:
autogptq_failed = True # Ugly hack to get it to free the VRAM of the last attempt like we do above, better suggestions welcome - Henk
if autogptq_failed:
model = AutoGPTQForCausalLM.from_quantized(location, model_basename=Path(gptq_file).stem, use_safetensors=gptq_file.endswith(".safetensors"), device_map=device_map, disable_exllama=True)
# Patch in embeddings function
def get_input_embeddings(self):
return self.model.get_input_embeddings()