Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-02-09 00:08:53 +01:00)

Commit 8bdf17f598 (parent 5253cdcb36)

Lazy loader can now use accelerate's init_empty_weights()
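For context, a minimal sketch (not KoboldAI code) of what accelerate's init_empty_weights() provides: any module instantiated inside the context manager gets its parameters placed on PyTorch's "meta" device, so a model skeleton can be built without allocating real weight memory. The tiny GPT2Config below is only an illustration, not the model KoboldAI loads.

    from accelerate import init_empty_weights
    from transformers import AutoModelForCausalLM, GPT2Config

    config = GPT2Config(n_layer=2)  # small example config
    with init_empty_weights():
        # Parameters are created on the meta device: shapes and dtypes only,
        # no real storage is allocated.
        model = AutoModelForCausalLM.from_config(config)

    print(next(model.parameters()).device)  # meta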
@@ -1788,7 +1788,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                 shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
             print("\n", flush=True)
             if(vars.lazy_load): # If we're using lazy loader, we need to figure out what the model's hidden layers are called
-                with torch_lazy_loader.use_lazy_torch_load(dematerialized_modules=True):
+                with torch_lazy_loader.use_lazy_torch_load(dematerialized_modules=True, use_accelerate_init_empty_weights=True):
                     try:
                         metamodel = AutoModelForCausalLM.from_config(model_config)
                     except Exception as e:
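The comment in the hunk above states why the metamodel is built: load_model needs to know how the model's hidden layers are named. A hedged sketch of that kind of inspection, using a small throwaway config in place of the real model_config:

    from transformers import AutoModelForCausalLM, GPT2Config

    model_config = GPT2Config(n_layer=2)  # stand-in for the real model_config
    metamodel = AutoModelForCausalLM.from_config(model_config)
    # Parameter names reveal how the hidden layers are addressed, e.g.
    # 'transformer.h.0.ln_1.weight', 'transformer.h.0.attn.c_attn.weight', ...
    print([name for name, _ in metamodel.named_parameters() if ".h." in name][:4])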
@@ -50,6 +50,7 @@ import itertools
 import zipfile
 import pickle
 import torch
+import utils
 from torch.nn import Module
 from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
 
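The new import is used below as utils.HAS_ACCELERATE. Its real definition lives in KoboldAI's utils module; the following is only a hypothetical sketch of that kind of feature-detection flag, assuming it simply checks whether the accelerate package is installed.

    import importlib.util

    # Hypothetical stand-in for utils.HAS_ACCELERATE; the actual KoboldAI
    # utils module may define this flag differently.
    HAS_ACCELERATE = importlib.util.find_spec("accelerate") is not None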
@@ -213,7 +214,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
 
 
 @contextlib.contextmanager
-def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False):
+def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
     if not enable:
         yield False
         return
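use_lazy_torch_load is a generator-based context manager: when disabled it yields False and returns immediately, otherwise it installs its patches, yields True, and restores everything afterwards. A generic sketch of that pattern (not the KoboldAI implementation):

    import contextlib

    @contextlib.contextmanager
    def maybe_patch(enable=True):
        if not enable:
            yield False
            return
        # ... install patches here ...
        try:
            yield True
        finally:
            # ... restore the originals here ...
            pass

    with maybe_patch(enable=False) as active:
        print(active)  # False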
@@ -236,24 +237,29 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
         torch.load = torch_load
 
         if dematerialized_modules:
-            old_linear_init = torch.nn.Linear.__init__
-            old_embedding_init = torch.nn.Embedding.__init__
-            old_layernorm_init = torch.nn.LayerNorm.__init__
+            if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
+                import accelerate
+                init_empty_weights = accelerate.init_empty_weights()
+                init_empty_weights.__enter__()
+            else:
+                old_linear_init = torch.nn.Linear.__init__
+                old_embedding_init = torch.nn.Embedding.__init__
+                old_layernorm_init = torch.nn.LayerNorm.__init__
 
             def linear_init(self, *args, device=None, **kwargs):
                 return old_linear_init(self, *args, device="meta", **kwargs)
 
             def embedding_init(self, *args, device=None, **kwargs):
                 return old_embedding_init(self, *args, device="meta", **kwargs)
 
             def layernorm_init(self, *args, device=None, **kwargs):
                 return old_layernorm_init(self, *args, device="meta", **kwargs)
 
             torch.nn.Linear.__init__ = linear_init
             torch.nn.Embedding.__init__ = embedding_init
             torch.nn.LayerNorm.__init__ = layernorm_init
             old_load_from_state_dict = torch.nn.Module._load_from_state_dict
             torch.nn.Module._load_from_state_dict = _load_from_state_dict
 
         yield True
 
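A standalone sketch of the fallback technique kept in the else: branch above: forcing device="meta" inside torch.nn.Linear.__init__ creates layers with shapes and dtypes but no storage, which is what accelerate's init_empty_weights() automates. This assumes a PyTorch version whose layers accept a device= keyword (1.10 or newer).

    import torch

    old_linear_init = torch.nn.Linear.__init__

    def linear_init(self, *args, device=None, **kwargs):
        # Ignore the requested device and build the layer on the meta device.
        return old_linear_init(self, *args, device="meta", **kwargs)

    torch.nn.Linear.__init__ = linear_init
    try:
        layer = torch.nn.Linear(4096, 4096)  # no 64 MiB weight allocation
        print(layer.weight.device)  # meta
    finally:
        torch.nn.Linear.__init__ = old_linear_init  # undo the patch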
@@ -262,7 +268,10 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
         if dematerialized_modules:
-            torch.nn.Linear.__init__ = old_linear_init
-            torch.nn.Embedding.__init__ = old_embedding_init
-            torch.nn.LayerNorm.__init__ = old_layernorm_init
-            torch.nn.Module._load_from_state_dict = old_load_from_state_dict
+            if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
+                init_empty_weights.__exit__(None, None, None)
+            else:
+                torch.nn.Linear.__init__ = old_linear_init
+                torch.nn.Embedding.__init__ = old_embedding_init
+                torch.nn.LayerNorm.__init__ = old_layernorm_init
+                torch.nn.Module._load_from_state_dict = old_load_from_state_dict
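The teardown in the last hunk mirrors the setup: the accelerate context manager that was entered by hand with __enter__() is closed by hand with __exit__(None, None, None). A small sketch of the same idea expressed with contextlib.ExitStack, an alternative way to hold a context manager open across separate setup and teardown phases:

    import contextlib

    @contextlib.contextmanager
    def demo():
        print("enter")
        yield
        print("exit")

    stack = contextlib.ExitStack()
    stack.enter_context(demo())   # comparable to cm.__enter__()
    # ... the context stays open while other work happens ...
    stack.close()                 # comparable to cm.__exit__(None, None, None)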