mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-25 16:08:00 +01:00
Aria2 Fixes
This commit is contained in:
parent
f62c740f7e
commit
cca3ce3493
@ -1,10 +0,0 @@
|
|||||||
import patch_torch_save
|
|
||||||
from transformers import AutoModel
|
|
||||||
|
|
||||||
def kaiad(): # put arbitrary code in here
|
|
||||||
print("This model was provided for free by KoboldAI. Check out our free interface at KoboldAI.org")
|
|
||||||
|
|
||||||
patched_save_function = patch_torch_save.patch_save_function(kaiad)
|
|
||||||
|
|
||||||
model = AutoModel.from_pretrained("facebook/opt-125m")
|
|
||||||
model.save_pretrained("./local_folder", save_function=patched_save_function) # optionally, upload to HF hub
|
|
8
utils.py
8
utils.py
@ -182,7 +182,7 @@ class Send_to_socketio(object):
|
|||||||
def write(self, bar):
|
def write(self, bar):
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
try:
|
try:
|
||||||
print(bar)
|
print(bar, end = "\r")
|
||||||
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True)
|
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
@ -429,7 +429,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
|
|||||||
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
|
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
|
||||||
|
|
||||||
for n in filenames:
|
for n in filenames:
|
||||||
prefix, suffix = n.rsplit("/", 1)
|
prefix, suffix = n.rsplit(os.sep, 1)
|
||||||
path = os.path.join(prefix, "kai-tempfile." + suffix + ".aria2")
|
path = os.path.join(prefix, "kai-tempfile." + suffix + ".aria2")
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
@ -437,10 +437,10 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
|
|||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
total_length = sum(int(h["Content-Length"]) for h in headers)
|
total_length = sum(int(h["Content-Length"]) for h in headers)
|
||||||
aria2_config = "\n".join(f"{u}\n out={os.path.join(prefix, 'kai-tempfile.' + suffix)}" for u, n in zip(urls, filenames) for prefix, suffix in [n.rsplit("/", 1)]).encode()
|
aria2_config = "\n".join(f"{u}\n out={os.path.join(prefix, 'kai-tempfile.' + suffix)}" for u, n in zip(urls, filenames) for prefix, suffix in [n.rsplit(os.sep, 1)]).encode()
|
||||||
_download_with_aria2(aria2_config, total_length, use_auth_token=token if use_auth_token else None, user_agent=user_agent, force_download=force_download)
|
_download_with_aria2(aria2_config, total_length, use_auth_token=token if use_auth_token else None, user_agent=user_agent, force_download=force_download)
|
||||||
for u, n in zip(urls, filenames):
|
for u, n in zip(urls, filenames):
|
||||||
prefix, suffix = n.rsplit("/", 1)
|
prefix, suffix = n.rsplit(os.sep, 1)
|
||||||
os.rename(os.path.join(prefix, "kai-tempfile." + suffix), os.path.join(prefix, suffix))
|
os.rename(os.path.join(prefix, "kai-tempfile." + suffix), os.path.join(prefix, suffix))
|
||||||
|
|
||||||
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
|
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user