somebody
2023-07-23 20:54:04 -05:00
parent 70d2da55e5
commit 1df03d9a27
3 changed files with 79 additions and 39 deletions

View File

@@ -89,6 +89,12 @@ class model_backend(HFTorchInferenceModel):
         return bool(gptq_model)
 
     def _load(self, save_model: bool, initial_load: bool) -> None:
+        try:
+            import hf_bleeding_edge
+            from hf_bleeding_edge import AutoModelForCausalLM
+        except ImportError:
+            from transformers import AutoModelForCausalLM
+
         # Make model path the same as the model name to make this consistent
         # with the other loading method if it isn't a known model type. This
         # code is not just a workaround for below, it is also used to make the
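For reference, a standalone sketch of the optional-dependency fallback added in this hunk: prefer the hf_bleeding_edge wrapper when it is installed, otherwise fall back to the stock transformers class. The model id below is only illustrative.

try:
    # Preferred path: the hf_bleeding_edge wrapper, when it is installed.
    from hf_bleeding_edge import AutoModelForCausalLM
except ImportError:
    # Fallback: the stock transformers class.
    from transformers import AutoModelForCausalLM

# "gpt2" is only an example model id; any causal LM id works the same way.
model = AutoModelForCausalLM.from_pretrained("gpt2")
print(type(model).__name__)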
@@ -98,7 +104,7 @@ class model_backend(HFTorchInferenceModel):
         self.init_model_config()
 
-        self.lazy_load = False
+        self.lazy_load = True
 
         gpulayers = self.breakmodel_config.gpu_blocks
@@ -181,11 +187,20 @@ class model_backend(HFTorchInferenceModel):
         model_type = self.get_model_type()
         logger.info(f"Using GPTQ file: {gptq_file}, {gptq_bits}-bit model, type {model_type}, version {gptq_version}{' (with bias)' if v2_bias else ''}, groupsize {gptq_groupsize}")
+        with lazy_loader.use_lazy_load(
+            enable=self.lazy_load,
+            dematerialized_modules=False,
+        ):
+            print(self.lazy_load)
         if model_type == "gptj":
             model = load_quant_offload(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias)
         elif model_type == "gpt_neox":
             model = load_quant_offload(gptneox_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias)
         elif model_type == "llama":
+            print("LLLLLAAAMMMAA")
+            print(torch.load)
             model = load_quant_offload(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias)
         elif model_type == "opt":
             model = load_quant_offload(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias)
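The hunk above routes the quantized-load dispatch through lazy_loader.use_lazy_load with an enable flag. A hypothetical sketch of that pattern follows; maybe_lazy and its print calls are made-up stand-ins for the project's real patching, not its implementation.

import contextlib

@contextlib.contextmanager
def maybe_lazy(enable: bool):
    # Disabled: behave as a plain no-op so call sites never need a branch.
    if not enable:
        yield
        return
    print("lazy-load hooks installed")      # stand-in for the real patching
    try:
        yield
    finally:
        print("lazy-load hooks removed")    # stand-in for restoring originals

with maybe_lazy(enable=True):
    print("quantized model would be loaded here")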
@@ -210,6 +225,7 @@ class model_backend(HFTorchInferenceModel):
             auto_gptq.modeling._utils.AutoConfig = hf_bleeding_edge.AutoConfig
             auto_gptq.modeling._base.AutoConfig = hf_bleeding_edge.AutoConfig
             auto_gptq.modeling._base.AutoModelForCausalLM = hf_bleeding_edge.AutoModelForCausalLM
             model = AutoGPTQForCausalLM.from_quantized(location, model_basename=Path(gptq_file).stem, use_safetensors=gptq_file.endswith(".safetensors"))
 
         # Patch in embeddings function

View File

@@ -358,16 +358,19 @@ def safetensors_load_tensor_independently(
 ) -> torch.Tensor:
     """A hacky way to load a tensor by itself and not mmap every single tensor
     or whatever is causing that big memory spike"""
+    print("[ld]", tensor_key)
     with safetensors.safe_open(checkpoint_file, framework="pt", device=device) as f:
         return f.get_tensor(tensor_key)
 
 
 def patch_safetensors(callback):
+    print("Hi! We are patching safetensors")
     # Safetensors load patch
     import transformers
 
     def safetensors_load(checkpoint_file: str) -> dict:
+        print("LOAD NOW", safetensors_load)
         # Monkeypatch applied to safetensors.torch.load_file
         if utils.koboldai_vars.hascuda:
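The unchanged safetensors.safe_open call above is what lets a single tensor be read without mapping the whole checkpoint. A self-contained sketch of that API, using a throwaway file and tensor names invented for the example:

import torch
import safetensors
from safetensors.torch import save_file

# Write a throwaway checkpoint, then read back one tensor by name.
save_file({"weight": torch.zeros(2, 2), "bias": torch.zeros(2)}, "tiny.safetensors")

with safetensors.safe_open("tiny.safetensors", framework="pt", device="cpu") as f:
    print(list(f.keys()))            # names stored in the file
    weight = f.get_tensor("weight")  # only this tensor is materialized

print(weight.shape)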
@@ -409,6 +412,7 @@ def patch_safetensors(callback):
         return tensors
 
     transformers.modeling_utils.safe_load_file = safetensors_load
+    safetensors.torch.load_file = safetensors_load
 
 
 @contextlib.contextmanager
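The new safetensors.torch.load_file assignment matters because transformers imports the function under its own name at import time, so both references have to be rebound for the patch to catch every caller. A toy illustration of that binding behaviour; lib and consumer are hypothetical stand-ins, not the real packages.

import types

# "lib" plays safetensors.torch; "consumer" plays transformers.modeling_utils,
# which imported the function under another name.
lib = types.ModuleType("lib")
lib.load_file = lambda path: "real load"

consumer = types.ModuleType("consumer")
consumer.safe_load_file = lib.load_file   # like: from safetensors.torch import load_file as safe_load_file

def patched(path):
    return "patched load"

lib.load_file = patched                   # rebinding the source module alone...
print(consumer.safe_load_file("x"))       # ...leaves the consumer's copy untouched -> "real load"

consumer.safe_load_file = patched         # the second assignment, as in the diff
print(consumer.safe_load_file("x"))       # -> "patched load"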
@@ -520,6 +524,7 @@ def use_lazy_load(
     old_torch_load = torch.load
 
     def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
+        print("TORCHLOAD", f)
         model_dict = old_torch_load(
             f=f,
             map_location=map_location,
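The print added above sits inside the existing wrap-and-delegate torch.load patch. A minimal standalone sketch of that pattern, with the lazy-tensor post-processing done by the real wrapper omitted:

import pickle
import torch

old_torch_load = torch.load  # keep the original so it can be restored

def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
    print("TORCHLOAD", f)    # trace which file is being read
    return old_torch_load(
        f=f,
        map_location=map_location,
        pickle_module=pickle_module,
        **pickle_load_args,
    )

torch.load = torch_load
try:
    pass  # checkpoint loading would happen here and be traced by the wrapper
finally:
    torch.load = old_torch_load  # always undo the patch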

View File

@@ -129,15 +129,34 @@ def patch_transformers_generation() -> None:
 
 class LazyloadPatches:
+    class StateDictFacade(dict):
+        def __init__(self, state_dict):
+            self.update(state_dict)
+
+        def __getitem__(self, name):
+            return super().__getitem__(name).materialize(map_location="cuda:0")
+
     old_load_state_dict = transformers.modeling_utils._load_state_dict_into_meta_model
+    torch_old_load_from_state_dict = torch.nn.Module._load_from_state_dict
 
     def __enter__() -> None:
         transformers.modeling_utils._load_state_dict_into_meta_model = (
             LazyloadPatches._load_state_dict_into_meta_model
         )
+        torch.nn.Module._load_from_state_dict = LazyloadPatches._torch_load_from_state_dict
+        # torch.nn.Module._load_from_state_dict = _agn
 
     def __exit__(exc_type, exc_value, exc_traceback) -> None:
         transformers.modeling_utils._load_state_dict_into_meta_model = LazyloadPatches.old_load_state_dict
+        torch.nn.Module._load_from_state_dict = LazyloadPatches.torch_old_load_from_state_dict
+
+    def _torch_load_from_state_dict(self, state_dict, *args, **kwargs):
+        return LazyloadPatches.torch_old_load_from_state_dict(
+            self,
+            LazyloadPatches.StateDictFacade(state_dict),
+            *args,
+            **kwargs
+        )
 
     def _load_state_dict_into_meta_model(
         model,
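The StateDictFacade added above converts lazy tensors only when _load_from_state_dict actually reads a key. A runnable sketch of the same dict-facade idea; the project's .materialize(map_location="cuda:0") call is not available here, so an ordinary .to() stands in for it.

import torch

class StateDictFacade(dict):
    # Mirrors the facade above: values are converted only when a key is read.
    # The real patch calls .materialize(map_location="cuda:0") on lazy tensors;
    # a plain .to() stands in so the sketch runs without the project code.
    def __init__(self, state_dict):
        self.update(state_dict)

    def __getitem__(self, name):
        return super().__getitem__(name).to("cpu")

facade = StateDictFacade({"weight": torch.zeros(2, 2)})
print(facade["weight"].device)  # the conversion happened on access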