Basic GPTQ Downloader

Henk
2023-08-19 13:02:50 +02:00
parent 029e8736c0
commit 13b68c67d1


@@ -228,12 +228,14 @@ class model_backend(HFTorchInferenceModel):
logger.warning(f"Gave up on lazy loading due to {e}")
self.lazy_load = False
if self.get_local_model_path():
# Model is stored locally, load it.
self.model = self._get_model(self.get_local_model_path())
self.tokenizer = self._get_tokenizer(self.get_local_model_path())
else:
raise NotImplementedError("GPTQ Model downloading not implemented")
if not self.get_local_model_path():
print(self.get_local_model_path())
from huggingface_hub import snapshot_download
target_dir = "models/" + self.model_name.replace("/", "_")
snapshot_download(self.model_name, local_dir=target_dir, local_dir_use_symlinks=False, cache_dir="cache/")
self.model = self._get_model(self.get_local_model_path())
self.tokenizer = self._get_tokenizer(self.get_local_model_path())
if (
utils.koboldai_vars.badwordsids is koboldai_settings.badwordsids_default
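
Note: the new branch above delegates downloading to huggingface_hub's snapshot_download and then falls through to the usual local-load path. Below is a minimal standalone sketch of that download step, assuming huggingface_hub is installed; the repo id is only a placeholder, not something taken from this commit.

# Sketch only: mirrors the download step added in the hunk above.
# Assumes huggingface_hub is installed; the repo id is a placeholder.
from huggingface_hub import snapshot_download

model_name = "TheBloke/Llama-2-7B-GPTQ"  # placeholder repo id
target_dir = "models/" + model_name.replace("/", "_")

# Fetch every file in the repo into models/<name>, writing real files
# (not symlinks into the cache) so the directory is self-contained.
snapshot_download(model_name, local_dir=target_dir, local_dir_use_symlinks=False, cache_dir="cache/")
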
@@ -350,23 +352,20 @@ class model_backend(HFTorchInferenceModel):
             dematerialized_modules=False,
         ):
             if self.implementation == "occam":
-                try:
-                    if model_type == "gptj":
-                        model = load_quant_offload_device_map(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
-                    elif model_type == "gpt_neox":
-                        model = load_quant_offload_device_map(gptneox_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
-                    elif model_type == "llama":
-                        model = load_quant_offload_device_map(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
-                    elif model_type == "opt":
-                        model = load_quant_offload_device_map(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
-                    elif model_type == "mpt":
-                        model = load_quant_offload_device_map(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
-                    elif model_type == "gpt_bigcode":
-                        model = load_quant_offload_device_map(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias).half()
-                    else:
-                        raise RuntimeError("Model not supported by Occam's GPTQ")
-                except:
-                    self.implementation = "AutoGPTQ"
+                if model_type == "gptj":
+                    model = load_quant_offload_device_map(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+                elif model_type == "gpt_neox":
+                    model = load_quant_offload_device_map(gptneox_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+                elif model_type == "llama":
+                    model = load_quant_offload_device_map(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+                elif model_type == "opt":
+                    model = load_quant_offload_device_map(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+                elif model_type == "mpt":
+                    model = load_quant_offload_device_map(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+                elif model_type == "gpt_bigcode":
+                    model = load_quant_offload_device_map(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias).half()
+                else:
+                    raise RuntimeError("Model not supported by Occam's GPTQ")
             if self.implementation == "AutoGPTQ":
                 try:
                     import auto_gptq
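
Note: the if/elif chain in this hunk simply maps model_type to an Occam GPTQ loader and, after this change, raises instead of silently switching to AutoGPTQ on failure. Below is a hedged sketch of that dispatch written as a lookup table; the function name occam_load is hypothetical, and the *_load_quant callables and load_quant_offload_device_map are assumed to be the ones imported elsewhere in this backend file.

# Sketch only: equivalent dispatch for the if/elif chain in the hunk above.
# The *_load_quant callables and load_quant_offload_device_map are assumed
# to be in scope (they are imported elsewhere in this backend).
OCCAM_LOADERS = {
    "gptj": gptj_load_quant,
    "gpt_neox": gptneox_load_quant,
    "llama": llama_load_quant,
    "opt": opt_load_quant,
    "mpt": mpt_load_quant,
    "gpt_bigcode": bigcode_load_quant,
}

def occam_load(model_type, location, gptq_file, gptq_bits, gptq_groupsize, device_map, v2_bias):
    loader = OCCAM_LOADERS.get(model_type)
    if loader is None:
        raise RuntimeError("Model not supported by Occam's GPTQ")
    model = load_quant_offload_device_map(
        loader, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias
    )
    # gpt_bigcode is the one type loaded in half precision in the original code.
    return model.half() if model_type == "gpt_bigcode" else model
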