From 3b05359e4b6c4811ac0fbc7fba76e5bda8bc260a Mon Sep 17 00:00:00 2001
From: somebody
Date: Tue, 7 Mar 2023 16:54:08 -0600
Subject: [PATCH] Model: Refuse to serve certain models over the API

---
 aiserver.py                        | 12 ++----------
 modeling/inference_model.py        |  3 +++
 modeling/inference_models/api.py   |  5 +++++
 modeling/inference_models/colab.py |  7 +++++++
 modeling/inference_models/horde.py |  7 +++++++
 5 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index fc05083f..e3cf40ef 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -3555,16 +3555,8 @@ def apiactionsubmit_tpumtjgenerate(txt, minimum, maximum):
     return genout
 
 def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=False, use_authors_note=False):
-    if(koboldai_vars.model == "Colab"):
-        raise NotImplementedError("API generation is not supported in old Colab API mode.")
-    elif(koboldai_vars.model == "API"):
-        raise NotImplementedError("API generation is not supported in API mode.")
-    elif(koboldai_vars.model == "CLUSTER"):
-        raise NotImplementedError("API generation is not supported in API mode.")
-    elif(koboldai_vars.model == "OAI"):
-        raise NotImplementedError("API generation is not supported in OpenAI/GooseAI mode.")
-    elif(koboldai_vars.model == "ReadOnly"):
-        raise NotImplementedError("API generation is not supported in read-only mode; please load a model and then try again.")
+    if not model or not model.capabilties.api_host:
+        raise NotImplementedError(f"API generation isn't allowed on model '{koboldai_vars.model}'")
 
     data = applyinputformatting(data)
 
diff --git a/modeling/inference_model.py b/modeling/inference_model.py
index 5b40733e..5157acb0 100644
--- a/modeling/inference_model.py
+++ b/modeling/inference_model.py
@@ -156,6 +156,9 @@ class ModelCapabilities:
     # TODO: Support non-live probabilities from APIs
     post_token_probs: bool = False
 
+    # Some models cannot be hosted over the API, namely the API itself.
+    api_host: bool = True
+
 
 class InferenceModel:
     """Root class for all models."""
diff --git a/modeling/inference_models/api.py b/modeling/inference_models/api.py
index 9fc98abb..83fcd7ab 100644
--- a/modeling/inference_models/api.py
+++ b/modeling/inference_models/api.py
@@ -14,6 +14,7 @@ from modeling.inference_model import (
     GenerationResult,
     GenerationSettings,
     InferenceModel,
+    ModelCapabilities,
 )
 
 
@@ -26,8 +27,12 @@ class APIInferenceModel(InferenceModel):
         tokenizer_id = requests.get(
             utils.koboldai_vars.colaburl[:-8] + "/api/v1/model",
         ).json()["result"]
+
         self.tokenizer = self._get_tokenizer(tokenizer_id)
 
+        # Do not allow API to be served over the API
+        self.capabilties = ModelCapabilities(api_host=False)
+
     def _raw_generate(
         self,
         prompt_tokens: Union[List[int], torch.Tensor],
diff --git a/modeling/inference_models/colab.py b/modeling/inference_models/colab.py
index faf06299..c42807da 100644
--- a/modeling/inference_models/colab.py
+++ b/modeling/inference_models/colab.py
@@ -11,6 +11,7 @@ from modeling.inference_model import (
     GenerationResult,
     GenerationSettings,
     InferenceModel,
+    ModelCapabilities,
 )
 
 
@@ -19,6 +20,12 @@ class ColabException(Exception):
 
 
 class ColabInferenceModel(InferenceModel):
+    def __init__(self) -> None:
+        super().__init__()
+
+        # Do not allow API to be served over the API
+        self.capabilties = ModelCapabilities(api_host=False)
+
     def _load(self, save_model: bool, initial_load: bool) -> None:
         self.tokenizer = self._get_tokenizer("EleutherAI/gpt-neo-2.7B")
 
diff --git a/modeling/inference_models/horde.py b/modeling/inference_models/horde.py
index c5498512..90f7a474 100644
--- a/modeling/inference_models/horde.py
+++ b/modeling/inference_models/horde.py
@@ -13,6 +13,7 @@ from modeling.inference_model import (
     GenerationResult,
     GenerationSettings,
     InferenceModel,
+    ModelCapabilities,
 )
 
 
@@ -21,6 +22,12 @@ class HordeException(Exception):
 
 
 class HordeInferenceModel(InferenceModel):
+    def __init__(self) -> None:
+        super().__init__()
+
+        # Do not allow API to be served over the API
+        self.capabilties = ModelCapabilities(api_host=False)
+
     def _load(self, save_model: bool, initial_load: bool) -> None:
         self.tokenizer = self._get_tokenizer(
             utils.koboldai_vars.cluster_requested_models[0]
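
Note: the capability gate above can be exercised in isolation. The following
self-contained sketch is not part of the patch; the Demo* class names are
hypothetical, invented only to show the pattern the patch introduces. Backends
that merely proxy another service construct ModelCapabilities(api_host=False),
and the server-side check in apiactionsubmit() refuses to serve them over the
API.

# Standalone sketch of the api_host capability gate (hypothetical Demo*
# names, not part of the patch). The "capabilties" attribute spelling
# matches the existing codebase.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelCapabilities:
    # Some models cannot be hosted over the API, namely the API itself.
    api_host: bool = True


class InferenceModel:
    def __init__(self) -> None:
        self.capabilties = ModelCapabilities()


class DemoLocalModel(InferenceModel):
    # Default capabilities: serving this model over the API is allowed.
    pass


class DemoRemoteModel(InferenceModel):
    def __init__(self) -> None:
        super().__init__()
        # Do not allow API to be served over the API
        self.capabilties = ModelCapabilities(api_host=False)


def apiactionsubmit(model: Optional[InferenceModel]) -> str:
    # The same check the patch adds to aiserver.py: refuse generation
    # when no model is loaded or the model opts out of API hosting.
    if not model or not model.capabilties.api_host:
        raise NotImplementedError("API generation isn't allowed on this model")
    return "generated text"


print(apiactionsubmit(DemoLocalModel()))  # -> "generated text"
try:
    apiactionsubmit(DemoRemoteModel())
except NotImplementedError as err:
    print(err)  # refused, as intended

Compared to the removed string-matching chain in apiactionsubmit, the flag
lives on the model object itself, so new remote backends opt out by
construction instead of being added to a hardcoded list in aiserver.py.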