diff --git a/aiserver.py b/aiserver.py
index a289409c..74adef38 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -16,6 +16,9 @@ os.environ['EVENTLET_THREADPOOL_SIZE'] = '1'
 os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 from eventlet import tpool
 
+import logging
+logging.getLogger("urllib3").setLevel(logging.ERROR)
+
 from os import path, getcwd
 import time
 import re
@@ -808,6 +811,7 @@ parser.add_argument("--ngrok", action='store_true', help="Optimizes KoboldAI for
 parser.add_argument("--localtunnel", action='store_true', help="Optimizes KoboldAI for Remote Play using Localtunnel")
 parser.add_argument("--host", action='store_true', help="Optimizes KoboldAI for Remote Play without using a proxy service")
 parser.add_argument("--port", type=int, help="Specify the port on which the application will be joinable")
+parser.add_argument("--aria2_port", type=int, help="Specify the port on which aria2's RPC interface will be open if aria2 is installed (defaults to 6799)")
 parser.add_argument("--model", help="Specify the Model Type to skip the Menu")
 parser.add_argument("--path", help="Specify the Path for local models (For model NeoCustom or GPT2Custom)")
 parser.add_argument("--revision", help="Specify the model revision for huggingface models (can be a git branch/tag name or a git commit hash)")
@@ -867,6 +871,8 @@ if args.cpu:
 vars.smandelete = vars.host == args.override_delete
 vars.smanrename = vars.host == args.override_rename
 
+vars.aria2_port = args.aria2_port or 6799
+
 # Select a model to run
 if args.model:
     print("Welcome to KoboldAI!\nYou have selected the following Model:", vars.model)
diff --git a/utils.py b/utils.py
index be064855..9565eaa4 100644
--- a/utils.py
+++ b/utils.py
@@ -5,6 +5,9 @@ import json
 import subprocess
 import tempfile
 import requests
+import requests.adapters
+import time
+from tqdm.auto import tqdm
 import os
 from typing import Optional
 
@@ -202,6 +205,8 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
     if not urls:
         return
     etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
+    # Response headers fetched with redirects followed; kept under their own name so the request `headers` passed to requests.head() are not overwritten.
+    response_headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
     filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
     for n in filenames:
         path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
@@ -217,18 +222,65 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
             path = os.path.join(_cache_dir, n)
             if os.path.exists(path):
                 os.remove(path)
+    # Combined Content-Length of every URL, used as the progress bar's total.
+    total_length = sum(int(h["Content-Length"]) for h in response_headers)
+    # Maps each file path reported by aria2 to (completed bytes, total bytes) as last seen over RPC.
+    lengths = {}
     aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
-    with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
-        f.write(aria2_config)
-        f.flush()
-        p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming", "false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
-        for line in p.stdout:
-            print(line.decode(), end="", flush=True)
-    path = f.name
+    s = requests.Session()
+    # Retries with backoff on the RPC endpoint; presumably this covers the window before aria2 starts accepting connections.
+    s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
+    bar = None
+    done = False  # becomes True once aria2 reports that no download is active any more
+    p = None
+    path = None
+    secret = os.urandom(17).hex()
     try:
-        os.remove(path)
-    except OSError:
-        pass
+        with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
+            # Remember the name immediately so the cleanup below removes this file even if a later step raises.
+            path = f.name
+            f.write(aria2_config)
+            f.flush()
+            # Arguments are passed as a list (no shell), so the header value must not be wrapped in quotes.
+            p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header=Authorization: Bearer {token}"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+            while p.poll() is None:
+                r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
+                if not r:
+                    # aria2 lists no active downloads: treat the transfer as finished and stop the process ourselves.
+                    done = True
+                    if bar is not None:
+                        bar.n = bar.total
+                    p.terminate()
+                    break
+                if bar is None:
+                    bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
+                visited = set()
+                for x in r:
+                    filename = x["files"][0]["path"]
+                    lengths[filename] = (int(x["completedLength"]), int(x["totalLength"]))
+                    visited.add(filename)
+                # Files that were seen before but are no longer active are counted as fully downloaded.
+                for k, v in lengths.items():
+                    if k not in visited:
+                        lengths[k] = (v[1], v[1])
+                bar.n = sum(v[0] for v in lengths.values())
+                # refresh() redraws the bar; update() would add 1 on top of the byte count just assigned.
+                bar.refresh()
+                time.sleep(0.1)
+    except Exception:
+        # p is still None when creating the temp file or spawning aria2 itself failed.
+        if p is not None:
+            p.terminate()
+        raise
+    finally:
+        s.close()
+        if bar is not None:
+            bar.close()
+        if path is not None:
+            try:
+                os.remove(path)
+            except OSError:
+                pass
     code = p.wait()
-    if code:
+    # After a completed download the process was stopped by terminate() above, which may yield a non-zero status; only an exit before completion is an error.
+    if code and not done:
         raise OSError(f"aria2 exited with exit code {code}")