mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Revision Cleanup
utils.py | 10 +++++-----
@@ -286,7 +286,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
         if token is None:
             raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
     _cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
-    _revision = args.revision if args.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
+    _revision = revision if revision is not None else huggingface_hub.constants.DEFAULT_REVISION
     sharded = False
     headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
     if use_auth_token:
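The cleanup here swaps the global args.revision for the hook's own revision parameter; when the caller passes None, huggingface_hub's default branch is used. A minimal sketch of that fallback, assuming a huggingface_hub release that exposes the same constant the hook uses:

import huggingface_hub

# DEFAULT_REVISION is the branch huggingface_hub falls back to when no
# revision is given; in current releases its value is "main".
revision = None
_revision = revision if revision is not None else huggingface_hub.constants.DEFAULT_REVISION
print(_revision)  # "main"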
@@ -297,7 +297,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa

     def is_cached(filename):
         try:
-            huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, local_files_only=True, revision=_revision)
+            huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, local_files_only=True, revision=revision)
         except ValueError:
             return False
         return True
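The is_cached helper probes the local cache without touching the network. A standalone sketch of the same pattern, assuming the older huggingface_hub releases this hook targets, where a cache miss with local_files_only=True raises ValueError (newer releases raise LocalEntryNotFoundError instead); the repo id and revision below are placeholders:

import huggingface_hub

def is_cached(repo_id, filename, cache_dir=None, revision=None):
    # local_files_only=True never downloads; a cache miss raises
    # ValueError in the huggingface_hub versions the hook supports.
    try:
        huggingface_hub.hf_hub_download(repo_id, filename, cache_dir=cache_dir, local_files_only=True, revision=revision)
    except ValueError:
        return False
    return True

# Hypothetical usage with placeholder arguments.
print(is_cached("EleutherAI/gpt-neo-2.7B", "pytorch_model.bin", revision="main"))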
@@ -306,7 +306,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
             filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
         except AttributeError:
             return
-        url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=_revision)
+        url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
         if is_cached(filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
             break
         if sharded:
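The remote half of that check works because requests.Response is truthy exactly when the status code is below 400, so a bare requests.head(...) in a condition doubles as an existence test. A minimal sketch with an example URL of the resolve-style form hf_hub_url produces:

import requests

# Response.__bool__ returns True iff status_code < 400, so the HEAD
# request itself serves as the "does this file exist?" check.
url = "https://huggingface.co/gpt2/resolve/main/config.json"  # example URL
if requests.head(url, allow_redirects=True):
    print("remote file exists")
else:
    print("remote file missing or request failed")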
@@ -316,11 +316,11 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
     if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
         filenames = [transformers.modeling_utils.WEIGHTS_NAME]
     else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
-        map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
+        map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision)
         with open(map_filename) as f:
             map_data = json.load(f)
         filenames = set(map_data["weight_map"].values())
-    urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=_revision) for n in filenames]
+    urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
     if not force_download:
         urls = [u for u, n in zip(urls, filenames) if not is_cached(n)]
         if not urls:
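The sharded branch leans on the index file's weight_map, which maps each parameter name to the shard file that stores it; the distinct values of that mapping are exactly the shard files aria2 must fetch. A minimal sketch of deriving the shard list, assuming the index file has already been downloaded to the working directory:

import json

# pytorch_model.bin.index.json maps parameter names to shard files;
# the set of values mirrors map_data["weight_map"].values() in the hook.
with open("pytorch_model.bin.index.json") as f:
    map_data = json.load(f)

filenames = sorted(set(map_data["weight_map"].values()))
for name in filenames:
    print(name)  # e.g. pytorch_model-00001-of-00002.bin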