diff --git a/aiserver.py b/aiserver.py
index 23c8f5e1..f3e1c4ed 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1170,6 +1170,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         utils.num_shards = None
         utils.current_shard = 0
+        utils.from_pretrained_model_name = pretrained_model_name_or_path
+        utils.from_pretrained_index_filename = None
+        utils.from_pretrained_kwargs = kwargs
+        utils.bar = None
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
@@ -1177,6 +1181,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
     def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
         utils.num_shards = utils.get_num_shards(index_filename)
+        utils.from_pretrained_index_filename = index_filename
         return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
     modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
@@ -1196,6 +1201,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
             ram_blocks = gpu_blocks = cumulative_gpu_blocks = None
 
         def lazy_load_callback(model_dict, f, **_):
+            if lazy_load_callback.nested:
+                return
+            lazy_load_callback.nested = True
+
             device_map = {}
 
             for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
@@ -1210,6 +1219,13 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                 if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map:
                     device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu"
 
+            if utils.num_shards is None or utils.current_shard == 0:
+                if utils.num_shards is not None:
+                    num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
+                else:
+                    num_tensors = len(device_map)
+                utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
+
             with zipfile.ZipFile(f, "r") as z:
                 try:
                     last_storage_key = None
@@ -1217,7 +1233,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                     current_offset = 0
                     if utils.num_shards is not None:
                         utils.current_shard += 1
-                    for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
+                    for key in sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
                         storage_key = model_dict[key].key
                         if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
                             last_storage_key = storage_key
@@ -1241,10 +1257,16 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                         model_dict[key] = model_dict[key].to(device)
                         #print("OK", flush=True)
                         current_offset += nbytes
+                        utils.bar.update(1)
                 finally:
+                    if utils.num_shards is None or utils.current_shard >= utils.num_shards:
+                        utils.bar.close()
+                        utils.bar = None
+                    lazy_load_callback.nested = False
                     if isinstance(f, zipfile.ZipExtFile):
                         f.close()
 
+        lazy_load_callback.nested = False
         return lazy_load_callback
 
     lazy_load_config_path = os.path.join("maps", vars.model_type + ".json")
@@ -1640,6 +1662,10 @@ else:
     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         utils.num_shards = None
         utils.current_shard = 0
+        utils.from_pretrained_model_name = pretrained_model_name_or_path
+        utils.from_pretrained_index_filename = None
+        utils.from_pretrained_kwargs = kwargs
+        utils.bar = None
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
@@ -1647,6 +1673,7 @@ else:
     old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
     def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
         utils.num_shards = utils.get_num_shards(index_filename)
+        utils.from_pretrained_index_filename = index_filename
         return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
     modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 2fa149d7..3d7bf735 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1160,6 +1160,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     import functools
 
     def callback(model_dict, f, **_):
+        if callback.nested:
+            return
+        callback.nested = True
         with zipfile.ZipFile(f, "r") as z:
             try:
                 last_storage_key = None
@@ -1167,9 +1170,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                 current_offset = 0
                 if utils.current_shard == 0:
                     print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
+
+                if utils.num_shards is None or utils.current_shard == 0:
+                    if utils.num_shards is not None:
+                        num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
+                    else:
+                        num_tensors = len(model_dict)
+                    utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
+
                 if utils.num_shards is not None:
                     utils.current_shard += 1
-                for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
+                for key in sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
 
                     # Some model weights are used by transformers but not by MTJ.
                     # We have to materialize these weights anyways because
@@ -1178,6 +1189,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                     # tensors, which don't take up any actual CPU or TPU memory.
                     if key not in model_spec:
                         model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
+                        utils.bar.update(1)
                         continue
 
                     storage_key = model_dict[key].key
@@ -1230,6 +1242,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                         np.empty(params["cores_per_replica"]),
                     )
 
+                    utils.bar.update(1)
+
                 if utils.num_shards is not None and utils.current_shard < utils.num_shards:
                     return
 
@@ -1250,9 +1264,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                             error = f"{mk} {pk} could not be found in the model checkpoint"
                             print("\n\nERROR: " + error, file=sys.stderr)
                             raise RuntimeError(error)
+            except:
+                import traceback
+                traceback.print_exc()
             finally:
+                if utils.num_shards is None or utils.current_shard >= utils.num_shards:
+                    utils.bar.close()
+                    utils.bar = None
+                callback.nested = False
                 if isinstance(f, zipfile.ZipExtFile):
                     f.close()
+    callback.nested = False
 
     if os.path.isdir(vars.model.replace('/', '_')):
         import shutil
diff --git a/utils.py b/utils.py
index 9565eaa4..0e4299de 100644
--- a/utils.py
+++ b/utils.py
@@ -9,11 +9,16 @@ import requests.adapters
 import time
 from tqdm.auto import tqdm
 import os
+import itertools
 from typing import Optional
 
 vars = None
 num_shards: Optional[int] = None
 current_shard = 0
+from_pretrained_model_name = ""
+from_pretrained_index_filename: Optional[str] = None
+from_pretrained_kwargs = {}
+bar = None
 
 #==================================================================#
 # Decorator to prevent a function's actions from being run until
@@ -280,3 +285,14 @@ def get_num_shards(filename):
     with open(filename) as f:
         map_data = json.load(f)
     return len(set(map_data["weight_map"].values()))
+
+#==================================================================#
+# Given the name/path of a sharded model and the path to a
+# pytorch_model.bin.index.json, returns a list of weight names in the
+# sharded model. Requires lazy loader to be enabled to work properly
+#==================================================================#
+def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
+    import transformers.modeling_utils
+    import torch
+    shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
+    return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
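A minimal standalone sketch of the progress-bar pattern the patch introduces (illustration only, not code from the repository; the shard and tensor counts below are hypothetical placeholders): one tqdm bar sized for the whole model is created before the first shard's callback runs, advanced once per materialized tensor, and closed only after the final shard, so every shard shares a single "Loading model tensors" bar instead of getting its own.

from typing import List, Optional

from tqdm.auto import tqdm

bar: Optional[tqdm] = None
num_shards = 3        # hypothetical: derived from pytorch_model.bin.index.json
current_shard = 0
total_tensors = 12    # hypothetical: tensor names summed across all shards


def shard_callback(shard_tensors: List[object]) -> None:
    # Stands in for lazy_load_callback / callback above: invoked once per shard.
    global bar, current_shard
    if current_shard == 0:
        # Create the bar once, sized to the whole model rather than to one shard.
        bar = tqdm(total=total_tensors, desc="Loading model tensors")
    current_shard += 1
    for _tensor in shard_tensors:
        bar.update(1)  # one tick per tensor, regardless of which shard it came from
    if current_shard >= num_shards:
        # Close only after the last shard so the bar spans the entire load.
        bar.close()
        bar = None


for shard in ([object()] * 4, [object()] * 4, [object()] * 4):
    shard_callback(shard)

The patch achieves the same effect through module-level state in utils (utils.bar, utils.num_shards, utils.current_shard), which lets both the CPU/GPU loader in aiserver.py and the TPU loader in tpu_mtj_backend.py share one bar.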