From c753671ac14850a2528c0e1028816a12ca8005ac Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Tue, 27 Jun 2023 07:39:37 +0200 Subject: [PATCH] Add exllama superhot positional embeddings compression support --- modeling/inference_models/exllama/class.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/modeling/inference_models/exllama/class.py b/modeling/inference_models/exllama/class.py index 995f5874..19478cc8 100644 --- a/modeling/inference_models/exllama/class.py +++ b/modeling/inference_models/exllama/class.py @@ -106,11 +106,18 @@ class model_backend(InferenceModel): return self.path or os.path.join("models", self.model_name.replace("/", "_")) def _load_config(self, model_name, model_path): + config = False if model_path is not None and os.path.exists(model_path): - return ExLlamaConfig(os.path.join(model_path, "config.json")) - if(os.path.exists("models/{}".format(model_name.replace('/', '_')))): - return ExLlamaConfig(os.path.join("models/{}".format(model_name.replace('/', '_')), "config.json")) - return False + config = ExLlamaConfig(os.path.join(model_path, "config.json")) + if not config and os.path.exists("models/{}".format(model_name.replace('/', '_'))): + config = ExLlamaConfig(os.path.join("models/{}".format(model_name.replace('/', '_')), "config.json")) + + if config and "superhot" in model_name.lower(): + # Set compress_pos_emb factor + config.max_seq_len = 8192 + config.compress_pos_emb = 4.0 + + return config def _load(self, save_model: bool, initial_load: bool) -> None: self.model = self._get_model(self.get_local_model_path(), {}) @@ -277,11 +284,6 @@ class model_backend(InferenceModel): else: gen_in = prompt_tokens - self.generator.settings.temperature = max(gen_settings.temp, 0.01) - self.generator.settings.top_k = gen_settings.top_k if gen_settings.top_k > 0 else 10000 - self.generator.settings.top_p = gen_settings.top_p - self.generator.settings.min_p = 0.0 - self.generator.gen_begin_reuse(gen_in) for i in range(max_new):