commit
fe64e480ee
56
aiserver.py
56
aiserver.py
|
@ -1314,9 +1314,64 @@ def patch_causallm(model):
|
|||
Embedding._koboldai_patch_causallm_model = model
|
||||
return model
|
||||
|
||||
def patch_transformers_download():
|
||||
global transformers
|
||||
import copy, requests, tqdm, time
|
||||
class Send_to_socketio(object):
|
||||
def write(self, bar):
|
||||
bar = bar.replace("\r", "")
|
||||
try:
|
||||
print(bar, end="\r")
|
||||
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True)
|
||||
eventlet.sleep(seconds=0)
|
||||
except:
|
||||
pass
|
||||
def http_get(
|
||||
url: str,
|
||||
temp_file: transformers.utils.hub.BinaryIO,
|
||||
proxies=None,
|
||||
resume_size=0,
|
||||
headers: transformers.utils.hub.Optional[transformers.utils.hub.Dict[str, str]] = None,
|
||||
file_name: transformers.utils.hub.Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Download remote file. Do not gobble up errors.
|
||||
"""
|
||||
headers = copy.deepcopy(headers)
|
||||
if resume_size > 0:
|
||||
headers["Range"] = f"bytes={resume_size}-"
|
||||
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
||||
transformers.utils.hub._raise_for_status(r)
|
||||
content_length = r.headers.get("Content-Length")
|
||||
total = resume_size + int(content_length) if content_length is not None else None
|
||||
# `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
|
||||
# and can be set using `utils.logging.enable/disable_progress_bar()`
|
||||
if url[-11:] != 'config.json':
|
||||
progress = tqdm.tqdm(
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
total=total,
|
||||
initial=resume_size,
|
||||
desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
|
||||
file=Send_to_socketio(),
|
||||
)
|
||||
for chunk in r.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
if url[-11:] != 'config.json':
|
||||
progress.update(len(chunk))
|
||||
temp_file.write(chunk)
|
||||
if url[-11:] != 'config.json':
|
||||
progress.close()
|
||||
|
||||
transformers.utils.hub.http_get = http_get
|
||||
|
||||
|
||||
def patch_transformers():
|
||||
global transformers
|
||||
|
||||
patch_transformers_download()
|
||||
|
||||
old_from_pretrained = PreTrainedModel.from_pretrained.__func__
|
||||
@classmethod
|
||||
def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
|
@ -6386,6 +6441,7 @@ if __name__ == "__main__":
|
|||
vars.flaskwebgui = True
|
||||
FlaskUI(app, socketio=socketio, start_server="flask-socketio", maximized=True, close_server_on_exit=True).run()
|
||||
except:
|
||||
pass
|
||||
import webbrowser
|
||||
webbrowser.open_new('http://localhost:{0}'.format(port))
|
||||
print("{0}Server started!\nYou may now connect with a browser at http://127.0.0.1:{1}/{2}"
|
||||
|
|
12
utils.py
12
utils.py
|
@ -172,6 +172,16 @@ def num_layers(config):
|
|||
#==================================================================#
|
||||
# Downloads huggingface checkpoints using aria2c if possible
|
||||
#==================================================================#
|
||||
from flask_socketio import emit
|
||||
class Send_to_socketio(object):
|
||||
def write(self, bar):
|
||||
print("should be emitting: ", bar, end="")
|
||||
time.sleep(0.01)
|
||||
try:
|
||||
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True)
|
||||
except:
|
||||
pass
|
||||
|
||||
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
|
||||
import transformers
|
||||
import transformers.modeling_utils
|
||||
|
@ -268,7 +278,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
|||
done = True
|
||||
break
|
||||
if bar is None:
|
||||
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
|
||||
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio())
|
||||
visited = set()
|
||||
for x in r:
|
||||
filename = x["files"][0]["path"]
|
||||
|
|
Loading…
Reference in New Issue