Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Add sliders for exllama context size and related methods
@@ -50,7 +50,7 @@ def load_model_gptq_settings(path):
     gptq_model = False
     gptq_file = False
 
-    gptq_legacy_files = glob.glob(os.path.join(path, "4bit*.safetensors"))
+    gptq_legacy_files = glob.glob(os.path.join(path, "*4bit*.safetensors"))
     if "gptq_bits" in js:
         gptq_model = True
         gptq_file = os.path.join(path, "model.safetensors")
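The loosened glob means legacy 4-bit checkpoints are still found when the model name comes before the "4bit" marker instead of the file name starting with it. A minimal sketch of the difference, using fnmatch and made-up file names (not taken from the repository):

import fnmatch

# Hypothetical legacy checkpoint names, purely for illustration.
names = ["4bit-128g.safetensors", "llama-7b-4bit-128g.safetensors"]

old_pattern = "4bit*.safetensors"    # only matches names that start with "4bit"
new_pattern = "*4bit*.safetensors"   # also matches names with a model prefix

print([n for n in names if fnmatch.fnmatch(n, old_pattern)])
# ['4bit-128g.safetensors']
print([n for n in names if fnmatch.fnmatch(n, new_pattern)])
# ['4bit-128g.safetensors', 'llama-7b-4bit-128g.safetensors']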
@@ -58,7 +58,7 @@ def load_model_gptq_settings(path):
         gptq_model = True
         gptq_file = gptq_legacy_files[0]
         fname = Path(gptq_file).parts[-1]
-        g = re.findall("^(?:4bit)(?:-)(\\d+)(?:g-?)", fname)
+        g = re.findall("(?:4bit)(?:-)(\\d+)(?:g-?)", fname)
 
     return gptq_model, gptq_file
 
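Dropping the ^ anchor applies the same fix to the group-size regex: the "4bit-<N>g" marker no longer has to sit at the very start of the file name. A quick check with a hypothetical legacy file name:

import re

fname = "llama-7b-4bit-128g.safetensors"  # hypothetical prefixed filename

print(re.findall("^(?:4bit)(?:-)(\\d+)(?:g-?)", fname))  # [] - anchored pattern misses the prefix
print(re.findall("(?:4bit)(?:-)(\\d+)(?:g-?)", fname))   # ['128'] - unanchored pattern finds the group size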
@@ -113,11 +113,6 @@ class model_backend(InferenceModel):
         if not config and os.path.exists("models/{}".format(model_name.replace('/', '_'))):
             config = ExLlamaConfig(os.path.join("models/{}".format(model_name.replace('/', '_')), "config.json"))
 
-        if config and "superhot" in model_name.lower():
-            # Set compress_pos_emb factor
-            config.max_seq_len = 8192
-            config.compress_pos_emb = 4.0
-
         return config
 
     def _load(self, save_model: bool, initial_load: bool) -> None:
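The removed block hard-coded SuperHOT behaviour (an 8192-token context with 4x position compression) whenever "superhot" appeared in the model name. With the sliders added below, the same effect can presumably be reproduced by hand; a sketch of the equivalent slider values, with the backend's other expected keys omitted:

# Hypothetical slider values mirroring the removed hard-coded SuperHOT settings.
parameters = {
    "max_ctx": 8192,       # replaces config.max_seq_len = 8192
    "compress_emb": 4.0,   # replaces config.compress_pos_emb = 4.0
    "ntk_alpha": 1,        # no NTK scaling, matching the old behaviour
}
# backend.set_input_parameters(parameters)  # the real call also takes the backend's other keys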
@@ -366,6 +361,51 @@ class model_backend(InferenceModel):
             "refresh_model_inputs": False
         })
 
+        requested_parameters.append({
+            "uitype": "slider",
+            "unit": "int",
+            "label": "Maximum Context",
+            "id": "max_ctx",
+            "min": 2048,
+            "max": 16384,
+            "step": 512,
+            "default": 2048,
+            "tooltip": "The maximum context size the model supports",
+            "menu_path": "Configuration",
+            "extra_classes": "",
+            "refresh_model_inputs": False
+        })
+
+        requested_parameters.append({
+            "uitype": "slider",
+            "unit": "float",
+            "label": "Embedding Compression",
+            "id": "compress_emb",
+            "min": 1,
+            "max": 8,
+            "step": 0.25,
+            "default": 1,
+            "tooltip": "If the model requires compressed embeddings, set them here",
+            "menu_path": "Configuration",
+            "extra_classes": "",
+            "refresh_model_inputs": False
+        })
+
+        requested_parameters.append({
+            "uitype": "slider",
+            "unit": "float",
+            "label": "NTK alpha",
+            "id": "ntk_alpha",
+            "min": 1,
+            "max": 32,
+            "step": 0.25,
+            "default": 1,
+            "tooltip": "NTK alpha value",
+            "menu_path": "Configuration",
+            "extra_classes": "",
+            "refresh_model_inputs": False
+        })
+
         return requested_parameters
 
     def set_input_parameters(self, parameters):
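Each dict describes one slider in the load-model UI; its "id" is the key that set_input_parameters reads back in the next hunk. As a sketch of how such a spec constrains a value, here is a hypothetical clamp-and-snap helper (the actual UI-side validation is not part of this commit):

def clamp_to_slider(value, spec):
    """Clamp a submitted value to a slider spec like the dicts above.

    Hypothetical helper for illustration; not taken from the KoboldAI code base.
    """
    value = max(spec["min"], min(spec["max"], value))
    # Snap to the nearest step counted from the minimum.
    steps = round((value - spec["min"]) / spec["step"])
    value = spec["min"] + steps * spec["step"]
    return int(value) if spec["unit"] == "int" else float(value)

max_ctx_spec = {"unit": "int", "min": 2048, "max": 16384, "step": 512, "default": 2048}
print(clamp_to_slider(5000, max_ctx_spec))   # 5120 - nearest multiple of 512 above 2048
print(clamp_to_slider(99999, max_ctx_spec))  # 16384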
@@ -387,6 +427,10 @@ class model_backend(InferenceModel):
             self.model_config.device_map.lm_head = "cuda:0"
             self.model_config.device_map.norm = "cuda:0"
 
+        self.model_config.max_seq_len = parameters["max_ctx"]
+        self.model_config.compress_pos_emb = parameters["compress_emb"]
+        self.model_config.alpha_value = parameters["ntk_alpha"]
+
         # Disable half2 for HIP
         self.model_config.rmsnorm_no_half2 = bool(torch.version.hip)
         self.model_config.rope_no_half2 = bool(torch.version.hip)
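max_seq_len, compress_pos_emb and alpha_value all act on the rotary position embeddings: compression divides the position index (linear interpolation), while NTK alpha stretches the rotary base instead. A rough sketch using the commonly cited NTK-aware scaling rule (illustrative only; exllama's internal math may differ in detail):

import math

def rope_theta(position, dim_index, head_dim=128,
               base=10000.0, compress_pos_emb=1.0, alpha_value=1.0):
    """Angle fed to sin/cos for one position/dimension pair.

    compress_pos_emb: linear interpolation - positions are divided by this factor.
    alpha_value: NTK-aware scaling - the rotary base is stretched instead.
    Common formulation, shown for illustration; not copied from exllama's source.
    """
    scaled_base = base * alpha_value ** (head_dim / (head_dim - 2))
    inv_freq = 1.0 / scaled_base ** (2 * dim_index / head_dim)
    return (position / compress_pos_emb) * inv_freq

# Position 4096 looks like position 1024 to a model loaded with compress_emb = 4:
print(math.isclose(rope_theta(4096, 10, compress_pos_emb=4.0), rope_theta(1024, 10)))  # True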
@@ -56,7 +56,7 @@ def load_model_gptq_settings(path):
     gptq_file = False
     gptq_version = -1
 
-    gptq_legacy_files = glob.glob(os.path.join(path, "4bit*.pt")) + glob.glob(os.path.join(path, "4bit*.safetensors"))
+    gptq_legacy_files = glob.glob(os.path.join(path, "*4bit*.pt")) + glob.glob(os.path.join(path, "*4bit*.safetensors"))
     if "gptq_bits" in js:
         gptq_model = True
         gptq_bits = js["gptq_bits"]
@@ -70,7 +70,7 @@ def load_model_gptq_settings(path):
         gptq_bits = 4
         gptq_file = gptq_legacy_files[0]
         fname = Path(gptq_file).parts[-1]
-        g = re.findall("^(?:4bit)(?:-)(\\d+)(?:g-?)", fname)
+        g = re.findall("(?:4bit)(?:-)(\\d+)(?:g-?)", fname)
         gptq_groupsize = int(g[0]) if g else -1
         gptq_version = -1
 