Model Backends now defined in the menu

This commit is contained in:
ebolam
2023-05-18 18:34:00 -04:00
parent 182ecff202
commit 4040538d34
5 changed files with 46 additions and 24 deletions

View File

@@ -178,11 +178,13 @@ class MenuModel(MenuItem):
vram_requirements: str = "",
model_type: MenuModelType = MenuModelType.HUGGINGFACE,
experimental: bool = False,
model_backend: str = "Huggingface",
) -> None:
super().__init__(label, name, experimental)
self.model_type = model_type
self.vram_requirements = vram_requirements
self.is_downloaded = is_model_downloaded(self.name)
self.model_backend = model_backend
def to_ui1(self) -> list:
return [
@@ -245,7 +247,7 @@ model_menu = {
MenuFolder("Official RWKV-4", "rwkvlist"),
MenuFolder("Untuned GPT2", "gpt2list"),
MenuFolder("Online Services", "apilist"),
MenuModel("Read Only (No AI)", "ReadOnly", model_type=MenuModelType.OTHER),
MenuModel("Read Only (No AI)", "ReadOnly", model_type=MenuModelType.OTHER, model_backend="Read Only"),
],
'adventurelist': [
MenuModel("Skein 20B", "KoboldAI/GPT-NeoX-20B-Skein", "64GB"),
@@ -369,25 +371,24 @@ model_menu = {
MenuFolder("Return to Main Menu", "mainmenu"),
],
'rwkvlist': [
MenuModel("RWKV Raven 14B", "RWKV/rwkv-raven-14b", ""),
MenuModel("RWKV Pile 14B", "RWKV/rwkv-4-14b-pile", ""),
MenuModel("RWKV Raven 7B", "RWKV/rwkv-raven-7b", ""),
MenuModel("RWKV Pile 7B", "RWKV/rwkv-4-7b-pile", ""),
MenuModel("RWKV Raven 3B", "RWKV/rwkv-raven-3b", ""),
MenuModel("RWKV Pile 3B", "RWKV/rwkv-4-3b-pile", ""),
MenuModel("RWKV Raven 1.5B", "RWKV/rwkv-raven-1b5", ""),
MenuModel("RWKV Pile 1.5B", "RWKV/rwkv-4-1b5-pile", ""),
MenuModel("RWKV Pile 430M", "RWKV/rwkv-4-430m-pile", ""),
MenuModel("RWKV Pile 169B", "RWKV/rwkv-4-169m-pile", ""),
MenuModel("RWKV Raven 14B", "RWKV/rwkv-raven-14b", "", model_backend="RWKV"),
MenuModel("RWKV Pile 14B", "RWKV/rwkv-4-14b-pile", "", model_backend="RWKV"),
MenuModel("RWKV Raven 7B", "RWKV/rwkv-raven-7b", "", model_backend="RWKV"),
MenuModel("RWKV Pile 7B", "RWKV/rwkv-4-7b-pile", "", model_backend="RWKV"),
MenuModel("RWKV Raven 3B", "RWKV/rwkv-raven-3b", "", model_backend="RWKV"),
MenuModel("RWKV Pile 3B", "RWKV/rwkv-4-3b-pile", "", model_backend="RWKV"),
MenuModel("RWKV Raven 1.5B", "RWKV/rwkv-raven-1b5", "", model_backend="RWKV"),
MenuModel("RWKV Pile 1.5B", "RWKV/rwkv-4-1b5-pile", "", model_backend="RWKV"),
MenuModel("RWKV Pile 430M", "RWKV/rwkv-4-430m-pile", "", model_backend="RWKV"),
# Label corrected to "169M": the checkpoint id below is rwkv-4-169m-pile.
MenuModel("RWKV Pile 169M", "RWKV/rwkv-4-169m-pile", "", model_backend="RWKV"),
MenuFolder("Return to Main Menu", "mainmenu"),
],
'apilist': [
MenuModel("GooseAI API (requires API key)", "GooseAI", model_type=MenuModelType.ONLINE_API),
MenuModel("OpenAI API (requires API key)", "OAI", model_type=MenuModelType.ONLINE_API),
MenuModel("InferKit API (requires API key)", "InferKit", model_type=MenuModelType.ONLINE_API),
MenuModel("KoboldAI API", "API", model_type=MenuModelType.ONLINE_API),
MenuModel("Basic Model API", "Colab", model_type=MenuModelType.ONLINE_API),
MenuModel("KoboldAI Horde", "CLUSTER", model_type=MenuModelType.ONLINE_API),
MenuModel("GooseAI API (requires API key)", "GooseAI", model_type=MenuModelType.ONLINE_API, model_backend="GooseAI"),
MenuModel("OpenAI API (requires API key)", "OAI", model_type=MenuModelType.ONLINE_API, model_backend="OpenAI"),
MenuModel("KoboldAI API", "API", model_type=MenuModelType.ONLINE_API, model_backend="KoboldAI API"),
MenuModel("Basic Model API", "Colab", model_type=MenuModelType.ONLINE_API, model_backend="KoboldAI Old Colab Method"),
MenuModel("KoboldAI Horde", "CLUSTER", model_type=MenuModelType.ONLINE_API, model_backend="Horde"),
MenuFolder("Return to Main Menu", "mainmenu"),
]
}
@@ -1670,6 +1671,7 @@ def load_model(model_backend, initial_load=False):
model = model_backends[model_backend]
model.load(initial_load=initial_load, save_model=not (args.colab or args.cacheonly) or args.savemodel)
koboldai_vars.model = model.model_name if "model_name" in vars(model) else model.id #Should have model_name, but it could be set to id depending on how it's setup
logger.debug("Model Type: {}".format(koboldai_vars.model_type))
# TODO: Convert everywhere to use model.tokenizer
@@ -6136,7 +6138,7 @@ def UI_2_select_model(data):
#Get load methods
if 'path' not in data or data['path'] == "":
valid_loaders = {}
for model_backend in model_backends:
for model_backend in set([item.model_backend for sublist in model_menu for item in model_menu[sublist] if item.name == data['id']]):
valid_loaders[model_backend] = model_backends[model_backend].get_requested_parameters(data["name"], data["path"] if 'path' in data else None, data["menu"])
emit("selected_model_info", {"model_backends": valid_loaders, "preselected": "Huggingface"})
else: