Merge pull request #128 from VE-FORBRYDERNE/opt

OPT breakmodel and TPU support
This commit is contained in:
henk717
2022-05-13 18:07:02 +02:00
committed by GitHub
7 changed files with 330 additions and 29 deletions

View File

@ -6,8 +6,11 @@ import subprocess
import tempfile
import requests
import os
from typing import Optional
vars = None
num_shards: Optional[int] = None
current_shard = 0
#==================================================================#
# Decorator to prevent a function's actions from being run until
@ -135,7 +138,13 @@ def decodenewlines(txt):
return txt
#==================================================================#
# Downloads sharded huggingface checkpoints using aria2c if possible
# Returns number of layers given an HF model config
#==================================================================#
def num_layers(config):
return config.num_layers if hasattr(config, "num_layers") else config.n_layer if hasattr(config, "n_layer") else config.num_hidden_layers
#==================================================================#
# Downloads huggingface checkpoints using aria2c if possible
#==================================================================#
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
import transformers
@ -227,3 +236,12 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
json.dump({"url": u, "etag": t}, f)
#==================================================================#
# Given the path to a pytorch_model.bin.index.json, returns how many
# shards there are in the model
#==================================================================#
def get_num_shards(filename):
with open(filename) as f:
map_data = json.load(f)
return len(set(map_data["weight_map"].values()))