Merge branch 'united' into mkultra

This commit is contained in:
vfbd
2022-07-27 11:35:32 -04:00
5 changed files with 82 additions and 6 deletions

View File

@@ -1044,6 +1044,7 @@ def general_startup(override_args=None):
parser.add_argument("--no_aria2", action='store_true', default=False, help="Prevents KoboldAI from using aria2 to download huggingface models more efficiently, in case aria2 is causing you issues")
parser.add_argument("--lowmem", action='store_true', help="Extra Low Memory loading for the GPU, slower but memory does not peak to twice the usage")
parser.add_argument("--savemodel", action='store_true', help="Saves the model to the models folder even if --colab is used (Allows you to save models to Google Drive)")
parser.add_argument("--customsettings", help="Preloads arguements from json file. You only need to provide the location of the json file. Use customsettings.json template file. It can be renamed if you wish so that you can store multiple configurations. Leave any settings you want as default as null. Any values you wish to set need to be in double quotation marks")
#args: argparse.Namespace = None
if "pytest" in sys.modules and override_args is None:
args = parser.parse_args([])
@@ -1057,6 +1058,14 @@ def general_startup(override_args=None):
else:
args = parser.parse_args()
if args.customsettings:
f = open (args.customsettings)
importedsettings = json.load(f)
for items in importedsettings:
if importedsettings[items] is not None:
setattr(args, items, importedsettings[items])
f.close()
vars.model = args.model;
vars.revision = args.revision
@@ -1305,9 +1314,64 @@ def patch_causallm(model):
Embedding._koboldai_patch_causallm_model = model
return model
def patch_transformers_download():
global transformers
import copy, requests, tqdm, time
class Send_to_socketio(object):
def write(self, bar):
bar = bar.replace("\r", "")
try:
print(bar, end="\r")
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", " ")}, broadcast=True)
eventlet.sleep(seconds=0)
except:
pass
def http_get(
url: str,
temp_file: transformers.utils.hub.BinaryIO,
proxies=None,
resume_size=0,
headers: transformers.utils.hub.Optional[transformers.utils.hub.Dict[str, str]] = None,
file_name: transformers.utils.hub.Optional[str] = None,
):
"""
Download remote file. Do not gobble up errors.
"""
headers = copy.deepcopy(headers)
if resume_size > 0:
headers["Range"] = f"bytes={resume_size}-"
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
transformers.utils.hub._raise_for_status(r)
content_length = r.headers.get("Content-Length")
total = resume_size + int(content_length) if content_length is not None else None
# `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
# and can be set using `utils.logging.enable/disable_progress_bar()`
if url[-11:] != 'config.json':
progress = tqdm.tqdm(
unit="B",
unit_scale=True,
unit_divisor=1024,
total=total,
initial=resume_size,
desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
file=Send_to_socketio(),
)
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
if url[-11:] != 'config.json':
progress.update(len(chunk))
temp_file.write(chunk)
if url[-11:] != 'config.json':
progress.close()
transformers.utils.hub.http_get = http_get
def patch_transformers():
global transformers
patch_transformers_download()
old_from_pretrained = PreTrainedModel.from_pretrained.__func__
@classmethod
def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
@@ -5782,7 +5846,7 @@ def importgame():
def importAidgRequest(id):
exitModes()
urlformat = "https://prompts.aidg.club/api/"
urlformat = "https://aetherroom.club/api/"
req = requests.get(urlformat+id)
if(req.status_code == 200):
@@ -6377,6 +6441,7 @@ if __name__ == "__main__":
vars.flaskwebgui = True
FlaskUI(app, socketio=socketio, start_server="flask-socketio", maximized=True, close_server_on_exit=True).run()
except:
pass
import webbrowser
webbrowser.open_new('http://localhost:{0}'.format(port))
print("{0}Server started!\nYou may now connect with a browser at http://127.0.0.1:{1}/{2}"

View File

@@ -0,0 +1 @@
{"aria2_port":null, "breakmodel":null, "breakmodel_disklayers":null, "breakmodel_gpulayers":null, "breakmodel_layers":null, "colab":null, "configname":null, "cpu":null, "host":null, "localtunnel":null, "lowmem":null, "model":null, "ngrok":null, "no_aria2":null, "noaimenu":null, "nobreakmodel":null, "override_delete":null, "override_rename":null, "path":null, "port":null, "quiet":null, "remote":null, "revision":null, "savemodel":null, "unblock":null}

View File

@@ -415,9 +415,9 @@ formatcontrols = [{
"tooltip": "Remove special characters (@,#,%,^, etc)"
},
{
"label": "Add sentence spacing",
"label": "Automatic spacing",
"id": "frmtadsnsp",
"tooltip": "If the last action ended with punctuation, add a space to the beginning of the next action."
"tooltip": "Add spaces automatically if needed"
},
{
"label": "Single Line",

View File

@@ -76,7 +76,7 @@
<div class="dropdown-menu">
<a class="dropdown-item" href="#" id="btn_import">AI Dungeon Adventure</a>
<a class="dropdown-item" href="#" id="btn_importwi">AI Dungeon World Info</a>
<a class="dropdown-item" href="#" id="btn_impaidg">aidg.club Prompt</a>
<a class="dropdown-item" href="#" id="btn_impaidg">aetherroom.club Prompt</a>
</div>
</li>
<li class="nav-item">
@@ -233,7 +233,7 @@
<div class="popuptitletext">Enter the Prompt Number</div>
</div>
<div class="aidgpopuplistheader">
(4-digit number at the end of aidg.club URL)
(4-digit number at the end of aetherroom.club URL)
</div>
<div class="aidgpopupcontent">
<input class="form-control" type="text" placeholder="Prompt Number" id="aidgpromptnum">

View File

@@ -172,6 +172,16 @@ def num_layers(config):
#==================================================================#
# Downloads huggingface checkpoints using aria2c if possible
#==================================================================#
from flask_socketio import emit
class Send_to_socketio(object):
def write(self, bar):
print("should be emitting: ", bar, end="")
time.sleep(0.01)
try:
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", "&nbsp;")}, broadcast=True)
except:
pass
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
import transformers
import transformers.modeling_utils
@@ -268,7 +278,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
done = True
break
if bar is None:
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio())
visited = set()
for x in r:
filename = x["files"][0]["path"]