Copy of VE's 8 bit lazyloading/breakmodel code

This commit is contained in:
ebolam
2022-11-30 18:56:22 -05:00
parent 06a25b4663
commit e278abd7c9
2 changed files with 33 additions and 17 deletions

View File

@@ -2465,10 +2465,6 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
if not koboldai_vars.bit_8_available or not koboldai_vars.experimental_features:
use_8_bit = False
if use_8_bit:
koboldai_vars.lazy_load = False
koboldai_vars.breakmodel = False
logger.info("koboldai_vars.lazy_load: {}".format(koboldai_vars.lazy_load))
if(initial_load):
use_breakmodel_args = True
if not utils.HAS_ACCELERATE:
@@ -2680,7 +2676,6 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
# Lazy loader
import torch_lazy_loader
def get_lazy_load_callback(n_layers, convert_to_float16=True):
logger.info("In Callback - koboldai_vars.lazy_load: {}".format(koboldai_vars.lazy_load))
if not koboldai_vars.lazy_load:
return
@@ -2742,6 +2737,8 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
koboldai_vars.loaded_layers = 0
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=Send_to_socketio())
if koboldai_vars.bit_8_available:
import bitsandbytes as bnb
with zipfile.ZipFile(f, "r") as z:
try:
last_storage_key = None
@@ -2769,7 +2766,10 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
model_dict[key] = model_dict[key].materialize(f, map_location="cpu")
if model_dict[key].dtype is torch.float32:
koboldai_vars.fp32_model = True
if convert_to_float16 and breakmodel.primary_device != "cpu" and koboldai_vars.hascuda and (koboldai_vars.breakmodel or koboldai_vars.usegpu) and model_dict[key].dtype is torch.float32:
if convert_to_float16 and breakmodel.primary_device != "cpu" and koboldai_vars.hascuda and (koboldai_vars.breakmodel or koboldai_vars.usegpu) and any(model_dict[key].dtype is t for t in (torch.float32, torch.float16)):
if use_8_bit:
model_dict[key] = bnb.nn.Int8Params(model_dict[key].to(torch.float16), requires_grad=False, has_fp16_weights=False).to(device if device not in ("shared", "disk") else "cpu")
else:
model_dict[key] = model_dict[key].to(torch.float16)
if breakmodel.primary_device == "cpu" or (not koboldai_vars.usegpu and not koboldai_vars.breakmodel and model_dict[key].dtype is torch.float16):
model_dict[key] = model_dict[key].to(torch.float32)
@@ -2882,7 +2882,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
import shutil
shutil.move(koboldai_vars.model.replace('/', '_'), "models/{}".format(koboldai_vars.model.replace('/', '_')))
if(koboldai_vars.lazy_load): # If we're using lazy loader, we need to figure out what the model's hidden layers are called
with torch_lazy_loader.use_lazy_torch_load(dematerialized_modules=True, use_accelerate_init_empty_weights=True):
with torch_lazy_loader.use_lazy_torch_load(bit_8_available=koboldai_vars.bit_8_available, dematerialized_modules=True, use_accelerate_init_empty_weights=True):
try:
metamodel = AutoModelForCausalLM.from_config(model_config)
except Exception as e:
@@ -2890,7 +2890,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
utils.layers_module_names = utils.get_layers_module_names(metamodel)
utils.module_names = list(metamodel.state_dict().keys())
utils.named_buffers = list(metamodel.named_buffers(recurse=True))
with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=koboldai_vars.lazy_load, callback=get_lazy_load_callback(utils.num_layers(model_config)) if koboldai_vars.lazy_load else None, dematerialized_modules=True):
with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(bit_8_available=koboldai_vars.bit_8_available, enable=koboldai_vars.lazy_load, callback=get_lazy_load_callback(utils.num_layers(model_config)) if koboldai_vars.lazy_load else None, dematerialized_modules=True):
if(koboldai_vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
lowmem = {}
if(os.path.isdir(koboldai_vars.custmodpth)):

View File

@@ -225,6 +225,10 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
try:
with torch.no_grad():
#param.copy_(input_param)
import bitsandbytes as bnb # This line is new
if isinstance(input_param, bnb.nn.Int8Params): # This line is new
new_param = input_param # This line is new
else: # This line is new
new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad) # This line is new
if name in self._parameters: # This line is new
self._parameters[name] = new_param # This line is new
@@ -277,7 +281,7 @@ def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler
pickle.load = old_pickle_load
@contextlib.contextmanager
def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
def use_lazy_torch_load(bit_8_available=False, enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
if not enable:
with use_custom_unpickler(RestrictedUnpickler):
yield False
@@ -298,17 +302,30 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
torch.load = torch_load
if dematerialized_modules:
old_linear_init = torch.nn.Linear.__init__
old_embedding_init = torch.nn.Embedding.__init__
old_layernorm_init = torch.nn.LayerNorm.__init__
old_load_from_state_dict = torch.nn.Module._load_from_state_dict
if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
import accelerate
init_empty_weights = accelerate.init_empty_weights()
init_empty_weights.__enter__()
else:
old_linear_init = torch.nn.Linear.__init__
old_embedding_init = torch.nn.Embedding.__init__
old_layernorm_init = torch.nn.LayerNorm.__init__
import accelerate
if bit_8_available:
import bitsandbytes as bnb
def linear_init(self, *args, device=None, **kwargs):
return old_linear_init(self, *args, device="meta", **kwargs)
if linear_init.nested_flag or not bit_8_available:
return old_linear_init(self, *args, device=device, **kwargs)
linear_init.nested_flag = True
try:
self.__class__ = bnb.nn.Linear8bitLt
with accelerate.init_empty_weights():
return bnb.nn.Linear8bitLt.__init__(self, *args, has_fp16_weights=False, threshold=6.0, **kwargs)
finally:
linear_init.nested_flag = False
linear_init.nested_flag = False
def embedding_init(self, *args, device=None, **kwargs):
return old_embedding_init(self, *args, device="meta", **kwargs)
@@ -319,7 +336,6 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
torch.nn.Linear.__init__ = linear_init
torch.nn.Embedding.__init__ = embedding_init
torch.nn.LayerNorm.__init__ = layernorm_init
old_load_from_state_dict = torch.nn.Module._load_from_state_dict
torch.nn.Module._load_from_state_dict = _load_from_state_dict
with use_custom_unpickler(_LazyUnpickler):