Loading a sharded model will now display only one progress bar

This commit is contained in:
Gnome Ann
2022-05-13 23:32:16 -04:00
parent f9f1a5f3a9
commit 0c5ca5261e
3 changed files with 67 additions and 2 deletions

View File

@ -9,11 +9,16 @@ import requests.adapters
import time
from tqdm.auto import tqdm
import os
import itertools
from typing import Optional
vars = None
num_shards: Optional[int] = None
current_shard = 0
from_pretrained_model_name = ""
from_pretrained_index_filename: Optional[str] = None
from_pretrained_kwargs = {}
bar = None
#==================================================================#
# Decorator to prevent a function's actions from being run until
@ -280,3 +285,14 @@ def get_num_shards(filename):
with open(filename) as f:
map_data = json.load(f)
return len(set(map_data["weight_map"].values()))
#==================================================================#
# Given the name/path of a sharded model and the path to a
# pytorch_model.bin.index.json, returns a list of weight names in the
# sharded model. Requires lazy loader to be enabled to work properl
#==================================================================#
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
import transformers.modeling_utils
import torch
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))