aria2_hook now uses new cache format if you have transformers 4.22

This commit is contained in:
vfbd
2022-09-15 16:50:43 -04:00
parent 7bf6c9a23f
commit 463bf86bcc
2 changed files with 267 additions and 49 deletions

View File

@ -1713,11 +1713,13 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
import transformers.configuration_utils
import transformers.modeling_utils
import transformers.file_utils
import huggingface_hub
legacy = packaging.version.parse(transformers_version) < packaging.version.parse("4.22.0.dev0")
# Save the config.json
shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, transformers.configuration_utils.CONFIG_NAME, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.configuration_utils.CONFIG_NAME))
shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(vars.model, transformers.configuration_utils.CONFIG_NAME, revision=vars.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.configuration_utils.CONFIG_NAME))
if(utils.num_shards is None):
# Save the pytorch_model.bin of an unsharded model
shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, transformers.modeling_utils.WEIGHTS_NAME, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_NAME))
shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(vars.model, transformers.modeling_utils.WEIGHTS_NAME, revision=vars.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_NAME))
else:
with open(utils.from_pretrained_index_filename) as f:
map_data = json.load(f)
@ -1726,7 +1728,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
shutil.move(utils.from_pretrained_index_filename, os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_INDEX_NAME))
# Then save the pytorch_model-#####-of-#####.bin files
for filename in filenames:
shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, filename, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(vars.model, filename, revision=vars.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
shutil.rmtree("cache/")
if(vars.hascuda):