mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Revision Cleanup
This commit is contained in:
10
utils.py
10
utils.py
@@ -286,7 +286,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
|
|||||||
if token is None:
|
if token is None:
|
||||||
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
|
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
|
||||||
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
|
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
|
||||||
_revision = args.revision if args.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
|
_revision = revision if revision is not None else huggingface_hub.constants.DEFAULT_REVISION
|
||||||
sharded = False
|
sharded = False
|
||||||
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
|
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
|
||||||
if use_auth_token:
|
if use_auth_token:
|
||||||
@@ -297,7 +297,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
|
|||||||
|
|
||||||
def is_cached(filename):
|
def is_cached(filename):
|
||||||
try:
|
try:
|
||||||
huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, local_files_only=True, revision=_revision)
|
huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, local_files_only=True, revision=revision)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
@@ -306,7 +306,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
|
|||||||
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
|
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return
|
return
|
||||||
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=_revision)
|
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
|
||||||
if is_cached(filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
|
if is_cached(filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
|
||||||
break
|
break
|
||||||
if sharded:
|
if sharded:
|
||||||
@@ -316,11 +316,11 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
|
|||||||
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
|
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
|
||||||
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
|
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
|
||||||
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
|
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
|
||||||
map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
|
map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision)
|
||||||
with open(map_filename) as f:
|
with open(map_filename) as f:
|
||||||
map_data = json.load(f)
|
map_data = json.load(f)
|
||||||
filenames = set(map_data["weight_map"].values())
|
filenames = set(map_data["weight_map"].values())
|
||||||
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=_revision) for n in filenames]
|
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
|
||||||
if not force_download:
|
if not force_download:
|
||||||
urls = [u for u, n in zip(urls, filenames) if not is_cached(n)]
|
urls = [u for u, n in zip(urls, filenames) if not is_cached(n)]
|
||||||
if not urls:
|
if not urls:
|
||||||
|
Reference in New Issue
Block a user