Fix NTK alpha

This commit is contained in:
0cc4m
2023-07-23 21:56:48 +02:00
parent 31a984aa3d
commit 49740aa5ab

View File

@@ -430,6 +430,7 @@ class model_backend(InferenceModel):
self.model_config.max_seq_len = parameters["max_ctx"]
self.model_config.compress_pos_emb = parameters["compress_emb"]
self.model_config.alpha_value = parameters["ntk_alpha"]
self.model_config.calculate_rotary_embedding_base()
# Disable half2 for HIP
self.model_config.rmsnorm_no_half2 = bool(torch.version.hip)