From 9ec50c997280856dee810a74e18cd11fd5304228 Mon Sep 17 00:00:00 2001
From: 0cc4m
Date: Sat, 6 May 2023 21:58:23 +0200
Subject: [PATCH] Fix 4-bit mpt

---
 modeling/inference_models/hf_torch_4bit.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/modeling/inference_models/hf_torch_4bit.py b/modeling/inference_models/hf_torch_4bit.py
index 959d6258..8aaddcc1 100644
--- a/modeling/inference_models/hf_torch_4bit.py
+++ b/modeling/inference_models/hf_torch_4bit.py
@@ -34,6 +34,7 @@ from gptq.gptj import load_quant as gptj_load_quant
 from gptq.gptneox import load_quant as gptneox_load_quant
 from gptq.llama import load_quant as llama_load_quant
 from gptq.opt import load_quant as opt_load_quant
+from gptq.mpt import load_quant as mpt_load_quant
 from gptq.offload import load_quant_offload
 
 
@@ -369,6 +370,8 @@ class HFTorch4BitInferenceModel(HFTorchInferenceModel):
             model = load_quant_offload(llama_load_quant, utils.koboldai_vars.custmodpth, path_4bit, 4, groupsize, self.gpu_layers_list)
         elif utils.koboldai_vars.model_type == "opt":
             model = load_quant_offload(opt_load_quant, utils.koboldai_vars.custmodpth, path_4bit, 4, groupsize, self.gpu_layers_list)
+        elif utils.koboldai_vars.model_type == "mpt":
+            model = load_quant_offload(mpt_load_quant, utils.koboldai_vars.custmodpth, path_4bit, 4, groupsize, self.gpu_layers_list)
         else:
             raise RuntimeError(f"4-bit load failed. Model type {utils.koboldai_vars.model_type} not supported in 4-bit")
 
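
Note (not part of the patch): the elif chain this change extends is, in effect, a lookup from utils.koboldai_vars.model_type to the matching GPTQ load_quant function, with load_quant_offload handling the GPU-layer split. A minimal, hypothetical Python sketch of that dispatch pattern follows; the stand-in lambdas and the pick_loader helper are illustrative only and do not exist in the repository.

    # Hypothetical sketch of the model_type -> loader dispatch shown in the hunk above.
    # The stand-in lambdas replace the real gptq.* load_quant functions.
    from typing import Callable, Dict

    def pick_loader(model_type: str, loaders: Dict[str, Callable]) -> Callable:
        """Return the 4-bit loader registered for model_type, or fail loudly."""
        try:
            return loaders[model_type]
        except KeyError:
            raise RuntimeError(
                f"4-bit load failed. Model type {model_type} not supported in 4-bit"
            )

    if __name__ == "__main__":
        loaders = {
            "gptj": lambda: "gptj loaded",
            "opt": lambda: "opt loaded",
            "mpt": lambda: "mpt loaded",  # the entry this patch adds
        }
        print(pick_loader("mpt", loaders)())  # prints "mpt loaded"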