Support safetensors in get_sharded_checkpoint_num_tensors

This commit is contained in:
somebody
2023-05-01 19:15:27 -05:00
parent 97e84928ba
commit f6b5548131
2 changed files with 10 additions and 3 deletions

View File

@@ -498,6 +498,7 @@ class HFTorchInferenceModel(HFInferenceModel):
utils.get_sharded_checkpoint_num_tensors( utils.get_sharded_checkpoint_num_tensors(
utils.from_pretrained_model_name, utils.from_pretrained_model_name,
utils.from_pretrained_index_filename, utils.from_pretrained_index_filename,
is_safetensors=is_safetensors,
**utils.from_pretrained_kwargs, **utils.from_pretrained_kwargs,
) )
) )

View File

@@ -569,13 +569,19 @@ def get_num_shards(filename):
#==================================================================# #==================================================================#
# Given the name/path of a sharded model and the path to a # Given the name/path of a sharded model and the path to a
# pytorch_model.bin.index.json, returns a list of weight names in the # pytorch_model.bin.index.json, returns a list of weight names in the
# sharded model. Requires lazy loader to be enabled to work properl # sharded model. Requires lazy loader to be enabled to work properly
#==================================================================# #==================================================================#
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs): def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, is_safetensors=False, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
import transformers.modeling_utils import transformers.modeling_utils
import torch
_revision = koboldai_vars.revision if koboldai_vars.revision is not None else huggingface_hub.constants.DEFAULT_REVISION _revision = koboldai_vars.revision if koboldai_vars.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=_revision) shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=_revision)
if is_safetensors:
from safetensors import safe_open
return list(itertools.chain(*(safe_open(p, framework="pt", device="cpu").keys() for p in shard_paths)))
# Torch
import torch
return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths))) return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
#==================================================================# #==================================================================#