Proper progress bar for aria2 downloads
This commit is contained in:
parent
7ea0c49c1a
commit
91d3672446
|
@ -16,6 +16,9 @@ os.environ['EVENTLET_THREADPOOL_SIZE'] = '1'
|
|||
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
||||
from eventlet import tpool
|
||||
|
||||
import logging
|
||||
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||
|
||||
from os import path, getcwd
|
||||
import time
|
||||
import re
|
||||
|
@ -808,6 +811,7 @@ parser.add_argument("--ngrok", action='store_true', help="Optimizes KoboldAI for
|
|||
parser.add_argument("--localtunnel", action='store_true', help="Optimizes KoboldAI for Remote Play using Localtunnel")
|
||||
parser.add_argument("--host", action='store_true', help="Optimizes KoboldAI for Remote Play without using a proxy service")
|
||||
parser.add_argument("--port", type=int, help="Specify the port on which the application will be joinable")
|
||||
parser.add_argument("--aria2_port", type=int, help="Specify the port on which aria2's RPC interface will be open if aria2 is installed (defaults to 6799)")
|
||||
parser.add_argument("--model", help="Specify the Model Type to skip the Menu")
|
||||
parser.add_argument("--path", help="Specify the Path for local models (For model NeoCustom or GPT2Custom)")
|
||||
parser.add_argument("--revision", help="Specify the model revision for huggingface models (can be a git branch/tag name or a git commit hash)")
|
||||
|
@ -867,6 +871,8 @@ if args.cpu:
|
|||
vars.smandelete = vars.host == args.override_delete
|
||||
vars.smanrename = vars.host == args.override_rename
|
||||
|
||||
vars.aria2_port = args.aria2_port or 6799
|
||||
|
||||
# Select a model to run
|
||||
if args.model:
|
||||
print("Welcome to KoboldAI!\nYou have selected the following Model:", vars.model)
|
||||
|
|
55
utils.py
55
utils.py
|
@ -5,6 +5,9 @@ import json
|
|||
import subprocess
|
||||
import tempfile
|
||||
import requests
|
||||
import requests.adapters
|
||||
import time
|
||||
from tqdm.auto import tqdm
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
|
@ -202,6 +205,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
|||
if not urls:
|
||||
return
|
||||
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
|
||||
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
|
||||
filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
|
||||
for n in filenames:
|
||||
path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
|
||||
|
@ -217,18 +221,49 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
|||
path = os.path.join(_cache_dir, n)
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
total_length = sum(int(h["Content-Length"]) for h in headers)
|
||||
lengths = {}
|
||||
aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
|
||||
with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
|
||||
f.write(aria2_config)
|
||||
f.flush()
|
||||
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming", "false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
for line in p.stdout:
|
||||
print(line.decode(), end="", flush=True)
|
||||
path = f.name
|
||||
s = requests.Session()
|
||||
s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
|
||||
bar = None
|
||||
secret = os.urandom(17).hex()
|
||||
try:
|
||||
os.remove(path)
|
||||
except OSError:
|
||||
pass
|
||||
with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
|
||||
f.write(aria2_config)
|
||||
f.flush()
|
||||
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
while p.poll() is None:
|
||||
r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
|
||||
if not r:
|
||||
s.close()
|
||||
if bar is not None:
|
||||
bar.n = bar.total
|
||||
bar.close()
|
||||
p.terminate()
|
||||
break
|
||||
if bar is None:
|
||||
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
|
||||
visited = set()
|
||||
for x in r:
|
||||
filename = x["files"][0]["path"]
|
||||
lengths[filename] = (int(x["completedLength"]), int(x["totalLength"]))
|
||||
visited.add(filename)
|
||||
for k, v in lengths.items():
|
||||
if k not in visited:
|
||||
lengths[k] = (v[1], v[1])
|
||||
bar.n = sum(v[0] for v in lengths.values())
|
||||
bar.update()
|
||||
time.sleep(0.1)
|
||||
path = f.name
|
||||
except Exception as e:
|
||||
p.terminate()
|
||||
raise e
|
||||
finally:
|
||||
try:
|
||||
os.remove(path)
|
||||
except OSError:
|
||||
pass
|
||||
code = p.wait()
|
||||
if code:
|
||||
raise OSError(f"aria2 exited with exit code {code}")
|
||||
|
|
Loading…
Reference in New Issue