Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-02-20 05:30:57 +01:00)
Merge pull request #155 from VE-FORBRYDERNE/accelerate
Initial support for Accelerate
Commit efed44ac8d

aiserver.py (142 changed lines)
@@ -610,6 +610,24 @@ def move_model_to_devices(model):
 
     model.half()
     gc.collect()
+
+    if(utils.HAS_ACCELERATE):
+        import accelerate
+        gpu_blocks = breakmodel.gpu_blocks
+        ram_blocks = len(vars.layers_module_names) - sum(gpu_blocks)
+        cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
+        device_map = {}
+        for name in vars.layers_module_names:
+            layer = int(name.rsplit(".", 1)[1])
+            device = "cpu" if layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
+            device_map[name] = device
+        for name in utils.get_missing_module_names(model, list(device_map.keys())):
+            device_map[name] = breakmodel.primary_device
+        accelerate.dispatch_model(model, device_map, main_device=breakmodel.primary_device)
+        gc.collect()
+        generator = model.generate
+        return
+
     if(hasattr(model, "transformer")):
         model.transformer.wte.to(breakmodel.primary_device)
         model.transformer.ln_f.to(breakmodel.primary_device)
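Note on the hunk above: when Accelerate is available, breakmodel's per-GPU block counts are turned into a module-name-to-device map and handed to accelerate.dispatch_model, with any module not covered by a hidden layer pinned to the primary device. A minimal standalone sketch of the bisect-based mapping, using hypothetical gpu_blocks and module names rather than KoboldAI's real state:

import bisect
import itertools

gpu_blocks = [3, 4]                                                 # hypothetical --breakmodel_gpulayers split across 2 GPUs
layers_module_names = ["transformer.h.%d" % i for i in range(12)]   # hypothetical 12-layer model
ram_blocks = len(layers_module_names) - sum(gpu_blocks)             # 12 - 7 = 5 layers stay in RAM
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))     # (3, 7)

device_map = {}
for name in layers_module_names:
    layer = int(name.rsplit(".", 1)[1])
    # layers 0-4 -> "cpu", layers 5-7 -> GPU 0, layers 8-11 -> GPU 1
    device_map[name] = "cpu" if layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)

print(device_map)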
@@ -1192,8 +1210,37 @@ def get_oai_models(key):
         print("{0}ERROR!{1}".format(colors.RED, colors.END))
         print(req.json())
         emit('from_server', {'cmd': 'errmsg', 'data': req.json()})
 
 
+# Function to patch transformers to use our soft prompt
+def patch_causallm(cls):
+    if(getattr(cls, "_koboldai_patch_causallm_patched", False)):
+        return
+    old_forward = cls.forward
+    def new_causallm_forward(self, *args, **kwargs):
+        input_ids = kwargs.get('input_ids').to(self.device)
+        assert input_ids is not None
+        kwargs['input_ids'] = None
+        if(vars.sp is not None):
+            shifted_input_ids = input_ids - self.config.vocab_size
+        input_ids.clamp_(max=self.config.vocab_size-1)
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+        if(vars.sp is not None):
+            vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
+            inputs_embeds = torch.where(
+                (shifted_input_ids >= 0)[..., None],
+                vars.sp[shifted_input_ids.clamp(min=0)],
+                inputs_embeds,
+            )
+        if(hasattr(self, "model") and hasattr(self.model, "embed_scale")):
+            inputs_embeds *= self.model.embed_scale
+        kwargs['inputs_embeds'] = inputs_embeds
+        return old_forward(self, *args, **kwargs)
+    cls.forward = new_causallm_forward
+    cls._koboldai_patch_causallm_patched = True
+    return cls
+
+
 def patch_transformers():
     global transformers
     old_from_pretrained = PreTrainedModel.from_pretrained.__func__
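Note: the rewritten patch_causallm treats token IDs at or above the vocabulary size as soft-prompt positions and swaps their embeddings in with torch.where, now fetching the embedding matrix through get_input_embeddings() instead of per-architecture attribute lookups. A self-contained sketch of that substitution with made-up sizes (not KoboldAI code):

import torch

vocab_size, embed_dim = 8, 4
embedding = torch.nn.Embedding(vocab_size, embed_dim)
soft_prompt = torch.randn(3, embed_dim)          # three learned soft-prompt vectors

input_ids = torch.tensor([[1, 5, 8, 9, 2]])      # 8 and 9 are "virtual" soft-prompt tokens
shifted = input_ids - vocab_size                 # >= 0 only at soft-prompt positions
clamped = input_ids.clamp(max=vocab_size - 1)    # keep the real-token lookup in range

inputs_embeds = embedding(clamped)
inputs_embeds = torch.where(
    (shifted >= 0)[..., None],                   # mask broadcast over the embedding dim
    soft_prompt[shifted.clamp(min=0)],           # rows 0 and 1 of the soft prompt
    inputs_embeds,                               # ordinary token embeddings elsewhere
)
print(inputs_embeds.shape)                       # torch.Size([1, 5, 4])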
@@ -1241,42 +1288,6 @@ def patch_transformers():
             return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
         XGLMSinusoidalPositionalEmbedding.forward = new_forward
 
-    # Patch transformers to use our soft prompt
-    def patch_causallm(cls):
-        old_forward = cls.forward
-        def new_causallm_forward(self, *args, **kwargs):
-            input_ids = kwargs.get('input_ids').to(self.device)
-            assert input_ids is not None
-            kwargs['input_ids'] = None
-            if(vars.sp is not None):
-                shifted_input_ids = input_ids - self.config.vocab_size
-            input_ids.clamp_(max=self.config.vocab_size-1)
-            if(hasattr(self, "transformer")):
-                inputs_embeds = self.transformer.wte(input_ids)
-            elif(not hasattr(self.model, "decoder")):
-                inputs_embeds = self.model.embed_tokens(input_ids)
-            else:
-                inputs_embeds = self.model.decoder.embed_tokens(input_ids)
-            if(vars.sp is not None):
-                vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
-                inputs_embeds = torch.where(
-                    (shifted_input_ids >= 0)[..., None],
-                    vars.sp[shifted_input_ids.clamp(min=0)],
-                    inputs_embeds,
-                )
-            if(hasattr(self, "model") and hasattr(self.model, "embed_scale")):
-                inputs_embeds *= self.model.embed_scale
-            kwargs['inputs_embeds'] = inputs_embeds
-            return old_forward(self, *args, **kwargs)
-        cls.forward = new_causallm_forward
-    for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
-        patch_causallm(cls)
-    for c in ("GPTJForCausalLM", "XGLMForCausalLM", "OPTForCausalLM"):
-        try:
-            patch_causallm(getattr(__import__("transformers"), c))
-        except:
-            pass
-
 
     # Fix a bug in OPTForCausalLM where self.lm_head is the wrong size
     if(packaging.version.parse("4.19.0.dev0") <= packaging.version.parse(transformers_version) < packaging.version.parse("4.20.0")):
@@ -1563,7 +1574,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
         loadsettings()
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
     vars.hascuda = torch.cuda.is_available()
-    vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm", "opt") and not vars.nobreakmodel
+    vars.bmsupported = (utils.HAS_ACCELERATE or vars.model_type in ("gpt_neo", "gptj", "xglm", "opt")) and not vars.nobreakmodel
     if(args.breakmodel is not None and args.breakmodel):
         print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --breakmodel_gpulayers is used (see --help for details).", file=sys.stderr)
     if(args.breakmodel_layers is not None):
@@ -1657,24 +1668,20 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
         else:
             ram_blocks = gpu_blocks = cumulative_gpu_blocks = None
 
-        def lazy_load_callback(model_dict, f, **_):
+        def lazy_load_callback(model_dict: Dict[str, Union[torch_lazy_loader.LazyTensor, torch.Tensor]], f, **_):
             if lazy_load_callback.nested:
                 return
             lazy_load_callback.nested = True
 
-            device_map = {}
+            device_map: Dict[str, Union[str, int]] = {}
 
-            for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
-                for layer in range(n_layers):
-                    key = _key.format(layer=layer)
-                    if key not in model_dict:
-                        continue
-                    device = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel or layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
-                    device_map[key] = device
-
             for key, value in model_dict.items():
-                if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map:
-                    device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu"
+                if isinstance(value, torch_lazy_loader.LazyTensor) and not any(key.startswith(n) or key.startswith(n.split(".", 1)[1]) for n in vars.layers_module_names):
+                    device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel else breakmodel.primary_device
+                else:
+                    layer = int(max((n for n in vars.layers_module_names if key.startswith(n) or key.startswith(n.split(".", 1)[1])), key=len).rsplit(".", 1)[1])
+                    device = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel else "shared" if layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
+                    device_map[key] = device
 
             if utils.num_shards is None or utils.current_shard == 0:
                 if utils.num_shards is not None:
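Note: in the callback above, each checkpoint key is matched against the recorded hidden-layer module names by longest prefix, so that a name like "transformer.h.1" cannot shadow "transformer.h.11"; the second startswith clause covers checkpoints whose keys omit the outermost module name. A small sketch with hypothetical names:

layers_module_names = ["transformer.h.%d" % i for i in range(24)]   # hypothetical module names

key = "transformer.h.11.attn.out_proj.weight"                       # hypothetical checkpoint key
matches = (n for n in layers_module_names
           if key.startswith(n) or key.startswith(n.split(".", 1)[1]))
layer = int(max(matches, key=len).rsplit(".", 1)[1])
print(layer)  # 11, even though "transformer.h.1" also matches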
@@ -1689,6 +1696,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                     last_storage_key = None
                     f = None
                     current_offset = 0
+                    able_to_pin_layers = True
                     if utils.num_shards is not None:
                         utils.current_shard += 1
                     for key in sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
@@ -1714,7 +1722,15 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                             model_dict[key] = model_dict[key].to(torch.float16)
                         if not vars.usegpu and not vars.breakmodel and model_dict[key].dtype is torch.float16:
                             model_dict[key] = model_dict[key].to(torch.float32)
-                        model_dict[key] = model_dict[key].to(device)
+                        if device == "shared":
+                            model_dict[key] = model_dict[key].to("cpu").detach_()
+                            if able_to_pin_layers and utils.HAS_ACCELERATE:
+                                try:
+                                    model_dict[key] = model_dict[key].pin_memory()
+                                except:
+                                    able_to_pin_layers = False
+                        else:
+                            model_dict[key] = model_dict[key].to(device)
                         #print("OK", flush=True)
                         current_offset += nbytes
                         utils.bar.update(1)
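Note: tensors mapped to the new "shared" device stay in CPU RAM, and pinning that memory makes the later host-to-GPU copies that Accelerate performs cheaper; able_to_pin_layers stops trying after the first failure so a machine short on lockable memory still finishes loading. The same pattern on a throwaway tensor (a sketch, not the loader itself):

import torch

able_to_pin_layers = True
tensor = torch.empty(1024, dtype=torch.float16)   # stand-in for one offloaded weight
tensor = tensor.to("cpu").detach_()
if able_to_pin_layers:
    try:
        tensor = tensor.pin_memory()              # page-locked host memory
    except Exception:
        able_to_pin_layers = False                # fall back to ordinary pageable memory
print("pinned" if able_to_pin_layers and tensor.is_pinned() else "left in pageable memory")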
@@ -1729,15 +1745,6 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                lazy_load_callback.nested = False
        return lazy_load_callback
 
-    lazy_load_config_path = os.path.join("maps", vars.model_type + ".json")
-    if(vars.lazy_load and "model_config" in globals() and os.path.isfile(lazy_load_config_path)):
-        with open(lazy_load_config_path) as f:
-            lazy_load_spec = json.load(f)
-
-    else:
-        vars.lazy_load = False
-
-
 
     def get_hidden_size_from_model(model):
        try:
@@ -1791,6 +1798,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
        else:
            model = model.to('cpu').float()
        generator = model.generate
+        patch_causallm(model.__class__)
    # Use the Generic implementation
    else:
        lowmem = maybe_low_cpu_mem_usage()
@@ -1799,6 +1807,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
        # feature yet
        if(vars.model_type == "gpt2"):
            lowmem = {}
+            vars.lazy_load = False  # Also, lazy loader doesn't support GPT-2 models
 
        # If we're using torch_lazy_loader, we need to get breakmodel config
        # early so that it knows where to load the individual model tensors
@@ -1812,6 +1821,13 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                import shutil
                shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
            print("\n", flush=True)
+            if(vars.lazy_load):  # If we're using lazy loader, we need to figure out what the model's hidden layers are called
+                with torch_lazy_loader.use_lazy_torch_load(dematerialized_modules=True, use_accelerate_init_empty_weights=True):
+                    try:
+                        metamodel = AutoModelForCausalLM.from_config(model_config)
+                    except Exception as e:
+                        metamodel = GPTNeoForCausalLM.from_config(model_config)
+                    vars.layers_module_names = utils.get_layers_module_names(metamodel)
            with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(utils.num_layers(model_config)) if vars.lazy_load else None, dematerialized_modules=True):
                if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
                    lowmem = {}
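Note: the metamodel created above never materializes real weights; it exists only so the hidden-layer module names can be recorded before any tensors are read. A sketch of that step (assumes transformers and accelerate are installed, and uses a default GPT-Neo config as a stand-in for the server's actual model_config):

from transformers import AutoModelForCausalLM, GPTNeoConfig
from accelerate import init_empty_weights

with init_empty_weights():
    # All parameters land on the "meta" device: shapes and dtypes only, no storage.
    metamodel = AutoModelForCausalLM.from_config(GPTNeoConfig())

layer_names = [
    name for name, module in metamodel.named_modules()
    if name.split(".")[-1].isnumeric() and type(module).__name__.endswith(("Block", "Layer"))
]
print(layer_names[0], "...", len(layer_names))   # transformer.h.0 ... 24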
@@ -1910,7 +1926,9 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                    for filename in filenames:
                        shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, filename, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
                    shutil.rmtree("cache/")
 
+            patch_causallm(model.__class__)
+
            if(vars.hascuda):
                if(vars.usegpu):
                    vars.modeldim = get_hidden_size_from_model(model)
torch_lazy_loader.py

@@ -50,6 +50,7 @@ import itertools
 import zipfile
 import pickle
 import torch
+import utils
 from torch.nn import Module
 from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
 
@@ -213,7 +214,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
 
 
 @contextlib.contextmanager
-def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False):
+def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
     if not enable:
         yield False
         return
@@ -236,24 +237,29 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
     torch.load = torch_load
 
     if dematerialized_modules:
-        old_linear_init = torch.nn.Linear.__init__
-        old_embedding_init = torch.nn.Embedding.__init__
-        old_layernorm_init = torch.nn.LayerNorm.__init__
+        if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
+            import accelerate
+            init_empty_weights = accelerate.init_empty_weights()
+            init_empty_weights.__enter__()
+        else:
+            old_linear_init = torch.nn.Linear.__init__
+            old_embedding_init = torch.nn.Embedding.__init__
+            old_layernorm_init = torch.nn.LayerNorm.__init__
 
            def linear_init(self, *args, device=None, **kwargs):
                return old_linear_init(self, *args, device="meta", **kwargs)
 
            def embedding_init(self, *args, device=None, **kwargs):
                return old_embedding_init(self, *args, device="meta", **kwargs)
 
            def layernorm_init(self, *args, device=None, **kwargs):
                return old_layernorm_init(self, *args, device="meta", **kwargs)
 
            torch.nn.Linear.__init__ = linear_init
            torch.nn.Embedding.__init__ = embedding_init
            torch.nn.LayerNorm.__init__ = layernorm_init
            old_load_from_state_dict = torch.nn.Module._load_from_state_dict
            torch.nn.Module._load_from_state_dict = _load_from_state_dict
 
     yield True
 
@@ -262,7 +268,10 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
         if dematerialized_modules:
-            torch.nn.Linear.__init__ = old_linear_init
-            torch.nn.Embedding.__init__ = old_embedding_init
-            torch.nn.LayerNorm.__init__ = old_layernorm_init
-            torch.nn.Module._load_from_state_dict = old_load_from_state_dict
+            if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
+                init_empty_weights.__exit__(None, None, None)
+            else:
+                torch.nn.Linear.__init__ = old_linear_init
+                torch.nn.Embedding.__init__ = old_embedding_init
+                torch.nn.LayerNorm.__init__ = old_layernorm_init
+                torch.nn.Module._load_from_state_dict = old_load_from_state_dict
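Note: accelerate.init_empty_weights() is an ordinary context manager, so the lazy loader can call __enter__ when its own context begins and __exit__(None, None, None) when it unwinds, as the two hunks above do, instead of patching the torch.nn constructors itself. An equivalent sketch (assumes accelerate is installed):

import torch
import accelerate

ctx = accelerate.init_empty_weights()
ctx.__enter__()
try:
    # Modules built here get "meta" parameters: shapes and dtypes, but no storage.
    layer = torch.nn.Linear(4096, 4096)
    print(layer.weight.device)                    # meta
finally:
    ctx.__exit__(None, None, None)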
utils.py (61 changed lines)
@@ -7,10 +7,19 @@ import tempfile
 import requests
 import requests.adapters
 import time
+from transformers import __version__ as transformers_version
+from transformers import PreTrainedModel
+import packaging.version
 from tqdm.auto import tqdm
 import os
 import itertools
-from typing import Optional
+from typing import List, Optional
 
+HAS_ACCELERATE = packaging.version.parse(transformers_version) >= packaging.version.parse("4.20.0.dev0")
+try:
+    import accelerate
+except ImportError:
+    HAS_ACCELERATE = False
+
 vars = None
 num_shards: Optional[int] = None
@@ -300,3 +309,53 @@ def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename,
     import torch
     shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
     return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
+
+#==================================================================#
+# Given a PreTrainedModel, returns the list of module names that correspond
+# to the model's hidden layers.
+#==================================================================#
+def get_layers_module_names(model: PreTrainedModel) -> List[str]:
+    names: List[str] = []
+    def recurse(module, head=""):
+        for c in module.named_children():
+            name = head + c[0]
+            if c[0].isnumeric() and any(c[1].__class__.__name__.endswith(suffix) for suffix in ("Block", "Layer")):
+                names.append(name)
+            else:
+                recurse(c[1], head=name + ".")
+    recurse(model)
+    return names
+
+#==================================================================#
+# Given a PreTrainedModel, returns the module name that corresponds
+# to the model's input embeddings.
+#==================================================================#
+def get_input_embeddings_module_name(model: PreTrainedModel) -> str:
+    embeddings = model.get_input_embeddings()
+    def recurse(module, head=""):
+        for c in module.named_children():
+            name = head + c[0]
+            if c[1] is embeddings:
+                return name
+            else:
+                return recurse(c[1], head=name + ".")
+    return recurse(model)
+
+#==================================================================#
+# Given a PreTrainedModel and a list of module names, returns a list
+# of module names such that the union of the set of modules given as input
+# and the set of modules returned as output contains all modules in the model.
+#==================================================================#
+def get_missing_module_names(model: PreTrainedModel, names: List[str]) -> List[str]:
+    missing_names: List[str] = []
+    def recurse(module, head=""):
+        for c in module.named_children():
+            name = head + c[0]
+            if any(name.startswith(n) for n in names):
+                continue
+            if next(c[1].named_children(), None) is None:
+                missing_names.append(name)
+            else:
+                recurse(c[1], head=name + ".")
+    recurse(model)
+    return missing_names
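Note: the three helpers added above only walk module names; they never touch weights. A standalone sketch with a toy model (and a local copy of the layer-name recursion, since utils.py is only importable inside the repo) showing what they would report:

import torch

class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = torch.nn.Linear(8, 8)

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = torch.nn.Embedding(10, 8)                          # input embeddings
        self.h = torch.nn.ModuleList([ToyBlock() for _ in range(2)])  # "hidden layers"
        self.ln_f = torch.nn.LayerNorm(8)

def layer_names(module, head=""):
    # Same idea as get_layers_module_names: numerically named children whose class
    # name ends with "Block" or "Layer" count as hidden layers.
    names = []
    for child_name, child in module.named_children():
        name = head + child_name
        if child_name.isnumeric() and type(child).__name__.endswith(("Block", "Layer")):
            names.append(name)
        else:
            names.extend(layer_names(child, head=name + "."))
    return names

model = ToyModel()
print(layer_names(model))   # ['h.0', 'h.1']
# get_missing_module_names(model, ['h.0', 'h.1']) would then report the remaining leaves,
# here ['wte', 'ln_f'], which move_model_to_devices pins to breakmodel's primary device.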