From a2ae87d1b730b575be8f52bd55a48f52a69dbf31 Mon Sep 17 00:00:00 2001 From: somebody Date: Sat, 15 Apr 2023 11:51:16 -0500 Subject: [PATCH] Utils: Support safetensors aria2 download --- utils.py | 61 +++++++++++++++++++++++++------------------------------- 1 file changed, 27 insertions(+), 34 deletions(-) diff --git a/utils.py b/utils.py index aa570668..90c514ca 100644 --- a/utils.py +++ b/utils.py @@ -193,23 +193,6 @@ def num_layers(config): from flask_socketio import emit def _download_with_aria2(aria2_config: str, total_length: int, directory: str = ".", user_agent=None, force_download=False, use_auth_token=None): - class Send_to_socketio(object): - def write(self, bar): - bar = bar.replace("\r", "").replace("\n", "") - - if bar != "" and [ord(num) for num in bar] != [27, 91, 65]: #No idea why we're getting the 27, 1, 65 character set, just killing to so we can move on - try: - print('\r' + bar, end='') - try: - socketio.emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1") - except: - pass - eventlet.sleep(seconds=0) - except: - pass - def flush(self): - pass - import transformers aria2_port = 6799 if koboldai_vars is None else koboldai_vars.aria2_port lengths = {} @@ -244,7 +227,7 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str = if k not in visited: lengths[k] = (v[1], v[1]) if bar is None: - bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio()) + bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=UIProgressBarFile()) koboldai_vars.status_message = "Download Model" koboldai_vars.total_download_chunks = sum(v[1] for v in lengths.values()) koboldai_vars.downloaded_chunks = sum(v[0] for v in lengths.values()) @@ -280,11 +263,9 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa raise 
EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.") _cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE _revision = koboldai_vars.revision if koboldai_vars.revision is not None else huggingface_hub.constants.DEFAULT_REVISION - sharded = False headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)} if use_auth_token: headers["authorization"] = f"Bearer {use_auth_token}" - storage_folder = os.path.join(_cache_dir, huggingface_hub.file_download.repo_folder_name(repo_id=pretrained_model_name_or_path, repo_type="model")) os.makedirs(storage_folder, exist_ok=True) @@ -294,25 +275,37 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa except ValueError: return False return True - while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file - try: - filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME - except AttributeError: - return - url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=_revision) - if is_cached(filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers): + + filename = None + + # NOTE: For now sharded Safetensors models are not supported. 
Haven't seen + # one of these out in the wild yet, probably due to how Safetensors has a + # lot of benefits of sharding built in + for possible_filename in [ + transformers.modeling_utils.SAFE_WEIGHTS_NAME, + transformers.modeling_utils.WEIGHTS_INDEX_NAME, + transformers.modeling_utils.WEIGHTS_NAME + ]: + # Try to get the huggingface.co URL of the model's weights file(s) + url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, possible_filename, revision=_revision) + + if is_cached(possible_filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers): + filename = possible_filename break - if sharded: - return - else: - sharded = True - if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download - filenames = [transformers.modeling_utils.WEIGHTS_NAME] - else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it + + if not filename: + return + + if filename not in [transformers.modeling_utils.WEIGHTS_INDEX_NAME]: + # If the model isn't sharded, there's only one file to download + filenames = [filename] + else: + # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision) with open(map_filename) as f: map_data = json.load(f) filenames = set(map_data["weight_map"].values()) + urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=_revision) for n in filenames] if not force_download: urls = [u for u, n in zip(urls, filenames) if not is_cached(n)]