mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge branch 'united' into mkultra
This commit is contained in:
67
aiserver.py
67
aiserver.py
@@ -1044,6 +1044,7 @@ def general_startup(override_args=None):
|
||||
parser.add_argument("--no_aria2", action='store_true', default=False, help="Prevents KoboldAI from using aria2 to download huggingface models more efficiently, in case aria2 is causing you issues")
|
||||
parser.add_argument("--lowmem", action='store_true', help="Extra Low Memory loading for the GPU, slower but memory does not peak to twice the usage")
|
||||
parser.add_argument("--savemodel", action='store_true', help="Saves the model to the models folder even if --colab is used (Allows you to save models to Google Drive)")
|
||||
parser.add_argument("--customsettings", help="Preloads arguements from json file. You only need to provide the location of the json file. Use customsettings.json template file. It can be renamed if you wish so that you can store multiple configurations. Leave any settings you want as default as null. Any values you wish to set need to be in double quotation marks")
|
||||
#args: argparse.Namespace = None
|
||||
if "pytest" in sys.modules and override_args is None:
|
||||
args = parser.parse_args([])
|
||||
@@ -1057,6 +1058,14 @@ def general_startup(override_args=None):
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.customsettings:
|
||||
f = open (args.customsettings)
|
||||
importedsettings = json.load(f)
|
||||
for items in importedsettings:
|
||||
if importedsettings[items] is not None:
|
||||
setattr(args, items, importedsettings[items])
|
||||
f.close()
|
||||
|
||||
vars.model = args.model;
|
||||
vars.revision = args.revision
|
||||
|
||||
@@ -1305,9 +1314,64 @@ def patch_causallm(model):
|
||||
Embedding._koboldai_patch_causallm_model = model
|
||||
return model
|
||||
|
||||
def patch_transformers_download():
|
||||
global transformers
|
||||
import copy, requests, tqdm, time
|
||||
class Send_to_socketio(object):
|
||||
def write(self, bar):
|
||||
bar = bar.replace("\r", "")
|
||||
try:
|
||||
print(bar, end="\r")
|
||||
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True)
|
||||
eventlet.sleep(seconds=0)
|
||||
except:
|
||||
pass
|
||||
def http_get(
|
||||
url: str,
|
||||
temp_file: transformers.utils.hub.BinaryIO,
|
||||
proxies=None,
|
||||
resume_size=0,
|
||||
headers: transformers.utils.hub.Optional[transformers.utils.hub.Dict[str, str]] = None,
|
||||
file_name: transformers.utils.hub.Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Download remote file. Do not gobble up errors.
|
||||
"""
|
||||
headers = copy.deepcopy(headers)
|
||||
if resume_size > 0:
|
||||
headers["Range"] = f"bytes={resume_size}-"
|
||||
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
||||
transformers.utils.hub._raise_for_status(r)
|
||||
content_length = r.headers.get("Content-Length")
|
||||
total = resume_size + int(content_length) if content_length is not None else None
|
||||
# `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
|
||||
# and can be set using `utils.logging.enable/disable_progress_bar()`
|
||||
if url[-11:] != 'config.json':
|
||||
progress = tqdm.tqdm(
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
total=total,
|
||||
initial=resume_size,
|
||||
desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
|
||||
file=Send_to_socketio(),
|
||||
)
|
||||
for chunk in r.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
if url[-11:] != 'config.json':
|
||||
progress.update(len(chunk))
|
||||
temp_file.write(chunk)
|
||||
if url[-11:] != 'config.json':
|
||||
progress.close()
|
||||
|
||||
transformers.utils.hub.http_get = http_get
|
||||
|
||||
|
||||
def patch_transformers():
|
||||
global transformers
|
||||
|
||||
patch_transformers_download()
|
||||
|
||||
old_from_pretrained = PreTrainedModel.from_pretrained.__func__
|
||||
@classmethod
|
||||
def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
@@ -5782,7 +5846,7 @@ def importgame():
|
||||
def importAidgRequest(id):
|
||||
exitModes()
|
||||
|
||||
urlformat = "https://prompts.aidg.club/api/"
|
||||
urlformat = "https://aetherroom.club/api/"
|
||||
req = requests.get(urlformat+id)
|
||||
|
||||
if(req.status_code == 200):
|
||||
@@ -6377,6 +6441,7 @@ if __name__ == "__main__":
|
||||
vars.flaskwebgui = True
|
||||
FlaskUI(app, socketio=socketio, start_server="flask-socketio", maximized=True, close_server_on_exit=True).run()
|
||||
except:
|
||||
pass
|
||||
import webbrowser
|
||||
webbrowser.open_new('http://localhost:{0}'.format(port))
|
||||
print("{0}Server started!\nYou may now connect with a browser at http://127.0.0.1:{1}/{2}"
|
||||
|
1
customsettings_template.json
Normal file
1
customsettings_template.json
Normal file
@@ -0,0 +1 @@
|
||||
{"aria2_port":null, "breakmodel":null, "breakmodel_disklayers":null, "breakmodel_gpulayers":null, "breakmodel_layers":null, "colab":null, "configname":null, "cpu":null, "host":null, "localtunnel":null, "lowmem":null, "model":null, "ngrok":null, "no_aria2":null, "noaimenu":null, "nobreakmodel":null, "override_delete":null, "override_rename":null, "path":null, "port":null, "quiet":null, "remote":null, "revision":null, "savemodel":null, "unblock":null}
|
@@ -415,9 +415,9 @@ formatcontrols = [{
|
||||
"tooltip": "Remove special characters (@,#,%,^, etc)"
|
||||
},
|
||||
{
|
||||
"label": "Add sentence spacing",
|
||||
"label": "Automatic spacing",
|
||||
"id": "frmtadsnsp",
|
||||
"tooltip": "If the last action ended with punctuation, add a space to the beginning of the next action."
|
||||
"tooltip": "Add spaces automatically if needed"
|
||||
},
|
||||
{
|
||||
"label": "Single Line",
|
||||
|
@@ -76,7 +76,7 @@
|
||||
<div class="dropdown-menu">
|
||||
<a class="dropdown-item" href="#" id="btn_import">AI Dungeon Adventure</a>
|
||||
<a class="dropdown-item" href="#" id="btn_importwi">AI Dungeon World Info</a>
|
||||
<a class="dropdown-item" href="#" id="btn_impaidg">aidg.club Prompt</a>
|
||||
<a class="dropdown-item" href="#" id="btn_impaidg">aetherroom.club Prompt</a>
|
||||
</div>
|
||||
</li>
|
||||
<li class="nav-item">
|
||||
@@ -233,7 +233,7 @@
|
||||
<div class="popuptitletext">Enter the Prompt Number</div>
|
||||
</div>
|
||||
<div class="aidgpopuplistheader">
|
||||
(4-digit number at the end of aidg.club URL)
|
||||
(4-digit number at the end of aetherroom.club URL)
|
||||
</div>
|
||||
<div class="aidgpopupcontent">
|
||||
<input class="form-control" type="text" placeholder="Prompt Number" id="aidgpromptnum">
|
||||
|
12
utils.py
12
utils.py
@@ -172,6 +172,16 @@ def num_layers(config):
|
||||
#==================================================================#
|
||||
# Downloads huggingface checkpoints using aria2c if possible
|
||||
#==================================================================#
|
||||
from flask_socketio import emit
|
||||
class Send_to_socketio(object):
|
||||
def write(self, bar):
|
||||
print("should be emitting: ", bar, end="")
|
||||
time.sleep(0.01)
|
||||
try:
|
||||
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True)
|
||||
except:
|
||||
pass
|
||||
|
||||
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
|
||||
import transformers
|
||||
import transformers.modeling_utils
|
||||
@@ -268,7 +278,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
|
||||
done = True
|
||||
break
|
||||
if bar is None:
|
||||
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
|
||||
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio())
|
||||
visited = set()
|
||||
for x in r:
|
||||
filename = x["files"][0]["path"]
|
||||
|
Reference in New Issue
Block a user