Merge pull request #129 from VE-FORBRYDERNE/tqdm
Better model saving and better progress bars
commit 24a2eb8c0b
aiserver.py (108 changed lines)
@@ -16,6 +16,9 @@ os.environ['EVENTLET_THREADPOOL_SIZE'] = '1'
 os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 from eventlet import tpool
 
+import logging
+logging.getLogger("urllib3").setLevel(logging.ERROR)
+
 from os import path, getcwd
 import time
 import re
@@ -54,6 +57,16 @@ if lupa.LUA_VERSION[:2] != (5, 4):
     print(f"Please install lupa==1.10. You have lupa {lupa.__version__}.", file=sys.stderr)
 
 
+# Make sure tqdm progress bars display properly in Colab
+from tqdm.auto import tqdm
+old_init = tqdm.__init__
+def new_init(self, *args, **kwargs):
+    old_init(self, *args, **kwargs)
+    if(self.ncols == 0 and kwargs.get("ncols") != 0):
+        self.ncols = 99
+tqdm.__init__ = new_init
+
+
 #==================================================================#
 # Variables & Storage
 #==================================================================#
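The tqdm patch above is self-contained and can be tried outside KoboldAI. A minimal sketch of the same monkey-patch (the 99-column fallback is this PR's choice; the `ncols == 0` condition is how an undetectable terminal width, as in Colab, shows up after tqdm's own autodetection):

```python
# Standalone sketch of the tqdm patch above: when tqdm cannot detect a
# terminal width it ends up with ncols == 0 and renders a degenerate bar,
# so a fallback width is forced after normal initialization.
from tqdm.auto import tqdm

old_init = tqdm.__init__

def new_init(self, *args, **kwargs):
    old_init(self, *args, **kwargs)
    # Only override when tqdm autodetected 0 columns and the caller did
    # not explicitly pass ncols=0 (which deliberately disables the bar).
    if self.ncols == 0 and kwargs.get("ncols") != 0:
        self.ncols = 99

tqdm.__init__ = new_init

# Any bar created afterwards picks up the fallback width:
for _ in tqdm(range(3), desc="demo"):
    pass
```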
@@ -241,6 +254,7 @@ class vars:
badwordsids = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting
badwordsids_neox = [[0], [1], [44162], [9502], [12520], [31841], [36320], [49824], [34417], [6038], [34494], [24815], [26635], [24345], [3455], [28905], [44270], [17278], [32666], [46880], [7086], [43189], [37322], [17778], [20879], [49821], [3138], [14490], [4681], [21391], [26786], [43134], [9336], [683], [48074], [41256], [19181], [29650], [28532], [36487], [45114], [46275], [16445], [15104], [11337], [1168], [5647], [29], [27482], [44965], [43782], [31011], [42944], [47389], [6334], [17548], [38329], [32044], [35487], [2239], [34761], [7444], [1084], [12399], [18990], [17636], [39083], [1184], [35830], [28365], [16731], [43467], [47744], [1138], [16079], [40116], [45564], [18297], [42368], [5456], [18022], [42696], [34476], [23505], [23741], [39334], [37944], [45382], [38709], [33440], [26077], [43600], [34418], [36033], [6660], [48167], [48471], [15775], [19884], [41533], [1008], [31053], [36692], [46576], [20095], [20629], [31759], [46410], [41000], [13488], [30952], [39258], [16160], [27655], [22367], [42767], [43736], [49694], [13811], [12004], [46768], [6257], [37471], [5264], [44153], [33805], [20977], [21083], [25416], [14277], [31096], [42041], [18331], [33376], [22372], [46294], [28379], [38475], [1656], [5204], [27075], [50001], [16616], [11396], [7748], [48744], [35402], [28120], [41512], [4207], [43144], [14767], [15640], [16595], [41305], [44479], [38958], [18474], [22734], [30522], [46267], [60], [13976], [31830], [48701], [39822], [9014], [21966], [31422], [28052], [34607], [2479], [3851], [32214], [44082], [45507], [3001], [34368], [34758], [13380], [38363], [4299], [46802], [30996], [12630], [49236], [7082], [8795], [5218], [44740], [9686], [9983], [45301], [27114], [40125], [1570], [26997], [544], [5290], [49193], [23781], [14193], [40000], [2947], [43781], [9102], [48064], [42274], [18772], [49384], [9884], [45635], [43521], [31258], [32056], [47686], [21760], [13143], [10148], [26119], [44308], [31379], [36399], [23983], [46694], [36134], [8562], [12977], [35117], [28591], [49021], [47093], [28653], [29013], [46468], [8605], [7254], [25896], [5032], [8168], [36893], [38270], [20499], [27501], [34419], [29547], [28571], [36586], [20871], [30537], [26842], [21375], [31148], [27618], [33094], [3291], [31789], [28391], [870], [9793], [41361], [47916], [27468], [43856], [8850], [35237], [15707], [47552], [2730], [41449], [45488], [3073], [49806], [21938], [24430], [22747], [20924], [46145], [20481], [20197], [8239], [28231], [17987], [42804], [47269], [29972], [49884], [21382], [46295], [36676], [34616], [3921], [26991], [27720], [46265], [654], [9855], [40354], [5291], [34904], [44342], [2470], [14598], [880], [19282], [2498], [24237], [21431], [16369], [8994], [44524], [45662], [13663], [37077], [1447], [37786], [30863], [42854], [1019], [20322], [4398], [12159], [44072], [48664], [31547], [18736], [9259], [31], [16354], [21810], [4357], [37982], [5064], [2033], [32871], [47446], [62], [22158], [37387], [8743], [47007], [17981], [11049], [4622], [37916], [36786], [35138], [29925], [14157], [18095], [27829], [1181], [22226], [5709], [4725], [30189], [37014], [1254], [11380], [42989], [696], [24576], [39487], [30119], [1092], [8088], [2194], [9899], [14412], [21828], [3725], [13544], [5180], [44679], [34398], [3891], [28739], [14219], [37594], [49550], [11326], [6904], [17266], [5749], [10174], [23405], [9955], [38271], [41018], [13011], [48392], [36784], [24254], [21687], [23734], [5413], [41447], [45472], [10122], [17555], [15830], [47384], [12084], [31350], [47940], 
[11661], [27988], [45443], [905], [49651], [16614], [34993], [6781], [30803], [35869], [8001], [41604], [28118], [46462], [46762], [16262], [17281], [5774], [10943], [5013], [18257], [6750], [4713], [3951], [11899], [38791], [16943], [37596], [9318], [18413], [40473], [13208], [16375]]
badwordsids_opt = [[44717], [46613], [48513], [49923], [50185], [48755], [8488], [43303], [49659], [48601], [49817], [45405], [48742], [49925], [47720], [11227], [48937], [48784], [50017], [42248], [49310], [48082], [49895], [50025], [49092], [49007], [8061], [44226], [0], [742], [28578], [15698], [49784], [46679], [39365], [49281], [49609], [48081], [48906], [46161], [48554], [49670], [48677], [49721], [49632], [48610], [48462], [47457], [10975], [46077], [28696], [48709], [43839], [49798], [49154], [48203], [49625], [48395], [50155], [47161], [49095], [48833], [49420], [49666], [48443], [22176], [49242], [48651], [49138], [49750], [40389], [48021], [21838], [49070], [45333], [40862], [1], [49915], [33525], [49858], [50254], [44403], [48992], [48872], [46117], [49853], [47567], [50206], [41552], [50068], [48999], [49703], [49940], [49329], [47620], [49868], [49962], [2], [44082], [50236], [31274], [50260], [47052], [42645], [49177], [17523], [48691], [49900], [49069], [49358], [48794], [47529], [46479], [48457], [646], [49910], [48077], [48935], [46386], [48902], [49151], [48759], [49803], [45587], [48392], [47789], [48654], [49836], [49230], [48188], [50264], [46844], [44690], [48505], [50161], [27779], [49995], [41833], [50154], [49097], [48520], [50018], [8174], [50084], [49366], [49526], [50193], [7479], [49982], [3]]
+    fp32_model = False # Whether or not the most recently loaded HF model was in fp32 format
deletewi = None # Temporary storage for UID to delete
wirmvwhtsp = False # Whether to remove leading whitespace from WI entries
widepth = 3 # How many historical actions to scan for WI hits
@@ -808,6 +822,7 @@ parser.add_argument("--ngrok", action='store_true', help="Optimizes KoboldAI for
 parser.add_argument("--localtunnel", action='store_true', help="Optimizes KoboldAI for Remote Play using Localtunnel")
 parser.add_argument("--host", action='store_true', help="Optimizes KoboldAI for Remote Play without using a proxy service")
 parser.add_argument("--port", type=int, help="Specify the port on which the application will be joinable")
+parser.add_argument("--aria2_port", type=int, help="Specify the port on which aria2's RPC interface will be open if aria2 is installed (defaults to 6799)")
 parser.add_argument("--model", help="Specify the Model Type to skip the Menu")
 parser.add_argument("--path", help="Specify the Path for local models (For model NeoCustom or GPT2Custom)")
 parser.add_argument("--revision", help="Specify the model revision for huggingface models (can be a git branch/tag name or a git commit hash)")
@@ -867,6 +882,8 @@ if args.cpu:
 vars.smandelete = vars.host == args.override_delete
 vars.smanrename = vars.host == args.override_rename
 
+vars.aria2_port = args.aria2_port or 6799
+
 # Select a model to run
 if args.model:
     print("Welcome to KoboldAI!\nYou have selected the following Model:", vars.model)
@@ -1152,17 +1169,24 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     old_from_pretrained = PreTrainedModel.from_pretrained.__func__
     @classmethod
     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        vars.fp32_model = False
         utils.num_shards = None
         utils.current_shard = 0
+        utils.from_pretrained_model_name = pretrained_model_name_or_path
+        utils.from_pretrained_index_filename = None
+        utils.from_pretrained_kwargs = kwargs
+        utils.bar = None
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
     PreTrainedModel.from_pretrained = new_from_pretrained
-    old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
-    def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
-        utils.num_shards = utils.get_num_shards(index_filename)
-        return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
-    modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
+    if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
+        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+        def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
+            utils.num_shards = utils.get_num_shards(index_filename)
+            utils.from_pretrained_index_filename = index_filename
+            return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+        modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
     # Lazy loader
     import torch_lazy_loader
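The `from_pretrained` interception above is plain attribute patching. A minimal, hypothetical sketch of the same pattern (not the PR's exact code; newer transformers releases may wrap the classmethod differently):

```python
# Minimal sketch of the interception pattern above; `record` is a
# stand-in for the utils fields the real code populates.
from transformers import PreTrainedModel

record = {}

old_from_pretrained = PreTrainedModel.from_pretrained.__func__

@classmethod
def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    record["name"] = pretrained_model_name_or_path  # observed before loading
    record["kwargs"] = kwargs
    return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)

PreTrainedModel.from_pretrained = new_from_pretrained
```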
@@ -1180,6 +1204,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     ram_blocks = gpu_blocks = cumulative_gpu_blocks = None
 
     def lazy_load_callback(model_dict, f, **_):
+        if lazy_load_callback.nested:
+            return
+        lazy_load_callback.nested = True
+
         device_map = {}
 
         for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
@@ -1194,6 +1222,14 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
             if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map:
                 device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu"
 
+        if utils.num_shards is None or utils.current_shard == 0:
+            if utils.num_shards is not None:
+                num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
+            else:
+                num_tensors = len(device_map)
+            print(flush=True)
+            utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
+
         with zipfile.ZipFile(f, "r") as z:
             try:
                 last_storage_key = None
@@ -1201,7 +1237,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                 current_offset = 0
                 if utils.num_shards is not None:
                     utils.current_shard += 1
-                for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
+                for key in sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
                     storage_key = model_dict[key].key
                     if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
                         last_storage_key = storage_key
@@ -1218,6 +1254,8 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                     nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
                     #print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... ", end="", flush=True)
                     model_dict[key] = model_dict[key].materialize(f, map_location="cpu")
+                    if model_dict[key].dtype is torch.float32:
+                        vars.fp32_model = True
                     if convert_to_float16 and vars.hascuda and (vars.breakmodel or vars.usegpu) and model_dict[key].dtype is torch.float32:
                         model_dict[key] = model_dict[key].to(torch.float16)
                     if not vars.usegpu and not vars.breakmodel and model_dict[key].dtype is torch.float16:
@@ -1225,10 +1263,16 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
                     model_dict[key] = model_dict[key].to(device)
                     #print("OK", flush=True)
                     current_offset += nbytes
+                    utils.bar.update(1)
             finally:
+                if utils.num_shards is None or utils.current_shard >= utils.num_shards:
+                    utils.bar.close()
+                    utils.bar = None
+                lazy_load_callback.nested = False
                 if isinstance(f, zipfile.ZipExtFile):
                     f.close()
 
+    lazy_load_callback.nested = False
     return lazy_load_callback
 
     lazy_load_config_path = os.path.join("maps", vars.model_type + ".json")
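The `lazy_load_callback.nested` flag guards against the callback re-entering itself, since `utils.get_sharded_checkpoint_num_tensors` loads shards through the same hook. A tiny self-contained sketch of that guard pattern, with stand-in functions:

```python
# Minimal sketch of the reentrancy guard above: a function attribute
# lets a callback detect that it (indirectly) triggered itself.
def callback(data):
    if callback.nested:
        return                      # inner invocation: skip all work
    callback.nested = True
    try:
        process(data)               # may indirectly call callback again
    finally:
        callback.nested = False     # always reset, even on error

callback.nested = False

def process(data):
    callback(data)                  # the nested call is a harmless no-op
    print("processed", data)

callback("weights")
```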
@@ -1566,6 +1610,16 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
             except Exception as e:
                 model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache", **lowmem)
         else:
+            old_rebuild_tensor = torch._utils._rebuild_tensor
+            def new_rebuild_tensor(storage, storage_offset, shape, stride):
+                dtype = storage.storage_type.dtype
+                if(not isinstance(dtype, torch.dtype)):
+                    dtype = storage.storage_type(0).dtype
+                if(dtype is torch.float32 and len(shape) >= 2):
+                    vars.fp32_model = True
+                return old_rebuild_tensor(storage, storage_offset, shape, stride)
+            torch._utils._rebuild_tensor = new_rebuild_tensor
+
             try:
                 tokenizer = AutoTokenizer.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
             except Exception as e:
@@ -1578,11 +1632,32 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
             except Exception as e:
                 model = GPTNeoForCausalLM.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache", **lowmem)
 
+            torch._utils._rebuild_tensor = old_rebuild_tensor
+
             if not args.colab or args.savemodel:
                 import shutil
-                model = model.half()
-                model.save_pretrained("models/{}".format(vars.model.replace('/', '_')), max_shard_size="500MiB")
                 tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
+                if(vars.fp32_model): # Use save_pretrained to convert fp32 models to fp16
+                    model = model.half()
+                    model.save_pretrained("models/{}".format(vars.model.replace('/', '_')), max_shard_size="500MiB")
+                else: # For fp16 models, we can just copy the model files directly
+                    import transformers.configuration_utils
+                    import transformers.modeling_utils
+                    import transformers.file_utils
+                    # Save the config.json
+                    shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, transformers.configuration_utils.CONFIG_NAME, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.configuration_utils.CONFIG_NAME))
+                    if(utils.num_shards is None):
+                        # Save the pytorch_model.bin of an unsharded model
+                        shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, transformers.modeling_utils.WEIGHTS_NAME, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_NAME))
+                    else:
+                        with open(utils.from_pretrained_index_filename) as f:
+                            map_data = json.load(f)
+                        filenames = set(map_data["weight_map"].values())
+                        # Save the pytorch_model.bin.index.json of a sharded model
+                        shutil.move(utils.from_pretrained_index_filename, os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_INDEX_NAME))
+                        # Then save the pytorch_model-#####-of-#####.bin files
+                        for filename in filenames:
+                            shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, filename, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
                 shutil.rmtree("cache/")
 
         if(vars.hascuda):
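The new save path distinguishes fp32 checkpoints (re-serialized through `save_pretrained`, whose `max_shard_size` parameter is standard transformers API) from fp16 ones, whose already-downloaded cache files are simply moved. A rough standalone sketch of the same decision, with hypothetical inputs:

```python
# Rough sketch of the save strategy above: convert-and-save when the
# checkpoint was fp32, otherwise reuse the already-downloaded fp16 files.
# `model`, `tokenizer`, `was_fp32` and `cached_files` are assumed inputs.
import os
import shutil

def save_local_copy(model, tokenizer, was_fp32, target_dir, cached_files):
    os.makedirs(target_dir, exist_ok=True)
    tokenizer.save_pretrained(target_dir)
    if was_fp32:
        # save_pretrained re-serializes, so halving first converts the
        # checkpoint to fp16 and reshards it into ~500 MiB pieces.
        model = model.half()
        model.save_pretrained(target_dir, max_shard_size="500MiB")
    else:
        # Already fp16 on disk: moving the cached files is much cheaper
        # than re-serializing multi-gigabyte tensors.
        for cached_path, filename in cached_files:
            shutil.move(cached_path, os.path.join(target_dir, filename))
```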
@@ -1622,17 +1697,24 @@ else:
     old_from_pretrained = PreTrainedModel.from_pretrained.__func__
     @classmethod
     def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        vars.fp32_model = False
         utils.num_shards = None
         utils.current_shard = 0
+        utils.from_pretrained_model_name = pretrained_model_name_or_path
+        utils.from_pretrained_index_filename = None
+        utils.from_pretrained_kwargs = kwargs
+        utils.bar = None
         if not args.no_aria2:
             utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
         return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
     PreTrainedModel.from_pretrained = new_from_pretrained
-    old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
-    def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
-        utils.num_shards = utils.get_num_shards(index_filename)
-        return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
-    modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
+    if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
+        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+        def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
+            utils.num_shards = utils.get_num_shards(index_filename)
+            utils.from_pretrained_index_filename = index_filename
+            return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+        modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
 
     def tpumtjgetsofttokens():
         soft_tokens = None
tpu_mtj_backend.py
@@ -1160,6 +1160,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     import functools
 
     def callback(model_dict, f, **_):
+        if callback.nested:
+            return
+        callback.nested = True
         with zipfile.ZipFile(f, "r") as z:
             try:
                 last_storage_key = None
@@ -1167,9 +1170,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                 current_offset = 0
                 if utils.current_shard == 0:
                     print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
 
+                if utils.num_shards is None or utils.current_shard == 0:
+                    if utils.num_shards is not None:
+                        num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
+                    else:
+                        num_tensors = len(model_dict)
+                    utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
+
                 if utils.num_shards is not None:
                     utils.current_shard += 1
-                for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
+                for key in sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
 
                     # Some model weights are used by transformers but not by MTJ.
                     # We have to materialize these weights anyways because
@@ -1178,6 +1189,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                     # tensors, which don't take up any actual CPU or TPU memory.
                     if key not in model_spec:
                         model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
+                        utils.bar.update(1)
                         continue
 
                     storage_key = model_dict[key].key
@@ -1230,6 +1242,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                         np.empty(params["cores_per_replica"]),
                     )
 
+                    utils.bar.update(1)
+
                 if utils.num_shards is not None and utils.current_shard < utils.num_shards:
                     return
 
@@ -1251,8 +1265,13 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
                 print("\n\nERROR: " + error, file=sys.stderr)
                 raise RuntimeError(error)
             finally:
+                if utils.num_shards is None or utils.current_shard >= utils.num_shards:
+                    utils.bar.close()
+                    utils.bar = None
+                callback.nested = False
                 if isinstance(f, zipfile.ZipExtFile):
                     f.close()
+    callback.nested = False
 
     if os.path.isdir(vars.model.replace('/', '_')):
         import shutil
utils.py (75 changed lines)
@@ -5,12 +5,20 @@ import json
 import subprocess
 import tempfile
 import requests
+import requests.adapters
+import time
+from tqdm.auto import tqdm
 import os
+import itertools
 from typing import Optional
 
 vars = None
 num_shards: Optional[int] = None
 current_shard = 0
+from_pretrained_model_name = ""
+from_pretrained_index_filename: Optional[str] = None
+from_pretrained_kwargs = {}
+bar = None
 
 #==================================================================#
 # Decorator to prevent a function's actions from being run until
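These module-level fields form the side channel between the `from_pretrained` wrapper in aiserver.py and the checkpoint callbacks. A self-contained sketch of that handshake, with stand-in functions and a `SimpleNamespace` in place of the real module:

```python
# Sketch of the shared-state handshake above; everything here is a
# stand-in, the real code lives in aiserver.py and the loader callbacks.
from types import SimpleNamespace
from tqdm.auto import tqdm

utils = SimpleNamespace(num_shards=None, current_shard=0,
                        from_pretrained_model_name="", bar=None)

def fake_from_pretrained(name):
    # What the new_from_pretrained wrapper does before loading:
    utils.num_shards = None
    utils.current_shard = 0
    utils.from_pretrained_model_name = name
    utils.bar = None

def fake_checkpoint_callback(tensor_names):
    # What a loader callback does with the shared state:
    if utils.bar is None:
        utils.bar = tqdm(total=len(tensor_names), desc="Loading model tensors")
    for _ in tensor_names:
        utils.bar.update(1)
    if utils.num_shards is None or utils.current_shard >= utils.num_shards:
        utils.bar.close()
        utils.bar = None

fake_from_pretrained("example/model")
fake_checkpoint_callback([f"t{i}" for i in range(5)])
```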
@@ -202,6 +210,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
     if not urls:
         return
     etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
+    headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
     filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
     for n in filenames:
         path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
@@ -217,20 +226,53 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
             path = os.path.join(_cache_dir, n)
             if os.path.exists(path):
                 os.remove(path)
+    total_length = sum(int(h["Content-Length"]) for h in headers)
+    lengths = {}
     aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
-    with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
-        f.write(aria2_config)
-        f.flush()
-        p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming", "false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
-        for line in p.stdout:
-            print(line.decode(), end="", flush=True)
-        path = f.name
+    s = requests.Session()
+    s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
+    bar = None
+    done = False
+    secret = os.urandom(17).hex()
     try:
-        os.remove(path)
-    except OSError:
-        pass
+        with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
+            f.write(aria2_config)
+            f.flush()
+            p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+            while p.poll() is None:
+                r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
+                if not r:
+                    s.close()
+                    if bar is not None:
+                        bar.n = bar.total
+                        bar.close()
+                    p.terminate()
+                    done = True
+                    break
+                if bar is None:
+                    bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
+                visited = set()
+                for x in r:
+                    filename = x["files"][0]["path"]
+                    lengths[filename] = (int(x["completedLength"]), int(x["totalLength"]))
+                    visited.add(filename)
+                for k, v in lengths.items():
+                    if k not in visited:
+                        lengths[k] = (v[1], v[1])
+                bar.n = sum(v[0] for v in lengths.values())
+                bar.update()
+                time.sleep(0.1)
+            path = f.name
+    except Exception as e:
+        p.terminate()
+        raise e
+    finally:
+        try:
+            os.remove(path)
+        except OSError:
+            pass
     code = p.wait()
-    if code:
+    if not done and code:
        raise OSError(f"aria2 exited with exit code {code}")
     for u, t, n in zip(urls, etags, filenames):
         os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
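The new polling loop drives the bar from aria2's JSON-RPC interface (`aria2.tellActive` is a real aria2 method; the port and secret below are placeholders for whatever the daemon was started with). A trimmed sketch that assumes an `aria2c` daemon is already running with `--enable-rpc`:

```python
# Trimmed sketch of the polling loop above: ask a running aria2c which
# downloads are active and mirror their completed byte counts into one
# tqdm bar. Loops until aria2 reports nothing active.
import time
import requests
from tqdm.auto import tqdm

def poll_aria2(port=6799, secret="example-secret", total_bytes=0):
    bar = tqdm(total=total_bytes, unit="B", unit_scale=True, unit_divisor=1000)
    while True:
        r = requests.post(
            f"http://localhost:{port}/jsonrpc",
            json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive",
                  "params": [f"token:{secret}"]},
        ).json()["result"]
        if not r:                 # nothing active anymore: download finished
            bar.n = bar.total
            bar.close()
            return
        bar.n = sum(int(x["completedLength"]) for x in r)
        bar.refresh()             # redraw with the externally-set position
        time.sleep(0.1)
```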
@@ -245,3 +287,14 @@ def get_num_shards(filename):
     with open(filename) as f:
         map_data = json.load(f)
     return len(set(map_data["weight_map"].values()))
+
+#==================================================================#
+# Given the name/path of a sharded model and the path to a
+# pytorch_model.bin.index.json, returns a list of weight names in the
+# sharded model. Requires lazy loader to be enabled to work properly
+#==================================================================#
+def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
+    import transformers.modeling_utils
+    import torch
+    shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
+    return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
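A hypothetical use of this helper, matching how aiserver.py sizes its progress bars (the model ID and index filename are illustrative; note the function returns the tensor names themselves, despite the `num` in its name, so callers take `len()` of the result):

```python
# Hypothetical usage of the helper above: count tensors across all
# shards so a progress bar can be given an exact total up front.
# Requires the KoboldAI utils module plus transformers and torch, and
# is only cheap when the lazy loader is enabled, as the comment notes.
import utils

tensor_names = utils.get_sharded_checkpoint_num_tensors(
    "EleutherAI/gpt-j-6B",              # illustrative sharded model ID
    "pytorch_model.bin.index.json",     # its shard index file
    cache_dir="cache",
)
print(len(tensor_names), "tensors across all shards")
```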