Utils: Support safetensors aria2 download

This commit is contained in:
somebody
2023-04-15 11:51:16 -05:00
parent 2b950f08d3
commit a2ae87d1b7

View File

@@ -193,23 +193,6 @@ def num_layers(config):
from flask_socketio import emit from flask_socketio import emit
def _download_with_aria2(aria2_config: str, total_length: int, directory: str = ".", user_agent=None, force_download=False, use_auth_token=None): def _download_with_aria2(aria2_config: str, total_length: int, directory: str = ".", user_agent=None, force_download=False, use_auth_token=None):
class Send_to_socketio(object):
def write(self, bar):
bar = bar.replace("\r", "").replace("\n", "")
if bar != "" and [ord(num) for num in bar] != [27, 91, 65]: #No idea why we're getting the 27, 91, 65 character set (ESC [ A, the ANSI cursor-up sequence), just killing it so we can move on
try:
print('\r' + bar, end='')
try:
socketio.emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True, room="UI_1")
except:
pass
eventlet.sleep(seconds=0)
except:
pass
def flush(self):
pass
import transformers import transformers
aria2_port = 6799 if koboldai_vars is None else koboldai_vars.aria2_port aria2_port = 6799 if koboldai_vars is None else koboldai_vars.aria2_port
lengths = {} lengths = {}
@@ -244,7 +227,7 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str =
if k not in visited: if k not in visited:
lengths[k] = (v[1], v[1]) lengths[k] = (v[1], v[1])
if bar is None: if bar is None:
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio()) bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=UIProgressBarFile())
koboldai_vars.status_message = "Download Model" koboldai_vars.status_message = "Download Model"
koboldai_vars.total_download_chunks = sum(v[1] for v in lengths.values()) koboldai_vars.total_download_chunks = sum(v[1] for v in lengths.values())
koboldai_vars.downloaded_chunks = sum(v[0] for v in lengths.values()) koboldai_vars.downloaded_chunks = sum(v[0] for v in lengths.values())
@@ -280,11 +263,9 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.") raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE _cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
_revision = koboldai_vars.revision if koboldai_vars.revision is not None else huggingface_hub.constants.DEFAULT_REVISION _revision = koboldai_vars.revision if koboldai_vars.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
sharded = False
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)} headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
if use_auth_token: if use_auth_token:
headers["authorization"] = f"Bearer {use_auth_token}" headers["authorization"] = f"Bearer {use_auth_token}"
storage_folder = os.path.join(_cache_dir, huggingface_hub.file_download.repo_folder_name(repo_id=pretrained_model_name_or_path, repo_type="model")) storage_folder = os.path.join(_cache_dir, huggingface_hub.file_download.repo_folder_name(repo_id=pretrained_model_name_or_path, repo_type="model"))
os.makedirs(storage_folder, exist_ok=True) os.makedirs(storage_folder, exist_ok=True)
@@ -294,25 +275,37 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
except ValueError: except ValueError:
return False return False
return True return True
while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
try: filename = None
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
except AttributeError: # NOTE: For now sharded Safetensors models are not supported. Haven't seen
return # one of these out in the wild yet, probably due to how Safetensors has a
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=_revision) # lot of benefits of sharding built in
if is_cached(filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers): for possible_filename in [
transformers.modeling_utils.SAFE_WEIGHTS_NAME,
transformers.modeling_utils.WEIGHTS_INDEX_NAME,
transformers.modeling_utils.WEIGHTS_NAME
]:
# Try to get the huggingface.co URL of the model's weights file(s)
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, possible_filename, revision=_revision)
if is_cached(possible_filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
filename = possible_filename
break break
if sharded:
if not filename:
return return
if filename not in [transformers.modeling_utils.WEIGHTS_INDEX_NAME]:
# If the model isn't sharded, there's only one file to download
filenames = [filename]
else: else:
sharded = True # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision) map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision)
with open(map_filename) as f: with open(map_filename) as f:
map_data = json.load(f) map_data = json.load(f)
filenames = set(map_data["weight_map"].values()) filenames = set(map_data["weight_map"].values())
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=_revision) for n in filenames] urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=_revision) for n in filenames]
if not force_download: if not force_download:
urls = [u for u, n in zip(urls, filenames) if not is_cached(n)] urls = [u for u, n in zip(urls, filenames) if not is_cached(n)]