Proper progress bar for aria2 downloads

This commit is contained in:
Gnome Ann
2022-05-13 17:00:10 -04:00
parent 7ea0c49c1a
commit 91d3672446
2 changed files with 51 additions and 10 deletions

View File

@ -16,6 +16,9 @@ os.environ['EVENTLET_THREADPOOL_SIZE'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false' os.environ['TOKENIZERS_PARALLELISM'] = 'false'
from eventlet import tpool from eventlet import tpool
import logging
logging.getLogger("urllib3").setLevel(logging.ERROR)
from os import path, getcwd from os import path, getcwd
import time import time
import re import re
@ -808,6 +811,7 @@ parser.add_argument("--ngrok", action='store_true', help="Optimizes KoboldAI for
parser.add_argument("--localtunnel", action='store_true', help="Optimizes KoboldAI for Remote Play using Localtunnel") parser.add_argument("--localtunnel", action='store_true', help="Optimizes KoboldAI for Remote Play using Localtunnel")
parser.add_argument("--host", action='store_true', help="Optimizes KoboldAI for Remote Play without using a proxy service") parser.add_argument("--host", action='store_true', help="Optimizes KoboldAI for Remote Play without using a proxy service")
parser.add_argument("--port", type=int, help="Specify the port on which the application will be joinable") parser.add_argument("--port", type=int, help="Specify the port on which the application will be joinable")
parser.add_argument("--aria2_port", type=int, help="Specify the port on which aria2's RPC interface will be open if aria2 is installed (defaults to 6799)")
parser.add_argument("--model", help="Specify the Model Type to skip the Menu") parser.add_argument("--model", help="Specify the Model Type to skip the Menu")
parser.add_argument("--path", help="Specify the Path for local models (For model NeoCustom or GPT2Custom)") parser.add_argument("--path", help="Specify the Path for local models (For model NeoCustom or GPT2Custom)")
parser.add_argument("--revision", help="Specify the model revision for huggingface models (can be a git branch/tag name or a git commit hash)") parser.add_argument("--revision", help="Specify the model revision for huggingface models (can be a git branch/tag name or a git commit hash)")
@ -867,6 +871,8 @@ if args.cpu:
vars.smandelete = vars.host == args.override_delete vars.smandelete = vars.host == args.override_delete
vars.smanrename = vars.host == args.override_rename vars.smanrename = vars.host == args.override_rename
vars.aria2_port = args.aria2_port or 6799
# Select a model to run # Select a model to run
if args.model: if args.model:
print("Welcome to KoboldAI!\nYou have selected the following Model:", vars.model) print("Welcome to KoboldAI!\nYou have selected the following Model:", vars.model)

View File

@ -5,6 +5,9 @@ import json
import subprocess import subprocess
import tempfile import tempfile
import requests import requests
import requests.adapters
import time
from tqdm.auto import tqdm
import os import os
from typing import Optional from typing import Optional
@ -202,6 +205,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
if not urls: if not urls:
return return
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]] etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)] filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
for n in filenames: for n in filenames:
path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2") path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
@ -217,14 +221,45 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
path = os.path.join(_cache_dir, n) path = os.path.join(_cache_dir, n)
if os.path.exists(path): if os.path.exists(path):
os.remove(path) os.remove(path)
total_length = sum(int(h["Content-Length"]) for h in headers)
lengths = {}
aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode() aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
s = requests.Session()
s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
bar = None
secret = os.urandom(17).hex()
try:
with tempfile.NamedTemporaryFile("w+b", delete=False) as f: with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
f.write(aria2_config) f.write(aria2_config)
f.flush() f.flush()
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming", "false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT) p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
for line in p.stdout: while p.poll() is None:
print(line.decode(), end="", flush=True) r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
if not r:
s.close()
if bar is not None:
bar.n = bar.total
bar.close()
p.terminate()
break
if bar is None:
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
visited = set()
for x in r:
filename = x["files"][0]["path"]
lengths[filename] = (int(x["completedLength"]), int(x["totalLength"]))
visited.add(filename)
for k, v in lengths.items():
if k not in visited:
lengths[k] = (v[1], v[1])
bar.n = sum(v[0] for v in lengths.values())
bar.update()
time.sleep(0.1)
path = f.name path = f.name
except Exception as e:
p.terminate()
raise e
finally:
try: try:
os.remove(path) os.remove(path)
except OSError: except OSError: