Add more model querying utilities

Gnome Ann
2022-06-18 18:16:56 -04:00
parent e143963161
commit f7ffdd7b6b
2 changed files with 43 additions and 4 deletions

View File

@@ -1805,7 +1805,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                 metamodel = AutoModelForCausalLM.from_config(model_config)
             except Exception as e:
                 metamodel = GPTNeoForCausalLM.from_config(model_config)
-            vars.layer_param_names = utils.get_layer_param_names(metamodel)
+            vars.layer_param_names = utils.get_layers_module_names(metamodel)
         with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(utils.num_layers(model_config)) if vars.lazy_load else None, dematerialized_modules=True):
            if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
                lowmem = {}
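For context: at this call site, load_model builds a throwaway "meta" model straight from the config so that the renamed utils.get_layers_module_names can enumerate the hidden-layer module names before the real checkpoint tensors are loaded. A minimal standalone sketch of the same querying idea, without the commit's lazy-loader machinery (gpt2 is just an illustrative checkpoint, not from the commit):

    from transformers import AutoConfig, AutoModelForCausalLM

    # Build a model skeleton from the config alone: no checkpoint weights
    # are downloaded, yet every module exists and has a dotted name.
    config = AutoConfig.from_pretrained("gpt2")
    metamodel = AutoModelForCausalLM.from_config(config)

    # The module tree can now be queried by name before real weights load.
    print([name for name, _ in metamodel.named_modules()][:5])
    # e.g. ['', 'transformer', 'transformer.wte', 'transformer.wpe', ...]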

View File

@@ -8,11 +8,12 @@ import requests
 import requests.adapters
 import time
 from transformers import __version__ as transformers_version
+from transformers import PreTrainedModel
 import packaging.version
 from tqdm.auto import tqdm
 import os
 import itertools
-from typing import Optional
+from typing import List, Optional
 
 HAS_ACCELERATE = packaging.version.parse(transformers_version) >= packaging.version.parse("4.20.0.dev0")
 
 try:
@@ -309,8 +310,12 @@ def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename,
     shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
     return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
 
-def get_layer_param_names(model):
-    names = []
+#==================================================================#
+# Given a PreTrainedModel, returns the list of module names that correspond
+# to the model's hidden layers.
+#==================================================================#
+def get_layers_module_names(model: PreTrainedModel) -> List[str]:
+    names: List[str] = []
     def recurse(module, head=""):
         for c in module.named_children():
             name = head + c[0]
@@ -320,3 +325,37 @@ def get_layer_param_names(model):
                 recurse(c[1], head=name + ".")
     recurse(model)
     return names
+
+#==================================================================#
+# Given a PreTrainedModel, returns the module name that corresponds
+# to the model's input embeddings.
+#==================================================================#
+def get_input_embeddings_module_name(model: PreTrainedModel) -> str:
+    embeddings = model.get_input_embeddings()
+    def recurse(module, head=""):
+        for c in module.named_children():
+            name = head + c[0]
+            if c[1] is embeddings:
+                return name
+            else:
+                return recurse(c[1], head=name + ".")
+    return recurse(model)
+
+#==================================================================#
+# Given a PreTrainedModel and a list of module names, returns a list
+# of module names such that the union of the set of modules given as input
+# and the set of modules returned as output contains all modules in the model.
+#==================================================================#
+def get_missing_module_names(model: PreTrainedModel, names: List[str]) -> List[str]:
+    missing_names: List[str] = []
+    def recurse(module, head=""):
+        for c in module.named_children():
+            name = head + c[0]
+            if any(name.startswith(n) for n in names):
+                continue
+            if next(c[1].named_children(), None) is None:
+                missing_names.append(name)
+            else:
+                recurse(c[1], head=name + ".")
+    recurse(model)
+    return missing_names
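Taken together, the three helpers give name-level handles on a model's structure. A hedged usage sketch; gpt2 and the expected names are illustrative of the usual GPT-2 module tree, not taken from the commit:

    from transformers import AutoConfig, AutoModelForCausalLM
    import utils  # KoboldAI's utils.py, where these helpers live

    model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained("gpt2"))

    layers = utils.get_layers_module_names(model)
    # expected: ["transformer.h.0", "transformer.h.1", ..., "transformer.h.11"]

    embeddings_name = utils.get_input_embeddings_module_name(model)
    # expected: "transformer.wte"

    # Leaf modules not covered by any name we already hold; for GPT-2 this
    # should surface e.g. "transformer.wpe", "transformer.ln_f" and "lm_head".
    rest = utils.get_missing_module_names(model, layers + [embeddings_name])

    # The union property from the comment block above: every leaf module's
    # name is equal to, or nested under, some name in the combined set.
    covered = layers + [embeddings_name] + rest
    leaves = [n for n, m in model.named_modules()
              if n and next(m.children(), None) is None]
    assert all(any(n == c or n.startswith(c + ".") for c in covered) for n in leaves)

One plausible consumer of this shape of information is layer-placement logic (note the HAS_ACCELERATE check at the top of the file), where every module that is not one of the hidden layers still has to be accounted for.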
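One design note on get_input_embeddings_module_name as committed: the return recurse(...) inside the for loop means only the first named child is inspected at each level, so the helper relies on the Hugging Face convention that the input embeddings sit on the first-child chain (transformer.wte for GPT-2, model.decoder.embed_tokens for OPT). A defensive variant, hypothetical and not part of this commit, would search the whole tree:

    def get_input_embeddings_module_name_exhaustive(model: PreTrainedModel) -> str:
        # Hypothetical alternative: scan every registered submodule instead
        # of assuming the embeddings lie along the chain of first children.
        embeddings = model.get_input_embeddings()
        for name, module in model.named_modules():
            if module is embeddings:
                return name
        raise ValueError("input embeddings not found among the model's submodules")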