mirror of https://github.com/KoboldAI/KoboldAI-Client.git
@@ -7,7 +7,7 @@ import torch
 import re
 import shutil
 import sys
-from typing import Union
+from typing import Dict, Union
 
 import utils
 import modeling.lazy_loader as lazy_loader
@@ -82,6 +82,71 @@ def get_gptq_version(fpath):
     logger.warning(f"GPTQ model identified as v0, but v1={v1} and v2={v2}")
     return 0, False
 
+
+def load_quant_offload_device_map(
+    load_quant_func, model, checkpoint, wbits, groupsize, device_map, offload_type=0, force_bias=False,
+):
+    from gptq.offload import (
+        find_layers,
+        llama_offload_forward,
+        gptneox_offload_forward,
+        gptj_offload_forward,
+        opt_offload_forward,
+        bigcode_offload_forward
+    )
+    from transformers.models.llama.modeling_llama import LlamaModel
+    from transformers.models.opt.modeling_opt import OPTModel
+    from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXModel
+    from transformers.models.gptj.modeling_gptj import GPTJModel
+    from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeModel
+    model = load_quant_func(model, checkpoint, wbits, groupsize, force_bias=force_bias)
+
+    m, layers, remaining = find_layers(model)
+    type(m).non_offload_forward = type(m).forward
+
+    # Hook offload_forward into found model
+    if type(m) == LlamaModel:
+        type(m).forward = llama_offload_forward
+    elif type(m) == GPTNeoXModel:
+        type(m).forward = gptneox_offload_forward
+    elif type(m) == GPTJModel:
+        type(m).forward = gptj_offload_forward
+    elif type(m) == OPTModel:
+        type(m).forward = opt_offload_forward
+    elif type(m) == GPTBigCodeModel:
+        type(m).forward = bigcode_offload_forward
+    else:
+        raise RuntimeError(f"Model type {type(m)} not supported by CPU offloader")
+
+    layers_done = len([1 for v in device_map.values() if v != "cpu"])
+
+    m.cpu_device = torch.device("cpu")
+    m.fast_offload = layers_done > len(layers) // 2
+    m.layer_count = len(layers)
+    m.cpu_layers = len(layers) - layers_done
+    m.gpu_layers = layers_done
+    m.offload_type = offload_type
+    # HACK
+    m.primary_gpu = list(device_map.values())[0]
+
+    if "layers" not in dir(m):
+        m.layers = layers
+
+    for i in range(len(layers)):
+        dev = None
+        for key, device in device_map.items():
+            key = int(*[x for x in key.split(".") if x.isdecimal()])
+            if key == i:
+                dev = device
+                break
+        if dev is None:
+            raise ValueError
+        layers[key].to(dev, torch.float16, False)
+
+    for module in remaining:
+        module.to(m.primary_gpu)
+
+    return model
+
+
 class model_backend(HFTorchInferenceModel):
     def is_valid(self, model_name, model_path, menu_path):
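
Note (illustrative sketch, not part of the patch): load_quant_offload_device_map expects a Hugging Face style device_map whose layer entries carry a numeric index in their dotted names, and it derives the CPU/GPU split by counting non-"cpu" values. The snippet below reproduces that per-layer lookup in isolation; the module names and devices are hypothetical examples, and a guard is added for entries without a numeric component (the patched code assumes the map only contains layer modules).

    # Sketch of the layer-index lookup used by load_quant_offload_device_map.
    device_map = {
        "model.embed_tokens": 0,
        "model.layers.0": 0,
        "model.layers.1": 0,
        "model.layers.2": "cpu",
        "model.norm": 0,
    }

    def layer_device(device_map, layer_index):
        # Take the numeric component of each key and match it against the layer index.
        for key, device in device_map.items():
            decimals = [x for x in key.split(".") if x.isdecimal()]
            if not decimals:
                continue  # non-layer entries such as embeddings carry no index
            if int(*decimals) == layer_index:
                return device
        raise ValueError(f"no device found for layer {layer_index}")

    print(layer_device(device_map, 2))                           # -> "cpu"
    print(len([1 for v in device_map.values() if v != "cpu"]))   # GPU-resident entries
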
@@ -89,6 +154,11 @@ class model_backend(HFTorchInferenceModel):
         return bool(gptq_model)
 
     def _load(self, save_model: bool, initial_load: bool) -> None:
+        try:
+            from hf_bleeding_edge import AutoModelForCausalLM
+        except ImportError:
+            from transformers import AutoModelForCausalLM
+
         # Make model path the same as the model name to make this consistent
         # with the other loading method if it isn't a known model type. This
         # code is not just a workaround for below, it is also used to make the
@@ -98,7 +168,7 @@ class model_backend(HFTorchInferenceModel):
 
         self.init_model_config()
 
-        self.lazy_load = False
+        self.lazy_load = True
 
         gpulayers = self.breakmodel_config.gpu_blocks
 
@@ -107,10 +177,6 @@ class model_backend(HFTorchInferenceModel):
         except (ValueError, AttributeError):
             self.gpu_layers_list = [utils.num_layers(self.model_config)]
 
-        tf_kwargs = {
-            "low_cpu_mem_usage": True,
-        }
-
         # If we're using torch_lazy_loader, we need to get breakmodel config
         # early so that it knows where to load the individual model tensors
         logger.debug("lazy_load: {} hascuda: {} breakmodel: {} nobreakmode: {}".format(self.lazy_load, utils.koboldai_vars.hascuda, self.breakmodel, self.nobreakmodel))
@@ -123,9 +189,6 @@ class model_backend(HFTorchInferenceModel):
             self.breakmodel_device_config(self.model_config)
 
         if self.lazy_load:
-            # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
-            tf_kwargs.pop("low_cpu_mem_usage", None)
-
             # If we're using lazy loader, we need to figure out what the model's hidden layers are called
             with lazy_loader.use_lazy_load(dematerialized_modules=True):
                 try:
@@ -141,7 +204,7 @@ class model_backend(HFTorchInferenceModel):
 
         if self.get_local_model_path():
             # Model is stored locally, load it.
-            self.model = self._get_model(self.get_local_model_path(), tf_kwargs)
+            self.model = self._get_model(self.get_local_model_path())
             self.tokenizer = self._get_tokenizer(self.get_local_model_path())
         else:
             raise NotImplementedError("GPTQ Model downloading not implemented")
@@ -161,7 +224,58 @@ class model_backend(HFTorchInferenceModel):
         self.model.kai_model = self
         utils.koboldai_vars.modeldim = self.get_hidden_size()
 
-    def _get_model(self, location: str, tf_kwargs: Dict):
+    def _patch_quant(self, device_map, quant_module) -> None:
+        def make_quant(module, names, bits, groupsize, name='', force_bias=False, **kwargs):
+            if isinstance(module, quant_module.QuantLinear):
+                return
+
+            for attr in dir(module):
+                tmp = getattr(module, attr)
+                name1 = name + '.' + attr if name != '' else attr
+                if name1 in names:
+                    parts = name1.split(".")
+                    device = None
+                    for i in reversed(range(len(parts))):
+                        maybe_key = ".".join(parts[:i])
+                        if maybe_key in device_map:
+                            device = device_map[maybe_key]
+                            break
+
+                    if device is None:
+                        raise ValueError(f"No device for {name1}")
+
+                    delattr(module, attr)
+
+                    ql = quant_module.QuantLinear(
+                        bits,
+                        groupsize,
+                        tmp.in_features,
+                        tmp.out_features,
+                        force_bias or tmp.bias is not None,
+                        **kwargs,
+                    )
+                    ql = ql.to(device)
+
+                    setattr(module, attr, ql)
+
+            for name1, child in module.named_children():
+                make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1, force_bias=force_bias)
+
+        quant_module.make_quant = make_quant
+
+
+    def _patch_quants(self, device_map) -> None:
+        # Load QuantLinears on the device corresponding to the device map
+
+        from gptq import quant_v3
+        from gptq import quant_v2
+        from gptq import quant_v1
+
+        for quant_module in [quant_v3, quant_v2, quant_v1]:
+            self._patch_quant(device_map, quant_module)
+
+
+    def _get_model(self, location: str):
         import gptq
         from gptq.gptj import load_quant as gptj_load_quant
         from gptq.gptneox import load_quant as gptneox_load_quant
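
Note (illustrative sketch, not part of the patch): _patch_quant overrides the GPTQ package's make_quant so each replacement QuantLinear is built and immediately moved to the device its module path resolves to in the device_map, walking ever-shorter prefixes of the dotted name. The self-contained sketch below shows the same replace-and-place pattern with a plain nn.Linear standing in for QuantLinear; the device_map and module names are hypothetical.

    import torch.nn as nn

    def replace_linears(module, device_map, name=""):
        # Recursively swap every nn.Linear for a fresh replacement placed on the
        # device assigned to the closest prefix of its dotted module path.
        for attr, child in list(module.named_children()):
            full_name = f"{name}.{attr}" if name else attr
            if isinstance(child, nn.Linear):
                parts = full_name.split(".")
                device = None
                for i in reversed(range(len(parts))):
                    maybe_key = ".".join(parts[:i])
                    if maybe_key in device_map:
                        device = device_map[maybe_key]
                        break
                if device is None:
                    raise ValueError(f"No device for {full_name}")
                replacement = nn.Linear(
                    child.in_features, child.out_features, bias=child.bias is not None
                ).to(device)
                setattr(module, attr, replacement)
            else:
                replace_linears(child, device_map, full_name)

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
    replace_linears(model, {"": "cpu"})  # hypothetical map: everything on the CPU
    print(model)
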
@@ -169,7 +283,12 @@ class model_backend(HFTorchInferenceModel):
         from gptq.opt import load_quant as opt_load_quant
         from gptq.bigcode import load_quant as bigcode_load_quant
         from gptq.mpt import load_quant as mpt_load_quant
-        from gptq.offload import load_quant_offload
+        try:
+            import hf_bleeding_edge
+            from hf_bleeding_edge import AutoModelForCausalLM
+        except ImportError:
+            from transformers import AutoModelForCausalLM
+
 
         gptq_model, gptq_bits, gptq_groupsize, gptq_file, gptq_version = load_model_gptq_settings(location)
         v2_bias = False
@@ -181,50 +300,68 @@ class model_backend(HFTorchInferenceModel):
         model_type = self.get_model_type()
 
         logger.info(f"Using GPTQ file: {gptq_file}, {gptq_bits}-bit model, type {model_type}, version {gptq_version}{' (with bias)' if v2_bias else ''}, groupsize {gptq_groupsize}")
-        if model_type == "gptj":
-            model = load_quant_offload(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias)
-        elif model_type == "gpt_neox":
-            model = load_quant_offload(gptneox_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias)
-        elif model_type == "llama":
-            model = load_quant_offload(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias)
-        elif model_type == "opt":
-            model = load_quant_offload(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias)
-        elif model_type == "mpt":
-            model = load_quant_offload(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias)
-        elif model_type == "gpt_bigcode":
-            model = load_quant_offload(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, self.gpu_layers_list, force_bias=v2_bias).half()
-        else:
-            try:
-                import auto_gptq
-                from auto_gptq import AutoGPTQForCausalLM
-            except ImportError:
-                raise RuntimeError(f"4-bit load failed. Model type {model_type} not supported in 4-bit")
-
-            try:
-                import hf_bleeding_edge
-                from hf_bleeding_edge import AutoModelForCausalLM
-            except ImportError:
-                from transformers import AutoModelForCausalLM
-
-            # Monkey patch in hf_bleeding_edge to avoid having to trust remote code
-            auto_gptq.modeling._utils.AutoConfig = hf_bleeding_edge.AutoConfig
-            auto_gptq.modeling._base.AutoConfig = hf_bleeding_edge.AutoConfig
-            auto_gptq.modeling._base.AutoModelForCausalLM = hf_bleeding_edge.AutoModelForCausalLM
-            model = AutoGPTQForCausalLM.from_quantized(location, model_basename=Path(gptq_file).stem, use_safetensors=gptq_file.endswith(".safetensors"))
-
-            # Patch in embeddings function
-            def get_input_embeddings(self):
-                return self.model.get_input_embeddings()
-
-            type(model).get_input_embeddings = get_input_embeddings
-
-            # Patch in args support..
-            def generate(self, *args, **kwargs):
-                """shortcut for model.generate"""
-                with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
-                    return self.model.generate(*args, **kwargs)
-
-            type(model).generate = generate
+        device_map = {}
+
+        if self.lazy_load:
+            with lazy_loader.use_lazy_load(dematerialized_modules=True):
+                metamodel = AutoModelForCausalLM.from_config(self.model_config)
+                if utils.args.cpu:
+                    device_map = {name: "cpu" for name in utils.layers_module_names}
+                    for name in utils.get_missing_module_names(
+                        metamodel, list(device_map.keys())
+                    ):
+                        device_map[name] = "cpu"
+                else:
+                    device_map = self.breakmodel_config.get_device_map(
+                        metamodel
+                    )
+
+        self._patch_quants(device_map)
+
+        with lazy_loader.use_lazy_load(
+            enable=self.lazy_load,
+            dematerialized_modules=False,
+        ):
+            if model_type == "gptj":
+                model = load_quant_offload_device_map(gptj_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+            elif model_type == "gpt_neox":
+                model = load_quant_offload_device_map(gptneox_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+            elif model_type == "llama":
+                model = load_quant_offload_device_map(llama_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+            elif model_type == "opt":
+                model = load_quant_offload_device_map(opt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+            elif model_type == "mpt":
+                model = load_quant_offload_device_map(mpt_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias)
+            elif model_type == "gpt_bigcode":
+                model = load_quant_offload_device_map(bigcode_load_quant, location, gptq_file, gptq_bits, gptq_groupsize, device_map, force_bias=v2_bias).half()
+            else:
+                try:
+                    import auto_gptq
+                    from auto_gptq import AutoGPTQForCausalLM
+                except ImportError:
+                    raise RuntimeError(f"4-bit load failed. Model type {model_type} not supported in 4-bit")
+
+                # Monkey patch in hf_bleeding_edge to avoid having to trust remote code
+                auto_gptq.modeling._utils.AutoConfig = hf_bleeding_edge.AutoConfig
+                auto_gptq.modeling._base.AutoConfig = hf_bleeding_edge.AutoConfig
+                auto_gptq.modeling._base.AutoModelForCausalLM = hf_bleeding_edge.AutoModelForCausalLM
+
+                model = AutoGPTQForCausalLM.from_quantized(location, model_basename=Path(gptq_file).stem, use_safetensors=gptq_file.endswith(".safetensors"))
+
+                # Patch in embeddings function
+                def get_input_embeddings(self):
+                    return self.model.get_input_embeddings()
+
+                type(model).get_input_embeddings = get_input_embeddings
+
+                # Patch in args support..
+                def generate(self, *args, **kwargs):
+                    """shortcut for model.generate"""
+                    with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
+                        return self.model.generate(*args, **kwargs)
+
+                type(model).generate = generate
 
         return model
 
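
Note (illustrative sketch, not part of the patch): the rewritten _get_model builds its device_map from a dematerialized "metamodel" (instantiated from the config only, so no weights are allocated) via KoboldAI's breakmodel_config.get_device_map. A comparable plan can be produced with Hugging Face accelerate, sketched below under the assumption that accelerate is installed; the model name and memory budget are examples only.

    from accelerate import infer_auto_device_map, init_empty_weights
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("facebook/opt-1.3b")  # example model
    with init_empty_weights():
        # No tensors are allocated; the module tree exists only for planning.
        metamodel = AutoModelForCausalLM.from_config(config)

    device_map = infer_auto_device_map(
        metamodel,
        max_memory={0: "6GiB", "cpu": "30GiB"},  # example budget: one GPU plus CPU spill
    )
    print(device_map)
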
@@ -176,9 +176,6 @@ class TorchLazyTensor(LazyTensor):
             CheckpointChunkCache.key = self.key
             ziproot = checkpoint.namelist()[0].split("/")[0]
             CheckpointChunkCache.handle = checkpoint.open(f"{ziproot}/data/{self.key}", "r")
-
-
-
         else:
             # Cache hit. Hip hip hooray! :^)
             # print(".", end="", flush=True)
@@ -318,7 +315,6 @@ class _LazyUnpickler(RestrictedUnpickler):
     lazy_loaded_storages: Dict[str, LazyTensor]
 
     def __init__(self, *args, **kwargs):
-        # print(args, kwargs)
         self.lazy_loaded_storages = {}
         return super().__init__(*args, **kwargs)
 
@@ -376,7 +372,7 @@ def patch_safetensors(callback):
             # (70 tensors/s -> 65 tensor/s). The memory savings probably
             # shouldn't be the happening, maybe there's a memory leak
             # somewhere in our pipeline with CPU tensors.
-            intermediary_device = "cuda"
+            intermediary_device = "cuda:0"
         else:
             intermediary_device = "cpu"
 
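
Note (illustrative sketch, not part of the patch): "cuda" resolves to whatever CUDA device is current at call time, while "cuda:0" pins the intermediary copies to the first GPU, which is presumably the intent of this one-character change. A quick check:

    import torch

    if torch.cuda.is_available():
        torch.cuda.set_device(torch.cuda.device_count() - 1)  # e.g. another GPU selected elsewhere
        print(torch.empty(1, device="cuda").device)    # cuda:<current index>
        print(torch.empty(1, device="cuda:0").device)  # always cuda:0
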
@@ -409,6 +405,7 @@ def patch_safetensors(callback):
         return tensors
 
     transformers.modeling_utils.safe_load_file = safetensors_load
+    safetensors.torch.load_file = safetensors_load
 
 
 @contextlib.contextmanager
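
Note (illustrative sketch, not part of the patch): rebinding a module attribute such as safetensors.torch.load_file only affects callers that look the function up through the module at call time; code that did `from safetensors.torch import load_file` before the patch keeps a reference to the original. The added line therefore complements the existing transformers.modeling_utils.safe_load_file patch. Minimal demonstration with a hypothetical stand-in loader:

    import safetensors.torch

    def fake_load_file(path, device="cpu"):
        # Hypothetical stand-in for the lazy-loading safetensors_load in the diff.
        print(f"intercepted load of {path}")
        return {}

    safetensors.torch.load_file = fake_load_file
    safetensors.torch.load_file("model.safetensors")  # goes through the patch
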
@@ -129,15 +129,33 @@ def patch_transformers_generation() -> None:
 
 
 class LazyloadPatches:
+    class StateDictFacade(dict):
+        def __init__(self, state_dict):
+            self.update(state_dict)
+
+        def __getitem__(self, name):
+            return super().__getitem__(name).materialize(map_location="cuda:0")
+
     old_load_state_dict = transformers.modeling_utils._load_state_dict_into_meta_model
+    torch_old_load_from_state_dict = torch.nn.Module._load_from_state_dict
 
     def __enter__() -> None:
         transformers.modeling_utils._load_state_dict_into_meta_model = (
             LazyloadPatches._load_state_dict_into_meta_model
         )
+        torch.nn.Module._load_from_state_dict = LazyloadPatches._torch_load_from_state_dict
 
     def __exit__(exc_type, exc_value, exc_traceback) -> None:
         transformers.modeling_utils._load_state_dict_into_meta_model = LazyloadPatches.old_load_state_dict
+        torch.nn.Module._load_from_state_dict = LazyloadPatches.torch_old_load_from_state_dict
+
+    def _torch_load_from_state_dict(self, state_dict, *args, **kwargs):
+        return LazyloadPatches.torch_old_load_from_state_dict(
+            self,
+            LazyloadPatches.StateDictFacade(state_dict),
+            *args,
+            **kwargs
+        )
 
     def _load_state_dict_into_meta_model(
         model,
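
Note (illustrative sketch, not part of the patch): StateDictFacade is a dict whose __getitem__ materializes the stored lazy tensor on access, so torch.nn.Module._load_from_state_dict receives real tensors without the whole checkpoint being loaded up front. The self-contained sketch below shows the same pattern with a hypothetical LazyStub and materializes to CPU so it runs without a GPU (the patch itself targets "cuda:0").

    import torch

    class LazyStub:
        def __init__(self, shape):
            self.shape = shape

        def materialize(self, map_location=None):
            # Allocate the real tensor only when it is actually requested.
            return torch.zeros(self.shape, device=map_location or "cpu")

    class MaterializingDict(dict):
        def __getitem__(self, name):
            return super().__getitem__(name).materialize(map_location="cpu")

    state_dict = MaterializingDict({"weight": LazyStub((2, 3))})
    print(state_dict["weight"].shape)  # torch.Size([2, 3]) -- materialized on access
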