diff --git a/colabkobold.sh b/colabkobold.sh
index 2aa5daf9..4c41675a 100644
--- a/colabkobold.sh
+++ b/colabkobold.sh
@@ -186,7 +186,7 @@ fi
 
 #Download routine for Aria2c scripts
 if [ ! -z ${aria2+x} ]; then
- curl -L $aria2 | aria2c -c -i- -d$dloc --user-agent=KoboldAI --file-allocation=none
+ curl -L $aria2 | aria2c -x 10 -s 10 -j 10 -c -i- -d$dloc --user-agent=KoboldAI --file-allocation=none
 fi
 
 #Extract the model with 7z
diff --git a/utils.py b/utils.py
index a2a5fef2..0eb1de28 100644
--- a/utils.py
+++ b/utils.py
@@ -183,13 +183,13 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
             sharded = True
         else:
             break
-    if not sharded:  # If the model has a pytorch_model.bin file, that's the only large file to download so it's probably more efficient to just let transformers download it
-        return
-    # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
-    map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent)
-    with open(map_filename) as f:
-        map_data = json.load(f)
-    filenames = set(map_data["weight_map"].values())
+    if not sharded:  # If the model has a pytorch_model.bin file, that's the only file to download
+        filenames = [transformers.modeling_utils.WEIGHTS_NAME]
+    else:  # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
+        map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent)
+        with open(map_filename) as f:
+            map_data = json.load(f)
+        filenames = set(map_data["weight_map"].values())
     urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames]
     def is_cached(url):
         try:
@@ -210,7 +210,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
     aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode()
     with tempfile.NamedTemporaryFile("w+b", delete=True) as f:
         f.write(aria2_config)
-        p = subprocess.Popen(["aria2c", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+        p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
         for line in p.stdout:
             print(line.decode(), end="", flush=True)
         for u, t, n in zip(urls, etags, filenames):
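For reviewers, here is a minimal standalone sketch of what the patched invocation does: it builds the same `URL` + ` out=filename` input-file format that `aria2_config` uses and runs `aria2c` with the new parallelism flags (`-x`/`-s` allow up to 10 connections/segments per file, `-j` downloads up to 10 files at once). The shard URLs, filenames, and download directory below are hypothetical placeholders, not values from the patch:

```python
import subprocess
import tempfile

# Hypothetical shard URLs and output filenames, for illustration only.
urls = [
    "https://huggingface.co/example/model/resolve/main/pytorch_model-00001-of-00002.bin",
    "https://huggingface.co/example/model/resolve/main/pytorch_model-00002-of-00002.bin",
]
filenames = [u.rsplit("/", 1)[-1] for u in urls]

# Same input-file format the hook generates: each URL followed by an
# indented "out=" option giving the local filename.
aria2_config = "\n".join(f"{u}\n out={n}" for u, n in zip(urls, filenames)).encode()

with tempfile.NamedTemporaryFile("w+b", delete=True) as f:
    f.write(aria2_config)
    f.flush()  # make sure aria2c sees the full config when it opens f.name
    # -x/-s: up to 10 connections/segments per file; -j: up to 10 downloads
    # in parallel; -c resumes partial downloads, matching the hook's
    # behavior when force_download is off.
    subprocess.run(
        ["aria2c", "-x", "10", "-s", "10", "-j", "10",
         "-d", "/tmp/aria2-test", "-i", f.name, "-c"],
        check=True,
    )
```

Without these flags, aria2c falls back to its more conservative defaults (notably `-x 1`, one connection per server per file), which is presumably the bottleneck this patch addresses for large sharded checkpoints.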