Lazy loader can now use accelerate's init_empty_weights()

Author: Gnome Ann
Date:   2022-06-16 18:56:16 -04:00
parent 5253cdcb36
commit 8bdf17f598

2 changed files with 29 additions and 20 deletions
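Until now, dematerialized_modules=True worked by monkey-patching the __init__ methods of torch.nn.Linear, torch.nn.Embedding and torch.nn.LayerNorm so that their parameters are created on the meta device. This commit adds an opt-in flag, use_accelerate_init_empty_weights, which delegates that job to accelerate's init_empty_weights() context manager when the accelerate package is available, and falls back to the old patching otherwise. As a minimal illustration of what init_empty_weights() provides (assuming accelerate is installed):

import torch
from accelerate import init_empty_weights

# Any module instantiated inside the context allocates its parameters on
# the meta device, so no real memory is consumed.
with init_empty_weights():
    layer = torch.nn.Linear(4096, 4096)

print(layer.weight.device)  # meta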

aiserver.py

@@ -1788,7 +1788,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model=""):
                     shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
                 print("\n", flush=True)
                 if(vars.lazy_load):  # If we're using lazy loader, we need to figure out what the model's hidden layers are called
-                    with torch_lazy_loader.use_lazy_torch_load(dematerialized_modules=True):
+                    with torch_lazy_loader.use_lazy_torch_load(dematerialized_modules=True, use_accelerate_init_empty_weights=True):
                         try:
                             metamodel = AutoModelForCausalLM.from_config(model_config)
                         except Exception as e:
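The metamodel built under the lazy loader is only used to figure out what the model's hidden layers are called, so empty meta-device weights are all it needs. A self-contained sketch of the same idea outside KoboldAI (the model id is an arbitrary example, not from the diff):

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-125M")
with init_empty_weights():
    metamodel = AutoModelForCausalLM.from_config(config)

# The module tree and parameter names are fully intact; only the tensor
# data is absent.
print(next(iter(metamodel.state_dict())))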

torch_lazy_loader.py

@@ -50,6 +50,7 @@ import itertools
 import zipfile
 import pickle
 import torch
+import utils
 from torch.nn import Module
 from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
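The new import utils gives the loader access to utils.HAS_ACCELERATE, which is checked before taking the accelerate code path. The diff does not show how that flag is defined; a guarded import along these lines is one plausible definition (an assumption, not the repository's exact code):

# Hypothetical sketch of HAS_ACCELERATE in utils.py:
try:
    import accelerate
    HAS_ACCELERATE = True
except ImportError:
    HAS_ACCELERATE = False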
@@ -213,7 +214,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
 @contextlib.contextmanager
-def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False):
+def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
     if not enable:
         yield False
         return
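The new parameter defaults to False, so existing callers are unaffected, and as the next hunk shows it only takes effect when dematerialized_modules=True and accelerate is importable.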
@@ -236,24 +237,29 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False):
     torch.load = torch_load
     if dematerialized_modules:
-        old_linear_init = torch.nn.Linear.__init__
-        old_embedding_init = torch.nn.Embedding.__init__
-        old_layernorm_init = torch.nn.LayerNorm.__init__
-        def linear_init(self, *args, device=None, **kwargs):
-            return old_linear_init(self, *args, device="meta", **kwargs)
-        def embedding_init(self, *args, device=None, **kwargs):
-            return old_embedding_init(self, *args, device="meta", **kwargs)
-        def layernorm_init(self, *args, device=None, **kwargs):
-            return old_layernorm_init(self, *args, device="meta", **kwargs)
-        torch.nn.Linear.__init__ = linear_init
-        torch.nn.Embedding.__init__ = embedding_init
-        torch.nn.LayerNorm.__init__ = layernorm_init
+        if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
+            import accelerate
+            init_empty_weights = accelerate.init_empty_weights()
+            init_empty_weights.__enter__()
+        else:
+            old_linear_init = torch.nn.Linear.__init__
+            old_embedding_init = torch.nn.Embedding.__init__
+            old_layernorm_init = torch.nn.LayerNorm.__init__
+            def linear_init(self, *args, device=None, **kwargs):
+                return old_linear_init(self, *args, device="meta", **kwargs)
+            def embedding_init(self, *args, device=None, **kwargs):
+                return old_embedding_init(self, *args, device="meta", **kwargs)
+            def layernorm_init(self, *args, device=None, **kwargs):
+                return old_layernorm_init(self, *args, device="meta", **kwargs)
+            torch.nn.Linear.__init__ = linear_init
+            torch.nn.Embedding.__init__ = embedding_init
+            torch.nn.LayerNorm.__init__ = layernorm_init
         old_load_from_state_dict = torch.nn.Module._load_from_state_dict
         torch.nn.Module._load_from_state_dict = _load_from_state_dict
     yield True
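Because use_lazy_torch_load is a generator-based context manager, the accelerate context has to stay open across the yield; hence init_empty_weights() is entered manually with __enter__() here and closed with __exit__() in the cleanup hunk below. A generic, runnable sketch of that pattern (contextlib.nullcontext stands in for accelerate.init_empty_weights(); the try/finally mirrors what the surrounding cleanup code presumably does):

import contextlib

@contextlib.contextmanager
def outer(use_inner=False):
    # Stand-in for accelerate.init_empty_weights()
    inner = contextlib.nullcontext() if use_inner else None
    if inner is not None:
        inner.__enter__()
    try:
        yield
    finally:
        # Close the inner context once the caller's with-block finishes
        if inner is not None:
            inner.__exit__(None, None, None)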
@@ -262,7 +268,10 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False):
     torch._utils._rebuild_tensor = old_rebuild_tensor
     torch.load = old_torch_load
     if dematerialized_modules:
-        torch.nn.Linear.__init__ = old_linear_init
-        torch.nn.Embedding.__init__ = old_embedding_init
-        torch.nn.LayerNorm.__init__ = old_layernorm_init
+        if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
+            init_empty_weights.__exit__(None, None, None)
+        else:
+            torch.nn.Linear.__init__ = old_linear_init
+            torch.nn.Embedding.__init__ = old_embedding_init
+            torch.nn.LayerNorm.__init__ = old_layernorm_init
         torch.nn.Module._load_from_state_dict = old_load_from_state_dict
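A practical difference between the two branches: the fallback path only dematerializes Linear, Embedding and LayerNorm modules, whereas accelerate's init_empty_weights() works by overriding torch.nn.Module.register_parameter, so parameters of every module type end up on the meta device.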