Merge commit 'refs/pull/149/head' of https://github.com/ebolam/KoboldAI into UI2

This commit is contained in:
ebolam
2022-09-21 13:35:32 -04:00
4 changed files with 40 additions and 16 deletions

View File

@@ -178,16 +178,25 @@ def num_layers(config):
# Downloads huggingface checkpoints using aria2c if possible
#==================================================================#
from flask_socketio import emit
class Send_to_socketio(object):
def write(self, bar):
time.sleep(0.01)
try:
print(bar)
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True)
except:
pass
def _download_with_aria2(aria2_config: str, total_length: int, directory: str = ".", user_agent=None, force_download=False, use_auth_token=None):
class Send_to_socketio(object):
def write(self, bar):
bar = bar.replace("\r", "").replace("\n", "")
if bar != "":
try:
print(bar, end="\n")
try:
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True)
except:
pass
eventlet.sleep(seconds=0)
except:
pass
def flush(self):
pass
import transformers
lengths = {}
path = None
@@ -221,7 +230,7 @@ def _download_with_aria2(aria2_config: str, total_length: int, directory: str =
if k not in visited:
lengths[k] = (v[1], v[1])
if bar is None:
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio())
koboldai_vars.total_download_chunks = sum(v[1] for v in lengths.values())
koboldai_vars.downloaded_chunks = sum(v[0] for v in lengths.values())
bar.n = koboldai_vars.downloaded_chunks
@@ -435,7 +444,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
for n in filenames:
prefix, suffix = n.rsplit("/", 1)
prefix, suffix = n.rsplit(os.sep, 1)
path = os.path.join(prefix, "kai-tempfile." + suffix + ".aria2")
if os.path.exists(path):
os.remove(path)
@@ -443,10 +452,10 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
if os.path.exists(path):
os.remove(path)
total_length = sum(int(h["Content-Length"]) for h in headers)
aria2_config = "\n".join(f"{u}\n out={os.path.join(prefix, 'kai-tempfile.' + suffix)}" for u, n in zip(urls, filenames) for prefix, suffix in [n.rsplit("/", 1)]).encode()
aria2_config = "\n".join(f"{u}\n out={os.path.join(prefix, 'kai-tempfile.' + suffix)}" for u, n in zip(urls, filenames) for prefix, suffix in [n.rsplit(os.sep, 1)]).encode()
_download_with_aria2(aria2_config, total_length, use_auth_token=token if use_auth_token else None, user_agent=user_agent, force_download=force_download)
for u, n in zip(urls, filenames):
prefix, suffix = n.rsplit("/", 1)
prefix, suffix = n.rsplit(os.sep, 1)
os.rename(os.path.join(prefix, "kai-tempfile." + suffix), os.path.join(prefix, suffix))
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):