Add sliders for exllama context size and related methods

This commit is contained in:
0cc4m
2023-07-23 07:11:28 +02:00
parent 58908ab846
commit 748e5ef318
2 changed files with 53 additions and 9 deletions

View File

@@ -50,7 +50,7 @@ def load_model_gptq_settings(path):
gptq_model = False gptq_model = False
gptq_file = False gptq_file = False
gptq_legacy_files = glob.glob(os.path.join(path, "4bit*.safetensors")) gptq_legacy_files = glob.glob(os.path.join(path, "*4bit*.safetensors"))
if "gptq_bits" in js: if "gptq_bits" in js:
gptq_model = True gptq_model = True
gptq_file = os.path.join(path, "model.safetensors") gptq_file = os.path.join(path, "model.safetensors")
@@ -58,7 +58,7 @@ def load_model_gptq_settings(path):
gptq_model = True gptq_model = True
gptq_file = gptq_legacy_files[0] gptq_file = gptq_legacy_files[0]
fname = Path(gptq_file).parts[-1] fname = Path(gptq_file).parts[-1]
g = re.findall("^(?:4bit)(?:-)(\\d+)(?:g-?)", fname) g = re.findall("(?:4bit)(?:-)(\\d+)(?:g-?)", fname)
return gptq_model, gptq_file return gptq_model, gptq_file
@@ -113,11 +113,6 @@ class model_backend(InferenceModel):
if not config and os.path.exists("models/{}".format(model_name.replace('/', '_'))): if not config and os.path.exists("models/{}".format(model_name.replace('/', '_'))):
config = ExLlamaConfig(os.path.join("models/{}".format(model_name.replace('/', '_')), "config.json")) config = ExLlamaConfig(os.path.join("models/{}".format(model_name.replace('/', '_')), "config.json"))
if config and "superhot" in model_name.lower():
# Set compress_pos_emb factor
config.max_seq_len = 8192
config.compress_pos_emb = 4.0
return config return config
def _load(self, save_model: bool, initial_load: bool) -> None: def _load(self, save_model: bool, initial_load: bool) -> None:
@@ -366,6 +361,51 @@ class model_backend(InferenceModel):
"refresh_model_inputs": False "refresh_model_inputs": False
}) })
requested_parameters.append({
"uitype": "slider",
"unit": "int",
"label": "Maximum Context",
"id": "max_ctx",
"min": 2048,
"max": 16384,
"step": 512,
"default": 2048,
"tooltip": "The maximum context size the model supports",
"menu_path": "Configuration",
"extra_classes": "",
"refresh_model_inputs": False
})
requested_parameters.append({
"uitype": "slider",
"unit": "float",
"label": "Embedding Compression",
"id": "compress_emb",
"min": 1,
"max": 8,
"step": 0.25,
"default": 1,
"tooltip": "If the model requires compressed embeddings, set them here",
"menu_path": "Configuration",
"extra_classes": "",
"refresh_model_inputs": False
})
requested_parameters.append({
"uitype": "slider",
"unit": "float",
"label": "NTK alpha",
"id": "ntk_alpha",
"min": 1,
"max": 32,
"step": 0.25,
"default": 1,
"tooltip": "NTK alpha value",
"menu_path": "Configuration",
"extra_classes": "",
"refresh_model_inputs": False
})
return requested_parameters return requested_parameters
def set_input_parameters(self, parameters): def set_input_parameters(self, parameters):
@@ -387,6 +427,10 @@ class model_backend(InferenceModel):
self.model_config.device_map.lm_head = "cuda:0" self.model_config.device_map.lm_head = "cuda:0"
self.model_config.device_map.norm = "cuda:0" self.model_config.device_map.norm = "cuda:0"
self.model_config.max_seq_len = parameters["max_ctx"]
self.model_config.compress_pos_emb = parameters["compress_emb"]
self.model_config.alpha_value = parameters["ntk_alpha"]
# Disable half2 for HIP # Disable half2 for HIP
self.model_config.rmsnorm_no_half2 = bool(torch.version.hip) self.model_config.rmsnorm_no_half2 = bool(torch.version.hip)
self.model_config.rope_no_half2 = bool(torch.version.hip) self.model_config.rope_no_half2 = bool(torch.version.hip)

View File

@@ -56,7 +56,7 @@ def load_model_gptq_settings(path):
gptq_file = False gptq_file = False
gptq_version = -1 gptq_version = -1
gptq_legacy_files = glob.glob(os.path.join(path, "4bit*.pt")) + glob.glob(os.path.join(path, "4bit*.safetensors")) gptq_legacy_files = glob.glob(os.path.join(path, "*4bit*.pt")) + glob.glob(os.path.join(path, "*4bit*.safetensors"))
if "gptq_bits" in js: if "gptq_bits" in js:
gptq_model = True gptq_model = True
gptq_bits = js["gptq_bits"] gptq_bits = js["gptq_bits"]
@@ -70,7 +70,7 @@ def load_model_gptq_settings(path):
gptq_bits = 4 gptq_bits = 4
gptq_file = gptq_legacy_files[0] gptq_file = gptq_legacy_files[0]
fname = Path(gptq_file).parts[-1] fname = Path(gptq_file).parts[-1]
g = re.findall("^(?:4bit)(?:-)(\\d+)(?:g-?)", fname) g = re.findall("(?:4bit)(?:-)(\\d+)(?:g-?)", fname)
gptq_groupsize = int(g[0]) if g else -1 gptq_groupsize = int(g[0]) if g else -1
gptq_version = -1 gptq_version = -1