diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index 56ffce60..1997e7fe 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -498,6 +498,7 @@ class HFTorchInferenceModel(HFInferenceModel):
                     utils.get_sharded_checkpoint_num_tensors(
                         utils.from_pretrained_model_name,
                         utils.from_pretrained_index_filename,
+                        is_safetensors=is_safetensors,
                         **utils.from_pretrained_kwargs,
                     )
                 )
diff --git a/utils.py b/utils.py
index fba8186c..3a9a884c 100644
--- a/utils.py
+++ b/utils.py
@@ -569,13 +569,19 @@ def get_num_shards(filename):
 #==================================================================#
 # Given the name/path of a sharded model and the path to a
 # pytorch_model.bin.index.json, returns a list of weight names in the
-# sharded model. Requires lazy loader to be enabled to work properl
+# sharded model. Requires lazy loader to be enabled to work properly
 #==================================================================#
-def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
+def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, is_safetensors=False, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
     import transformers.modeling_utils
-    import torch
     _revision = koboldai_vars.revision if koboldai_vars.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
     shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=_revision)
+
+    if is_safetensors:
+        from safetensors import safe_open
+        return list(itertools.chain(*(safe_open(p, framework="pt", device="cpu").keys() for p in shard_paths)))
+
+    # Torch
+    import torch
     return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
 
 #==================================================================#