From f7ffdd7b6b91616e5f04e9b324aceefb01abb715 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Sat, 18 Jun 2022 18:16:56 -0400
Subject: [PATCH] Add more model querying utilities

---
 aiserver.py |  2 +-
 utils.py    | 48 +++++++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 11ac5b6f..721ddc3b 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1805,7 +1805,7 @@ def load_model(use_gpu=True, gpu_layers=None, initial_load=False, online_model="
                 metamodel = AutoModelForCausalLM.from_config(model_config)
             except Exception as e:
                 metamodel = GPTNeoForCausalLM.from_config(model_config)
-            vars.layer_param_names = utils.get_layer_param_names(metamodel)
+            vars.layer_param_names = utils.get_layers_module_names(metamodel)
         with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(utils.num_layers(model_config)) if vars.lazy_load else None, dematerialized_modules=True):
             if(vars.lazy_load):  # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
                 lowmem = {}
diff --git a/utils.py b/utils.py
index bbb42c52..78b21cad 100644
--- a/utils.py
+++ b/utils.py
@@ -8,11 +8,12 @@ import requests
 import requests.adapters
 import time
 from transformers import __version__ as transformers_version
+from transformers import PreTrainedModel
 import packaging.version
 from tqdm.auto import tqdm
 import os
 import itertools
-from typing import Optional
+from typing import List, Optional
 
 HAS_ACCELERATE = packaging.version.parse(transformers_version) >= packaging.version.parse("4.20.0.dev0")
 try:
@@ -309,8 +310,12 @@ def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename,
     shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
     return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
 
-def get_layer_param_names(model):
-    names = []
+#==================================================================#
+# Given a PreTrainedModel, returns the list of module names that correspond
+# to the model's hidden layers.
+#==================================================================#
+def get_layers_module_names(model: PreTrainedModel) -> List[str]:
+    names: List[str] = []
     def recurse(module, head=""):
         for c in module.named_children():
             name = head + c[0]
@@ -320,3 +325,40 @@ def get_layer_param_names(model):
             recurse(c[1], head=name + ".")
     recurse(model)
     return names
+
+#==================================================================#
+# Given a PreTrainedModel, returns the module name that corresponds
+# to the model's input embeddings.
+#==================================================================#
+def get_input_embeddings_module_name(model: PreTrainedModel) -> str:
+    embeddings = model.get_input_embeddings()
+    def recurse(module, head=""):
+        for c in module.named_children():
+            name = head + c[0]
+            if c[1] is embeddings:
+                return name
+            # Search this child's subtree; fall through to its siblings otherwise.
+            found = recurse(c[1], head=name + ".")
+            if found is not None:
+                return found
+    return recurse(model)
+
+#==================================================================#
+# Given a PreTrainedModel and a list of module names, returns the names of
+# the model's leaf modules not covered by any of the given names, so that
+# the given names plus the returned names cover every module in the model.
+#==================================================================#
+def get_missing_module_names(model: PreTrainedModel, names: List[str]) -> List[str]:
+    missing_names: List[str] = []
+    def recurse(module, head=""):
+        for c in module.named_children():
+            name = head + c[0]
+            if any(name.startswith(n) for n in names):
+                continue
+            # A module with no children of its own is a leaf.
+            if next(c[1].named_children(), None) is None:
+                missing_names.append(name)
+            else:
+                recurse(c[1], head=name + ".")
+    recurse(model)
+    return missing_names
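
A minimal usage sketch of the three new helpers, assuming it runs from the
KoboldAI source tree (so the patched utils module is importable) against a
small randomly initialized GPT-2-style model. The config values and the module
names shown in the comments are illustrative only; the exact output of
get_layers_module_names depends on layer-matching logic that lies outside this
diff's context lines.

    from transformers import GPT2Config, GPT2LMHeadModel
    import utils

    # Tiny random model so the example runs without downloading weights.
    model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64))

    layer_names = utils.get_layers_module_names(model)
    # e.g. ["transformer.h.0", "transformer.h.1"]

    embeddings_name = utils.get_input_embeddings_module_name(model)
    # "transformer.wte" for GPT-2-style models

    # Leaf modules covered by neither the hidden layers nor the embeddings:
    remaining = utils.get_missing_module_names(model, layer_names + [embeddings_name])
    # e.g. ["transformer.wpe", "transformer.drop", "transformer.ln_f", "lm_head"]

Together the three calls account for every leaf module in the model: the
hidden layers, the input embeddings, and the remainder.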