mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2024-12-12 08:36:28 +01:00
aria2_hook now handles properly when vars is None
This commit is contained in:
parent
bae8d88651
commit
a51e4f0651
5
utils.py
5
utils.py
@ -205,6 +205,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
|||||||
token = HfFolder.get_token()
|
token = HfFolder.get_token()
|
||||||
if token is None:
|
if token is None:
|
||||||
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
|
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
|
||||||
|
aria2_port = 6799 if vars is None else vars.aria2_port
|
||||||
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
|
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
|
||||||
sharded = False
|
sharded = False
|
||||||
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
|
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
|
||||||
@ -269,9 +270,9 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
|||||||
with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
|
with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
|
||||||
f.write(aria2_config)
|
f.write(aria2_config)
|
||||||
f.flush()
|
f.flush()
|
||||||
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||||
while p.poll() is None:
|
while p.poll() is None:
|
||||||
r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
|
r = s.post(f"http://localhost:{aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
|
||||||
if not r:
|
if not r:
|
||||||
s.close()
|
s.close()
|
||||||
if bar is not None:
|
if bar is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user