Merge branch 'united' of https://github.com/henk717/KoboldAI into accelerate-offloading
aiserver.py (115 changed lines)
@@ -13,7 +13,7 @@ import shutil
 import eventlet
 
 eventlet.monkey_patch(all=True, thread=False, os=False)
-import os, inspect
+import os, inspect, contextlib, pickle
 os.system("")
 __file__ = os.path.dirname(os.path.realpath(__file__))
 os.chdir(__file__)
@@ -285,6 +285,7 @@ model_menu = {
         ],
     'nsfwlist': [
         MenuModel("Erebus 20B (NSFW)", "KoboldAI/GPT-NeoX-20B-Erebus", "64GB"),
+        MenuModel("Nerybus 13B (NSFW)", "KoboldAI/OPT-13B-Nerybus-Mix", "32GB"),
         MenuModel("Erebus 13B (NSFW)", "KoboldAI/OPT-13B-Erebus", "32GB"),
         MenuModel("Shinen FSD 13B (NSFW)", "KoboldAI/fairseq-dense-13B-Shinen", "32GB"),
         MenuModel("Erebus 6.7B (NSFW)", "KoboldAI/OPT-6.7B-Erebus", "16GB"),
@@ -629,14 +630,20 @@ from modeling.patches import patch_transformers
 import importlib
 model_backend_code = {}
 model_backends = {}
+model_backend_type_crosswalk = {}
 for module in os.listdir("./modeling/inference_models"):
     if not os.path.isfile(os.path.join("./modeling/inference_models",module)) and module != '__pycache__':
         try:
             model_backend_code[module] = importlib.import_module('modeling.inference_models.{}.class'.format(module))
             model_backends[model_backend_code[module].model_backend_name] = model_backend_code[module].model_backend()
-            if 'disable' in vars(model_backends[model_backend_code[module].model_backend_name]):
-                if model_backends[model_backend_code[module].model_backend_name].disable:
-                    del model_backends[model_backend_code[module].model_backend_name]
+            if 'disable' in vars(model_backends[model_backend_code[module].model_backend_name]) and model_backends[model_backend_code[module].model_backend_name].disable:
+                del model_backends[model_backend_code[module].model_backend_name]
+            else:
+                if model_backend_code[module].model_backend_type in model_backend_type_crosswalk:
+                    model_backend_type_crosswalk[model_backend_code[module].model_backend_type].append(model_backend_code[module].model_backend_name)
+                else:
+                    model_backend_type_crosswalk[model_backend_code[module].model_backend_type] = [model_backend_code[module].model_backend_name]
+
         except Exception:
             logger.error("Model Backend {} failed to load".format(module))
             logger.error(traceback.format_exc())
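Note on the hunk above: each sub-folder of modeling/inference_models is imported as a plugin, its model_backend is instantiated, and the new model_backend_type_crosswalk maps a backend type to every backend name registered under it. A minimal, self-contained sketch of the same discover-and-group pattern (directory layout and attribute names mirror the code above but are assumptions for illustration, not the exact KoboldAI code):

# Sketch of the plugin discovery pattern used above (illustrative only).
# Assumes each sub-folder of plugin_root is a package exposing a "class" module
# with model_backend_name, model_backend_type and a model_backend class,
# as in the hunk above.
import importlib
import os

plugin_root = "./modeling/inference_models"
model_backends = {}                 # backend display name -> backend instance
model_backend_type_crosswalk = {}   # backend type -> list of backend display names

for module in os.listdir(plugin_root):
    if os.path.isfile(os.path.join(plugin_root, module)) or module == "__pycache__":
        continue
    code = importlib.import_module("modeling.inference_models.{}.class".format(module))
    backend = code.model_backend()
    if getattr(backend, "disable", False):
        continue  # backend opted out of being listed
    model_backends[code.model_backend_name] = backend
    model_backend_type_crosswalk.setdefault(code.model_backend_type, []).append(code.model_backend_name)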
@@ -1392,9 +1399,7 @@ def general_startup(override_args=None):
     parser.add_argument("--summarizer_model", action='store', default="philschmid/bart-large-cnn-samsum", help="Huggingface model to use for summarization. Defaults to sshleifer/distilbart-cnn-12-6")
     parser.add_argument("--max_summary_length", action='store', default=75, help="Maximum size for summary to send to image generation")
     parser.add_argument("--multi_story", action='store_true', default=False, help="Allow multi-story mode (experimental)")
-    parser.add_argument("--peft", type=str, help="Specify the path or HuggingFace ID of a Peft to load it. Not supported on TPU. (Experimental)")
-    parser.add_argument("--trust_remote_code", action='store_true', default=False, help="Allow Huggingface Models to Execute Code (Insecure!)")
-
+    parser.add_argument("--peft", type=str, help="Specify the path or HuggingFace ID of a Peft to load it. Not supported on TPU. (Experimental)")
     parser.add_argument('-f', action='store', help="option for compatability with colab memory profiles")
     parser.add_argument('-v', '--verbosity', action='count', default=0, help="The default logging level is ERROR or higher. This value increases the amount of logging seen in your screen")
     parser.add_argument('-q', '--quiesce', action='count', default=0, help="The default logging level is ERROR or higher. This value decreases the amount of logging seen in your screen")
@@ -1476,7 +1481,6 @@ def general_startup(override_args=None):
        args.remote = True;
        args.override_rename = True;
        args.override_delete = True;
        args.nobreakmodel = True;
        args.quiet = True;
        args.lowmem = True;
        args.noaimenu = True;
@@ -1485,7 +1489,8 @@ def general_startup(override_args=None):
         koboldai_vars.quiet = True
 
     if args.nobreakmodel:
-        model_backends['Huggingface'].nobreakmodel = True
+        for model_backend in model_backends:
+            model_backends[model_backend].nobreakmodel = True
 
     if args.remote:
         koboldai_vars.host = True;
@@ -1522,13 +1527,6 @@ def general_startup(override_args=None):
         allowed_ips = sorted(allowed_ips, key=lambda ip: int(''.join([i.zfill(3) for i in ip.split('.')])))
         print(f"Allowed IPs: {allowed_ips}")
 
-    if args.trust_remote_code:
-        logger.warning("EXECUTION OF UNSAFE REMOTE CODE IS ENABLED!!!")
-        logger.warning("You are not protected from Model Viruses in this mode!")
-        logger.warning("Exit the program now to abort execution!")
-        logger.warning("Only use this mode with models that you trust and verified!")
-        time.sleep(25)
-        koboldai_vars.trust_remote_code = True
     if args.cpu:
         koboldai_vars.use_colab_tpu = False
         koboldai_vars.hascuda = False
@@ -1639,7 +1637,75 @@ def unload_model():
 
     #Reload our badwords
     koboldai_vars.badwordsids = koboldai_settings.badwordsids_default
 
+
+class RestrictedUnpickler(pickle.Unpickler):
+    def original_persistent_load(self, saved_id):
+        return super().persistent_load(saved_id)
+
+    def forced_persistent_load(self, saved_id):
+        if saved_id[0] != "storage":
+            raise pickle.UnpicklingError("`saved_id[0]` must be 'storage'")
+        return self.original_persistent_load(saved_id)
+
+    def find_class(self, module, name):
+        if module == "collections" and name == "OrderedDict":
+            return collections.OrderedDict
+        elif module == "torch._utils" and name == "_rebuild_tensor_v2":
+            return torch._utils._rebuild_tensor_v2
+        elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
+            return torch._tensor._rebuild_from_type_v2
+        elif module == "torch" and name in (
+            "DoubleStorage",
+            "FloatStorage",
+            "HalfStorage",
+            "LongStorage",
+            "IntStorage",
+            "ShortStorage",
+            "CharStorage",
+            "ByteStorage",
+            "BoolStorage",
+            "BFloat16Storage",
+            "Tensor",
+        ):
+            return getattr(torch, name)
+        elif module == "numpy.core.multiarray" and name == "scalar":
+            return np.core.multiarray.scalar
+        elif module == "numpy" and name == "dtype":
+            return np.dtype
+        elif module == "_codecs" and name == "encode":
+            return _codecs.encode
+        else:
+            # Forbid everything else.
+            qualified_name = name if module == "__builtin__" else f"{module}.{name}"
+            raise pickle.UnpicklingError(
+                f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code. If you think this is incorrect ask the developer to unban the ability for {module} to execute {name}"
+            )
+
+    def load(self, *args, **kwargs):
+        self.original_persistent_load = getattr(
+            self, "persistent_load", pickle.Unpickler.persistent_load
+        )
+        self.persistent_load = self.forced_persistent_load
+        return super().load(*args, **kwargs)
+
+@contextlib.contextmanager
+def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
+    try:
+        old_unpickler = pickle.Unpickler
+        pickle.Unpickler = unpickler
+
+        old_pickle_load = pickle.load
+
+        def new_pickle_load(*args, **kwargs):
+            return pickle.Unpickler(*args, **kwargs).load()
+
+        pickle.load = new_pickle_load
+
+        yield
+
+    finally:
+        pickle.Unpickler = old_unpickler
+        pickle.load = old_pickle_load
 
 def load_model(model_backend, initial_load=False):
     global model
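For context on the block added above: use_custom_unpickler works by temporarily swapping the module-level pickle.Unpickler (and pickle.load), so any pickle-based loading performed inside the with-block, including torch.load on a legacy pickle checkpoint, is funneled through RestrictedUnpickler.find_class and its allow-list. A rough usage sketch, assuming the definitions above are in scope; the checkpoint path is a placeholder:

# Rough usage sketch (placeholder path, not a file from this repo).
# Inside the block, pickle.Unpickler is RestrictedUnpickler, so torch.load's
# default pickle module resolves classes through find_class; anything outside
# the allow-list raises pickle.UnpicklingError instead of importing arbitrary code.
import torch

with use_custom_unpickler(RestrictedUnpickler):
    state_dict = torch.load("some-model/pytorch_model.bin", map_location="cpu")

# After the block, pickle.Unpickler and pickle.load are restored, so normal
# unpickling elsewhere in the program is unaffected.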
@@ -1687,8 +1753,10 @@ def load_model(model_backend, initial_load=False):
     koboldai_vars.default_preset = koboldai_settings.default_preset
 
 
-    model = model_backends[model_backend]
-    model.load(initial_load=initial_load, save_model=not (args.colab or args.cacheonly) or args.savemodel)
+
+    with use_custom_unpickler(RestrictedUnpickler):
+        model = model_backends[model_backend]
+        model.load(initial_load=initial_load, save_model=not (args.colab or args.cacheonly) or args.savemodel)
     koboldai_vars.model = model.model_name if "model_name" in vars(model) else model.id #Should have model_name, but it could be set to id depending on how it's setup
     if koboldai_vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
         koboldai_vars.model = os.path.basename(os.path.normpath(model.path))
@@ -1788,9 +1856,6 @@ def load_model(model_backend, initial_load=False):
     if not os.path.exists("./softprompts"):
         os.mkdir("./softprompts")
     koboldai_vars.splist = [[f, get_softprompt_desc(os.path.join("./softprompts", f),None,True)] for f in os.listdir("./softprompts") if os.path.isfile(os.path.join("./softprompts", f)) and valid_softprompt(os.path.join("./softprompts", f))]
-    if initial_load and koboldai_vars.cloudflare_link != "":
-        print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 1: " + koboldai_vars.cloudflare_link + format(colors.END))
-        print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 2: " + koboldai_vars.cloudflare_link + "/new_ui" + format(colors.END))
 
 
     # Setup IP Whitelisting
@@ -6153,6 +6218,7 @@ def UI_2_load_model_button(data):
 @socketio.on('select_model')
 @logger.catch
 def UI_2_select_model(data):
+    global model_backend_type_crosswalk #No idea why I have to make this a global where I don't for model_backends...
     logger.debug("Clicked on model entry: {}".format(data))
     if data["name"] in model_menu and data['ismenu'] == "true":
         emit("open_model_load_menu", {"items": [{**item.to_json(), **{"menu":data["name"]}} for item in model_menu[data["name"]] if item.should_show()]})
@@ -6162,8 +6228,9 @@ def UI_2_select_model(data):
         valid_loaders = {}
         if data['id'] in [item.name for sublist in model_menu for item in model_menu[sublist]]:
             #Here if we have a model id that's in our menu, we explicitly use that backend
-            for model_backend in set([item.model_backend for sublist in model_menu for item in model_menu[sublist] if item.name == data['id']]):
-                valid_loaders[model_backend] = model_backends[model_backend].get_requested_parameters(data["name"], data["path"] if 'path' in data else None, data["menu"])
+            for model_backend_type in set([item.model_backend for sublist in model_menu for item in model_menu[sublist] if item.name == data['id']]):
+                for model_backend in model_backend_type_crosswalk[model_backend_type]:
+                    valid_loaders[model_backend] = model_backends[model_backend].get_requested_parameters(data["name"], data["path"] if 'path' in data else None, data["menu"])
             emit("selected_model_info", {"model_backends": valid_loaders})
         else:
             #Here we have a model that's not in our menu structure (either a custom model or a custom path
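The selection change above means a menu entry's model_backend field is now treated as a backend type that fans out, via model_backend_type_crosswalk, to every concrete backend registered under that type, and each one is asked for its requested parameters. A tiny illustrative sketch of that fan-out (dictionary contents and backend names are made up for the example):

# Illustrative fan-out: one backend type from the menu entry expands to every
# backend registered under that type. Names below are placeholders, not the
# actual registry contents.
model_backend_type_crosswalk = {"Huggingface": ["Huggingface", "Huggingface 4-bit"]}

def loaders_for(backend_type, get_requested_parameters):
    # Ask each backend mapped to this type for the parameters it wants the UI to show.
    return {name: get_requested_parameters(name)
            for name in model_backend_type_crosswalk.get(backend_type, [])}

print(loaders_for("Huggingface", lambda name: {"backend": name}))
# -> {'Huggingface': {'backend': 'Huggingface'}, 'Huggingface 4-bit': {'backend': 'Huggingface 4-bit'}}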
@@ -10794,6 +10861,8 @@ def run():
         # delay the display of this message until after that step
         logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {cloudflare}")
         logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {cloudflare}/new_ui")
+        logger.message(f"KoboldAI has finished loading and is available at the following link for KoboldAI Lite: {cloudflare}/lite")
+        logger.message(f"KoboldAI has finished loading and is available at the following link for the API: {cloudflare}/api")
     else:
         logger.init_ok("Webserver", status="OK")
         logger.message(f"Webserver has started, you can now connect to this machine at port: {port}")