Model Backends now defined in the menu

Author: ebolam
Date: 2023-05-18 18:34:00 -04:00
Parent: 182ecff202
Commit: 4040538d34
5 changed files with 46 additions and 24 deletions
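
In short: every entry in the model menu now names the backend that should load it, the model-selection handler offers the user only those backends instead of every registered one, and the OpenAI/GooseAI backend gains a multi-select engine list along with the UI and styling changes needed to display it.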


@@ -178,11 +178,13 @@ class MenuModel(MenuItem):
         vram_requirements: str = "",
         model_type: MenuModelType = MenuModelType.HUGGINGFACE,
         experimental: bool = False,
+        model_backend: str = "Huggingface",
     ) -> None:
         super().__init__(label, name, experimental)
         self.model_type = model_type
         self.vram_requirements = vram_requirements
         self.is_downloaded = is_model_downloaded(self.name)
+        self.model_backend = model_backend

     def to_ui1(self) -> list:
         return [
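
The constructor change in isolation, as a runnable sketch (the MenuItem base class, the download check, and the to_ui1() serialization of the real class are omitted here):

    from enum import Enum, auto

    class MenuModelType(Enum):
        HUGGINGFACE = auto()
        ONLINE_API = auto()
        OTHER = auto()

    class MenuModel:
        def __init__(self, label, name, vram_requirements="",
                     model_type=MenuModelType.HUGGINGFACE,
                     model_backend="Huggingface"):
            self.label = label
            self.name = name
            self.vram_requirements = vram_requirements
            self.model_type = model_type
            # New in this commit: each entry records the backend that loads it.
            self.model_backend = model_backend

    # Plain Hugging Face checkpoints need no edits; the default applies:
    skein = MenuModel("Skein 20B", "KoboldAI/GPT-NeoX-20B-Skein", "64GB")
    assert skein.model_backend == "Huggingface"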
@@ -245,7 +247,7 @@ model_menu = {
         MenuFolder("Official RWKV-4", "rwkvlist"),
         MenuFolder("Untuned GPT2", "gpt2list"),
         MenuFolder("Online Services", "apilist"),
-        MenuModel("Read Only (No AI)", "ReadOnly", model_type=MenuModelType.OTHER),
+        MenuModel("Read Only (No AI)", "ReadOnly", model_type=MenuModelType.OTHER, model_backend="Read Only"),
     ],
     'adventurelist': [
         MenuModel("Skein 20B", "KoboldAI/GPT-NeoX-20B-Skein", "64GB"),
@@ -369,25 +371,24 @@ model_menu = {
         MenuFolder("Return to Main Menu", "mainmenu"),
     ],
     'rwkvlist': [
-        MenuModel("RWKV Raven 14B", "RWKV/rwkv-raven-14b", ""),
-        MenuModel("RWKV Pile 14B", "RWKV/rwkv-4-14b-pile", ""),
-        MenuModel("RWKV Raven 7B", "RWKV/rwkv-raven-7b", ""),
-        MenuModel("RWKV Pile 7B", "RWKV/rwkv-4-7b-pile", ""),
-        MenuModel("RWKV Raven 3B", "RWKV/rwkv-raven-3b", ""),
-        MenuModel("RWKV Pile 3B", "RWKV/rwkv-4-3b-pile", ""),
-        MenuModel("RWKV Raven 1.5B", "RWKV/rwkv-raven-1b5", ""),
-        MenuModel("RWKV Pile 1.5B", "RWKV/rwkv-4-1b5-pile", ""),
-        MenuModel("RWKV Pile 430M", "RWKV/rwkv-4-430m-pile", ""),
-        MenuModel("RWKV Pile 169B", "RWKV/rwkv-4-169m-pile", ""),
+        MenuModel("RWKV Raven 14B", "RWKV/rwkv-raven-14b", "", model_backend="RWKV"),
+        MenuModel("RWKV Pile 14B", "RWKV/rwkv-4-14b-pile", "", model_backend="RWKV"),
+        MenuModel("RWKV Raven 7B", "RWKV/rwkv-raven-7b", "", model_backend="RWKV"),
+        MenuModel("RWKV Pile 7B", "RWKV/rwkv-4-7b-pile", "", model_backend="RWKV"),
+        MenuModel("RWKV Raven 3B", "RWKV/rwkv-raven-3b", "", model_backend="RWKV"),
+        MenuModel("RWKV Pile 3B", "RWKV/rwkv-4-3b-pile", "", model_backend="RWKV"),
+        MenuModel("RWKV Raven 1.5B", "RWKV/rwkv-raven-1b5", "", model_backend="RWKV"),
+        MenuModel("RWKV Pile 1.5B", "RWKV/rwkv-4-1b5-pile", "", model_backend="RWKV"),
+        MenuModel("RWKV Pile 430M", "RWKV/rwkv-4-430m-pile", "", model_backend="RWKV"),
+        MenuModel("RWKV Pile 169B", "RWKV/rwkv-4-169m-pile", "", model_backend="RWKV"),
         MenuFolder("Return to Main Menu", "mainmenu"),
     ],
     'apilist': [
-        MenuModel("GooseAI API (requires API key)", "GooseAI", model_type=MenuModelType.ONLINE_API),
-        MenuModel("OpenAI API (requires API key)", "OAI", model_type=MenuModelType.ONLINE_API),
-        MenuModel("InferKit API (requires API key)", "InferKit", model_type=MenuModelType.ONLINE_API),
-        MenuModel("KoboldAI API", "API", model_type=MenuModelType.ONLINE_API),
-        MenuModel("Basic Model API", "Colab", model_type=MenuModelType.ONLINE_API),
-        MenuModel("KoboldAI Horde", "CLUSTER", model_type=MenuModelType.ONLINE_API),
+        MenuModel("GooseAI API (requires API key)", "GooseAI", model_type=MenuModelType.ONLINE_API, model_backend="GooseAI"),
+        MenuModel("OpenAI API (requires API key)", "OAI", model_type=MenuModelType.ONLINE_API, model_backend="OpenAI"),
+        MenuModel("KoboldAI API", "API", model_type=MenuModelType.ONLINE_API, model_backend="KoboldAI API"),
+        MenuModel("Basic Model API", "Colab", model_type=MenuModelType.ONLINE_API, model_backend="KoboldAI Old Colab Method"),
+        MenuModel("KoboldAI Horde", "CLUSTER", model_type=MenuModelType.ONLINE_API, model_backend="Horde"),
         MenuFolder("Return to Main Menu", "mainmenu"),
     ]
 }
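
Taken together, the menu edits route every RWKV checkpoint to the dedicated "RWKV" backend, bind each online service to a named backend ("GooseAI", "OpenAI", "KoboldAI API", "KoboldAI Old Colab Method", "Horde"), and drop the InferKit entry from the API list entirely. (The "RWKV Pile 169B" label looks like a pre-existing typo for the 169M checkpoint, rwkv-4-169m-pile; this commit leaves it unchanged.)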
@@ -1670,6 +1671,7 @@ def load_model(model_backend, initial_load=False):
     model = model_backends[model_backend]
     model.load(initial_load=initial_load, save_model=not (args.colab or args.cacheonly) or args.savemodel)
     koboldai_vars.model = model.model_name if "model_name" in vars(model) else model.id #Should have model_name, but it could be set to id depending on how it's setup
+    logger.debug("Model Type: {}".format(koboldai_vars.model_type))

     # TODO: Convert everywhere to use model.tokenizer
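
The dispatch that consumes these declarations, reduced to a runnable skeleton; ReadOnlyBackend is a hypothetical stand-in for KoboldAI's real backend classes:

    # Hypothetical stand-in for a real backend class.
    class ReadOnlyBackend:
        def __init__(self):
            self.model_name = "ReadOnly"

        def load(self, initial_load=False, save_model=True):
            print(f"loading {self.model_name} (initial_load={initial_load})")

    # Registry keyed by the same backend names the menu entries declare.
    model_backends = {"Read Only": ReadOnlyBackend()}

    def load_model(model_backend, initial_load=False):
        model = model_backends[model_backend]
        model.load(initial_load=initial_load)
        # Same fallback as the diff: prefer model_name, else the backend's id.
        return model.model_name if "model_name" in vars(model) else model.id

    print(load_model("Read Only"))  # -> ReadOnly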
@@ -6136,7 +6138,7 @@ def UI_2_select_model(data):
     #Get load methods
     if 'path' not in data or data['path'] == "":
         valid_loaders = {}
-        for model_backend in model_backends:
+        for model_backend in set([item.model_backend for sublist in model_menu for item in model_menu[sublist] if item.name == data['id']]):
             valid_loaders[model_backend] = model_backends[model_backend].get_requested_parameters(data["name"], data["path"] if 'path' in data else None, data["menu"])
         emit("selected_model_info", {"model_backends": valid_loaders, "preselected": "Huggingface"})
     else:
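
The rewritten loop is the heart of the commit but is dense; here is an equivalent, more readable form of the lookup (a sketch reusing the MenuModel shape from the earlier example):

    def backends_for(selected_id, model_menu):
        # Collect the declared backend of every menu entry whose name
        # matches the id the user clicked, across all menu sublists.
        backends = set()
        for menu_key in model_menu:
            for item in model_menu[menu_key]:
                if item.name == selected_id:
                    backends.add(item.model_backend)
        return backends

    menu = {"mainmenu": [MenuModel("Read Only (No AI)", "ReadOnly",
                                   model_backend="Read Only")]}
    print(backends_for("ReadOnly", menu))  # -> {'Read Only'}

Only those backends are asked for their requested parameters, so the dialog no longer offers loaders that cannot handle the chosen entry. Note that the emit still unconditionally preselects "Huggingface", even for entries whose only valid backend is something else.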


@@ -70,6 +70,7 @@ class model_backend(InferenceModel):
                 "id": "model",
                 "default": model_name,
                 "check": {"value": "", 'check': "!="},
+                'multiple': True,
                 "tooltip": "Which model to use when running OpenAI/GooseAI.",
                 "menu_path": "",
                 "refresh_model_inputs": False,
@@ -102,7 +103,7 @@ class model_backend(InferenceModel):
         engines = req.json()
         try:
-            engines = [{"text": en["name"], "value": en["name"]} for en in engines]
+            engines = [{"text": "all", "value": "all"}] + [{"text": en["name"], "value": en["name"]} for en in engines]
         except:
             logger.error(engines)
             raise
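
The functional change is prepending an "all" choice to the engine list fetched from the service, so the new multi-select can target every engine at once; the bare except exists only to log the unexpected payload before re-raising.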


@@ -352,7 +352,7 @@ border-top-right-radius: var(--tabs_rounding);
     grid-template-areas: "label value"
                          "item item"
                          "minlabel maxlabel";
-    grid-template-rows: 20px 23px 20px;
+    grid-template-rows: 20px auto 20px;
     grid-template-columns: auto 30px;
     row-gap: 0.2em;
     background-color: var(--setting_background);
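
With the engine dropdown now able to render as a ten-row list (see the JavaScript change below), a fixed 23px middle row would clip it; auto lets the "item" row grow to fit its content.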
@@ -2124,6 +2124,13 @@ body {
     cursor: pointer;
     background-color: #688f1f;
 }
+
+.loadmodelsettings {
+    overflow-y: auto;
+    max-height: 50%;
+}

 /*----------------------------- Model Load Popup ------------------------------------------*/
 #specspan, .popup_list_area .model_item .model {
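
The new .loadmodelsettings rule caps the backend-settings area at half its container's height and scrolls the overflow, keeping the Load and Cancel buttons reachable no matter how many parameters a backend requests.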
@@ -3539,7 +3546,7 @@ h2 .material-icons-outlined {
 }

-.horde_trigger[model_model="ReadOnly"],
+.horde_trigger[model_model="Read Only"],
 .horde_trigger[model_model="CLUSTER"] {
     display: none;
 }
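
The selector is updated to match the Read Only entry's new backend name, "Read Only", so the Horde toggle stays hidden in no-AI mode.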


@@ -1695,12 +1695,20 @@ function model_settings_checker() {
             for (const temp of this.check_data['sum']) {
                 if (document.getElementById(this.id.split("|")[0] +"|" + temp + "_value")) {
                     document.getElementById(this.id.split("|")[0] +"|" + temp + "_value").closest(".setting_container_model").classList.add('input_error');
+                    if (this.check_data['check_message']) {
+                        document.getElementById(this.id.split("|")[0] +"|" + temp + "_value").closest(".setting_container_model").setAttribute("tooltip", this.check_data['check_message']);
+                    } else {
+                        document.getElementById(this.id.split("|")[0] +"|" + temp + "_value").closest(".setting_container_model").removeAttribute("tooltip");
+                    }
                 }
             }
         } else {
             this.closest(".setting_container_model").classList.add('input_error');
+            if (this.check_data['check_message']) {
+                this.closest(".setting_container_model").setAttribute("tooltip", this.check_data['check_message']);
+            } else {
+                this.closest(".setting_container_model").removeAttribute("tooltip");
+            }
         }
     }
 }
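
Failing checks no longer just turn the setting container red: if the backend supplied a check_message it is attached to the container as a tooltip attribute, and otherwise any previous tooltip is removed so stale messages do not linger.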
@@ -1841,6 +1849,10 @@ function selected_model_info(sent_data) {
                 select_element.setAttribute("data_type", item['unit']);
                 select_element.onchange = onchange_event;
                 select_element.setAttribute("refresh_model_inputs", item['refresh_model_inputs']);
+                if (('multiple' in item) && (item['multiple'])) {
+                    select_element.multiple = true;
+                    select_element.size = 10;
+                }
                 if ('check' in item) {
                     select_element.check_data = item['check'];
                 } else {
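
This is the consumer of the 'multiple': True flag added in the OpenAI/GooseAI backend above: when a parameter declares it, the generated select element becomes a ten-row multi-select, letting the user pick several engines at once.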


@@ -48,7 +48,7 @@
             </span>
             <div id="loadmodellistbreadcrumbs"></div>
             <div id="loadmodellistcontent" class="popup_list_area"></div>
-            <div id="loadmodelplugin" class="popup_load_cancel loadmodelsettings"><select id="modelplugin" class="settings_select hidden"></select></div>
+            <div id="loadmodelplugin" class="popup_load_cancel"><select id="modelplugin" class="settings_select hidden"></select></div>
             <div id="loadmodelsettings" class="popup_load_cancel loadmodelsettings"></div>
             <div class="popup_load_cancel">
                 <button type="button" class="btn popup_load_cancel_button action_button disabled" onclick="load_model()" id="btn_loadmodelaccept" disabled>Load</button>