Always use the offloader script, because it speeds up multi-GPU setups

0cc4m
2023-04-30 18:17:43 +02:00
parent d8949042d4
commit 20a5587d66


@@ -333,25 +333,13 @@ class HFTorch4BitInferenceModel(HFTorchInferenceModel):
print(f"Trying to load {utils.koboldai_vars.model_type} model in 4-bit")
if utils.koboldai_vars.model_type == "gptj":
if self.offload_4bit:
model = load_quant_offload(gptj_load_quant, utils.koboldai_vars.custmodpth, path_4bit, 4, groupsize, self.gpu_layers_list)
else:
model = gptj_load_quant(utils.koboldai_vars.custmodpth, path_4bit, 4, groupsize)
model = load_quant_offload(gptj_load_quant, utils.koboldai_vars.custmodpth, path_4bit, 4, groupsize, self.gpu_layers_list)
elif utils.koboldai_vars.model_type == "gpt_neox":
if self.offload_4bit:
model = load_quant_offload(gptneox_load_quant, utils.koboldai_vars.custmodpth, path_4bit, 4, groupsize, self.gpu_layers_list)
else:
model = gptneox_load_quant(utils.koboldai_vars.custmodpth, path_4bit, 4, groupsize)
model = load_quant_offload(gptneox_load_quant, utils.koboldai_vars.custmodpth, path_4bit, 4, groupsize, self.gpu_layers_list)
elif utils.koboldai_vars.model_type == "llama":
if self.offload_4bit:
model = load_quant_offload(llama_load_quant, utils.koboldai_vars.custmodpth, path_4bit, 4, groupsize, self.gpu_layers_list)
else:
model = llama_load_quant(utils.koboldai_vars.custmodpth, path_4bit, 4, groupsize)
model = load_quant_offload(llama_load_quant, utils.koboldai_vars.custmodpth, path_4bit, 4, groupsize, self.gpu_layers_list)
elif utils.koboldai_vars.model_type == "opt":
if self.offload_4bit:
model = load_quant_offload(opt_load_quant, utils.koboldai_vars.custmodpth, path_4bit, 4, groupsize, self.gpu_layers_list)
else:
model = opt_load_quant(utils.koboldai_vars.custmodpth, path_4bit, 4, groupsize)
model = load_quant_offload(opt_load_quant, utils.koboldai_vars.custmodpth, path_4bit, 4, groupsize, self.gpu_layers_list)
else:
raise RuntimeError(f"4-bit load failed. Model type {utils.koboldai_vars.model_type} not supported in 4-bit")
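
With the offload path unconditional, the four branches differ only in which model-specific loader they hand to load_quant_offload. As a minimal sketch (not part of this commit), the same dispatch can be written as a table lookup; it assumes gptj_load_quant, gptneox_load_quant, llama_load_quant, opt_load_quant and load_quant_offload are importable as in the patched module, and the helper name load_4bit_model and its argument names are hypothetical.

# Sketch only: the same dispatch as a lookup table instead of an if/elif chain.
# Assumes the loader functions and load_quant_offload are imported as in the
# patched file; load_4bit_model and its parameters are illustrative names.
QUANT_LOADERS = {
    "gptj": gptj_load_quant,
    "gpt_neox": gptneox_load_quant,
    "llama": llama_load_quant,
    "opt": opt_load_quant,
}

def load_4bit_model(model_type, custmodpth, path_4bit, groupsize, gpu_layers_list):
    loader = QUANT_LOADERS.get(model_type)
    if loader is None:
        raise RuntimeError(f"4-bit load failed. Model type {model_type} not supported in 4-bit")
    # load_quant_offload wraps the model-specific loader and distributes the
    # quantized layers across GPUs according to gpu_layers_list.
    return load_quant_offload(loader, custmodpth, path_4bit, 4, groupsize, gpu_layers_list)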