Allow TPU Colab to load sharded HF models

This commit is contained in:
Gnome Ann
2022-05-12 23:51:40 -04:00
parent 4fa5f1cd6a
commit b1d8797a54
4 changed files with 43 additions and 5 deletions

View File

@ -6,8 +6,11 @@ import subprocess
import tempfile
import requests
import os
from typing import Optional
# Module-level shared state. Starts as None here; presumably assigned by
# the importing application before use -- TODO confirm against callers.
vars = None
# Number of shards in the checkpoint being loaded, or None when it has not
# been set (presumably filled in from get_num_shards() -- confirm).
num_shards: Optional[int] = None
# Shard counter, initialised to 0 (presumably advanced by the loader as
# each shard is processed -- confirm against callers).
current_shard = 0
#==================================================================#
# Decorator to prevent a function's actions from being run until
@ -133,7 +136,7 @@ def decodenewlines(txt):
return txt
#==================================================================#
# Downloads sharded huggingface checkpoints using aria2c if possible
# Downloads huggingface checkpoints using aria2c if possible
#==================================================================#
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
import transformers
@ -225,3 +228,12 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
json.dump({"url": u, "etag": t}, f)
#==================================================================#
# Given the path to a pytorch_model.bin.index.json, returns how many
# shards there are in the model
#==================================================================#
def get_num_shards(filename: str) -> int:
    """Return the number of distinct shard files listed in a checkpoint index.

    Args:
        filename: Path to a JSON index file (e.g. a
            ``pytorch_model.bin.index.json``) whose top-level object has a
            ``"weight_map"`` mapping of entry name to shard file name.

    Returns:
        The count of unique shard file names referenced by ``"weight_map"``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the JSON object has no ``"weight_map"`` key.
    """
    # JSON is UTF-8 by specification; pin the encoding so the result does
    # not depend on the platform's default locale encoding.
    with open(filename, encoding="utf-8") as f:
        index = json.load(f)
    # Many entries map to the same shard file, so de-duplicate before counting.
    return len(set(index["weight_map"].values()))