Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
First concept of model plugins with a conceptual UI.
Completely breaks UI2 model loading.
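
For orientation, a minimal sketch of the plugin contract this commit introduces, inferred from the new `model_loader` classes in the diff below (the comments and the `**kwargs` signature are assumptions, not code from the commit):

from modeling.inference_model import InferenceModel

class model_loader(InferenceModel):
    # Sketch only: each modeling/inference_models/*.py backend exposes a class
    # literally named `model_loader` with roughly this interface.
    def is_valid(self, model_name, model_path, menu_path):
        # Return True if this backend can load the selected menu entry.
        ...

    def get_requested_parameters(self, model_name, model_path, menu_path):
        # Describe the extra inputs (uitype, label, id, default, tooltip, ...)
        # the UI should render before loading.
        ...

    def set_input_parameters(self, **kwargs):
        # Receive the values the user entered for those inputs;
        # each backend in the diff declares its own keyword arguments.
        ...

    def _load(self, save_model: bool, initial_load: bool) -> None:
        # Actually load the model; invoked through InferenceModel.load().
        ...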
aiserver.py (123 lines changed)
@@ -168,6 +168,7 @@ class MenuFolder(MenuItem):
            "size": "",
            "isMenu": True,
            "isDownloaded": False,
            "isDirectory": False
        }

class MenuModel(MenuItem):
@@ -200,8 +201,28 @@ class MenuModel(MenuItem):
            "size": self.vram_requirements,
            "isMenu": False,
            "isDownloaded": self.is_downloaded,
            "isDirectory": False,
        }

class MenuPath(MenuItem):
    def to_ui1(self) -> list:
        return [
            self.label,
            self.name,
            "",
            True,
        ]

    def to_json(self) -> dict:
        return {
            "label": self.label,
            "name": self.name,
            "size": "",
            "isMenu": True,
            "isDownloaded": False,
            "isDirectory": True,
            "path": "./models"
        }

# AI models Menu
# This is a dict of lists where they key is the menu name, and the list is the menu items.
@@ -209,8 +230,8 @@ class MenuModel(MenuItem):
# 3: the memory requirement for the model, 4: if the item is a menu or not (True/False)
model_menu = {
    "mainmenu": [
        MenuModel("Load a model from its directory", "NeoCustom"),
        MenuModel("Load an old GPT-2 model (eg CloverEdition)", "GPT2Custom"),
        MenuPath("Load a model from its directory", "NeoCustom"),
        MenuPath("Load an old GPT-2 model (eg CloverEdition)", "GPT2Custom"),
        MenuFolder("Load custom model from Hugging Face", "customhuggingface"),
        MenuFolder("Adventure Models", "adventurelist"),
        MenuFolder("Novel Models", "novellist"),
@@ -600,6 +621,15 @@ utils.socketio = socketio
# Weird import position to steal koboldai_vars from utils
from modeling.patches import patch_transformers

#Load all of the model importers
import importlib
model_loader_code = {}
model_loaders = {}
for module in os.listdir("./modeling/inference_models"):
    if os.path.isfile(os.path.join("./modeling/inference_models",module)) and module[-3:] == '.py':
        model_loader_code[module[:-3]] = importlib.import_module('modeling.inference_models.{}'.format(module[:-3]))
        model_loaders[module[:-3]] = model_loader_code[module[:-3]].model_loader()


old_socketio_on = socketio.on
def new_socketio_on(*a, **k):
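
A small usage sketch of what the discovery loop above produces; the dictionary keys are assumptions based on the backend modules touched in this commit:

# model_loaders maps module name -> one instantiated loader, e.g. (assumed):
#   {"api": <api.model_loader>, "basic_api": <basic_api.model_loader>, "horde": ..., "openai": ...}
loader = model_loaders["api"]                     # hypothetical lookup by module name
if loader.is_valid("API", None, "mainmenu"):      # does this backend claim the selected entry?
    params = loader.get_requested_parameters("API", None, "mainmenu")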
@@ -906,6 +936,8 @@ def sendModelSelection(menu="mainmenu", folder="./models"):
    )

def get_folder_path_info(base):
    if base is None:
        return [], []
    if base == 'This PC':
        breadcrumbs = [['This PC', 'This PC']]
        paths = [["{}:\\".format(chr(i)), "{}:\\".format(chr(i))] for i in range(65, 91) if os.path.exists("{}:".format(chr(i)))]
@@ -1932,25 +1964,25 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
        koboldai_vars.breakmodel = False

        if koboldai_vars.model == "Colab":
            from modeling.inference_models.basic_api import BasicAPIInferenceModel
            model = BasicAPIInferenceModel()
            from modeling.inference_models.basic_api import model_loader
            model = model_loader()
        elif koboldai_vars.model == "API":
            from modeling.inference_models.api import APIInferenceModel
            model = APIInferenceModel(koboldai_vars.colaburl.replace("/request", ""))
            from modeling.inference_models.api import model_loader
            model = model_loader(koboldai_vars.colaburl.replace("/request", ""))
        elif koboldai_vars.model == "CLUSTER":
            from modeling.inference_models.horde import HordeInferenceModel
            model = HordeInferenceModel()
            from modeling.inference_models.horde import model_loader
            model = model_loader()
        elif koboldai_vars.model == "OAI":
            from modeling.inference_models.openai import OpenAIAPIInferenceModel
            model = OpenAIAPIInferenceModel()
            from modeling.inference_models.openai import model_loader
            model = model_loader()

        model.load(initial_load=initial_load)
    # TODO: This check sucks, make a model object or somethign
    elif "rwkv" in koboldai_vars.model:
        if koboldai_vars.use_colab_tpu:
            raise RuntimeError("RWKV is not supported on the TPU.")
        from modeling.inference_models.rwkv import RWKVInferenceModel
        model = RWKVInferenceModel(koboldai_vars.model)
        from modeling.inference_models.rwkv import model_loader
        model = model_loader(koboldai_vars.model)
        model.load()
    elif not koboldai_vars.use_colab_tpu and not koboldai_vars.noai:
        # HF Torch
@@ -1961,8 +1993,8 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
        except:
            pass

        from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
        model = GenericHFTorchInferenceModel(
        from modeling.inference_models.generic_hf_torch import model_loader
        model = model_loader(
            koboldai_vars.model,
            lazy_load=koboldai_vars.lazy_load,
            low_mem=args.lowmem
@@ -1975,8 +2007,8 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
        logger.info(f"Pipeline created: {koboldai_vars.model}")
    else:
        # TPU
        from modeling.inference_models.hf_mtj import HFMTJInferenceModel
        model = HFMTJInferenceModel(
        from modeling.inference_models.hf_mtj import model_loader
        model = model_loader(
            koboldai_vars.model
        )
        model.load(
@@ -6430,7 +6462,9 @@ def UI_2_retry(data):
@socketio.on('load_model_button')
@logger.catch
def UI_2_load_model_button(data):
    sendModelSelection()
    emit("open_model_load_menu", {"items": [{**item.to_json(), **{"menu":"mainmenu"}} for item in model_menu['mainmenu'] if item.should_show()]})


#==================================================================#
# Event triggered when user clicks the a model
@@ -6438,6 +6472,38 @@ def UI_2_load_model_button(data):
@socketio.on('select_model')
@logger.catch
def UI_2_select_model(data):
    logger.debug("Clicked on model entry: {}".format(data))
    if data["name"] in model_menu and data['ismenu'] == "true":
        emit("open_model_load_menu", {"items": [{**item.to_json(), **{"menu":data["name"]}} for item in model_menu[data["name"]] if item.should_show()]})
    else:
        #Get load methods
        logger.debug("Asking for model info on potential model: {}".format(data))
        valid = False
        if 'path' not in data or data['path'] == "":
            valid_loaders = {}
            for model_loader in model_loaders:
                logger.debug("Testing Loader {} for model {}: {}".format(model_loader, data["name"], model_loaders[model_loader].is_valid(data["name"], data["path"] if 'path' in data else None, data["menu"])))
                if model_loaders[model_loader].is_valid(data["name"], data["path"] if 'path' in data else None, data["menu"]):
                    valid_loaders[model_loader] = model_loaders[model_loader].get_requested_parameters(data["name"], data["path"] if 'path' in data else None, data["menu"])
                    valid = True
        if valid:
            logger.debug("Valid Loaders: {}".format(valid_loaders))
            emit("selected_model_info", valid_loaders)
        if not valid:
            #Get directories
            paths, breadcrumbs = get_folder_path_info(data['path'])
            output = []
            for path in paths:
                valid=False
                for model_loader in model_loaders:
                    if model_loaders[model_loader].is_valid(path[1], path[0], "Custom"):
                        valid=True
                        break
                output.append({'label': path[1], 'name': path[0], 'size': "", "menu": "Custom", 'path': path[0], 'isMenu': not valid})
            emit("open_model_load_menu", {"items": output+[{'label': 'Return to Main Menu', 'name':'mainmenu', 'size': "", "menu": "Custom", 'isMenu': True}], 'breadcrumbs': breadcrumbs})

        return


    #We've selected a menu
    if data['model'] in model_menu:
@@ -6462,26 +6528,9 @@ def UI_2_select_model(data):
@socketio.on('load_model')
@logger.catch
def UI_2_load_model(data):
    if not os.path.exists("settings/"):
        os.mkdir("settings")
    changed = True
    if os.path.exists("settings/" + data['model'].replace('/', '_') + ".breakmodel"):
        with open("settings/" + data['model'].replace('/', '_') + ".breakmodel", "r") as file:
            file_data = file.read().split('\n')[:2]
            if len(file_data) < 2:
                file_data.append("0")
            gpu_layers, disk_layers = file_data
            if gpu_layers == data['gpu_layers'] and disk_layers == data['disk_layers']:
                changed = False
    if changed:
        f = open("settings/" + data['model'].replace('/', '_') + ".breakmodel", "w")
        f.write("{}\n{}".format(data['gpu_layers'], data['disk_layers']))
        f.close()
    koboldai_vars.colaburl = data['url'] + "/request"
    koboldai_vars.model = data['model']
    koboldai_vars.custmodpth = data['path']
    print("loading Model")
    load_model(use_gpu=data['use_gpu'], gpu_layers=data['gpu_layers'], disk_layers=data['disk_layers'], online_model=data['online_model'], url=koboldai_vars.colaburl, use_8_bit=data['use_8_bit'])
    logger.info("loading Model")
    logger.info(data)
    #load_model(use_gpu=data['use_gpu'], gpu_layers=data['gpu_layers'], disk_layers=data['disk_layers'], online_model=data['online_model'], url=koboldai_vars.colaburl, use_8_bit=data['use_8_bit'])

#==================================================================#
# Event triggered when load story is clicked

modeling/inference_model.py
@@ -169,6 +169,15 @@ class InferenceModel:
        ]
        self.tokenizer = None
        self.capabilties = ModelCapabilities()

    def is_valid(self, model_name, model_path, menu_path, vram):
        return True

    def requested_parameters(self, model_name, model_path, menu_path, vram):
        return {}

    def define_input_parameters(self):
        return

    def load(self, save_model: bool = False, initial_load: bool = False) -> None:
        """User-facing load function. Do not override this; try `_load()` instead."""

modeling/inference_models/api.py
@@ -22,9 +22,31 @@ class APIException(Exception):
    """To be used for errors when using the Kobold API as an interface."""


class APIInferenceModel(InferenceModel):
    def __init__(self, base_url: str) -> None:
class model_loader(InferenceModel):
    def __init__(self) -> None:
        super().__init__()
        #self.base_url = ""

    def is_valid(self, model_name, model_path, menu_path):
        return model_name == "API"

    def get_requested_parameters(self, model_name, model_path, menu_path):
        requested_parameters = []
        requested_parameters.append({
            "uitype": "text",
            "unit": "text",
            "label": "URL",
            "id": "base_url",
            "default": False,
            "check": {"value": "", 'check': "!="},
            "tooltip": "The URL of the KoboldAI API to connect to.",
            "menu_path": "",
            "extra_classes": "",
            "refresh_model_inputs": False
        })
        return requested_parameters

    def set_input_parameters(self, base_url=""):
        self.base_url = base_url.rstrip("/")

    def _load(self, save_model: bool, initial_load: bool) -> None:
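
For context, a sketch (an assumption, not taken from the diff) of the `selected_model_info` payload that `UI_2_select_model` builds from `get_requested_parameters` results like the one above, keyed by loader module name:

# Assumed shape of the payload emitted as "selected_model_info" (sketch only):
selected_model_info = {
    "api": [  # module name of the loader whose is_valid() returned True
        {
            "uitype": "text", "unit": "text", "label": "URL", "id": "base_url",
            "default": False, "check": {"value": "", "check": "!="},
            "tooltip": "The URL of the KoboldAI API to connect to.",
            "menu_path": "", "extra_classes": "", "refresh_model_inputs": False,
        },
    ],
}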

modeling/inference_models/basic_api.py
@@ -19,12 +19,37 @@ class BasicAPIException(Exception):
    """To be used for errors when using the Basic API as an interface."""


class BasicAPIInferenceModel(InferenceModel):
class model_loader(InferenceModel):
    def __init__(self) -> None:
        super().__init__()

        # Do not allow API to be served over the API
        self.capabilties = ModelCapabilities(api_host=False)

    def is_valid(self, model_name, model_path, menu_path):
        return model_name == "Colab"

    def get_requested_parameters(self, model_name, model_path, menu_path):
        requested_parameters = []
        requested_parameters.append({
            "uitype": "text",
            "unit": "text",
            "label": "URL",
            "id": "colaburl",
            "default": False,
            "check": {"value": "", 'check': "!="},
            "tooltip": "The URL of the Colab KoboldAI API to connect to.",
            "menu_path": "",
            "extra_classes": "",
            "refresh_model_inputs": False
        })
        return requested_parameters

    def set_input_parameters(self, colaburl=""):
        self.colaburl = colaburl

    def _initialize_model(self):
        return

    def _load(self, save_model: bool, initial_load: bool) -> None:
        self.tokenizer = self._get_tokenizer("EleutherAI/gpt-neo-2.7B")
@@ -68,7 +93,7 @@ class BasicAPIInferenceModel(InferenceModel):
        }

        # Create request
        req = requests.post(utils.koboldai_vars.colaburl, json=reqdata)
        req = requests.post(self.colaburl, json=reqdata)

        if req.status_code != 200:
            raise BasicAPIException(f"Bad status code {req.status_code}")

modeling/inference_models/generic_hf_torch.py
@@ -20,10 +20,14 @@ except ModuleNotFoundError as e:
    if not utils.koboldai_vars.use_colab_tpu:
        raise e

from modeling.inference_models.hf_torch import HFTorchInferenceModel
from modeling.inference_models.parents.hf_torch import HFTorchInferenceModel


class GenericHFTorchInferenceModel(HFTorchInferenceModel):
class model_loader(HFTorchInferenceModel):

    def _initialize_model(self):
        return

    def _load(self, save_model: bool, initial_load: bool) -> None:
        utils.koboldai_vars.allowsp = True

modeling/inference_models/hf.py (deleted)
@@ -1,190 +0,0 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from transformers import AutoConfig
|
||||
|
||||
import utils
|
||||
import koboldai_settings
|
||||
from logger import logger
|
||||
from modeling.inference_model import InferenceModel
|
||||
|
||||
|
||||
class HFInferenceModel(InferenceModel):
|
||||
def __init__(self, model_name: str) -> None:
|
||||
super().__init__()
|
||||
self.model_config = None
|
||||
self.model_name = model_name
|
||||
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
|
||||
def _post_load(self) -> None:
|
||||
# These are model specific tokenizer overrides if a model has bad defaults
|
||||
if utils.koboldai_vars.model_type == "llama":
|
||||
# Note: self.tokenizer is a GenericTokenizer, and self.tokenizer.tokenizer is the actual LlamaTokenizer
|
||||
self.tokenizer.add_bos_token = False
|
||||
|
||||
# HF transformers no longer supports decode_with_prefix_space
|
||||
# We work around this by wrapping decode, encode, and __call__
|
||||
# with versions that work around the 'prefix space' misfeature
|
||||
# of sentencepiece.
|
||||
vocab = self.tokenizer.convert_ids_to_tokens(range(self.tokenizer.vocab_size))
|
||||
has_prefix_space = {i for i, tok in enumerate(vocab) if tok.startswith("▁")}
|
||||
|
||||
# Wrap 'decode' with a method that always returns text starting with a space
|
||||
# when the head token starts with a space. This is what 'decode_with_prefix_space'
|
||||
# used to do, and we implement it using the same technique (building a cache of
|
||||
# tokens that should have a prefix space, and then prepending a space if the first
|
||||
# token is in this set.) We also work around a bizarre behavior in which decoding
|
||||
# a single token 13 behaves differently than decoding a squence containing only [13].
|
||||
original_decode = type(self.tokenizer.tokenizer).decode
|
||||
def decode_wrapper(self, token_ids, *args, **kwargs):
|
||||
first = None
|
||||
# Note, the code below that wraps single-value token_ids in a list
|
||||
# is to work around this wonky behavior:
|
||||
# >>> t.decode(13)
|
||||
# '<0x0A>'
|
||||
# >>> t.decode([13])
|
||||
# '\n'
|
||||
# Not doing this causes token streaming to receive <0x0A> characters
|
||||
# instead of newlines.
|
||||
if isinstance(token_ids, int):
|
||||
first = token_ids
|
||||
token_ids = [first]
|
||||
elif hasattr(token_ids, 'dim'): # Check for e.g. torch.Tensor
|
||||
# Tensors don't support the Python standard of 'empty is False'
|
||||
# and the special case of dimension 0 tensors also needs to be
|
||||
# handled separately.
|
||||
if token_ids.dim() == 0:
|
||||
first = int(token_ids.item())
|
||||
token_ids = [first]
|
||||
elif len(token_ids) > 0:
|
||||
first = int(token_ids[0])
|
||||
elif token_ids:
|
||||
first = token_ids[0]
|
||||
result = original_decode(self, token_ids, *args, **kwargs)
|
||||
if first is not None and first in has_prefix_space:
|
||||
result = " " + result
|
||||
return result
|
||||
# GenericTokenizer overrides __setattr__ so we need to use object.__setattr__ to bypass it
|
||||
object.__setattr__(self.tokenizer, 'decode', decode_wrapper.__get__(self.tokenizer))
|
||||
|
||||
# Wrap encode and __call__ to work around the 'prefix space' misfeature also.
|
||||
# The problem is that "Bob" at the start of text is encoded as if it is
|
||||
# " Bob". This creates a problem because it means you can't split text, encode
|
||||
# the pieces, concatenate the tokens, decode them, and get the original text back.
|
||||
# The workaround is to prepend a known token that (1) starts with a space; and
|
||||
# (2) is not the prefix of any other token. After searching through the vocab
|
||||
# " ," (space comma) is the only token containing only printable ascii characters
|
||||
# that fits this bill. By prepending ',' to the text, the original encode
|
||||
# method always returns [1919, ...], where the tail of the sequence is the
|
||||
# actual encoded result we want without the prefix space behavior.
|
||||
original_encode = type(self.tokenizer.tokenizer).encode
|
||||
def encode_wrapper(self, text, *args, **kwargs):
|
||||
if type(text) is str:
|
||||
text = ',' + text
|
||||
result = original_encode(self, text, *args, **kwargs)
|
||||
result = result[1:]
|
||||
else:
|
||||
result = original_encode(self, text, *args, **kwargs)
|
||||
return result
|
||||
object.__setattr__(self.tokenizer, 'encode', encode_wrapper.__get__(self.tokenizer))
|
||||
|
||||
# Since 'encode' is documented as being deprecated, also override __call__.
|
||||
# This doesn't appear to currently be used by KoboldAI, but doing so
|
||||
# in case someone uses it in the future.
|
||||
original_call = type(self.tokenizer.tokenizer).__call__
|
||||
def call_wrapper(self, text, *args, **kwargs):
|
||||
if type(text) is str:
|
||||
text = ',' + text
|
||||
result = original_call(self, text, *args, **kwargs)
|
||||
result = result[1:]
|
||||
else:
|
||||
result = original_call(self, text, *args, **kwargs)
|
||||
return result
|
||||
object.__setattr__(self.tokenizer, '__call__', call_wrapper.__get__(self.tokenizer))
|
||||
|
||||
elif utils.koboldai_vars.model_type == "opt":
|
||||
self.tokenizer._koboldai_header = self.tokenizer.encode("")
|
||||
self.tokenizer.add_bos_token = False
|
||||
self.tokenizer.add_prefix_space = False
|
||||
|
||||
# Change newline behavior to match model quirks
|
||||
if utils.koboldai_vars.model_type == "xglm":
|
||||
# Default to </s> newline mode if using XGLM
|
||||
utils.koboldai_vars.newlinemode = "s"
|
||||
elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
|
||||
# Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
|
||||
utils.koboldai_vars.newlinemode = "ns"
|
||||
|
||||
# Clean up tokens that cause issues
|
||||
if (
|
||||
utils.koboldai_vars.badwordsids == koboldai_settings.badwordsids_default
|
||||
and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
|
||||
):
|
||||
utils.koboldai_vars.badwordsids = [
|
||||
[v]
|
||||
for k, v in self.tokenizer.get_vocab().items()
|
||||
if any(c in str(k) for c in "[]")
|
||||
]
|
||||
|
||||
if utils.koboldai_vars.newlinemode == "n":
|
||||
utils.koboldai_vars.badwordsids.append([self.tokenizer.eos_token_id])
|
||||
|
||||
return super()._post_load()
|
||||
|
||||
def get_local_model_path(
|
||||
self, legacy: bool = False, ignore_existance: bool = False
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Returns a string of the model's path locally, or None if it is not downloaded.
|
||||
If ignore_existance is true, it will always return a path.
|
||||
"""
|
||||
|
||||
if self.model_name in ["NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]:
|
||||
model_path = utils.koboldai_vars.custmodpth
|
||||
assert model_path
|
||||
|
||||
# Path can be absolute or relative to models directory
|
||||
if os.path.exists(model_path):
|
||||
return model_path
|
||||
|
||||
model_path = os.path.join("models", model_path)
|
||||
|
||||
try:
|
||||
assert os.path.exists(model_path)
|
||||
except AssertionError:
|
||||
logger.error(f"Custom model does not exist at '{utils.koboldai_vars.custmodpth}' or '{model_path}'.")
|
||||
raise
|
||||
|
||||
return model_path
|
||||
|
||||
basename = utils.koboldai_vars.model.replace("/", "_")
|
||||
if legacy:
|
||||
ret = basename
|
||||
else:
|
||||
ret = os.path.join("models", basename)
|
||||
|
||||
if os.path.isdir(ret) or ignore_existance:
|
||||
return ret
|
||||
return None
|
||||
|
||||
def init_model_config(self) -> None:
|
||||
# Get the model_type from the config or assume a model type if it isn't present
|
||||
try:
|
||||
self.model_config = AutoConfig.from_pretrained(
|
||||
self.get_local_model_path() or self.model_name,
|
||||
revision=utils.koboldai_vars.revision,
|
||||
cache_dir="cache",
|
||||
)
|
||||
utils.koboldai_vars.model_type = self.model_config.model_type
|
||||
except ValueError:
|
||||
utils.koboldai_vars.model_type = {
|
||||
"NeoCustom": "gpt_neo",
|
||||
"GPT2Custom": "gpt2",
|
||||
}.get(utils.koboldai_vars.model)
|
||||
|
||||
if not utils.koboldai_vars.model_type:
|
||||
logger.warning(
|
||||
"No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)"
|
||||
)
|
||||
utils.koboldai_vars.model_type = "gpt_neo"
|
@@ -16,19 +16,17 @@ from modeling.inference_model import (
|
||||
GenerationSettings,
|
||||
ModelCapabilities,
|
||||
)
|
||||
from modeling.inference_models.hf import HFInferenceModel
|
||||
|
||||
# This file shouldn't be imported unless using the TPU
|
||||
assert utils.koboldai_vars.use_colab_tpu
|
||||
import tpu_mtj_backend
|
||||
from modeling.inference_models.parents.hf import HFInferenceModel
|
||||
|
||||
|
||||
class HFMTJInferenceModel(HFInferenceModel):
|
||||
|
||||
|
||||
class model_loader(HFInferenceModel):
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
#model_name: str,
|
||||
) -> None:
|
||||
super().__init__(model_name)
|
||||
super().__init__()
|
||||
|
||||
self.model_config = None
|
||||
self.capabilties = ModelCapabilities(
|
||||
@@ -38,8 +36,13 @@ class HFMTJInferenceModel(HFInferenceModel):
|
||||
post_token_probs=False,
|
||||
uses_tpu=True,
|
||||
)
|
||||
|
||||
def is_valid(self, model_name, model_path, menu_path):
|
||||
# This file shouldn't be imported unless using the TPU
|
||||
return utils.koboldai_vars.use_colab_tpu and super().is_valid(model_name, model_path, menu_path)
|
||||
|
||||
def setup_mtj(self) -> None:
|
||||
import tpu_mtj_backend
|
||||
def mtj_warper_callback(scores) -> "np.array":
|
||||
scores_shape = scores.shape
|
||||
scores_list = scores.tolist()
|
||||
@@ -175,6 +178,7 @@ class HFMTJInferenceModel(HFInferenceModel):
|
||||
tpu_mtj_backend.settings_callback = mtj_settings_callback
|
||||
|
||||
def _load(self, save_model: bool, initial_load: bool) -> None:
|
||||
import tpu_mtj_backend
|
||||
self.setup_mtj()
|
||||
self.init_model_config()
|
||||
utils.koboldai_vars.allowsp = True
|
||||
@@ -207,6 +211,7 @@ class HFMTJInferenceModel(HFInferenceModel):
|
||||
]
|
||||
|
||||
def get_soft_tokens(self) -> np.array:
|
||||
import tpu_mtj_backend
|
||||
soft_tokens = None
|
||||
|
||||
if utils.koboldai_vars.sp is None:
|
||||
@@ -258,6 +263,7 @@ class HFMTJInferenceModel(HFInferenceModel):
|
||||
seed: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> GenerationResult:
|
||||
import tpu_mtj_backend
|
||||
warpers.update_settings()
|
||||
|
||||
soft_tokens = self.get_soft_tokens()
|
||||
|
@@ -21,13 +21,99 @@ class HordeException(Exception):
|
||||
"""To be used for errors on server side of the Horde."""
|
||||
|
||||
|
||||
class HordeInferenceModel(InferenceModel):
|
||||
class model_loader(InferenceModel):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.url = "https://horde.koboldai.net"
|
||||
self.key = "0000000000"
|
||||
self.models = self.get_cluster_models()
|
||||
|
||||
|
||||
# Do not allow API to be served over the API
|
||||
self.capabilties = ModelCapabilities(api_host=False)
|
||||
|
||||
def is_valid(self, model_name, model_path, menu_path):
|
||||
logger.debug("Horde Models: {}".format(self.models))
|
||||
return model_name == "CLUSTER" or model_name in [x['value'] for x in self.models]
|
||||
|
||||
def get_requested_parameters(self, model_name, model_path, menu_path):
|
||||
requested_parameters = []
|
||||
requested_parameters.extend([{
|
||||
"uitype": "text",
|
||||
"unit": "text",
|
||||
"label": "URL",
|
||||
"id": "url",
|
||||
"default": self.url,
|
||||
"tooltip": "URL to the horde.",
|
||||
"menu_path": "",
|
||||
"check": {"value": "", 'check': "!="},
|
||||
"refresh_model_inputs": True,
|
||||
"extra_classes": ""
|
||||
},
|
||||
{
|
||||
"uitype": "text",
|
||||
"unit": "text",
|
||||
"label": "Key",
|
||||
"id": "key",
|
||||
"default": self.key,
|
||||
"check": {"value": "", 'check': "!="},
|
||||
"tooltip": "User Key to use when connecting to Horde (0000000000 is anonymous).",
|
||||
"menu_path": "",
|
||||
"refresh_model_inputs": True,
|
||||
"extra_classes": ""
|
||||
},
|
||||
{
|
||||
"uitype": "dropdown",
|
||||
"unit": "text",
|
||||
"label": "Model",
|
||||
"id": "model",
|
||||
"default": "",
|
||||
"check": {"value": "", 'check': "!="},
|
||||
"tooltip": "Which model to use when running OpenAI/GooseAI.",
|
||||
"menu_path": "",
|
||||
"refresh_model_inputs": False,
|
||||
"extra_classes": "",
|
||||
'children': self.models,
|
||||
|
||||
}])
|
||||
return requested_parameters
|
||||
|
||||
def set_input_parameters(self, url="", key="", model=""):
|
||||
self.key = key.strip()
|
||||
self.model = model
|
||||
self.url = url
|
||||
|
||||
def get_cluster_models(self):
|
||||
# Get list of models from public cluster
|
||||
logger.info("<purple>Retrieving engine list...</purple>")
|
||||
try:
|
||||
req = requests.get(f"{self.url}/api/v2/status/models?type=text")
|
||||
except:
|
||||
logger.init_err("KAI Horde Models", status="Failed")
|
||||
logger.error("Provided KoboldAI Horde URL unreachable")
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': "Provided KoboldAI Horde URL unreachable"})
|
||||
return
|
||||
if not req.ok:
|
||||
# Something went wrong, print the message and quit since we can't initialize an engine
|
||||
logger.init_err("KAI Horde Models", status="Failed")
|
||||
logger.error(req.json())
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': req.json()}, room="UI_1")
|
||||
return
|
||||
|
||||
engines = req.json()
|
||||
try:
|
||||
engines = [{"text": en["name"], "value": en["name"]} for en in engines]
|
||||
except:
|
||||
logger.error(engines)
|
||||
raise
|
||||
logger.debug(engines)
|
||||
|
||||
online_model = ""
|
||||
|
||||
logger.init_ok("KAI Horde Models", status="OK")
|
||||
|
||||
return engines
|
||||
|
||||
def _load(self, save_model: bool, initial_load: bool) -> None:
|
||||
self.tokenizer = self._get_tokenizer(
|
||||
utils.koboldai_vars.cluster_requested_models[0]
|
||||
|
@@ -12,13 +12,96 @@ from modeling.inference_model import (
|
||||
)
|
||||
|
||||
|
||||
|
||||
class OpenAIAPIError(Exception):
|
||||
def __init__(self, error_type: str, error_message) -> None:
|
||||
super().__init__(f"{error_type}: {error_message}")
|
||||
|
||||
|
||||
class OpenAIAPIInferenceModel(InferenceModel):
|
||||
class model_loader(InferenceModel):
|
||||
"""InferenceModel for interfacing with OpenAI's generation API."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.key = ""
|
||||
|
||||
def is_valid(self, model_name, model_path, menu_path):
|
||||
return model_name == "OAI" or model_name == "GooseAI"
|
||||
|
||||
def get_requested_parameters(self, model_name, model_path, menu_path):
|
||||
self.source = model_name
|
||||
requested_parameters = []
|
||||
requested_parameters.extend([{
|
||||
"uitype": "text",
|
||||
"unit": "text",
|
||||
"label": "Key",
|
||||
"id": "key",
|
||||
"default": "",
|
||||
"check": {"value": "", 'check': "!="},
|
||||
"tooltip": "User Key to use when connecting to OpenAI/GooseAI.",
|
||||
"menu_path": "",
|
||||
"refresh_model_inputs": True,
|
||||
"extra_classes": ""
|
||||
},
|
||||
{
|
||||
"uitype": "dropdown",
|
||||
"unit": "text",
|
||||
"label": "Model",
|
||||
"id": "model",
|
||||
"default": "",
|
||||
"check": {"value": "", 'check': "!="},
|
||||
"tooltip": "Which model to use when running OpenAI/GooseAI.",
|
||||
"menu_path": "",
|
||||
"refresh_model_inputs": False,
|
||||
"extra_classes": "",
|
||||
'children': self.get_oai_models(),
|
||||
|
||||
}])
|
||||
return requested_parameters
|
||||
|
||||
def set_input_parameters(self, key="", model=""):
|
||||
self.key = key.strip()
|
||||
self.model = model
|
||||
|
||||
def get_oai_models(self):
|
||||
if self.key == "":
|
||||
return []
|
||||
if self.source == 'OAI':
|
||||
url = "https://api.openai.com/v1/engines"
|
||||
elif self.source == 'GooseAI':
|
||||
url = "https://api.goose.ai/v1/engines"
|
||||
else:
|
||||
return
|
||||
|
||||
# Get list of models from OAI
|
||||
logger.init("OAI Engines", status="Retrieving")
|
||||
req = requests.get(
|
||||
url,
|
||||
headers = {
|
||||
'Authorization': 'Bearer '+self.key
|
||||
}
|
||||
)
|
||||
if(req.status_code == 200):
|
||||
r = req.json()
|
||||
engines = r["data"]
|
||||
try:
|
||||
engines = [{"value": en["id"], "text": "{} ({})".format(en['id'], "Ready" if en["ready"] == True else "Not Ready")} for en in engines]
|
||||
except:
|
||||
logger.error(engines)
|
||||
raise
|
||||
|
||||
online_model = ""
|
||||
|
||||
|
||||
logger.init_ok("OAI Engines", status="OK")
|
||||
return engines
|
||||
else:
|
||||
# Something went wrong, print the message and quit since we can't initialize an engine
|
||||
logger.init_err("OAI Engines", status="Failed")
|
||||
logger.error(req.json())
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': req.json()})
|
||||
return []
|
||||
|
||||
|
||||
def _load(self, save_model: bool, initial_load: bool) -> None:
|
||||
self.tokenizer = self._get_tokenizer("gpt2")
|
||||
|
modeling/inference_models/parents/hf.py (new file, 219 lines)
@@ -0,0 +1,219 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from transformers import AutoConfig
|
||||
|
||||
import utils
|
||||
import koboldai_settings
|
||||
from logger import logger
|
||||
from modeling.inference_model import InferenceModel
|
||||
import torch
|
||||
|
||||
|
||||
class HFInferenceModel(InferenceModel):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.model_config = None
|
||||
#self.model_name = model_name
|
||||
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
|
||||
def is_valid(self, model_name, model_path, menu_path):
|
||||
try:
|
||||
if model_path is not None and os.path.exists(model_path):
|
||||
model_config = AutoConfig.from_pretrained(model_path)
|
||||
elif(os.path.exists("models/{}".format(model_name.replace('/', '_')))):
|
||||
model_config = AutoConfig.from_pretrained("models/{}".format(model_name.replace('/', '_')), revision=utils.koboldai_vars.revision, cache_dir="cache")
|
||||
else:
|
||||
model_config = AutoConfig.from_pretrained(model_name, revision=utils.koboldai_vars.revision, cache_dir="cache")
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def get_requested_parameters(self, model_name, model_path, menu_path):
|
||||
requested_parameters = []
|
||||
|
||||
if model_path is not None and os.path.exists(model_path):
|
||||
model_config = AutoConfig.from_pretrained(model_path)
|
||||
elif(os.path.exists("models/{}".format(model_name.replace('/', '_')))):
|
||||
model_config = AutoConfig.from_pretrained("models/{}".format(model_name.replace('/', '_')), revision=utils.koboldai_vars.revision, cache_dir="cache")
|
||||
else:
|
||||
model_config = AutoConfig.from_pretrained(model_name, revision=utils.koboldai_vars.revision, cache_dir="cache")
|
||||
layer_count = model_config["n_layer"] if isinstance(model_config, dict) else model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer if hasattr(model_config, "n_layer") else model_config.num_hidden_layers if hasattr(model_config, 'num_hidden_layers') else None
|
||||
if layer_count is not None and layer_count >= 0:
|
||||
if os.path.exists("settings/{}.breakmodel".format(model_name.replace("/", "_"))):
|
||||
with open("settings/{}.breakmodel".format(model_name.replace("/", "_")), "r") as file:
|
||||
data = [x for x in file.read().split("\n")[:2] if x != '']
|
||||
if len(data) < 2:
|
||||
data.append("0")
|
||||
break_values, disk_blocks = data
|
||||
break_values = break_values.split(",")
|
||||
else:
|
||||
break_values = [layer_count]
|
||||
disk_blocks = None
|
||||
break_values = [int(x) for x in break_values if x != '' and x is not None]
|
||||
gpu_count = torch.cuda.device_count()
|
||||
break_values += [0] * (gpu_count - len(break_values))
|
||||
if disk_blocks is not None:
|
||||
break_values += [disk_blocks]
|
||||
for i in range(gpu_count):
|
||||
requested_parameters.append({
|
||||
"uitype": "slider",
|
||||
"unit": "int",
|
||||
"label": "{} Layers".format(torch.cuda.get_device_name(i)),
|
||||
"id": "{} Layers".format(i),
|
||||
"min": 0,
|
||||
"max": layer_count,
|
||||
"step": 1,
|
||||
"check": {"sum": ["{} Layers".format(i) for i in range(gpu_count)]+['CPU Layers']+(['Disk_Layers'] if disk_blocks is not None else []), "value": layer_count, 'check': "="},
|
||||
"check_message": "The sum of assigned layers must equal {}".format(layer_count),
|
||||
"default": break_values[i],
|
||||
"tooltip": "The number of layers to put on {}.".format(torch.cuda.get_device_name(i)),
|
||||
"menu_path": "Layers",
|
||||
"extra_classes": "",
|
||||
"refresh_model_inputs": False
|
||||
})
|
||||
requested_parameters.append({
|
||||
"uitype": "slider",
|
||||
"unit": "int",
|
||||
"label": "CPU Layers",
|
||||
"id": "CPU Layers",
|
||||
"min": 0,
|
||||
"max": layer_count,
|
||||
"step": 1,
|
||||
"check": {"sum": ["{} Layers".format(i) for i in range(gpu_count)]+['CPU Layers']+(['Disk_Layers'] if disk_blocks is not None else []), "value": layer_count, 'check': "="},
|
||||
"check_message": "The sum of assigned layers must equal {}".format(layer_count),
|
||||
"default": layer_count - sum(break_values),
|
||||
"tooltip": "The number of layers to put on the CPU. This will use your system RAM. It will also do inference partially on CPU. Use if you must.",
|
||||
"menu_path": "Layers",
|
||||
"extra_classes": "",
|
||||
"refresh_model_inputs": False
|
||||
})
|
||||
if disk_blocks is not None:
|
||||
requested_parameters.append({
|
||||
"uitype": "slider",
|
||||
"unit": "int",
|
||||
"label": "Disk Layers",
|
||||
"id": "Disk_Layers",
|
||||
"min": 0,
|
||||
"max": layer_count,
|
||||
"step": 1,
|
||||
"check": {"sum": ["{} Layers".format(i) for i in range(gpu_count)]+['CPU Layers']+(['Disk_Layers'] if disk_blocks is not None else []), "value": layer_count, 'check': "="},
|
||||
"check_message": "The sum of assigned layers must equal {}".format(layer_count),
|
||||
"default": disk_blocks,
|
||||
"tooltip": "The number of layers to put on the disk. This will use your hard drive. The is VERY slow in comparison to GPU or CPU. Use as a last resort.",
|
||||
"menu_path": "Layers",
|
||||
"extra_classes": "",
|
||||
"refresh_model_inputs": False
|
||||
})
|
||||
else:
|
||||
requested_parameters.append({
|
||||
"uitype": "toggle",
|
||||
"unit": "bool",
|
||||
"label": "Use GPU",
|
||||
"id": "use_gpu",
|
||||
"default": False,
|
||||
"tooltip": "The number of layers to put on the disk. This will use your hard drive. The is VERY slow in comparison to GPU or CPU. Use as a last resort.",
|
||||
"menu_path": "Layers",
|
||||
"extra_classes": "",
|
||||
"refresh_model_inputs": False
|
||||
})
|
||||
|
||||
|
||||
return requested_parameters
|
||||
|
||||
def set_input_parameters(self, layers=[], disk_layers=0, use_gpu=False):
|
||||
self.layers = layers
|
||||
self.disk_layers = disk_layers
|
||||
self.use_gpu = use_gpu
|
||||
|
||||
def _post_load(self) -> None:
|
||||
# These are model specific tokenizer overrides if a model has bad defaults
|
||||
if utils.koboldai_vars.model_type == "llama":
|
||||
self.tokenizer.decode_with_prefix_space = True
|
||||
self.tokenizer.add_bos_token = False
|
||||
elif utils.koboldai_vars.model_type == "opt":
|
||||
self.tokenizer._koboldai_header = self.tokenizer.encode("")
|
||||
self.tokenizer.add_bos_token = False
|
||||
self.tokenizer.add_prefix_space = False
|
||||
|
||||
# Change newline behavior to match model quirks
|
||||
if utils.koboldai_vars.model_type == "xglm":
|
||||
# Default to </s> newline mode if using XGLM
|
||||
utils.koboldai_vars.newlinemode = "s"
|
||||
elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
|
||||
# Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
|
||||
utils.koboldai_vars.newlinemode = "ns"
|
||||
|
||||
# Clean up tokens that cause issues
|
||||
if (
|
||||
utils.koboldai_vars.badwordsids == koboldai_settings.badwordsids_default
|
||||
and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
|
||||
):
|
||||
utils.koboldai_vars.badwordsids = [
|
||||
[v]
|
||||
for k, v in self.tokenizer.get_vocab().items()
|
||||
if any(c in str(k) for c in "[]")
|
||||
]
|
||||
|
||||
if utils.koboldai_vars.newlinemode == "n":
|
||||
utils.koboldai_vars.badwordsids.append([self.tokenizer.eos_token_id])
|
||||
|
||||
return super()._post_load()
|
||||
|
||||
def get_local_model_path(
|
||||
self, legacy: bool = False, ignore_existance: bool = False
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Returns a string of the model's path locally, or None if it is not downloaded.
|
||||
If ignore_existance is true, it will always return a path.
|
||||
"""
|
||||
|
||||
if self.model_name in ["NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]:
|
||||
model_path = utils.koboldai_vars.custmodpth
|
||||
assert model_path
|
||||
|
||||
# Path can be absolute or relative to models directory
|
||||
if os.path.exists(model_path):
|
||||
return model_path
|
||||
|
||||
model_path = os.path.join("models", model_path)
|
||||
|
||||
try:
|
||||
assert os.path.exists(model_path)
|
||||
except AssertionError:
|
||||
logger.error(f"Custom model does not exist at '{utils.koboldai_vars.custmodpth}' or '{model_path}'.")
|
||||
raise
|
||||
|
||||
return model_path
|
||||
|
||||
basename = utils.koboldai_vars.model.replace("/", "_")
|
||||
if legacy:
|
||||
ret = basename
|
||||
else:
|
||||
ret = os.path.join("models", basename)
|
||||
|
||||
if os.path.isdir(ret) or ignore_existance:
|
||||
return ret
|
||||
return None
|
||||
|
||||
def init_model_config(self) -> None:
|
||||
# Get the model_type from the config or assume a model type if it isn't present
|
||||
try:
|
||||
self.model_config = AutoConfig.from_pretrained(
|
||||
self.get_local_model_path() or self.model_name,
|
||||
revision=utils.koboldai_vars.revision,
|
||||
cache_dir="cache",
|
||||
)
|
||||
utils.koboldai_vars.model_type = self.model_config.model_type
|
||||
except ValueError:
|
||||
utils.koboldai_vars.model_type = {
|
||||
"NeoCustom": "gpt_neo",
|
||||
"GPT2Custom": "gpt2",
|
||||
}.get(utils.koboldai_vars.model)
|
||||
|
||||
if not utils.koboldai_vars.model_type:
|
||||
logger.warning(
|
||||
"No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)"
|
||||
)
|
||||
utils.koboldai_vars.model_type = "gpt_neo"
|
@@ -31,7 +31,7 @@ from modeling import warpers
|
||||
from modeling.warpers import Warper
|
||||
from modeling.stoppers import Stoppers
|
||||
from modeling.post_token_hooks import PostTokenHooks
|
||||
from modeling.inference_models.hf import HFInferenceModel
|
||||
from modeling.inference_models.parents.hf import HFInferenceModel
|
||||
from modeling.inference_model import (
|
||||
GenerationResult,
|
||||
GenerationSettings,
|
||||
@@ -55,13 +55,13 @@ LOG_SAMPLER_NO_EFFECT = False
|
||||
class HFTorchInferenceModel(HFInferenceModel):
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
lazy_load: bool,
|
||||
low_mem: bool,
|
||||
#model_name: str,
|
||||
#lazy_load: bool,
|
||||
#low_mem: bool,
|
||||
) -> None:
|
||||
super().__init__(model_name)
|
||||
self.lazy_load = lazy_load
|
||||
self.low_mem = low_mem
|
||||
super().__init__()
|
||||
#self.lazy_load = lazy_load
|
||||
#self.low_mem = low_mem
|
||||
|
||||
self.post_token_hooks = [
|
||||
PostTokenHooks.stream_tokens,
|
||||
@@ -211,40 +211,6 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
new_sample.old_sample = transformers.GenerationMixin.sample
|
||||
use_core_manipulations.sample = new_sample
|
||||
|
||||
# PEFT Loading. This MUST be done after all save_pretrained calls are
|
||||
# finished on the main model.
|
||||
if utils.args.peft:
|
||||
from peft import PeftModel, PeftConfig
|
||||
local_peft_dir = os.path.join(m_self.get_local_model_path(), "peft")
|
||||
|
||||
# Make PEFT dir if it doesn't exist
|
||||
try:
|
||||
os.makedirs(local_peft_dir)
|
||||
except FileExistsError:
|
||||
pass
|
||||
|
||||
peft_local_path = os.path.join(local_peft_dir, utils.args.peft.replace("/", "_"))
|
||||
logger.debug(f"Loading PEFT '{utils.args.peft}', possible local path is '{peft_local_path}'.")
|
||||
|
||||
peft_installed_locally = True
|
||||
possible_peft_locations = [peft_local_path, utils.args.peft]
|
||||
|
||||
for i, location in enumerate(possible_peft_locations):
|
||||
try:
|
||||
m_self.model = PeftModel.from_pretrained(m_self.model, location)
|
||||
logger.debug(f"Loaded PEFT at '{location}'")
|
||||
break
|
||||
except ValueError:
|
||||
peft_installed_locally = False
|
||||
if i == len(possible_peft_locations) - 1:
|
||||
raise RuntimeError(f"Unable to load PeftModel for given name '{utils.args.peft}'. Does it exist?")
|
||||
except RuntimeError:
|
||||
raise RuntimeError("Error while loading PeftModel. Are you using the correct model?")
|
||||
|
||||
if not peft_installed_locally:
|
||||
logger.debug(f"PEFT not saved to models folder; saving to '{peft_local_path}'")
|
||||
m_self.model.save_pretrained(peft_local_path)
|
||||
|
||||
return super()._post_load()
|
||||
|
||||
def _raw_generate(
|
||||
@@ -272,13 +238,8 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
|
||||
with torch.no_grad():
|
||||
start_time = time.time()
|
||||
|
||||
# HEED & BEWARE: All arguments passed to self.model.generate MUST be
|
||||
# kwargs; see https://github.com/huggingface/peft/issues/232. If they
|
||||
# aren't, PeftModel will EXPLODE!!!! But nothing will happen without
|
||||
# a PEFT loaded so it's sneaky.
|
||||
genout = self.model.generate(
|
||||
input_ids=gen_in,
|
||||
gen_in,
|
||||
do_sample=True,
|
||||
max_length=min(
|
||||
len(prompt_tokens) + max_new, utils.koboldai_vars.max_length
|
||||
@@ -304,7 +265,6 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
def _get_model(self, location: str, tf_kwargs: Dict):
|
||||
tf_kwargs["revision"] = utils.koboldai_vars.revision
|
||||
tf_kwargs["cache_dir"] = "cache"
|
||||
tf_kwargs["trust_remote_code"] = utils.koboldai_vars.trust_remote_code
|
||||
|
||||
# If we have model hints for legacy model, use them rather than fall back.
|
||||
try:
|
@@ -17,7 +17,7 @@ from torch.nn import functional as F
|
||||
os.environ["RWKV_JIT_ON"] = "1"
|
||||
# TODO: Include compiled kernel
|
||||
os.environ["RWKV_CUDA_ON"] = "1"
|
||||
from rwkv.model import RWKV
|
||||
|
||||
|
||||
import utils
|
||||
from logger import logger
|
||||
@@ -55,13 +55,13 @@ MODEL_FILES = {
|
||||
}
|
||||
|
||||
|
||||
class RWKVInferenceModel(InferenceModel):
|
||||
class model_loader(InferenceModel):
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
#model_name: str,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
#self.model_name = model_name
|
||||
|
||||
self.post_token_hooks = [
|
||||
PostTokenHooks.stream_tokens,
|
||||
@@ -83,6 +83,23 @@ class RWKVInferenceModel(InferenceModel):
|
||||
)
|
||||
self._old_stopping_criteria = None
|
||||
|
||||
def is_valid(self, model_name, model_path, menu_path):
|
||||
try:
|
||||
from rwkv.model import RWKV
|
||||
valid = True
|
||||
except:
|
||||
valid = False
|
||||
return valid and "rwkv" in model_name.lower()
|
||||
|
||||
def get_requested_parameters(self, model_name, model_path, menu_path):
|
||||
self.source = model_name
|
||||
requested_parameters = []
|
||||
return requested_parameters
|
||||
|
||||
def set_input_parameters(self):
|
||||
return
|
||||
|
||||
|
||||
def _ensure_directory_structure(self) -> None:
|
||||
for path in ["models/rwkv", "models/rwkv/models"]:
|
||||
try:
|
||||
@@ -145,6 +162,7 @@ class RWKVInferenceModel(InferenceModel):
|
||||
# Now we load!
|
||||
|
||||
# TODO: Breakmodel to strat
|
||||
from rwkv.model import RWKV
|
||||
self.model = RWKV(model=model_path, strategy="cuda:0 fp16")
|
||||
|
||||
def _apply_warpers(
|
||||
|
@@ -347,6 +347,28 @@ border-top-right-radius: var(--tabs_rounding);
|
||||
}
|
||||
|
||||
|
||||
.setting_container_model {
|
||||
display: grid;
|
||||
grid-template-areas: "label value"
|
||||
"item item"
|
||||
"minlabel maxlabel";
|
||||
grid-template-rows: 20px 23px 20px;
|
||||
grid-template-columns: auto 30px;
|
||||
row-gap: 0.2em;
|
||||
background-color: var(--setting_background);
|
||||
color: var(--setting_text);
|
||||
border-radius: var(--radius_settings_background);
|
||||
padding: 2px;
|
||||
margin: 2px;
|
||||
width: calc(100%);
|
||||
}
|
||||
|
||||
.setting_container_model .setting_item{
|
||||
font-size: calc(0.93em + var(--font_size_adjustment));
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
|
||||
.setting_minlabel {
|
||||
padding-top: 6px;
|
||||
grid-area: minlabel;
|
||||
@@ -3370,6 +3392,23 @@ textarea {
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes pulse-red {
|
||||
0% {
|
||||
transform: scale(0.95);
|
||||
box-shadow: 0 0 0 0 rgba(255, 0, 0, 0.7);
|
||||
}
|
||||
|
||||
70% {
|
||||
transform: scale(1);
|
||||
box-shadow: 0 0 0 10px rgba(255, 0, 0, 0);
|
||||
}
|
||||
|
||||
100% {
|
||||
transform: scale(0.95);
|
||||
box-shadow: 0 0 0 0 rgba(255, 0, 0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes pulse-text {
|
||||
0% {
|
||||
filter: blur(3px);
|
||||
@@ -3391,6 +3430,11 @@ textarea {
|
||||
}
|
||||
}
|
||||
|
||||
.input_error {
|
||||
border: 5px solid red !important;
|
||||
box-sizing: border-box !important;
|
||||
}
|
||||
|
||||
.single_pulse {
|
||||
animation: pulse-text 0.5s 1;
|
||||
}
|
||||
|
@@ -15,6 +15,7 @@ socket.on('popup_items', function(data){popup_items(data);});
|
||||
socket.on('popup_breadcrumbs', function(data){popup_breadcrumbs(data);});
|
||||
socket.on('popup_edit_file', function(data){popup_edit_file(data);});
|
||||
socket.on('show_model_menu', function(data){show_model_menu(data);});
|
||||
socket.on('open_model_load_menu', function(data){new_show_model_menu(data);});
|
||||
socket.on('selected_model_info', function(data){selected_model_info(data);});
|
||||
socket.on('oai_engines', function(data){oai_engines(data);});
|
||||
socket.on('buildload', function(data){buildload(data);});
|
||||
@@ -81,6 +82,7 @@ const on_colab = $el("#on_colab").textContent == "true";
|
||||
let story_id = -1;
|
||||
var dirty_chunks = [];
|
||||
var initial_socketio_connection_occured = false;
|
||||
var selected_model_data;
|
||||
|
||||
// Each entry into this array should be an object that looks like:
|
||||
// {class: "class", key: "key", func: callback}
|
||||
@@ -1500,49 +1502,46 @@ function getModelParameterCount(modelName) {
|
||||
return base * multiplier;
|
||||
}
|
||||
|
||||
function show_model_menu(data) {
|
||||
//clear old options
|
||||
document.getElementById("modelkey").classList.add("hidden");
|
||||
document.getElementById("modelkey").value = "";
|
||||
document.getElementById("modelurl").classList.add("hidden");
|
||||
document.getElementById("use_gpu_div").classList.add("hidden");
|
||||
document.getElementById("use_8_bit_div").classList.add("hidden");
|
||||
document.getElementById("modellayers").classList.add("hidden");
|
||||
document.getElementById("oaimodel").classList.add("hidden");
|
||||
var model_layer_bars = document.getElementById('model_layer_bars');
|
||||
while (model_layer_bars.firstChild) {
|
||||
model_layer_bars.removeChild(model_layer_bars.firstChild);
|
||||
function new_show_model_menu(data) {
|
||||
//clear out the loadmodelsettings
|
||||
var loadmodelsettings = document.getElementById('loadmodelsettings')
|
||||
while (loadmodelsettings.firstChild) {
|
||||
loadmodelsettings.removeChild(loadmodelsettings.firstChild);
|
||||
}
|
||||
document.getElementById("modelplugin").classList.add("hidden");
|
||||
var accept = document.getElementById("btn_loadmodelaccept");
|
||||
accept.disabled = false;
|
||||
|
||||
//clear out the breadcrumbs
|
||||
var breadcrumbs = document.getElementById('loadmodellistbreadcrumbs')
|
||||
while (breadcrumbs.firstChild) {
|
||||
breadcrumbs.removeChild(breadcrumbs.firstChild);
|
||||
}
|
||||
//add breadcrumbs
|
||||
//console.log(data.breadcrumbs);
|
||||
for (item of data.breadcrumbs) {
|
||||
var button = document.createElement("button");
|
||||
button.classList.add("breadcrumbitem");
|
||||
button.setAttribute("model", data.menu);
|
||||
button.setAttribute("folder", item[0]);
|
||||
button.textContent = item[1];
|
||||
button.onclick = function () {
|
||||
socket.emit('select_model', {'menu': "", 'model': this.getAttribute("model"), 'path': this.getAttribute("folder")});
|
||||
};
|
||||
breadcrumbs.append(button);
|
||||
var span = document.createElement("span");
|
||||
span.textContent = "\\";
|
||||
breadcrumbs.append(span);
|
||||
}
|
||||
|
||||
//add breadcrumbs
|
||||
if ('breadcrumbs' in data) {
|
||||
for (item of data.breadcrumbs) {
|
||||
var button = document.createElement("button");
|
||||
button.classList.add("breadcrumbitem");
|
||||
button.setAttribute("model", data.menu);
|
||||
button.setAttribute("folder", item[0]);
|
||||
button.textContent = item[1];
|
||||
button.onclick = function () {
|
||||
socket.emit('select_model', {'menu': "", 'name': this.getAttribute("model"), 'path': this.getAttribute("folder")});
|
||||
};
|
||||
breadcrumbs.append(button);
|
||||
var span = document.createElement("span");
|
||||
span.textContent = "\\";
|
||||
breadcrumbs.append(span);
|
||||
}
|
||||
}
|
||||
//clear out the items
|
||||
var model_list = document.getElementById('loadmodellistcontent')
|
||||
while (model_list.firstChild) {
|
||||
model_list.removeChild(model_list.firstChild);
|
||||
}
|
||||
//add items
|
||||
for (item of data.data) {
|
||||
for (item of data.items) {
|
||||
var list_item = document.createElement("span");
|
||||
list_item.classList.add("model_item");
|
||||
|
||||
@@ -1564,10 +1563,27 @@ function show_model_menu(data) {
|
||||
//create the actual item
|
||||
var popup_item = document.createElement("span");
|
||||
popup_item.classList.add("model");
|
||||
popup_item.setAttribute("display_name", item.label);
|
||||
popup_item.id = item.name;
|
||||
for (const key in item) {
|
||||
if (key == "name") {
|
||||
popup_item.id = item[key];
|
||||
}
|
||||
popup_item.setAttribute(key, item[key]);
|
||||
}
|
||||
|
||||
popup_item.onclick = function() {
|
||||
var attributes = this.attributes;
|
||||
var obj = {};
|
||||
|
||||
for (var i = 0, len = attributes.length; i < len; i++) {
|
||||
obj[attributes[i].name] = attributes[i].value;
|
||||
}
|
||||
//put the model data on the accept button so we can send it to the server when you accept
|
||||
var accept = document.getElementById("popup_accept");
|
||||
selected_model_data = obj;
|
||||
//send the data to the server so it can figure out what data we need from the user for the model
|
||||
socket.emit('select_model', obj);
|
||||
}
|
||||
|
||||
popup_item.setAttribute("Menu", data.menu)
|
||||
//name text
|
||||
var text = document.createElement("span");
|
||||
text.style="grid-area: item;";
|
||||
@@ -1615,241 +1631,223 @@ function show_model_menu(data) {
|
||||
});
|
||||
})();
|
||||
|
||||
popup_item.onclick = function () {
|
||||
var accept = document.getElementById("btn_loadmodelaccept");
|
||||
accept.classList.add("disabled");
|
||||
socket.emit("select_model", {"model": this.id, "menu": this.getAttribute("Menu"), "display_name": this.getAttribute("display_name")});
|
||||
var model_list = document.getElementById('loadmodellistcontent').getElementsByClassName("selected");
|
||||
for (model of model_list) {
|
||||
model.classList.remove("selected");
|
||||
}
|
||||
this.classList.add("selected");
|
||||
accept.setAttribute("selected_model", this.id);
|
||||
accept.setAttribute("menu", this.getAttribute("Menu"));
|
||||
accept.setAttribute("display_name", this.getAttribute("display_name"));
|
||||
};
|
||||
list_item.append(popup_item);
|
||||
|
||||
|
||||
model_list.append(list_item);
|
||||
}
|
||||
var accept = document.getElementById("btn_loadmodelaccept");
|
||||
accept.disabled = true;
|
||||
|
||||
//finally, if they selected the custom hugging face menu we show the input box
|
||||
if (data['menu'] == "customhuggingface") {
|
||||
document.getElementById("custommodelname").classList.remove("hidden");
|
||||
} else {
|
||||
document.getElementById("custommodelname").classList.add("hidden");
|
||||
}
|
||||
|
||||
|
||||
// detect if we are in a model selection screen and show the reference
|
||||
var refelement = document.getElementById("modelspecifier");
|
||||
var check = document.getElementById("mainmenu");
|
||||
if (check) {
|
||||
refelement.classList.remove("hidden");
|
||||
} else {
|
||||
refelement.classList.add("hidden");
|
||||
}
|
||||
|
||||
openPopup("load-model");
|
||||
|
||||
}

function selected_model_info(data) {
    //clear out the loadmodelsettings
    var loadmodelsettings = document.getElementById('loadmodelsettings')
    while (loadmodelsettings.firstChild) {
        loadmodelsettings.removeChild(loadmodelsettings.firstChild);
    }
    var accept = document.getElementById("btn_loadmodelaccept");
    //hide or unhide key
    if (data.key) {
        document.getElementById("modelkey").classList.remove("hidden");
        document.getElementById("modelkey").value = data.key_value;
    } else {
        document.getElementById("modelkey").classList.add("hidden");
        document.getElementById("modelkey").value = "";
    }
    //hide or unhide URL
    if (data.url) {
        document.getElementById("modelurl").classList.remove("hidden");
    } else {
        document.getElementById("modelurl").classList.add("hidden");
    }

    //hide or unhide 8 bit mode
    if (data.bit_8_available) {
        document.getElementById("use_8_bit_div").classList.remove("hidden");
    } else {
        document.getElementById("use_8_bit_div").classList.add("hidden");
        document.getElementById("use_8_bit").checked = false;
    }

    //default URL loading
    if (data.default_url != null) {
        document.getElementById("modelurl").value = data.default_url;
    }

    //change model loading on url if needed
    if (data.models_on_url) {
        document.getElementById("modelurl").onchange = function () {socket.emit('get_cluster_models', {'model': document.getElementById('btn_loadmodelaccept').getAttribute('selected_model'), 'key': document.getElementById("modelkey").value, 'url': this.value});};
        document.getElementById("modelkey").onchange = function () {socket.emit('get_cluster_models', {'model': document.getElementById('btn_loadmodelaccept').getAttribute('selected_model'), 'key': this.value, 'url': document.getElementById("modelurl").value});};
    } else {
        document.getElementById("modelkey").onchange = function () {socket.emit('OAI_Key_Update', {'model': document.getElementById('btn_loadmodelaccept').getAttribute('selected_model'), 'key': this.value});};
        document.getElementById("modelurl").onchange = null;
    }

    //show model select for APIs
    if (data.show_online_model_select) {
        document.getElementById("oaimodel").classList.remove("hidden");
    } else {
        document.getElementById("oaimodel").classList.add("hidden");
    }

    //Multiple Model Select?
    if (data.multi_online_models) {
        document.getElementById("oaimodel").setAttribute("multiple", "");
        document.getElementById("oaimodel").options[0].textContent = "All"
    } else {
        document.getElementById("oaimodel").removeAttribute("multiple");
        document.getElementById("oaimodel").options[0].textContent = "Select Model(s)"
    }

    //hide or unhide the use gpu checkbox
    if (data.gpu) {
        document.getElementById("use_gpu_div").classList.remove("hidden");
    } else {
        document.getElementById("use_gpu_div").classList.add("hidden");
    }
    //setup breakmodel
    if (data.breakmodel) {
        document.getElementById("modellayers").classList.remove("hidden");
        //setup model layer count
        document.getElementById("gpu_layers_current").textContent = data.break_values.reduce((a, b) => a + b, 0);
        document.getElementById("gpu_layers_max").textContent = data.layer_count;
        document.getElementById("gpu_count").value = data.gpu_count;

        //create the gpu load bars
        var model_layer_bars = document.getElementById('model_layer_bars');
        while (model_layer_bars.firstChild) {
            model_layer_bars.removeChild(model_layer_bars.firstChild);
        }

        //Add the bars
        for (let i = 0; i < data.gpu_names.length; i++) {
            var div = document.createElement("div");
            div.classList.add("model_setting_container");
            //build GPU text
            var span = document.createElement("span");
            span.classList.add("model_setting_label");
            span.textContent = "GPU " + i + " " + data.gpu_names[i] + ": "
            //build layer count box
            var input = document.createElement("input");
            input.classList.add("model_setting_value");
            input.classList.add("setting_value");
            input.inputmode = "numeric";
            input.id = "gpu_layers_box_"+i;
            input.value = data.break_values[i];
            input.onblur = function () {
                document.getElementById(this.id.replace("_box", "")).value = this.value;
                update_gpu_layers();
            }
            span.append(input);
            div.append(span);
            //build layer count slider
            var input = document.createElement("input");
            input.classList.add("model_setting_item");
            input.type = "range";
            input.min = 0;
            input.max = data.layer_count;
            input.step = 1;
            input.value = data.break_values[i];
            input.id = "gpu_layers_" + i;
            input.onchange = function () {
                document.getElementById(this.id.replace("gpu_layers", "gpu_layers_box")).value = this.value;
                update_gpu_layers();
            }
            div.append(input);
            //build slider bar #s
            //min
            var span = document.createElement("span");
            span.classList.add("model_setting_minlabel");
            var span2 = document.createElement("span");
            span2.style="top: -4px; position: relative;";
            span2.textContent = 0;
            span.append(span2);
            div.append(span);
            //max
            var span = document.createElement("span");
            span.classList.add("model_setting_maxlabel");
            var span2 = document.createElement("span");
            span2.style="top: -4px; position: relative;";
            span2.textContent = data.layer_count;
            span.append(span2);
            div.append(span);

            model_layer_bars.append(div);
        }

        //add the disk layers
        if (data.disk_break) {
            var div = document.createElement("div");
            div.classList.add("model_setting_container");
            //build disk cache text
            var span = document.createElement("span");
            span.classList.add("model_setting_label");
            span.textContent = "Disk cache: "
            //build layer count box
            var input = document.createElement("input");
            input.classList.add("model_setting_value");
            input.classList.add("setting_value");
            input.inputmode = "numeric";
            input.id = "disk_layers_box";
            input.value = data.disk_break_value;
            input.onblur = function () {
                document.getElementById(this.id.replace("_box", "")).value = this.value;
                update_gpu_layers();
            }
            span.append(input);
            div.append(span);
            //build layer count slider
            var input = document.createElement("input");
            input.classList.add("model_setting_item");
            input.type = "range";
            input.min = 0;
            input.max = data.layer_count;
            input.step = 1;
            input.value = data.disk_break_value;
            input.id = "disk_layers";
            input.onchange = function () {
                document.getElementById(this.id+"_box").value = this.value;
                update_gpu_layers();
            }
            div.append(input);
            //build slider bar #s
            //min
            var span = document.createElement("span");
            span.classList.add("model_setting_minlabel");
            var span2 = document.createElement("span");
            span2.style="top: -4px; position: relative;";
            span2.textContent = 0;
            span.append(span2);
            div.append(span);
            //max
            var span = document.createElement("span");
            span.classList.add("model_setting_maxlabel");
            var span2 = document.createElement("span");
            span2.style="top: -4px; position: relative;";
            span2.textContent = data.layer_count;
            span.append(span2);
            div.append(span);
        }

        model_layer_bars.append(div);

        update_gpu_layers();
    } else {
        document.getElementById("modellayers").classList.add("hidden");
        accept.classList.remove("disabled");
    }
    accept.disabled = false;

    modelplugin = document.getElementById("modelplugin");
    modelplugin.classList.remove("hidden");
    modelplugin.onchange = function () {
        for (const area of document.getElementsByClassName("model_plugin_settings_area")) {
            area.classList.add("hidden");
        }
        document.getElementById(this.value + "_settings_area").classList.remove("hidden");
    }
    //create the content
    for (const [loader, items] of Object.entries(data)) {
        model_area = document.createElement("DIV");
        model_area.id = loader + "_settings_area";
        model_area.classList.add("model_plugin_settings_area");
        model_area.classList.add("hidden");
        modelpluginoption = document.createElement("option");
        modelpluginoption.innerText = loader;
        modelpluginoption.value = loader;
        modelplugin.append(modelpluginoption);

        for (item of items) {
            let new_setting = document.getElementById('blank_model_settings').cloneNode(true);
            new_setting.id = loader;
            new_setting.classList.remove("hidden");
            new_setting.querySelector('#blank_model_settings_label').innerText = item['label'];
            new_setting.querySelector('#blank_model_settings_tooltip').setAttribute("tooltip", item['tooltip']);

            onchange_event = function () {
                //inputs without validation data have nothing to check
                if (this.check_data == null) {
                    return;
                }
                //get check value:
                if ('sum' in this.check_data) {
                    check_value = 0
                    for (const temp of this.check_data['sum']) {
                        if (document.getElementById(this.id.split("|")[0] +"|" + temp + "_value")) {
                            check_value += parseInt(document.getElementById(this.id.split("|")[0] +"|" + temp + "_value").value);
                        }
                    }
                } else {
                    check_value = this.value
                }
                if (this.check_data['check'] == "=") {
                    valid = (check_value == this.check_data['value']);
                } else if (this.check_data['check'] == "!=") {
                    valid = (check_value != this.check_data['value']);
                } else if (this.check_data['check'] == ">=") {
                    valid = (check_value >= this.check_data['value']);
                } else if (this.check_data['check'] == "<=") {
                    valid = (check_value <= this.check_data['value']);
                } else if (this.check_data['check'] == ">") {
                    valid = (check_value > this.check_data['value']);
                } else if (this.check_data['check'] == "<") {
                    valid = (check_value < this.check_data['value']);
                }
                if (valid) {
                    //if we are supposed to refresh when this value changes we'll resubmit
                    if (this.getAttribute("refresh_model_inputs") == "true") {
                        console.log("resubmit");
                    }
                    if ('sum' in this.check_data) {
                        for (const temp of this.check_data['sum']) {
                            if (document.getElementById(this.id.split("|")[0] +"|" + temp + "_value")) {
                                document.getElementById(this.id.split("|")[0] +"|" + temp + "_value").closest(".setting_container_model").classList.remove('input_error');
                                document.getElementById(this.id.split("|")[0] +"|" + temp + "_value").closest(".setting_container_model").removeAttribute("tooltip");
                            }
                        }
                    } else {
                        this.closest(".setting_container_model").classList.remove('input_error');
                        this.closest(".setting_container_model").removeAttribute("tooltip");
                    }
                    var accept = document.getElementById("btn_loadmodelaccept");
                    if (document.getElementsByClassName("input_error").length)
                        accept.disabled = true;
                } else {
                    if ('sum' in this.check_data) {
                        for (const temp of this.check_data['sum']) {
                            if (document.getElementById(this.id.split("|")[0] +"|" + temp + "_value")) {
                                document.getElementById(this.id.split("|")[0] +"|" + temp + "_value").closest(".setting_container_model").classList.add('input_error');
                                document.getElementById(this.id.split("|")[0] +"|" + temp + "_value").closest(".setting_container_model").setAttribute("tooltip", this.check_data['check_message']);
                            }
                        }
                    } else {
                        this.closest(".setting_container_model").classList.add('input_error');
                        this.closest(".setting_container_model").setAttribute("tooltip", this.check_data['check_message']);
                    }
                }
                var accept = document.getElementById("btn_loadmodelaccept");
                if (document.getElementsByClassName("input_error").length > 0) {
                    accept.classList.add("disabled");
                    accept.disabled = true;
                } else {
                    accept.classList.remove("disabled");
                    accept.disabled = false;
                }

            }
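
            // Example (illustrative only) of the validation metadata onchange_event reads from
            // check_data; only the keys ('sum', 'check', 'value', 'check_message') come from the
            // code above, the ids and numbers are made up:
            //   {"sum": ["GPU_0_Layers", "GPU_1_Layers", "Disk_Layers"], "check": "=", "value": 32,
            //    "check_message": "The assigned layers must add up to the model's layer count"}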
            if (item['uitype'] == "slider") {
                var slider_number = new_setting.querySelector('#blank_model_settings_value_slider_number');
                slider_number.value = item['default'];
                slider_number.id = loader + "|" + item['id'] + "_value_text";
                slider_number.onchange = function() { document.getElementById(this.id.replace("_text", "")).value = this.value;};

                var slider = new_setting.querySelector('#blank_model_settings_slider');
                slider.value = item['default'];
                slider.min = item['min'];
                slider.max = item['max'];
                slider.id = loader + "|" + item['id'] + "_value";
                if ('check' in item) {
                    slider.check_data = item['check'];
                    slider_number.check_data = item['check'];
                } else {
                    slider.check_data = null;
                    slider_number.check_data = null;
                }
                slider.oninput = function() { document.getElementById(this.id+"_text").value = this.value;};
                slider.onchange = onchange_event;
                slider.setAttribute("refresh_model_inputs", item['refresh_model_inputs']);
                new_setting.querySelector('#blank_model_settings_min_label').innerText = item['min'];
                new_setting.querySelector('#blank_model_settings_max_label').innerText = item['max'];
                slider.onchange();
            } else {
                new_setting.querySelector('#blank_model_settings_slider').classList.add("hidden");
            }
            if (item['uitype'] == "toggle") {
                var toggle = new_setting.querySelector('#blank_model_settings_toggle');
                toggle.id = loader + "|" + item['id'] + "_value";
                toggle.checked = item['default'];
                toggle.onchange = onchange_event;
                toggle.setAttribute("refresh_model_inputs", item['refresh_model_inputs']);
                if ('check' in item) {
                    toggle.check_data = item['check'];
                } else {
                    toggle.check_data = null;
                }
                toggle.onchange();
            } else {
                new_setting.querySelector('#blank_model_settings_checkbox_container').classList.add("hidden");
                new_setting.querySelector('#blank_model_settings_toggle').classList.add("hidden");
            }
            if (item['uitype'] == "dropdown") {
                var select_element = new_setting.querySelector('#blank_model_settings_dropdown');
                select_element.id = loader + "|" + item['id'] + "_value";
                for (const dropdown_value of item['children']) {
                    new_option = document.createElement("option");
                    new_option.value = dropdown_value['value'];
                    new_option.innerText = dropdown_value['text'];
                    select_element.append(new_option);
                }
                select_element.value = item['default'];
                select_element.onchange = onchange_event;
                select_element.setAttribute("refresh_model_inputs", item['refresh_model_inputs']);
                if ('check' in item) {
                    select_element.check_data = item['check'];
                } else {
                    select_element.check_data = null;
                }
                select_element.onchange();
            } else {
                new_setting.querySelector('#blank_model_settings_dropdown').classList.add("hidden");
            }
            if (item['uitype'] == "password") {
                var password_item = new_setting.querySelector('#blank_model_settings_password');
                password_item.id = loader + "|" + item['id'] + "_value";
                password_item.value = item['default'];
                password_item.onchange = onchange_event;
                password_item.setAttribute("refresh_model_inputs", item['refresh_model_inputs']);
                if ('check' in item) {
                    password_item.check_data = item['check'];
                } else {
                    password_item.check_data = null;
                }
                password_item.onchange();
            } else {
                new_setting.querySelector('#blank_model_settings_password').classList.add("hidden");
            }
            if (item['uitype'] == "text") {
                var text_item = new_setting.querySelector('#blank_model_settings_text');
                text_item.id = loader + "|" + item['id'] + "_value";
                text_item.value = item['default'];
                text_item.onchange = onchange_event;
                text_item.setAttribute("refresh_model_inputs", item['refresh_model_inputs']);
                if ('check' in item) {
                    text_item.check_data = item['check'];
                } else {
                    text_item.check_data = null;
                }
                text_item.onchange();
            } else {
                new_setting.querySelector('#blank_model_settings_text').classList.add("hidden");
            }

            model_area.append(new_setting);
            loadmodelsettings.append(model_area);
        }
    }

    //unhide the first plugin settings
    console.log(document.getElementById("modelplugin").value + "_settings_area");
    if (document.getElementById(document.getElementById("modelplugin").value + "_settings_area")) {
        document.getElementById(document.getElementById("modelplugin").value + "_settings_area").classList.remove("hidden");
    }

}
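
// Illustrative sketch of the per-loader payload selected_model_info() consumes: data maps each
// loader name to a list of setting descriptors. Only the keys read above (label, tooltip, id,
// uitype, default, min, max, children, check, refresh_model_inputs) come from the code; the
// loader name and concrete values below are assumptions, not taken from this commit.
//
// {
//     "generic_hf_torch": [
//         {"uitype": "slider", "id": "GPU_0_Layers", "label": "GPU 0 Layers",
//          "tooltip": "Number of layers to put on this GPU", "default": 0, "min": 0, "max": 32,
//          "refresh_model_inputs": false,
//          "check": {"sum": ["GPU_0_Layers", "Disk_Layers"], "check": "=", "value": 32,
//                    "check_message": "All layers must be assigned to a GPU or the disk cache"}},
//         {"uitype": "dropdown", "id": "precision", "label": "Precision",
//          "tooltip": "Weight precision to load the model with", "default": "fp16",
//          "refresh_model_inputs": false,
//          "children": [{"text": "fp16", "value": "fp16"}, {"text": "fp32", "value": "fp32"}]}
//     ]
// }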

@@ -1877,42 +1875,16 @@ function update_gpu_layers() {

function load_model() {
    var accept = document.getElementById('btn_loadmodelaccept');
    gpu_layers = []
    disk_layers = 0;
    if (!(document.getElementById("modellayers").classList.contains("hidden"))) {
        for (let i=0; i < document.getElementById("gpu_count").value; i++) {
            gpu_layers.push(document.getElementById("gpu_layers_"+i).value);
        }
        if (document.getElementById("disk_layers")) {
            disk_layers = document.getElementById("disk_layers").value;
        }
    }
    //Need to do different stuff with custom models
    if ((accept.getAttribute('menu') == 'GPT2Custom') || (accept.getAttribute('menu') == 'NeoCustom')) {
        var model = document.getElementById("btn_loadmodelaccept").getAttribute("menu");
        var path = document.getElementById("btn_loadmodelaccept").getAttribute("display_name");
    } else {
        var model = document.getElementById("btn_loadmodelaccept").getAttribute("selected_model");
        var path = "";
    }
    settings_area = document.getElementById(document.getElementById("modelplugin").value + "_settings_area");

    let selected_models = [];
    for (item of document.getElementById("oaimodel").selectedOptions) {
        selected_models.push(item.value);
    }
    if ((selected_models.length == 1) && (selected_models[0] == '')) {
        selected_models = [];
    } else if (selected_models.length == 1) {
        selected_models = selected_models[0];
    }
    //get an object of all the input settings from the user
    data = {}
    for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) {
        data[element.id.split("|")[1].replace("_value", "")] = element.value;
    }
    data = {...data, ...selected_model_data}

    message = {'model': model, 'path': path, 'use_gpu': document.getElementById("use_gpu").checked,
               'key': document.getElementById('modelkey').value, 'gpu_layers': gpu_layers.join(),
               'disk_layers': disk_layers, 'url': document.getElementById("modelurl").value,
               'online_model': selected_models,
               'use_8_bit': document.getElementById('use_8_bit').checked};
    socket.emit("load_model", message);
    socket.emit("load_model", data);
    closePopups();
}
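
// For reference: each plugin input id above has the form "<loader>|<setting_id>_value"
// (for example a hypothetical "generic_hf_torch|GPU_0_Layers_value"), so the collection loop
// strips the loader prefix and the "_value" suffix, leaving something like
// {"GPU_0_Layers": "28", "Disk_Layers": "0"} to be merged with selected_model_data and sent
// to the server as the load_model payload. The concrete ids and values here are illustrative,
// not taken from this commit.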

@@ -46,35 +46,11 @@
        <div id="model-spec-usage">Usage (VRAM)</div>
    </span>
</span>
<div id="loadmodellistbreadcrumbs"></div>
<div id="loadmodellistcontent" class="popup_list_area"></div>
<div id="loadmodelplugin" class="popup_load_cancel loadmodelsettings"><select id="modelplugin" class="settings_select hidden"></select></div>
<div id="loadmodelsettings" class="popup_load_cancel loadmodelsettings"></div>
<div class="popup_load_cancel">
    <div>
        <input class="hidden fullwidth" type="text" placeholder="key" id="modelkey" onchange="socket.emit('OAI_Key_Update', {'model': document.getElementById('btn_loadmodelaccept').getAttribute('selected_model'), 'key': this.value});">
        <input class="hidden fullwidth" type="text" placeholder="Enter the URL of the server (For example a trycloudflare link)" id="modelurl" onchange="check_enable_model_load()">
        <input class="hidden fullwidth" type="text" placeholder="Hugging Face Model Name" id="custommodelname" menu="" onblur="socket.emit('get_model_info', this.value);
            document.getElementById('btn_loadmodelaccept').setAttribute('selected_model', this.value);
            ">
        <select class="hidden fullwidth settings_select" id="oaimodel"><option value="">Select OAI Model</option></select>
    </div>
    <div class="hidden" id=modellayers>
        <div class="justifyleft">
            GPU/Disk Layers<span class="material-icons-outlined helpicon" tooltip="Number of layers to assign to GPUs and to disk cache. Remaining layers will be put into CPU RAM.">help_icon</span>
        </div>
        <div class="justifyright"><span id="gpu_layers_current">0</span>/<span id="gpu_layers_max">0</span></div>
        <div id=model_layer_bars style="color: white"></div>
        <input type=hidden id='gpu_count' value=0/>
    </div>
    <div class="box flex-push-right hidden" id=use_gpu_div>
        <input type="checkbox" data-toggle="toggle" data-onstyle="success" id="use_gpu" checked>
        <div class="box-label">Use GPU</div>
    </div>
    <div class="box flex-push-right hidden" id=use_8_bit_div onclick="set_8_bit_mode()">
        <input type="checkbox" data-toggle="toggle" data-onstyle="success" id="use_8_bit" checked>
        <div class="box-label">Use 8 bit mode</div>
    </div>
    <button type="button" class="btn popup_load_cancel_button action_button disabled" onclick="load_model()" id="btn_loadmodelaccept" disabled>Load</button>
    <button type="button" class="btn popup_load_cancel_button" onclick='closePopups();' id="btn_loadmodelclose">Cancel</button>
</div>

@@ -154,3 +154,22 @@
        </div>
    </div>
</div>
<!---------------- Model Settings ---------------------->
<div id="blank_model_settings" class="setting_container_model">
    <span class="setting_label">
        <span id="blank_model_settings_label">: </span><span id="blank_model_settings_tooltip" class="helpicon material-icons-outlined" style="text-align: left;" tooltip="">help_icon</span>
    </span>
    <input autocomplete="off" class="setting_value" id="blank_model_settings_value_slider_number">
    <span class="setting_item">
        <input type="range" id="blank_model_settings_slider" class="setting_item_input blank_model_settings_input model_settings_input">
        <span id="blank_model_settings_checkbox_container">
            <input type=checkbox id="blank_model_settings_toggle" class="setting_item_input blank_model_settings_input model_settings_input" data-size="mini" data-onstyle="success" data-toggle="toggle">
        </span>
        <select id="blank_model_settings_dropdown" class="settings_select blank_model_settings_input model_settings_input"></select>
        <input type=password id="blank_model_settings_password" class="settings_select blank_model_settings_input model_settings_input">
        <input id="blank_model_settings_text" class="settings_select blank_model_settings_input model_settings_input">
    </span>
    <span class="setting_minlabel"><span style="position: relative;" id="blank_model_settings_min_label"></span></span>
    <span class="setting_maxlabel"><span style="position: relative;" id="blank_model_settings_max_label"></span></span>
</div>