Utils: Support safetensors aria2 download
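Probe the Hugging Face Hub for model.safetensors (SAFE_WEIGHTS_NAME) before falling back to pytorch_model.bin.index.json and pytorch_model.bin, so safetensors checkpoints can be fetched through the aria2 download path. Also drops the inline Send_to_socketio progress class in favor of the shared UIProgressBarFile. Sharded safetensors models remain unsupported for now.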
utils.py (61 changed lines)
@@ -193,23 +193,6 @@ def num_layers(config):
 from flask_socketio import emit
 
 def _download_with_aria2(aria2_config: str, total_length: int, directory: str = ".", user_agent=None, force_download=False, use_auth_token=None):
-    class Send_to_socketio(object):
-        def write(self, bar):
-            bar = bar.replace("\r", "").replace("\n", "")
-
-            if bar != "" and [ord(num) for num in bar] != [27, 91, 65]:  # No idea why we're getting the 27, 91, 65 character set, just killing it so we can move on
-                try:
-                    print('\r' + bar, end='')
-                    try:
-                        socketio.emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", "&nbsp;")}, broadcast=True, room="UI_1")
-                    except:
-                        pass
-                    eventlet.sleep(seconds=0)
-                except:
-                    pass
-        def flush(self):
-            pass
-
     import transformers
     aria2_port = 6799 if koboldai_vars is None else koboldai_vars.aria2_port
     lengths = {}
@@ -244,7 +227,7 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str =
             if k not in visited:
                 lengths[k] = (v[1], v[1])
         if bar is None:
-            bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio())
+            bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=UIProgressBarFile())
             koboldai_vars.status_message = "Download Model"
             koboldai_vars.total_download_chunks = sum(v[1] for v in lengths.values())
         koboldai_vars.downloaded_chunks = sum(v[0] for v in lengths.values())
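For context on the two hunks above: tqdm accepts any object exposing write() and flush() as its file argument, which is how both the removed Send_to_socketio class and its UIProgressBarFile replacement redirect progress text toward the web UI instead of the terminal. A minimal runnable sketch of that pattern (ProgressSink is a hypothetical stand-in, not a class from this codebase):

    from tqdm import tqdm
    import time

    class ProgressSink:  # hypothetical stand-in for UIProgressBarFile
        def write(self, s):
            # tqdm writes "\r"-prefixed status strings; strip the control
            # characters and forward the bare text wherever it is needed
            s = s.replace("\r", "").replace("\n", "")
            if s:
                print("\r" + s, end="")  # a real sink would emit this to the UI
        def flush(self):
            pass

    bar = tqdm(total=1000, desc="[aria2] Downloading model", unit="B",
               unit_scale=True, unit_divisor=1000, file=ProgressSink())
    for _ in range(10):
        time.sleep(0.05)
        bar.update(100)
    bar.close()

Moving the sink class out of _download_with_aria2 into a shared UIProgressBarFile lets other download paths reuse the same UI reporting instead of each defining its own.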
@@ -280,11 +263,9 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
|
||||
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
|
||||
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
|
||||
_revision = koboldai_vars.revision if koboldai_vars.revision is not None else huggingface_hub.constants.DEFAULT_REVISION
|
||||
sharded = False
|
||||
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
|
||||
if use_auth_token:
|
||||
headers["authorization"] = f"Bearer {use_auth_token}"
|
||||
|
||||
storage_folder = os.path.join(_cache_dir, huggingface_hub.file_download.repo_folder_name(repo_id=pretrained_model_name_or_path, repo_type="model"))
|
||||
os.makedirs(storage_folder, exist_ok=True)
|
||||
|
||||
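The storage_folder computed above follows the standard huggingface_hub cache layout: repo_folder_name() flattens a repo id into a single directory name under the cache root. A small illustration (the repo id and cache path are examples, not values from this diff):

    import os
    import huggingface_hub

    # repo_folder_name maps "owner/name" to the hub cache directory name
    folder = huggingface_hub.file_download.repo_folder_name(
        repo_id="facebook/opt-350m", repo_type="model"
    )
    print(folder)  # models--facebook--opt-350m
    # typical default cache root, assuming no HF_HOME/TRANSFORMERS_CACHE override
    print(os.path.join(os.path.expanduser("~/.cache/huggingface/hub"), folder))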
@@ -294,25 +275,37 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
         except ValueError:
             return False
         return True
-    while True:  # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
-        try:
-            filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
-        except AttributeError:
-            return
-        url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=_revision)
-        if is_cached(filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
-            break
-        if sharded:
-            return
-        else:
-            sharded = True
-    if not sharded:  # If the model has a pytorch_model.bin file, that's the only file to download
-        filenames = [transformers.modeling_utils.WEIGHTS_NAME]
-    else:  # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
+    filename = None
+
+    # NOTE: For now sharded Safetensors models are not supported. Haven't seen
+    # one of these out in the wild yet, probably due to how Safetensors has a
+    # lot of the benefits of sharding built in
+    for possible_filename in [
+        transformers.modeling_utils.SAFE_WEIGHTS_NAME,
+        transformers.modeling_utils.WEIGHTS_INDEX_NAME,
+        transformers.modeling_utils.WEIGHTS_NAME
+    ]:
+        # Try to get the huggingface.co URL of the model's weights file(s)
+        url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, possible_filename, revision=_revision)
+
+        if is_cached(possible_filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
+            filename = possible_filename
+            break
+
+    if not filename:
+        return
+
+    if filename not in [transformers.modeling_utils.WEIGHTS_INDEX_NAME]:
+        # If the model isn't sharded, there's only one file to download
+        filenames = [filename]
+    else:
+        # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
         map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision)
         with open(map_filename) as f:
             map_data = json.load(f)
         filenames = set(map_data["weight_map"].values())
 
     urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=_revision) for n in filenames]
     if not force_download:
         urls = [u for u, n in zip(urls, filenames) if not is_cached(n)]
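Taken out of the hook, the new selection logic reduces to the probe below: try the safetensors name first, then the sharded PyTorch index, then the single PyTorch checkpoint, and keep the first filename the Hub actually serves. The literal filenames mirror what the transformers constants expand to; find_weights_filename is a name invented for this sketch, and the auth/proxy headers from the real code are omitted.

    import requests
    import huggingface_hub

    def find_weights_filename(repo_id, revision="main"):
        for candidate in [
            "model.safetensors",             # SAFE_WEIGHTS_NAME
            "pytorch_model.bin.index.json",  # WEIGHTS_INDEX_NAME (sharded)
            "pytorch_model.bin",             # WEIGHTS_NAME
        ]:
            url = huggingface_hub.hf_hub_url(repo_id, candidate, revision=revision)
            # requests.Response is truthy for 2xx/3xx statuses, which is why the
            # diff can use the HEAD response directly in a boolean test
            if requests.head(url, allow_redirects=True):
                return candidate
        return None

    print(find_weights_filename("gpt2"))

Ordering safetensors first means a repo that ships both formats downloads the safetensors file, which is the point of the commit.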
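When the probe lands on pytorch_model.bin.index.json, the else branch above reads the shard map. The JSON below is a made-up but representative index file; the set of weight_map values is exactly the shard list that gets handed to aria2.

    import json

    index_json = """
    {
      "metadata": {"total_size": 13476839424},
      "weight_map": {
        "transformer.wte.weight": "pytorch_model-00001-of-00002.bin",
        "transformer.h.0.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
        "lm_head.weight": "pytorch_model-00002-of-00002.bin"
      }
    }
    """
    filenames = set(json.loads(index_json)["weight_map"].values())
    print(sorted(filenames))
    # ['pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin']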