Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Support safetensors in get_sharded_checkpoint_num_tensors
@@ -498,6 +498,7 @@ class HFTorchInferenceModel(HFInferenceModel):
             utils.get_sharded_checkpoint_num_tensors(
                 utils.from_pretrained_model_name,
                 utils.from_pretrained_index_filename,
+                is_safetensors=is_safetensors,
                 **utils.from_pretrained_kwargs,
             )
         )
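For context, here is a hedged sketch of how a caller might derive the is_safetensors flag passed above. The helper is hypothetical (this hunk does not show how the flag is actually computed); it assumes the standard Hugging Face index filenames, where sharded safetensors checkpoints ship a model.safetensors.index.json and sharded PyTorch checkpoints ship a pytorch_model.bin.index.json:

    # Hypothetical helper, not part of this commit's visible hunks:
    # infer the checkpoint format from the index filename.
    def infer_is_safetensors(index_filename: str) -> bool:
        return index_filename.endswith("safetensors.index.json")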
utils.py (12 changed lines)
@@ -569,13 +569,19 @@ def get_num_shards(filename):
 #==================================================================#
 # Given the name/path of a sharded model and the path to a
 # pytorch_model.bin.index.json, returns a list of weight names in the
-# sharded model. Requires lazy loader to be enabled to work properl
+# sharded model. Requires lazy loader to be enabled to work properly
 #==================================================================#
-def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
+def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, is_safetensors=False, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
     import transformers.modeling_utils
-    import torch
     _revision = koboldai_vars.revision if koboldai_vars.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
     shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=_revision)
+
+    if is_safetensors:
+        from safetensors import safe_open
+        return list(itertools.chain(*(safe_open(p, framework="pt", device="cpu").keys() for p in shard_paths)))
+
+    # Torch
+    import torch
     return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
 
 #==================================================================#
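Taken on its own, the new utils.py logic amounts to the following standalone sketch. The function name, shard filenames, and usage are hypothetical placeholders; the real function resolves shard paths via transformers.modeling_utils.get_checkpoint_shard_files as shown in the hunk above:

    import itertools

    def list_shard_tensor_names(shard_paths, is_safetensors=False):
        if is_safetensors:
            # safe_open only parses the shard header, so tensor names come
            # back without deserializing any tensor data.
            from safetensors import safe_open
            return list(itertools.chain(*(safe_open(p, framework="pt", device="cpu").keys() for p in shard_paths)))
        # torch.load deserializes the whole shard, which is why the comment
        # in the diff notes that the lazy loader must be enabled for this
        # to work properly.
        import torch
        return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))

    # Hypothetical usage:
    # list_shard_tensor_names(["model-00001-of-00002.safetensors",
    #                          "model-00002-of-00002.safetensors"],
    #                         is_safetensors=True)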