Use aria2 to download split checkpoints

Gnome Ann
2022-05-10 21:28:13 -04:00
parent 7fcc1a9acb
commit a388c63023
3 changed files with 106 additions and 2 deletions
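
In short: the two hunks below wrap PreTrainedModel.from_pretrained so that utils.aria2_hook (defined in the utils module, whose diff is not shown in this excerpt) runs before transformers' own downloader, giving aria2 a chance to fetch the split (sharded) checkpoint files first.

As a minimal standalone sketch of the same wrapper pattern, with a hypothetical no-op download_hook standing in for utils.aria2_hook. This variant preserves the classmethod protocol via __func__ so the original still receives the calling class as cls; the commit itself uses the simpler plain-function assignment shown in the diff.

    from transformers import PreTrainedModel

    def download_hook(pretrained_model_name_or_path, **kwargs):
        # Hypothetical stand-in: pre-fetch checkpoint files here.
        pass

    # Grab the underlying function of the classmethod so cls can be forwarded.
    old_from_pretrained = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        download_hook(pretrained_model_name_or_path, **kwargs)
        return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)

    PreTrainedModel.from_pretrained = new_from_pretrained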


@@ -1111,6 +1111,13 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     import transformers.generation_utils
     from transformers import __version__ as transformers_version
 
+    from transformers import PreTrainedModel
+    old_from_pretrained = PreTrainedModel.from_pretrained
+    def new_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
+        utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
+        return old_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+    PreTrainedModel.from_pretrained = new_from_pretrained
+
     # Lazy loader
     import torch_lazy_loader
     def get_lazy_load_callback(n_layers, convert_to_float16=True):
@@ -1535,6 +1542,13 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         from transformers import GPT2TokenizerFast
         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
 else:
+    from transformers import PreTrainedModel
+    old_from_pretrained = PreTrainedModel.from_pretrained
+    def new_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
+        utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
+        return old_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+    PreTrainedModel.from_pretrained = new_from_pretrained
+
     def tpumtjgetsofttokens():
         soft_tokens = None
         if(vars.sp is None):
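
The aria2_hook implementation itself lives in the utils module and is not shown in this excerpt. As a rough, hypothetical sketch of what an aria2-based pre-download of a sharded checkpoint can look like (assuming huggingface_hub's hf_hub_url and an aria2c binary on PATH; the function name, cache directory, and flag values here are illustrative, not the commit's actual code):

    import json
    import os
    import subprocess
    import urllib.request

    from huggingface_hub import hf_hub_url

    def aria2_sketch(repo_id, directory="cache/"):
        # A sharded ("split") checkpoint lists its shard filenames in an index file.
        index_url = hf_hub_url(repo_id, "pytorch_model.bin.index.json")
        with urllib.request.urlopen(index_url) as response:
            index = json.load(response)
        filenames = sorted(set(index["weight_map"].values()))
        os.makedirs(directory, exist_ok=True)
        for filename in filenames:
            url = hf_hub_url(repo_id, filename)
            # -x/-s: fetch each file over multiple parallel connections;
            # -d/-o: target directory and output filename.
            subprocess.run(
                ["aria2c", "-x", "16", "-s", "16", "-d", directory, "-o", filename, url],
                check=True,
            )

The point of routing downloads through aria2 is segmented, parallel transfer: each shard is pulled over several connections at once instead of transformers' single-stream download.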