Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Lazy loader can now use accelerate's init_empty_weights()
@@ -50,6 +50,7 @@ import itertools
 import zipfile
 import pickle
 import torch
+import utils
 from torch.nn import Module
 from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
 
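The added import makes `utils.HAS_ACCELERATE` available for the checks further down. As a rough, hypothetical sketch of how such a capability flag is typically defined (KoboldAI's actual `utils.py` may do this differently, e.g. with a version check):

# Hypothetical sketch of a HAS_ACCELERATE flag; not the project's actual code.
try:
    import accelerate  # noqa: F401 (import probe only)
    HAS_ACCELERATE = True
except ImportError:
    HAS_ACCELERATE = False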
@@ -213,7 +214,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
 
 
 @contextlib.contextmanager
-def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False):
+def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
     if not enable:
         yield False
         return
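The signature change adds an opt-in keyword, `use_accelerate_init_empty_weights`, alongside the existing `dematerialized_modules` switch. A minimal usage sketch, assuming the function is imported from this file's module (`torch_lazy_loader`); the layer below is only there to show where the weights end up:

import torch
from torch_lazy_loader import use_lazy_torch_load  # module name assumed from this file

with use_lazy_torch_load(
    dematerialized_modules=True,
    use_accelerate_init_empty_weights=True,  # falls back to the manual patch if accelerate is absent
):
    layer = torch.nn.Linear(4096, 4096)  # constructed while the patch is active

print(layer.weight.device)  # expected: meta, i.e. no real weight memory was allocated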
@@ -236,24 +237,29 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
         torch.load = torch_load
 
         if dematerialized_modules:
-            old_linear_init = torch.nn.Linear.__init__
-            old_embedding_init = torch.nn.Embedding.__init__
-            old_layernorm_init = torch.nn.LayerNorm.__init__
+            if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
+                import accelerate
+                init_empty_weights = accelerate.init_empty_weights()
+                init_empty_weights.__enter__()
+            else:
+                old_linear_init = torch.nn.Linear.__init__
+                old_embedding_init = torch.nn.Embedding.__init__
+                old_layernorm_init = torch.nn.LayerNorm.__init__
 
-            def linear_init(self, *args, device=None, **kwargs):
-                return old_linear_init(self, *args, device="meta", **kwargs)
+                def linear_init(self, *args, device=None, **kwargs):
+                    return old_linear_init(self, *args, device="meta", **kwargs)
 
-            def embedding_init(self, *args, device=None, **kwargs):
-                return old_embedding_init(self, *args, device="meta", **kwargs)
+                def embedding_init(self, *args, device=None, **kwargs):
+                    return old_embedding_init(self, *args, device="meta", **kwargs)
 
-            def layernorm_init(self, *args, device=None, **kwargs):
-                return old_layernorm_init(self, *args, device="meta", **kwargs)
+                def layernorm_init(self, *args, device=None, **kwargs):
+                    return old_layernorm_init(self, *args, device="meta", **kwargs)
 
-            torch.nn.Linear.__init__ = linear_init
-            torch.nn.Embedding.__init__ = embedding_init
-            torch.nn.LayerNorm.__init__ = layernorm_init
-            old_load_from_state_dict = torch.nn.Module._load_from_state_dict
-            torch.nn.Module._load_from_state_dict = _load_from_state_dict
+                torch.nn.Linear.__init__ = linear_init
+                torch.nn.Embedding.__init__ = embedding_init
+                torch.nn.LayerNorm.__init__ = layernorm_init
+                old_load_from_state_dict = torch.nn.Module._load_from_state_dict
+                torch.nn.Module._load_from_state_dict = _load_from_state_dict
 
         yield True
 
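When the flag is set and accelerate is available, the setup path enters `accelerate.init_empty_weights()` by hand (`__enter__` here, the matching `__exit__` in the cleanup below) instead of patching `Linear`, `Embedding` and `LayerNorm` onto the meta device itself. A minimal standalone sketch of what that accelerate context manager does, independent of this loader:

# Modules built inside init_empty_weights() get meta-device parameters,
# so no weight memory is allocated for them.
import torch
from accelerate import init_empty_weights

with init_empty_weights():
    layer = torch.nn.Linear(8192, 8192)

print(layer.weight.device)   # meta
print(layer.weight.numel())  # 67108864 elements, but no storage behind them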
@@ -262,7 +268,10 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
         if dematerialized_modules:
-            torch.nn.Linear.__init__ = old_linear_init
-            torch.nn.Embedding.__init__ = old_embedding_init
-            torch.nn.LayerNorm.__init__ = old_layernorm_init
-            torch.nn.Module._load_from_state_dict = old_load_from_state_dict
+            if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
+                init_empty_weights.__exit__(None, None, None)
+            else:
+                torch.nn.Linear.__init__ = old_linear_init
+                torch.nn.Embedding.__init__ = old_embedding_init
+                torch.nn.LayerNorm.__init__ = old_layernorm_init
+                torch.nn.Module._load_from_state_dict = old_load_from_state_dict
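The cleanup mirrors the setup: the accelerate branch closes the manually entered context manager with `__exit__(None, None, None)`, while the fallback branch restores the saved `__init__` methods and `_load_from_state_dict`. For comparison only, a hypothetical way to express the same enter/exit pairing with `contextlib.ExitStack`; this is not what the commit does:

import contextlib
import accelerate

@contextlib.contextmanager
def empty_weights_if(enabled: bool):
    # Hypothetical helper: ExitStack closes init_empty_weights() on exit,
    # playing the same role as the explicit __exit__(None, None, None) above.
    with contextlib.ExitStack() as stack:
        if enabled:
            stack.enter_context(accelerate.init_empty_weights())
        yield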