Lazy loader can now use accelerate's `init_empty_weights()`

This commit is contained in:
Gnome Ann 2022-06-16 18:56:16 -04:00
parent 5253cdcb36
commit 8bdf17f598
2 changed files with 29 additions and 20 deletions

View File

@ -1788,7 +1788,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
print("\n", flush=True)
if(vars.lazy_load): # If we're using lazy loader, we need to figure out what the model's hidden layers are called
with torch_lazy_loader.use_lazy_torch_load(dematerialized_modules=True):
with torch_lazy_loader.use_lazy_torch_load(dematerialized_modules=True, use_accelerate_init_empty_weights=True):
try:
metamodel = AutoModelForCausalLM.from_config(model_config)
except Exception as e:

View File

@ -50,6 +50,7 @@ import itertools
import zipfile
import pickle
import torch
import utils
from torch.nn import Module
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
@ -213,7 +214,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
@contextlib.contextmanager
def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False):
def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
if not enable:
yield False
return
@ -236,24 +237,29 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
torch.load = torch_load
if dematerialized_modules:
old_linear_init = torch.nn.Linear.__init__
old_embedding_init = torch.nn.Embedding.__init__
old_layernorm_init = torch.nn.LayerNorm.__init__
if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
import accelerate
init_empty_weights = accelerate.init_empty_weights()
init_empty_weights.__enter__()
else:
old_linear_init = torch.nn.Linear.__init__
old_embedding_init = torch.nn.Embedding.__init__
old_layernorm_init = torch.nn.LayerNorm.__init__
def linear_init(self, *args, device=None, **kwargs):
return old_linear_init(self, *args, device="meta", **kwargs)
def linear_init(self, *args, device=None, **kwargs):
return old_linear_init(self, *args, device="meta", **kwargs)
def embedding_init(self, *args, device=None, **kwargs):
return old_embedding_init(self, *args, device="meta", **kwargs)
def embedding_init(self, *args, device=None, **kwargs):
return old_embedding_init(self, *args, device="meta", **kwargs)
def layernorm_init(self, *args, device=None, **kwargs):
return old_layernorm_init(self, *args, device="meta", **kwargs)
def layernorm_init(self, *args, device=None, **kwargs):
return old_layernorm_init(self, *args, device="meta", **kwargs)
torch.nn.Linear.__init__ = linear_init
torch.nn.Embedding.__init__ = embedding_init
torch.nn.LayerNorm.__init__ = layernorm_init
old_load_from_state_dict = torch.nn.Module._load_from_state_dict
torch.nn.Module._load_from_state_dict = _load_from_state_dict
torch.nn.Linear.__init__ = linear_init
torch.nn.Embedding.__init__ = embedding_init
torch.nn.LayerNorm.__init__ = layernorm_init
old_load_from_state_dict = torch.nn.Module._load_from_state_dict
torch.nn.Module._load_from_state_dict = _load_from_state_dict
yield True
@ -262,7 +268,10 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
torch._utils._rebuild_tensor = old_rebuild_tensor
torch.load = old_torch_load
if dematerialized_modules:
torch.nn.Linear.__init__ = old_linear_init
torch.nn.Embedding.__init__ = old_embedding_init
torch.nn.LayerNorm.__init__ = old_layernorm_init
torch.nn.Module._load_from_state_dict = old_load_from_state_dict
if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
init_empty_weights.__exit__(None, None, None)
else:
torch.nn.Linear.__init__ = old_linear_init
torch.nn.Embedding.__init__ = old_embedding_init
torch.nn.LayerNorm.__init__ = old_layernorm_init
torch.nn.Module._load_from_state_dict = old_load_from_state_dict