Fix `no attribute get_checkpoint_shard_files`

Gnome Ann 2022-05-14 11:49:04 -04:00
parent 6e82f205b4
commit d5ab3ef5b1
1 changed file with 14 additions and 12 deletions
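Background: `transformers.modeling_utils.get_checkpoint_shard_files` only exists in transformers releases that support sharded checkpoints. On older installs, the previous code's unconditional `old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files` raised `AttributeError: module 'transformers.modeling_utils' has no attribute 'get_checkpoint_shard_files'` at startup. The fix wraps the monkey-patch in a `hasattr` guard so the shard-tracking hook is installed only when the function actually exists. Below is a minimal, runnable sketch of that guard pattern; `demo_module` is an illustrative stand-in for `transformers.modeling_utils`, not KoboldAI code.

```python
# Sketch of the hasattr-guarded monkey-patch pattern used in this commit.
# demo_module stands in for transformers.modeling_utils (illustrative only).
import types

demo_module = types.ModuleType("demo_module")
# Comment out the next line to simulate an old library version that lacks
# the function; the guard below then simply skips the patch.
demo_module.get_checkpoint_shard_files = lambda name, index, *a, **kw: ["shard-0", "shard-1"]

if hasattr(demo_module, "get_checkpoint_shard_files"):
    old_fn = demo_module.get_checkpoint_shard_files

    def new_fn(name, index, *args, **kwargs):
        # Record bookkeeping state before delegating, mirroring how the real
        # hook stores utils.num_shards / utils.from_pretrained_index_filename.
        print(f"about to load shards listed in {index}")
        return old_fn(name, index, *args, **kwargs)

    demo_module.get_checkpoint_shard_files = new_fn

print(demo_module.get_checkpoint_shard_files("model", "model.bin.index.json"))
```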


```diff
@@ -1180,12 +1180,13 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
     PreTrainedModel.from_pretrained = new_from_pretrained
-    old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
-    def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
-        utils.num_shards = utils.get_num_shards(index_filename)
-        utils.from_pretrained_index_filename = index_filename
-        return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
-    modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
+    if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
+        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+        def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
+            utils.num_shards = utils.get_num_shards(index_filename)
+            utils.from_pretrained_index_filename = index_filename
+            return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+        modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
     # Lazy loader
     import torch_lazy_loader
```
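The wrapper itself is unchanged: before delegating to the real loader it records the shard count and the index filename so the rest of the program (e.g., download progress reporting) can see them. For reference, a shard count can be derived from a standard Hugging Face `*.index.json` file as below; this is a hedged sketch of what a helper such as `utils.get_num_shards` plausibly does, not the verbatim KoboldAI implementation.

```python
import json

def get_num_shards_sketch(index_filename: str) -> int:
    # A sharded Hugging Face checkpoint ships an index JSON whose
    # "weight_map" maps each tensor name to the shard file containing it.
    # The number of distinct shard files is the shard count.
    with open(index_filename) as f:
        index = json.load(f)
    return len(set(index["weight_map"].values()))
```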
```diff
@@ -1707,12 +1708,13 @@ else:
         utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
     PreTrainedModel.from_pretrained = new_from_pretrained
-    old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
-    def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
-        utils.num_shards = utils.get_num_shards(index_filename)
-        utils.from_pretrained_index_filename = index_filename
-        return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
-    modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
+    if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
+        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+        def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
+            utils.num_shards = utils.get_num_shards(index_filename)
+            utils.from_pretrained_index_filename = index_filename
+            return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+        modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
     def tpumtjgetsofttokens():
         soft_tokens = None
```
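The second hunk applies the identical guard in the TPU (`else:`) branch. To check whether a given install is new enough to define the function at all, a one-liner like the following works (assuming transformers is importable):

```python
from transformers import modeling_utils

# True on releases that support sharded checkpoints; False on older ones,
# which is exactly the case the hasattr guard in this commit handles.
print(hasattr(modeling_utils, "get_checkpoint_shard_files"))
```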