Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Loading a sharded model will now display only one progress bar
1 changed file: utils.py (16 additions, 0 deletions)
@@ -9,11 +9,16 @@ import requests.adapters
import time
from tqdm.auto import tqdm
import os
import itertools
from typing import Optional

vars = None
num_shards: Optional[int] = None
current_shard = 0
from_pretrained_model_name = ""
from_pretrained_index_filename: Optional[str] = None
from_pretrained_kwargs = {}
bar = None

#==================================================================#
# Decorator to prevent a function's actions from being run until
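The module-level state added in this hunk (num_shards, current_shard, bar) is what lets the loader share one tqdm instance across every shard instead of opening a fresh bar per shard. The real call sites live elsewhere in the KoboldAI loader; the following is only a minimal sketch of the pattern, with hypothetical begin_shard/tensor_loaded/end_shard hooks standing in for them:

from typing import Optional
from tqdm.auto import tqdm

num_shards: Optional[int] = None  # set from the index before loading starts
current_shard = 0                 # how many shards have begun loading
bar = None                        # the one tqdm instance shared by all shards

def begin_shard(total_tensors: int):
    # Create the bar only when the first shard starts; later shards
    # reuse it, so N shards render one progress bar instead of N.
    global current_shard, bar
    current_shard += 1
    if bar is None:
        bar = tqdm(total=total_tensors, desc="Loading model tensors")

def tensor_loaded():
    # Called once per materialized tensor, whichever shard it came from.
    bar.update(1)

def end_shard():
    # Tear the bar down only after the final shard finishes.
    global current_shard, bar
    if current_shard == num_shards:
        bar.close()
        bar = None
        current_shard = 0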
@@ -280,3 +285,14 @@ def get_num_shards(filename):
    with open(filename) as f:
        map_data = json.load(f)
    return len(set(map_data["weight_map"].values()))

#==================================================================#
# Given the name/path of a sharded model and the path to a
# pytorch_model.bin.index.json, returns a list of weight names in the
# sharded model. Requires lazy loader to be enabled to work properly
#==================================================================#
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
    import transformers.modeling_utils
    import torch
    shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
    return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
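For context: get_num_shards works because the pytorch_model.bin.index.json that transformers writes for sharded checkpoints maps every tensor name to the shard file holding it, so counting distinct values gives the shard count; get_sharded_checkpoint_num_tensors then flattens the per-shard key lists with itertools.chain. A self-contained sketch, using an invented two-shard index (the tensor names and shard filenames are made up for illustration):

import itertools
import json

# Minimal index in the format transformers writes for sharded checkpoints.
index_json = """
{
  "weight_map": {
    "transformer.wte.weight": "pytorch_model-00001-of-00002.bin",
    "transformer.h.0.attn.weight": "pytorch_model-00001-of-00002.bin",
    "lm_head.weight": "pytorch_model-00002-of-00002.bin"
  }
}
"""
map_data = json.loads(index_json)

# What get_num_shards computes: distinct shard files, not tensors -> 2
print(len(set(map_data["weight_map"].values())))

# What get_sharded_checkpoint_num_tensors returns: per-shard key lists
# (here stubbed in place of torch.load(...).keys()) flattened into one list.
keys_per_shard = [["transformer.wte.weight", "transformer.h.0.attn.weight"],
                  ["lm_head.weight"]]
print(list(itertools.chain(*keys_per_shard)))  # all three tensor names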