Merge pull request #129 from VE-FORBRYDERNE/tqdm

Better model saving and better progress bars
This commit is contained in:
henk717
2022-05-14 18:02:41 +02:00
committed by GitHub
3 changed files with 179 additions and 25 deletions

View File

@@ -16,6 +16,9 @@ os.environ['EVENTLET_THREADPOOL_SIZE'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
from eventlet import tpool
import logging
logging.getLogger("urllib3").setLevel(logging.ERROR)
from os import path, getcwd
import time
import re
@@ -54,6 +57,16 @@ if lupa.LUA_VERSION[:2] != (5, 4):
print(f"Please install lupa==1.10. You have lupa {lupa.__version__}.", file=sys.stderr) print(f"Please install lupa==1.10. You have lupa {lupa.__version__}.", file=sys.stderr)
# Make sure tqdm progress bars display properly in Colab
from tqdm.auto import tqdm
old_init = tqdm.__init__
def new_init(self, *args, **kwargs):
old_init(self, *args, **kwargs)
if(self.ncols == 0 and kwargs.get("ncols") != 0):
self.ncols = 99
tqdm.__init__ = new_init
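The patch above exists because Colab pipes output, so tqdm can detect a terminal width of 0 and render a degenerate bar; the wrapper pins the width to 99 columns unless the caller explicitly asked for `ncols=0`. A minimal sketch of the behavior after patching (illustrative, not part of the commit):

```python
from tqdm.auto import tqdm

bar = tqdm(total=100)
# In an environment where width detection yields 0 (e.g. Colab), the patched
# __init__ has replaced it with 99; in a real terminal, ncols is unchanged.
print(bar.ncols)
bar.close()

bar = tqdm(total=100, ncols=0)  # an explicit ncols=0 is still honored
bar.close()
```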
#==================================================================#
# Variables & Storage
#==================================================================#
@@ -241,6 +254,7 @@ class vars:
badwordsids = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting
badwordsids_neox = [[0], [1], [44162], [9502], [12520], [31841], [36320], [49824], [34417], [6038], [34494], [24815], [26635], [24345], [3455], [28905], [44270], [17278], [32666], [46880], [7086], [43189], [37322], [17778], [20879], [49821], [3138], [14490], [4681], [21391], [26786], [43134], [9336], [683], [48074], [41256], [19181], [29650], [28532], [36487], [45114], [46275], [16445], [15104], [11337], [1168], [5647], [29], [27482], [44965], [43782], [31011], [42944], [47389], [6334], [17548], [38329], [32044], [35487], [2239], [34761], [7444], [1084], [12399], [18990], [17636], [39083], [1184], [35830], [28365], [16731], [43467], [47744], [1138], [16079], [40116], [45564], [18297], [42368], [5456], [18022], [42696], [34476], [23505], [23741], [39334], [37944], [45382], [38709], [33440], [26077], [43600], [34418], [36033], [6660], [48167], [48471], [15775], [19884], [41533], [1008], [31053], [36692], [46576], [20095], [20629], [31759], [46410], [41000], [13488], [30952], [39258], [16160], [27655], [22367], [42767], [43736], [49694], [13811], [12004], [46768], [6257], [37471], [5264], [44153], [33805], [20977], [21083], [25416], [14277], [31096], [42041], [18331], [33376], [22372], [46294], [28379], [38475], [1656], [5204], [27075], [50001], [16616], [11396], [7748], [48744], [35402], [28120], [41512], [4207], [43144], [14767], [15640], [16595], [41305], [44479], [38958], [18474], [22734], [30522], [46267], [60], [13976], [31830], [48701], [39822], [9014], [21966], [31422], [28052], [34607], [2479], [3851], [32214], [44082], [45507], [3001], [34368], [34758], [13380], [38363], [4299], [46802], [30996], [12630], [49236], [7082], [8795], [5218], [44740], [9686], [9983], [45301], [27114], [40125], [1570], [26997], [544], [5290], [49193], [23781], [14193], [40000], [2947], [43781], [9102], [48064], [42274], [18772], [49384], [9884], [45635], [43521], [31258], [32056], [47686], [21760], [13143], [10148], [26119], [44308], [31379], [36399], [23983], [46694], [36134], [8562], [12977], [35117], [28591], [49021], [47093], [28653], [29013], [46468], [8605], [7254], [25896], [5032], [8168], [36893], [38270], [20499], [27501], [34419], [29547], [28571], [36586], [20871], [30537], [26842], [21375], [31148], [27618], [33094], [3291], [31789], [28391], [870], [9793], [41361], [47916], [27468], [43856], [8850], [35237], [15707], [47552], [2730], [41449], [45488], [3073], [49806], [21938], [24430], [22747], [20924], [46145], [20481], [20197], [8239], [28231], [17987], [42804], [47269], [29972], [49884], [21382], [46295], [36676], [34616], [3921], [26991], [27720], [46265], [654], [9855], [40354], [5291], [34904], [44342], [2470], [14598], [880], [19282], [2498], [24237], [21431], [16369], [8994], [44524], [45662], [13663], [37077], [1447], [37786], [30863], [42854], [1019], [20322], [4398], [12159], [44072], [48664], [31547], [18736], [9259], [31], [16354], [21810], [4357], [37982], [5064], [2033], [32871], [47446], [62], [22158], [37387], [8743], [47007], [17981], [11049], [4622], [37916], [36786], [35138], [29925], [14157], [18095], [27829], [1181], [22226], [5709], [4725], [30189], [37014], [1254], [11380], [42989], [696], [24576], [39487], [30119], [1092], [8088], [2194], [9899], [14412], [21828], [3725], [13544], [5180], [44679], [34398], [3891], [28739], [14219], [37594], [49550], [11326], [6904], [17266], [5749], [10174], [23405], [9955], [38271], [41018], [13011], [48392], [36784], [24254], [21687], [23734], [5413], [41447], [45472], [10122], [17555], [15830], [47384], [12084], [31350], [47940], [11661], [27988], [45443], [905], [49651], [16614], [34993], [6781], [30803], [35869], [8001], [41604], [28118], [46462], [46762], [16262], [17281], [5774], [10943], [5013], [18257], [6750], [4713], [3951], [11899], [38791], [16943], [37596], [9318], [18413], [40473], [13208], [16375]]
badwordsids_opt = [[44717], [46613], [48513], [49923], [50185], [48755], [8488], [43303], [49659], [48601], [49817], [45405], [48742], [49925], [47720], [11227], [48937], [48784], [50017], [42248], [49310], [48082], [49895], [50025], [49092], [49007], [8061], [44226], [0], [742], [28578], [15698], [49784], [46679], [39365], [49281], [49609], [48081], [48906], [46161], [48554], [49670], [48677], [49721], [49632], [48610], [48462], [47457], [10975], [46077], [28696], [48709], [43839], [49798], [49154], [48203], [49625], [48395], [50155], [47161], [49095], [48833], [49420], [49666], [48443], [22176], [49242], [48651], [49138], [49750], [40389], [48021], [21838], [49070], [45333], [40862], [1], [49915], [33525], [49858], [50254], [44403], [48992], [48872], [46117], [49853], [47567], [50206], [41552], [50068], [48999], [49703], [49940], [49329], [47620], [49868], [49962], [2], [44082], [50236], [31274], [50260], [47052], [42645], [49177], [17523], [48691], [49900], [49069], [49358], [48794], [47529], [46479], [48457], [646], [49910], [48077], [48935], [46386], [48902], [49151], [48759], [49803], [45587], [48392], [47789], [48654], [49836], [49230], [48188], [50264], [46844], [44690], [48505], [50161], [27779], [49995], [41833], [50154], [49097], [48520], [50018], [8174], [50084], [49366], [49526], [50193], [7479], [49982], [3]]
fp32_model = False # Whether or not the most recently loaded HF model was in fp32 format
deletewi = None # Temporary storage for UID to delete
wirmvwhtsp = False # Whether to remove leading whitespace from WI entries
widepth = 3 # How many historical actions to scan for WI hits
@@ -808,6 +822,7 @@ parser.add_argument("--ngrok", action='store_true', help="Optimizes KoboldAI for
parser.add_argument("--localtunnel", action='store_true', help="Optimizes KoboldAI for Remote Play using Localtunnel") parser.add_argument("--localtunnel", action='store_true', help="Optimizes KoboldAI for Remote Play using Localtunnel")
parser.add_argument("--host", action='store_true', help="Optimizes KoboldAI for Remote Play without using a proxy service") parser.add_argument("--host", action='store_true', help="Optimizes KoboldAI for Remote Play without using a proxy service")
parser.add_argument("--port", type=int, help="Specify the port on which the application will be joinable") parser.add_argument("--port", type=int, help="Specify the port on which the application will be joinable")
parser.add_argument("--aria2_port", type=int, help="Specify the port on which aria2's RPC interface will be open if aria2 is installed (defaults to 6799)")
parser.add_argument("--model", help="Specify the Model Type to skip the Menu") parser.add_argument("--model", help="Specify the Model Type to skip the Menu")
parser.add_argument("--path", help="Specify the Path for local models (For model NeoCustom or GPT2Custom)") parser.add_argument("--path", help="Specify the Path for local models (For model NeoCustom or GPT2Custom)")
parser.add_argument("--revision", help="Specify the model revision for huggingface models (can be a git branch/tag name or a git commit hash)") parser.add_argument("--revision", help="Specify the model revision for huggingface models (can be a git branch/tag name or a git commit hash)")
@@ -867,6 +882,8 @@ if args.cpu:
vars.smandelete = vars.host == args.override_delete
vars.smanrename = vars.host == args.override_rename
vars.aria2_port = args.aria2_port or 6799
# Select a model to run
if args.model:
print("Welcome to KoboldAI!\nYou have selected the following Model:", vars.model)
@@ -1152,15 +1169,22 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
old_from_pretrained = PreTrainedModel.from_pretrained.__func__
@classmethod
def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
vars.fp32_model = False
utils.num_shards = None
utils.current_shard = 0
utils.from_pretrained_model_name = pretrained_model_name_or_path
utils.from_pretrained_index_filename = None
utils.from_pretrained_kwargs = kwargs
utils.bar = None
if not args.no_aria2:
utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
PreTrainedModel.from_pretrained = new_from_pretrained
if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
utils.num_shards = utils.get_num_shards(index_filename)
utils.from_pretrained_index_filename = index_filename
return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
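The hook above uses a general pattern for wrapping a classmethod: pull the plain function out through `__func__`, define a replacement decorated with `@classmethod`, and rebind it on the class. The same pattern in isolation (class and method names here are illustrative, not from the codebase):

```python
class Loader:
    @classmethod
    def from_pretrained(cls, name):
        return f"{cls.__name__} loaded {name}"

old_from_pretrained = Loader.from_pretrained.__func__  # the unbound function

@classmethod
def new_from_pretrained(cls, name):
    # Reset bookkeeping / run side effects here, then delegate to the original.
    return old_from_pretrained(cls, name)

Loader.from_pretrained = new_from_pretrained
print(Loader.from_pretrained("demo"))  # -> Loader loaded demo
```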
@@ -1180,6 +1204,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
ram_blocks = gpu_blocks = cumulative_gpu_blocks = None
def lazy_load_callback(model_dict, f, **_):
if lazy_load_callback.nested:
return
lazy_load_callback.nested = True
device_map = {}
for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
@@ -1194,6 +1222,14 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
if isinstance(value, torch_lazy_loader.LazyTensor) and key not in device_map:
device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu"
if utils.num_shards is None or utils.current_shard == 0:
if utils.num_shards is not None:
num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
else:
num_tensors = len(device_map)
print(flush=True)
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
with zipfile.ZipFile(f, "r") as z: with zipfile.ZipFile(f, "r") as z:
try: try:
last_storage_key = None last_storage_key = None
@@ -1201,7 +1237,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
current_offset = 0
if utils.num_shards is not None:
utils.current_shard += 1
for key in sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
storage_key = model_dict[key].key
if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
last_storage_key = storage_key
@@ -1218,6 +1254,8 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
#print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... ", end="", flush=True)
model_dict[key] = model_dict[key].materialize(f, map_location="cpu")
if model_dict[key].dtype is torch.float32:
vars.fp32_model = True
if convert_to_float16 and vars.hascuda and (vars.breakmodel or vars.usegpu) and model_dict[key].dtype is torch.float32:
model_dict[key] = model_dict[key].to(torch.float16)
if not vars.usegpu and not vars.breakmodel and model_dict[key].dtype is torch.float16:
@@ -1225,10 +1263,16 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
model_dict[key] = model_dict[key].to(device)
#print("OK", flush=True)
current_offset += nbytes
utils.bar.update(1)
finally:
if utils.num_shards is None or utils.current_shard >= utils.num_shards:
utils.bar.close()
utils.bar = None
lazy_load_callback.nested = False
if isinstance(f, zipfile.ZipExtFile):
f.close()
lazy_load_callback.nested = False
return lazy_load_callback
lazy_load_config_path = os.path.join("maps", vars.model_type + ".json")
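The new `nested` flag guards against re-entrancy: materializing a tensor can trigger another pass through the lazy loader, which would otherwise re-run the callback and re-create the progress bar. The guard pattern in isolation (a sketch; the real callback also closes `utils.bar` in its `finally`):

```python
def make_callback():
    def callback(*args, **kwargs):
        if callback.nested:
            return                  # re-entrant call: the outer invocation is in charge
        callback.nested = True
        try:
            pass                    # real work: device map, tensor transfers, bar updates
        finally:
            callback.nested = False
    callback.nested = False         # state lives on the function object itself
    return callback
```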
@@ -1566,6 +1610,16 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
except Exception as e:
model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache", **lowmem)
else:
old_rebuild_tensor = torch._utils._rebuild_tensor
def new_rebuild_tensor(storage, storage_offset, shape, stride):
dtype = storage.storage_type.dtype
if(not isinstance(dtype, torch.dtype)):
dtype = storage.storage_type(0).dtype
if(dtype is torch.float32 and len(shape) >= 2):
vars.fp32_model = True
return old_rebuild_tensor(storage, storage_offset, shape, stride)
torch._utils._rebuild_tensor = new_rebuild_tensor
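`_rebuild_tensor` runs for every tensor that `torch.load` deserializes, so this hook marks the model as fp32 as soon as any float32 tensor with two or more dimensions appears; 1-D tensors (biases, layer norms) are ignored, since those can be float32 even in otherwise fp16 checkpoints. The same heuristic as a standalone check over an already-loaded state dict (a sketch):

```python
import torch

def looks_like_fp32(state_dict) -> bool:
    # Only 2-D-or-larger float32 tensors count, mirroring the hook above.
    return any(
        isinstance(t, torch.Tensor) and t.dtype is torch.float32 and t.dim() >= 2
        for t in state_dict.values()
    )
```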
try:
tokenizer = AutoTokenizer.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
except Exception as e:
@@ -1578,11 +1632,32 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
except Exception as e:
model = GPTNeoForCausalLM.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache", **lowmem)
torch._utils._rebuild_tensor = old_rebuild_tensor
if not args.colab or args.savemodel:
import shutil
tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
if(vars.fp32_model): # Use save_pretrained to convert fp32 models to fp16
model = model.half()
model.save_pretrained("models/{}".format(vars.model.replace('/', '_')), max_shard_size="500MiB")
else: # For fp16 models, we can just copy the model files directly
import transformers.configuration_utils
import transformers.modeling_utils
import transformers.file_utils
# Save the config.json
shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, transformers.configuration_utils.CONFIG_NAME, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.configuration_utils.CONFIG_NAME))
if(utils.num_shards is None):
# Save the pytorch_model.bin of an unsharded model
shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, transformers.modeling_utils.WEIGHTS_NAME, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_NAME))
else:
with open(utils.from_pretrained_index_filename) as f:
map_data = json.load(f)
filenames = set(map_data["weight_map"].values())
# Save the pytorch_model.bin.index.json of a sharded model
shutil.move(utils.from_pretrained_index_filename, os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_INDEX_NAME))
# Then save the pytorch_model-#####-of-#####.bin files
for filename in filenames:
shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(vars.model, filename, revision=vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
shutil.rmtree("cache/") shutil.rmtree("cache/")
if(vars.hascuda):
@@ -1622,15 +1697,22 @@ else:
old_from_pretrained = PreTrainedModel.from_pretrained.__func__
@classmethod
def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
vars.fp32_model = False
utils.num_shards = None
utils.current_shard = 0
utils.from_pretrained_model_name = pretrained_model_name_or_path
utils.from_pretrained_index_filename = None
utils.from_pretrained_kwargs = kwargs
utils.bar = None
if not args.no_aria2:
utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
PreTrainedModel.from_pretrained = new_from_pretrained
if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
utils.num_shards = utils.get_num_shards(index_filename)
utils.from_pretrained_index_filename = index_filename
return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files

View File

@@ -1160,6 +1160,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
import functools
def callback(model_dict, f, **_):
if callback.nested:
return
callback.nested = True
with zipfile.ZipFile(f, "r") as z: with zipfile.ZipFile(f, "r") as z:
try: try:
last_storage_key = None last_storage_key = None
@@ -1167,9 +1170,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
current_offset = 0
if utils.current_shard == 0:
print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
if utils.num_shards is None or utils.current_shard == 0:
if utils.num_shards is not None:
num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
else:
num_tensors = len(model_dict)
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")
if utils.num_shards is not None:
utils.current_shard += 1
for key in sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
# Some model weights are used by transformers but not by MTJ.
# We have to materialize these weights anyways because
@@ -1178,6 +1189,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
# tensors, which don't take up any actual CPU or TPU memory.
if key not in model_spec:
model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
utils.bar.update(1)
continue
storage_key = model_dict[key].key
@@ -1230,6 +1242,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
np.empty(params["cores_per_replica"]),
)
utils.bar.update(1)
if utils.num_shards is not None and utils.current_shard < utils.num_shards:
return
@@ -1251,8 +1265,13 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
print("\n\nERROR: " + error, file=sys.stderr) print("\n\nERROR: " + error, file=sys.stderr)
raise RuntimeError(error) raise RuntimeError(error)
finally: finally:
if utils.num_shards is None or utils.current_shard >= utils.num_shards:
utils.bar.close()
utils.bar = None
callback.nested = False
if isinstance(f, zipfile.ZipExtFile):
f.close()
callback.nested = False
if os.path.isdir(vars.model.replace('/', '_')):
import shutil
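Both backends now share one progress-bar lifecycle for sharded checkpoints: `utils.bar` is created on the first shard with the tensor count of the whole model, updated once per tensor, and closed only after the final shard. Reduced to its essentials (a sketch with illustrative names):

```python
from tqdm.auto import tqdm

bar = None  # module-level, like utils.bar

def load_shard(tensors, current_shard, num_shards, total_tensors):
    global bar
    if bar is None:
        bar = tqdm(total=total_tensors, desc="Loading model tensors")
    try:
        for tensor in tensors:
            ...                       # materialize / transfer the tensor
            bar.update(1)
    finally:
        if num_shards is None or current_shard >= num_shards:
            bar.close()
            bar = None
```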

View File

@@ -5,12 +5,20 @@ import json
import subprocess
import tempfile
import requests
import requests.adapters
import time
from tqdm.auto import tqdm
import os
import itertools
from typing import Optional
vars = None
num_shards: Optional[int] = None
current_shard = 0
from_pretrained_model_name = ""
from_pretrained_index_filename: Optional[str] = None
from_pretrained_kwargs = {}
bar = None
#==================================================================#
# Decorator to prevent a function's actions from being run until # Decorator to prevent a function's actions from being run until
@@ -202,6 +210,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
if not urls:
return
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
for n in filenames:
path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
@@ -217,20 +226,53 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
path = os.path.join(_cache_dir, n)
if os.path.exists(path):
os.remove(path)
total_length = sum(int(h["Content-Length"]) for h in headers)
lengths = {}
aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
s = requests.Session()
s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
bar = None
done = False
secret = os.urandom(17).hex()
try:
with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
f.write(aria2_config)
f.flush()
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
while p.poll() is None:
r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
if not r:
s.close()
if bar is not None:
bar.n = bar.total
bar.close()
p.terminate()
done = True
break
if bar is None:
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
visited = set()
for x in r:
filename = x["files"][0]["path"]
lengths[filename] = (int(x["completedLength"]), int(x["totalLength"]))
visited.add(filename)
for k, v in lengths.items():
if k not in visited:
lengths[k] = (v[1], v[1])
bar.n = sum(v[0] for v in lengths.values())
bar.update()
time.sleep(0.1)
path = f.name
except Exception as e:
p.terminate()
raise e
finally:
try:
os.remove(path)
except OSError:
pass
code = p.wait()
if not done and code:
raise OSError(f"aria2 exited with exit code {code}")
for u, t, n in zip(urls, etags, filenames):
os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
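Download progress now comes from aria2's JSON-RPC interface instead of parsing stdout: `aria2.tellActive` reports each active download's `completedLength` and `totalLength`, and files that finish drop out of the active list, so the loop pins them at their total. The polling step in isolation (a sketch; assumes aria2c was started with `--enable-rpc` and the same secret):

```python
import requests

def poll_progress(port: int, secret: str, lengths: dict) -> int:
    payload = {"jsonrpc": "2.0", "id": "kai",
               "method": "aria2.tellActive", "params": [f"token:{secret}"]}
    active = requests.post(f"http://localhost:{port}/jsonrpc", json=payload).json()["result"]
    visited = set()
    for download in active:
        name = download["files"][0]["path"]
        lengths[name] = (int(download["completedLength"]), int(download["totalLength"]))
        visited.add(name)
    for name, (_, total) in lengths.items():
        if name not in visited:       # no longer active, so it must have finished
            lengths[name] = (total, total)
    return sum(done for done, _ in lengths.values())  # feed this into bar.n
```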
@@ -245,3 +287,14 @@ def get_num_shards(filename):
with open(filename) as f:
map_data = json.load(f)
return len(set(map_data["weight_map"].values()))
#==================================================================#
# Given the name/path of a sharded model and the path to a
# pytorch_model.bin.index.json, returns a list of weight names in the
# sharded model. Requires lazy loader to be enabled to work properly
#==================================================================#
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
import transformers.modeling_utils
import torch
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))
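This helper is what lets both loaders size a single bar for an entire sharded model before any tensors are materialized; the note about requiring the lazy loader presumably reflects that `torch.load` on each shard then yields lazy tensors, keeping the key enumeration cheap. Typical use (model name and index path are illustrative):

```python
num_tensors = len(get_sharded_checkpoint_num_tensors(
    "KoboldAI/fairseq-dense-13B", "pytorch_model.bin.index.json"))
bar = tqdm(total=num_tensors, desc="Loading model tensors")
```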