mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Support safetensors in get_sharded_checkpoint_num_tensors
This commit is contained in:
@@ -498,6 +498,7 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
utils.get_sharded_checkpoint_num_tensors(
|
||||
utils.from_pretrained_model_name,
|
||||
utils.from_pretrained_index_filename,
|
||||
is_safetensors=is_safetensors,
|
||||
**utils.from_pretrained_kwargs,
|
||||
)
|
||||
)
|
||||
|
12
utils.py
12
utils.py
@@ -569,13 +569,19 @@ def get_num_shards(filename):
|
||||
#==================================================================#
|
||||
# Given the name/path of a sharded model and the path to a
|
||||
# pytorch_model.bin.index.json, returns a list of weight names in the
|
||||
# sharded model. Requires lazy loader to be enabled to work properl
|
||||
# sharded model. Requires lazy loader to be enabled to work properly
|
||||
#==================================================================#
|
||||
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
|
||||
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, is_safetensors=False, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
|
||||
import transformers.modeling_utils
|
||||
import torch
|
||||
_revision = koboldai_vars.revision if koboldai_vars.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
|
||||
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=_revision)
|
||||
|
||||
if is_safetensors:
|
||||
from safetensors import safe_open
|
||||
return list(itertools.chain(*(safe_open(p, framework="pt", device="cpu").keys() for p in shard_paths)))
|
||||
|
||||
# Torch
|
||||
import torch
|
||||
return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
|
||||
|
||||
#==================================================================#
|
||||
|
Reference in New Issue
Block a user