Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
really really really sketchy breakmodel implementation
im gonna go lie down for an extended period of time
@@ -82,6 +82,79 @@ def get_gptq_version(fpath):
         logger.warning(f"GPTQ model identified as v0, but v1={v1} and v2={v2}")
     return 0, False
 
 
+def load_quant_offload_device_map(
+    load_quant_func, model, checkpoint, wbits, groupsize, device_map, offload_type=0, force_bias=False,
+):
+    from gptq.offload import (
+        find_layers,
+        llama_offload_forward,
+        gptneox_offload_forward,
+        gptj_offload_forward,
+        opt_offload_forward,
+        bigcode_offload_forward
+    )
+    from transformers.models.llama.modeling_llama import LlamaModel
+    from transformers.models.opt.modeling_opt import OPTModel
+    from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXModel
+    from transformers.models.gptj.modeling_gptj import GPTJModel
+    from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeModel
+    model = load_quant_func(model, checkpoint, wbits, groupsize, force_bias=force_bias)
+
+    print(device_map)
+
+    m, layers, remaining = find_layers(model)
+
+    type(m).non_offload_forward = type(m).forward
+
+    # Hook offload_forward into found model
+    if type(m) == LlamaModel:
+        type(m).forward = llama_offload_forward
+    elif type(m) == GPTNeoXModel:
+        type(m).forward = gptneox_offload_forward
+    elif type(m) == GPTJModel:
+        type(m).forward = gptj_offload_forward
+    elif type(m) == OPTModel:
+        type(m).forward = opt_offload_forward
+    elif type(m) == GPTBigCodeModel:
+        type(m).forward = bigcode_offload_forward
+    else:
+        raise RuntimeError(f"Model type {type(m)} not supported by CPU offloader")
+
+    layers_done = len([1 for v in device_map.values() if v != "cpu"])
+    print("LDone", layers_done)
+
+    m.cpu_device = torch.device("cpu")
+    m.fast_offload = layers_done > len(layers) // 2
+    m.layer_count = len(layers)
+    m.cpu_layers = len(layers) - layers_done
+    m.gpu_layers = layers_done
+    m.offload_type = offload_type
+    # HACK
+    m.primary_gpu = list(device_map.values())[0]
+
+    if "layers" not in dir(m):
+        m.layers = layers
+
+    print(len(layers))
+    print(len(device_map))
+
+    print(m.primary_gpu)
+    for i in range(len(layers)):
+        dev = None
+        for key, device in device_map.items():
+            key = int(*[x for x in key.split(".") if x.isdecimal()])
+            if key == i:
+                dev = device
+                break
+        if dev is None:
+            raise ValueError
+        layers[key].to(dev, torch.float16, False)
+
+    for module in remaining:
+        module.to(m.primary_gpu)
+
+    return model
+
+
 class model_backend(HFTorchInferenceModel):
     def is_valid(self, model_name, model_path, menu_path):
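The new helper mirrors gptq.offload.load_quant_offload, but places layers according to a Hugging Face-style device_map instead of a per-GPU layer-count list: it loads the quantized checkpoint, swaps the root model's forward for the architecture's offload-aware variant, records the CPU/GPU layer split, and finally moves each decoder layer to the device its index maps to. A minimal usage sketch follows; the model path, checkpoint name, bit width, and groupsize are illustrative placeholders, not values from this commit.

    from gptq.llama import load_quant as llama_load_quant

    # Hypothetical map: two decoder layers on GPU 0, two offloaded to CPU.
    device_map = {
        "model.layers.0": 0,
        "model.layers.1": 0,
        "model.layers.2": "cpu",
        "model.layers.3": "cpu",
    }

    model = load_quant_offload_device_map(
        llama_load_quant,                    # architecture-specific loader
        "models/llama-7b",                   # model directory (placeholder)
        "models/llama-7b/4bit.safetensors",  # quantized checkpoint (placeholder)
        4,                                   # wbits
        128,                                 # groupsize
        device_map,
    )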
@@ -166,7 +239,7 @@ class model_backend(HFTorchInferenceModel):
         self.model.kai_model = self
         utils.koboldai_vars.modeldim = self.get_hidden_size()
 
-    def _patch_quant(self) -> None:
+    def _patch_quant(self, device_map) -> None:
         # QuantLinear loads on the CPU by default, using a lot of RAM! If we
         # load it to the same device that the weights are gonna be on, it
         # mysteriously uses no additional VRAM
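Passing device_map into _patch_quant is what enables the hunk below: with destination devices known up front, each replacement QuantLinear can be constructed directly on the device its weights will live on, instead of materializing on the CPU first.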
@@ -175,14 +248,54 @@ class model_backend(HFTorchInferenceModel):
         from gptq import quant_v2
         from gptq import quant_v1
 
-        def _ql_init_(self, *args, **kwargs):
-            ret = type(self)._unpatched_init(self, *args, **kwargs)
-            self.to("cuda:0")
-            return ret
-
-        for quant_module in [quant_v3, quant_v2, quant_v1]:
-            quant_module.QuantLinear._unpatched_init = quant_module.QuantLinear.__init__
-            quant_module.QuantLinear.__init__ = _ql_init_
+        def make_quant(module, names, bits, groupsize, name='', force_bias=False):
+            if isinstance(module, quant_v3.QuantLinear):
+                return
+
+            for attr in dir(module):
+                tmp = getattr(module, attr)
+                name1 = name + '.' + attr if name != '' else attr
+                if name1 in names:
+                    parts = name1.split(".")
+                    device = None
+                    for i in reversed(range(len(parts))):
+                        maybe_key = ".".join(parts[:i])
+                        if maybe_key in device_map:
+                            device = device_map[maybe_key]
+                            break
+
+                    if device is None:
+                        print(name1)
+                        print(device_map)
+                        raise ValueError
+
+                    print("[ql]", name1, device)
+                    delattr(module, attr)
+
+                    ql = quant_v3.QuantLinear(
+                        bits,
+                        groupsize,
+                        tmp.in_features,
+                        tmp.out_features,
+                        force_bias or tmp.bias is not None
+                    )
+                    ql = ql.to(device)
+
+                    setattr(module, attr, ql)
+
+            for name1, child in module.named_children():
+                make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1, force_bias=force_bias)
+
+        quant_v3.make_quant = make_quant
+
+        # def _ql_init_(self, *args, **kwargs):
+        #     ret = type(self)._unpatched_init(self, *args, **kwargs)
+        #     self.to("cuda:0")
+        #     return ret
+
+        # for quant_module in [quant_v3, quant_v2, quant_v1]:
+        #     quant_module.QuantLinear._unpatched_init = quant_module.QuantLinear.__init__
+        #     quant_module.QuantLinear.__init__ = _ql_init_
 
     def _get_model(self, location: str, tf_kwargs: Dict):
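The patched make_quant resolves a module's device by walking its dotted name from the longest proper prefix down until it hits a device_map entry, so a map keyed per layer (e.g. "model.layers.5") covers every projection inside that layer. A self-contained sketch of just that lookup, with a hypothetical name and map:

    def resolve_device(name1, device_map):
        # Longest proper prefix first: for "a.b.c.d" try "a.b.c", "a.b", "a", "".
        parts = name1.split(".")
        for i in reversed(range(len(parts))):
            maybe_key = ".".join(parts[:i])
            if maybe_key in device_map:
                return device_map[maybe_key]
        return None  # the caller raises ValueError in this case

    print(resolve_device("model.layers.5.self_attn.q_proj",
                         {"model.layers.5": 0}))  # -> 0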
@@ -193,9 +306,12 @@ class model_backend(HFTorchInferenceModel):
         from gptq.opt import load_quant as opt_load_quant
         from gptq.bigcode import load_quant as bigcode_load_quant
         from gptq.mpt import load_quant as mpt_load_quant
-        from gptq.offload import load_quant_offload
-
-        self._patch_quant()
+        try:
+            import hf_bleeding_edge
+            from hf_bleeding_edge import AutoModelForCausalLM
+        except ImportError:
+            from transformers import AutoModelForCausalLM
+
         gptq_model, gptq_bits, gptq_groupsize, gptq_file, gptq_version = load_model_gptq_settings(location)
         v2_bias = False
@@ -208,22 +324,43 @@ class model_backend(HFTorchInferenceModel):
 
         logger.info(f"Using GPTQ file: {gptq_file}, {gptq_bits}-bit model, type {model_type}, version {gptq_version}{' (with bias)' if v2_bias else ''}, groupsize {gptq_groupsize}")
 
+        device_map = {}
+
+        if self.lazy_load:
+            with lazy_loader.use_lazy_load(dematerialized_modules=True):
+                metamodel = AutoModelForCausalLM.from_config(self.model_config)
+                if utils.args.cpu:
+                    device_map = {name: "cpu" for name in utils.layers_module_names}
+                    for name in utils.get_missing_module_names(
+                        metamodel, list(device_map.keys())
+                    ):
+                        device_map[name] = "cpu"
+                else:
+                    device_map = self.breakmodel_config.get_device_map(
+                        metamodel
+                    )
+
+        self._patch_quant(device_map)
+
         with lazy_loader.use_lazy_load(
             enable=self.lazy_load,
             dematerialized_modules=False,
         ):
             if model_type == "gptj":
-                model = load_quant_offload(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias)
+                model = load_quant_offload_device_map(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
             elif model_type == "gpt_neox":
-                model = load_quant_offload(gptneox_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias)
+                model = load_quant_offload_device_map(gptneox_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
             elif model_type == "llama":
-                model = load_quant_offload(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias)
+                print("YE LAMA")
+
+                # model = llama_load_quant(location, gptq_file, gptq_bits, gptq_groupsize, force_bias=v2_bias)
+                model = load_quant_offload_device_map(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
             elif model_type == "opt":
-                model = load_quant_offload(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias)
+                model = load_quant_offload_device_map(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
             elif model_type == "mpt":
-                model = load_quant_offload(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias)
+                model = load_quant_offload_device_map(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
             elif model_type == "gpt_bigcode":
-                model = load_quant_offload(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias).half()
+                model = load_quant_offload_device_map(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias).half()
             else:
                 try:
                     import auto_gptq
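Before loading, the device map now comes from the breakmodel config (or is forced all-CPU under --cpu), computed against a dematerialized meta-model so no weights are allocated yet. For illustration only, a hypothetical map for a four-layer model split three GPU / one CPU, in the shape the offload helper consumes:

    device_map = {
        "model.layers.0": 0,      # GPU 0
        "model.layers.1": 0,
        "model.layers.2": 0,
        "model.layers.3": "cpu",  # offloaded; excluded from layers_done
    }
    # layers_done = len([1 for v in device_map.values() if v != "cpu"])  # -> 3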
@@ -231,12 +368,6 @@ class model_backend(HFTorchInferenceModel):
                 except ImportError:
                     raise RuntimeError(f"4-bit load failed. Model type {model_type} not supported in 4-bit")
 
-                try:
-                    import hf_bleeding_edge
-                    from hf_bleeding_edge import AutoModelForCausalLM
-                except ImportError:
-                    from transformers import AutoModelForCausalLM
-
                 # Monkey patch in hf_bleeding_edge to avoid having to trust remote code
                 auto_gptq.modeling._utils.AutoConfig = hf_bleeding_edge.AutoConfig
                 auto_gptq.modeling._base.AutoConfig = hf_bleeding_edge.AutoConfig