Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
GPTQ improvements
@@ -242,7 +242,8 @@ model_menu = {
     "mainmenu": [
         MenuPath("Load a model from its directory", "NeoCustom"),
         MenuPath("Load an old GPT-2 model (eg CloverEdition)", "GPT2Custom"),
-        MenuModel("Load custom model from Hugging Face", "customhuggingface", ""),
+        MenuModel("Load custom Pytorch model from Hugging Face", "customhuggingface", ""),
+        MenuModel("Load custom GPTQ model from Hugging Face", "customgptq", "", model_backend="GPTQ"),
         MenuFolder("Instruct Models", "instructlist"),
         MenuFolder("Novel Models", "novellist"),
         MenuFolder("Chat Models", "chatlist"),
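The two added menu entries give the custom Hugging Face loader an explicit backend hint via model_backend="GPTQ". Below is a minimal, assumed sketch (not KoboldAI's actual MenuModel class) of how an entry that carries a backend name can be dispatched; all names in it are illustrative.

# Assumed sketch, not KoboldAI's real MenuModel: a menu entry that remembers
# which backend should load the model it points to.
from dataclasses import dataclass

@dataclass
class MenuEntry:
    label: str                          # text shown in the menu
    key: str                            # e.g. "customgptq"
    default: str = ""
    model_backend: str = "Huggingface"  # which backend handles the load

def pick_backend(entry, backends):
    # Dispatch purely on the backend name carried by the entry.
    return backends[entry.model_backend]

backends = {"Huggingface": "generic torch loader", "GPTQ": "GPTQ loader"}
entry = MenuEntry("Load custom GPTQ model from Hugging Face", "customgptq", model_backend="GPTQ")
print(pick_backend(entry, backends))    # -> "GPTQ loader"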
@@ -155,7 +155,7 @@ class model_backend(HFTorchInferenceModel):
 
     def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}):
         requested_parameters = super().get_requested_parameters(model_name, model_path, menu_path, parameters)
-        if model_name != 'customhuggingface' or "custom_model_name" in parameters:
+        if model_name != 'customgptq' or "custom_model_name" in parameters:
             if os.path.exists("settings/{}.generic_hf_torch.model_backend.settings".format(model_name.replace("/", "_"))) and 'base_url' not in vars(self):
                 with open("settings/{}.generic_hf_torch.model_backend.settings".format(model_name.replace("/", "_")), "r") as f:
                     temp = json.load(f)
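The settings check above reloads a previously saved per-model backend configuration from settings/<model>.generic_hf_torch.model_backend.settings. A small self-contained sketch of that JSON round trip follows; the file-name pattern comes from the diff, the keys inside the file are assumptions.

# Sketch of the settings round trip; only the file-name pattern is from the
# diff, the dictionary contents are assumed for illustration.
import json, os

model_name = "TheBloke/example-GPTQ"    # hypothetical model id
path = "settings/{}.generic_hf_torch.model_backend.settings".format(model_name.replace("/", "_"))

os.makedirs("settings", exist_ok=True)
with open(path, "w") as f:
    json.dump({"layers": [28, 4], "disk_layers": 0}, f)   # assumed contents

if os.path.exists(path):
    with open(path, "r") as f:
        temp = json.load(f)
    print(temp)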
@@ -232,6 +232,7 @@ class model_backend(HFTorchInferenceModel):
             print(self.get_local_model_path())
             from huggingface_hub import snapshot_download
             target_dir = "models/" + self.model_name.replace("/", "_")
+            print(self.model_name)
             snapshot_download(self.model_name, local_dir=target_dir, local_dir_use_symlinks=False, cache_dir="cache/")
 
         self.model = self._get_model(self.get_local_model_path())
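The new print(self.model_name) line is added just before the snapshot_download call that mirrors the whole Hugging Face repo into models/<name>. A standalone sketch of that download step, assuming huggingface_hub is installed and using a placeholder repo id:

# Standalone sketch of the download step; requires huggingface_hub and a real
# repo id (the one below is a placeholder). local_dir_use_symlinks matches the
# huggingface_hub releases that were current when this commit was made.
from huggingface_hub import snapshot_download

model_name = "some-user/some-gptq-model"          # placeholder repo id
target_dir = "models/" + model_name.replace("/", "_")
snapshot_download(
    model_name,
    local_dir=target_dir,                         # materialize files under models/<name>
    local_dir_use_symlinks=False,                 # copy real files instead of cache symlinks
    cache_dir="cache/",
)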
@@ -352,20 +353,24 @@ class model_backend(HFTorchInferenceModel):
             dematerialized_modules=False,
         ):
             if self.implementation == "occam":
-                if model_type == "gptj":
-                    model = load_quant_offload_device_map(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
-                elif model_type == "gpt_neox":
-                    model = load_quant_offload_device_map(gptneox_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
-                elif model_type == "llama":
-                    model = load_quant_offload_device_map(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
-                elif model_type == "opt":
-                    model = load_quant_offload_device_map(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
-                elif model_type == "mpt":
-                    model = load_quant_offload_device_map(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
-                elif model_type == "gpt_bigcode":
-                    model = load_quant_offload_device_map(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias).half()
-                else:
-                    raise RuntimeError("Model not supported by Occam's GPTQ")
+                try:
+                    if model_type == "gptj":
+                        model = load_quant_offload_device_map(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+                    elif model_type == "gpt_neox":
+                        model = load_quant_offload_device_map(gptneox_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+                    elif model_type == "llama":
+                        model = load_quant_offload_device_map(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+                    elif model_type == "opt":
+                        model = load_quant_offload_device_map(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+                    elif model_type == "mpt":
+                        model = load_quant_offload_device_map(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+                    elif model_type == "gpt_bigcode":
+                        model = load_quant_offload_device_map(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias).half()
+                    else:
+                        raise RuntimeError("Model not supported by Occam's GPTQ")
+                except:
+                    self.implementation = "AutoGPTQ"
+
             if self.implementation == "AutoGPTQ":
                 try:
                     import auto_gptq
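The change above wraps the Occam GPTQ loaders in try/except so that any failure flips self.implementation to "AutoGPTQ", letting the following block retry with the generic loader instead of aborting. A minimal sketch of that dispatch-with-fallback shape, using made-up loader functions:

# Sketch of the fallback shape, with made-up loader functions.
def load_with_occam(model_type):
    loaders = {"gptj": lambda: "occam gptj model", "llama": lambda: "occam llama model"}
    if model_type not in loaders:
        raise RuntimeError("Model not supported by Occam's GPTQ")
    return loaders[model_type]()

def load_with_autogptq(model_type):
    return "autogptq " + model_type + " model"

def load(model_type, implementation="occam"):
    model = None
    if implementation == "occam":
        try:
            model = load_with_occam(model_type)
        except Exception:
            implementation = "AutoGPTQ"       # fall through to the generic loader
    if implementation == "AutoGPTQ":
        model = load_with_autogptq(model_type)
    return model

print(load("mpt"))   # unsupported by the first loader -> "autogptq mpt model"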
@@ -378,11 +383,13 @@ class model_backend(HFTorchInferenceModel):
                 auto_gptq.modeling._base.AutoConfig = hf_bleeding_edge.AutoConfig
                 auto_gptq.modeling._base.AutoModelForCausalLM = hf_bleeding_edge.AutoModelForCausalLM
 
+                autogptq_failed = False
                 try:
                     model = AutoGPTQForCausalLM.from_quantized(location, model_basename=Path(gptq_file).stem, use_safetensors=gptq_file.endswith(".safetensors"), device_map=device_map)
                 except:
+                    autogptq_failed = True # Ugly hack to get it to free the VRAM of the last attempt like we do above, better suggestions welcome - Henk
+                if autogptq_failed:
                     model = AutoGPTQForCausalLM.from_quantized(location, model_basename=Path(gptq_file).stem, use_safetensors=gptq_file.endswith(".safetensors"), device_map=device_map, disable_exllama=True)
-
                 # Patch in embeddings function
                 def get_input_embeddings(self):
                     return self.model.get_input_embeddings()
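The autogptq_failed flag moves the retry out of the except: block. While an except block is running, the caught exception and its traceback keep the failed attempt's frames, and anything they reference, alive, so retrying inside the handler can hold on to the first attempt's VRAM; setting a flag and retrying after the handler has exited avoids that. A minimal sketch of the pattern, independent of AutoGPTQ:

# Sketch of the flag pattern: attempt, remember the failure, retry only after
# the except block has exited and the exception (whose traceback pins the
# failed attempt's locals) has been cleared.
def load(path, fast=True):
    if fast:
        raise MemoryError("pretend the fast kernel ran out of VRAM")
    return "model loaded with fallback settings"

failed = False
try:
    model = load("some/path", fast=True)
except Exception:
    failed = True            # do not retry here; the live exception still holds references
if failed:
    model = load("some/path", fast=False)     # first attempt already released
print(model)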
@@ -47,7 +47,7 @@ class HFInferenceModel(InferenceModel):
         requested_parameters = []
         if not self.hf_torch:
             return []
-        if model_name == 'customhuggingface':
+        if model_name in ('customhuggingface', 'customgptq'):
             requested_parameters.append({
                 "uitype": "text",
                 "unit": "text",
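With the widened check, both customhuggingface and customgptq now prompt the user for a model name through the requested_parameters list of UI descriptors. A hypothetical example of one such descriptor; only "uitype", "unit" and "extra_classes" are taken from the diff, the remaining keys are assumptions:

# Hypothetical example of one requested_parameters entry; "uitype", "unit" and
# "extra_classes" appear in the diff, the other keys are assumptions.
custom_name_parameter = {
    "uitype": "text",
    "unit": "text",
    "label": "Custom Model Name",   # assumed
    "id": "custom_model_name",      # assumed, matches the checks in the diff
    "default": "",                  # assumed
    "extra_classes": "",
}

def needs_custom_name(model_name):
    return model_name in ('customhuggingface', 'customgptq')

requested_parameters = []
if needs_custom_name("customgptq"):
    requested_parameters.append(custom_name_parameter)
print(requested_parameters)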
@@ -61,7 +61,7 @@ class HFInferenceModel(InferenceModel):
                 "extra_classes": ""
             })
 
-        if model_name != 'customhuggingface' or "custom_model_name" in parameters:
+        if model_name not in ('customhuggingface', 'customgptq') or "custom_model_name" in parameters:
             model_name = parameters["custom_model_name"] if "custom_model_name" in parameters and parameters["custom_model_name"] != "" else model_name
             if model_path is not None and os.path.exists(model_path):
                 self.model_config = AutoConfig.from_pretrained(model_path)