Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00.
Add 4bit safetensor support, improve loading code
This commit is contained in:
78
aiserver.py
78
aiserver.py
@@ -90,6 +90,7 @@ global tpu_mtj_backend
|
|||||||
|
|
||||||
# 4-bit dependencies
|
# 4-bit dependencies
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import glob
|
||||||
sys.path.insert(0, os.path.abspath(Path("repos/gptq")))
|
sys.path.insert(0, os.path.abspath(Path("repos/gptq")))
|
||||||
from gptj import load_quant as gptj_load_quant
|
from gptj import load_quant as gptj_load_quant
|
||||||
from gptneox import load_quant as gptneox_load_quant
|
from gptneox import load_quant as gptneox_load_quant
|
||||||
@@ -2659,6 +2660,50 @@ def unload_model():
|
|||||||
koboldai_vars.badwordsids = koboldai_settings.badwordsids_default
|
koboldai_vars.badwordsids = koboldai_settings.badwordsids_default
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_4bit_load(modelpath):
    """Locate a 4-bit quantized checkpoint file inside *modelpath*.

    Current-format files ("4bit.pt", "4bit.safetensors") are preferred;
    old-format files ("4bit-old.pt", "4bit-old.safetensors") are tried as a
    fallback. Loading an old-format file monkey-patches the gptq modules'
    ``make_quant`` to the legacy ``old_make_quant``; finding a current-format
    file after a previous old-format load undoes that patch.

    Args:
        modelpath: directory expected to contain the quantized checkpoint.

    Returns:
        Full path of the checkpoint file that was found.

    Raises:
        RuntimeError: if no checkpoint file exists in *modelpath*.
    """
    paths_4bit = ["4bit.pt", "4bit.safetensors"]
    paths_4bit_old = ["4bit-old.pt", "4bit-old.safetensors"]

    def _first_existing(names):
        # Return the first existing file among *names* in modelpath, else None.
        for name in names:
            candidate = os.path.join(modelpath, name)
            if os.path.isfile(candidate):
                return candidate
        return None

    global monkey_patched_4bit
    # Defensive init: the flag lives at module level and may not have been
    # assigned yet the first time this function runs.
    try:
        monkey_patched_4bit
    except NameError:
        monkey_patched_4bit = False

    result = _first_existing(paths_4bit)

    if not result:
        # BUGFIX: the original messages referenced undefined names
        # `path_4bit` / `path_4bit_old`, so this path raised NameError
        # instead of printing the intended diagnostics.
        print(f"4-bit files {paths_4bit} not found, falling back to {paths_4bit_old}")
        result = _first_existing(paths_4bit_old)

        if not result:
            print(f"4-bit old-format files {paths_4bit_old} not found, loading failed")
            raise RuntimeError(f"4-bit load failed. PT-File not found in {modelpath}")

        # Monkey-patch in old-format pt-file support
        import llama, opt, gptneox, gptj, old_quant
        llama.make_quant = old_quant.old_make_quant
        opt.make_quant = old_quant.old_make_quant
        gptneox.make_quant = old_quant.old_make_quant
        gptj.make_quant = old_quant.old_make_quant
        monkey_patched_4bit = True
    elif monkey_patched_4bit:
        # Undo monkey patch
        print("Undoing 4-bit old format monkey patch")
        import llama, opt, gptneox, gptj, quant
        llama.make_quant = quant.make_quant
        opt.make_quant = quant.make_quant
        gptneox.make_quant = quant.make_quant
        gptj.make_quant = quant.make_quant
        monkey_patched_4bit = False

    return result
|
||||||
|
|
||||||
|
|
||||||
def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=False, online_model="", use_breakmodel_args=False, breakmodel_args_default_to_cpu=False, url=None, use_8_bit=False, use_4_bit=False):
|
def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=False, online_model="", use_breakmodel_args=False, breakmodel_args_default_to_cpu=False, url=None, use_8_bit=False, use_4_bit=False):
|
||||||
global model
|
global model
|
||||||
global generator
|
global generator
|
||||||
@@ -3127,36 +3172,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
|||||||
if(os.path.isdir(koboldai_vars.custmodpth)):
|
if(os.path.isdir(koboldai_vars.custmodpth)):
|
||||||
|
|
||||||
if use_4_bit:
|
if use_4_bit:
|
||||||
path_4bit = os.path.join(koboldai_vars.custmodpth, "4bit.pt")
|
path_4bit = prepare_4bit_load(koboldai_vars.custmodpth)
|
||||||
path_4bit_old = os.path.join(koboldai_vars.custmodpth, "4bit-old.pt")
|
|
||||||
|
|
||||||
global monkey_patched_4bit
|
|
||||||
|
|
||||||
# Monkey-patch in old-format pt-file support
|
|
||||||
if not os.path.isfile(path_4bit):
|
|
||||||
print(f"4-bit file {path_4bit} not found, falling back to {path_4bit_old}")
|
|
||||||
path_4bit = path_4bit_old
|
|
||||||
|
|
||||||
import llama, opt, gptneox, gptj, old_quant
|
|
||||||
llama.make_quant = old_quant.old_make_quant
|
|
||||||
opt.make_quant = old_quant.old_make_quant
|
|
||||||
gptneox.make_quant = old_quant.old_make_quant
|
|
||||||
gptj.make_quant = old_quant.old_make_quant
|
|
||||||
monkey_patched_4bit = True
|
|
||||||
elif monkey_patched_4bit:
|
|
||||||
# Undo monkey patch
|
|
||||||
print("Undoing 4-bit old format monkey patch")
|
|
||||||
import llama, opt, gptneox, gptj, quant
|
|
||||||
llama.make_quant = quant.make_quant
|
|
||||||
opt.make_quant = quant.make_quant
|
|
||||||
gptneox.make_quant = quant.make_quant
|
|
||||||
gptj.make_quant = quant.make_quant
|
|
||||||
monkey_patched_4bit = False
|
|
||||||
|
|
||||||
|
|
||||||
if not os.path.isfile(path_4bit):
|
|
||||||
print(f"4-bit old-format file {path_4bit} not found, loading failed")
|
|
||||||
raise RuntimeError(f"4-bit load failed. PT-File not found at {path_4bit}")
|
|
||||||
|
|
||||||
print(f"Trying to load {koboldai_vars.model_type} model in 4-bit")
|
print(f"Trying to load {koboldai_vars.model_type} model in 4-bit")
|
||||||
koboldai_vars.breakmodel = False
|
koboldai_vars.breakmodel = False
|
||||||
@@ -3171,7 +3187,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
|||||||
model = llama_load_quant(koboldai_vars.custmodpth, path_4bit, 4, -1)
|
model = llama_load_quant(koboldai_vars.custmodpth, path_4bit, 4, -1)
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(koboldai_vars.custmodpth)
|
tokenizer = LlamaTokenizer.from_pretrained(koboldai_vars.custmodpth)
|
||||||
elif koboldai_vars.model_type == "opt":
|
elif koboldai_vars.model_type == "opt":
|
||||||
model = opt_load_quant(koboldai_vars.custmodpth, path_4bit, 4)
|
model = opt_load_quant(koboldai_vars.custmodpth, path_4bit, 4, -1)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(koboldai_vars.custmodpth)
|
tokenizer = AutoTokenizer.from_pretrained(koboldai_vars.custmodpth)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"4-bit load failed. Model type {koboldai_vars.model_type} not supported in 4-bit")
|
raise RuntimeError(f"4-bit load failed. Model type {koboldai_vars.model_type} not supported in 4-bit")
|
||||||
|
Reference in New Issue
Block a user