From 52f5d879061c7ce593fe05a417466d83425f0ad6 Mon Sep 17 00:00:00 2001 From: ebolam Date: Fri, 26 May 2023 11:25:28 -0400 Subject: [PATCH] Fix horde tokenizer --- modeling/inference_models/horde/class.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/modeling/inference_models/horde/class.py b/modeling/inference_models/horde/class.py index 2cc01708..f7da6604 100644 --- a/modeling/inference_models/horde/class.py +++ b/modeling/inference_models/horde/class.py @@ -30,6 +30,7 @@ class model_backend(InferenceModel): self.key = "0000000000" self.models = self.get_cluster_models() self.model_name = "Horde" + self.model = [] # Do not allow API to be served over the API @@ -114,7 +115,7 @@ class model_backend(InferenceModel): engines = req.json() try: - engines = [{"text": "all", "value": "all"}] + [{"text": en["name"], "value": en["name"]} for en in engines] + engines = [{"text": "All", "value": "all"}] + [{"text": en["name"], "value": en["name"]} for en in engines] except: logger.error(engines) raise @@ -127,10 +128,14 @@ class model_backend(InferenceModel): return engines def _load(self, save_model: bool, initial_load: bool) -> None: + tokenizer_name = "gpt2" + if len(self.model) > 0: + if self.model[0] == "all" and len(self.model) > 1: + tokenizer_name = self.model[1] + else: + tokenizer_name = self.model[0] self.tokenizer = self._get_tokenizer( - self.model - #if len(self.model) > 0 - #else "gpt2", + tokenizer_name ) def _save_settings(self):