Add exllama superhot positional embeddings compression support

This commit is contained in:
0cc4m
2023-06-27 07:39:37 +02:00
parent adad81639d
commit c753671ac1

View File

@@ -106,11 +106,18 @@ class model_backend(InferenceModel):
return self.path or os.path.join("models", self.model_name.replace("/", "_"))
def _load_config(self, model_name, model_path):
# NOTE(review): this span is a rendered diff hunk (`@@ -106,11 +106,18 @@`)
# with no +/- markers and no indentation, so lines from the pre-commit and
# post-commit versions of this method appear interleaved. Do not read it as
# one runnable function body.
#
# Purpose (both versions): build an ExLlamaConfig from a `config.json` file,
# looking first in `model_path` (when it is not None and exists on disk) and
# otherwise in `models/<model_name with '/' replaced by '_'>`. When neither
# directory exists the result is False.
config = False
if model_path is not None and os.path.exists(model_path):
# Presumably the next four lines are the removed (pre-commit) version, which
# returned as soon as a config was found -- confirm against the commit.
return ExLlamaConfig(os.path.join(model_path, "config.json"))
if(os.path.exists("models/{}".format(model_name.replace('/', '_')))):
return ExLlamaConfig(os.path.join("models/{}".format(model_name.replace('/', '_')), "config.json"))
return False
# Presumably the remaining lines are the added (post-commit) version: the
# config is kept in a local so it can be adjusted before it is returned --
# confirm against the commit.
config = ExLlamaConfig(os.path.join(model_path, "config.json"))
if not config and os.path.exists("models/{}".format(model_name.replace('/', '_'))):
config = ExLlamaConfig(os.path.join("models/{}".format(model_name.replace('/', '_')), "config.json"))
# Applies only when a config was loaded and the model name contains
# "superhot" (case-insensitive match on the name, not on the config contents).
if config and "superhot" in model_name.lower():
# Set compress_pos_emb factor
config.max_seq_len = 8192
config.compress_pos_emb = 4.0
return config
def _load(self, save_model: bool, initial_load: bool) -> None:
self.model = self._get_model(self.get_local_model_path(), {})
@@ -277,11 +284,6 @@ class model_backend(InferenceModel):
else:
gen_in = prompt_tokens
self.generator.settings.temperature = max(gen_settings.temp, 0.01)
self.generator.settings.top_k = gen_settings.top_k if gen_settings.top_k > 0 else 10000
self.generator.settings.top_p = gen_settings.top_p
self.generator.settings.min_p = 0.0
self.generator.gen_begin_reuse(gen_in)
for i in range(max_new):