Fix patching code of `PreTrainedModel.from_pretrained()`

Gnome Ann 2022-05-11 00:41:53 -04:00
parent 22b4f3c9df
commit f09959f9be
1 changed file with 8 additions and 6 deletions


@@ -1114,10 +1114,11 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     from transformers import __version__ as transformers_version
     from transformers import PreTrainedModel
-    old_from_pretrained = PreTrainedModel.from_pretrained
-    def new_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
+    old_from_pretrained = PreTrainedModel.from_pretrained.__func__
+    @classmethod
+    def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
-        return old_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
     PreTrainedModel.from_pretrained = new_from_pretrained
     # Lazy loader
@@ -1545,10 +1546,11 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
     else:
         from transformers import PreTrainedModel
-        old_from_pretrained = PreTrainedModel.from_pretrained
-        def new_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
+        old_from_pretrained = PreTrainedModel.from_pretrained.__func__
+        @classmethod
+        def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
-            return old_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+            return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
         PreTrainedModel.from_pretrained = new_from_pretrained
 def tpumtjgetsofttokens():
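
For context on why the change matters: `from_pretrained` is a classmethod, so the attribute `PreTrainedModel.from_pretrained` is a method already bound to `PreTrainedModel`. The old patch saved that bound method and delegated to it from a plain function, so every call made through a subclass ran with `cls` silently forced back to `PreTrainedModel`. The fix unwraps the underlying function via `.__func__`, accepts `cls` explicitly, and re-registers the wrapper with `@classmethod`. The following minimal sketch (hypothetical `Base`/`Sub`/`load` names, not the KoboldAI code itself) demonstrates the difference:

# A minimal, self-contained sketch of the bug this commit fixes; Base, Sub,
# and load() are hypothetical stand-ins for PreTrainedModel.from_pretrained.

class Base:
    @classmethod
    def load(cls, name):
        return f"{cls.__name__} loaded {name}"

class Sub(Base):
    pass

# Old (broken) patch: Base.load is already bound with cls=Base, so any
# subclass that calls the wrapper is silently redirected to Base.
old_bound = Base.load                       # bound method, cls pinned to Base
def broken_wrapper(name):
    return old_bound(name)                  # always runs with cls=Base
Base.load = broken_wrapper
assert Sub.load("x") == "Base loaded x"     # wrong: Sub's identity is lost

# New (fixed) patch, mirroring the diff above: unwrap the classmethod with
# __func__, accept cls explicitly, and re-register with @classmethod.
old_func = old_bound.__func__               # the underlying plain function
@classmethod
def fixed_wrapper(cls, name):
    return old_func(cls, name)              # cls forwarded correctly
Base.load = fixed_wrapper
assert Sub.load("x") == "Sub loaded x"      # right: cls dispatches correctly

In the real diff the wrapper additionally forwards *model_args and **kwargs and calls utils.aria2_hook first; the sketch keeps only the binding behavior that the commit corrects.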