From 0320678b27f2de7c0222bdcc2ce6f7757f05e8e0 Mon Sep 17 00:00:00 2001
From: somebody
Date: Mon, 13 Mar 2023 14:11:06 -0500
Subject: [PATCH] Model: WIP horde and API tests

---
 modeling/inference_models/api.py      | 19 +++++++-------
 modeling/inference_models/hf_torch.py |  1 -
 modeling/inference_models/horde.py    |  8 +++---
 modeling/test_generation.py           | 36 ++++++++++++++++++++++++---
 4 files changed, 46 insertions(+), 18 deletions(-)

diff --git a/modeling/inference_models/api.py b/modeling/inference_models/api.py
index f36e7205..70cea10b 100644
--- a/modeling/inference_models/api.py
+++ b/modeling/inference_models/api.py
@@ -23,10 +23,12 @@ class APIException(Exception):
 
 
 class APIInferenceModel(InferenceModel):
+    def __init__(self, base_url: str = "http://localhost:5000") -> None:
+        super().__init__()
+        self.base_url = base_url
+
     def _load(self, save_model: bool, initial_load: bool) -> None:
-        tokenizer_id = requests.get(
-            utils.koboldai_vars.colaburl[:-8] + "/api/v1/model",
-        ).json()["result"]
+        tokenizer_id = requests.get(f"{self.base_url}/api/v1/model").json()["result"]
         self.tokenizer = self._get_tokenizer(tokenizer_id)
 
 
@@ -73,13 +75,10 @@ class APIInferenceModel(InferenceModel):
 
         # Create request
         while True:
-            req = requests.post(
-                utils.koboldai_vars.colaburl[:-8] + "/api/v1/generate",
-                json=reqdata,
-            )
-            if (
-                req.status_code == 503
-            ):  # Server is currently generating something else so poll until it's our turn
+            req = requests.post(f"{self.base_url}/api/v1/generate", json=reqdata)
+
+            if req.status_code == 503:
+                # Server is currently generating something else so poll until it's our turn
                 time.sleep(1)
                 continue
 
diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index 3d4b7029..61536507 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -471,7 +471,6 @@ class HFTorchInferenceModel(HFInferenceModel):
         additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []
 
         if seed is not None:
-            print("seeding", seed)
             torch.manual_seed(seed)
 
         with torch.no_grad():
diff --git a/modeling/inference_models/horde.py b/modeling/inference_models/horde.py
index 3e54df07..c6294374 100644
--- a/modeling/inference_models/horde.py
+++ b/modeling/inference_models/horde.py
@@ -4,7 +4,7 @@ import time
 import torch
 import requests
 import numpy as np
-from typing import List, Union
+from typing import List, Optional, Union
 
 import utils
 from logger import logger
@@ -87,7 +87,7 @@ class HordeInferenceModel(InferenceModel):
         try:
             # Create request
             req = requests.post(
-                utils.koboldai_vars.colaburl[:-8] + "/api/v2/generate/text/async",
+                f"{utils.koboldai_vars.horde_url}/api/v2/generate/text/async",
                 json=cluster_metadata,
                 headers=cluster_headers,
             )
@@ -102,8 +102,8 @@ class HordeInferenceModel(InferenceModel):
                 raise HordeException(errmsg)
             elif not req.ok:
                 errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
+                logger.error(req.url)
                 logger.error(errmsg)
-                logger.error(f"HTTP {req.status_code}!!!")
                 logger.error(req.text)
                 raise HordeException(errmsg)
 
@@ -125,7 +125,7 @@ class HordeInferenceModel(InferenceModel):
         while not finished:
             try:
                 req = requests.get(
-                    f"{utils.koboldai_vars.colaburl[:-8]}/api/v2/generate/text/status/{request_id}",
+                    f"{utils.koboldai_vars.horde_url}/api/v2/generate/text/status/{request_id}",
                     headers=cluster_agent_headers,
                 )
             except requests.exceptions.ConnectionError:
diff --git a/modeling/test_generation.py b/modeling/test_generation.py
index a3efa7bc..947a83c4 100644
--- a/modeling/test_generation.py
+++ b/modeling/test_generation.py
@@ -1,8 +1,12 @@
-# We have to go through aiserver to initalize koboldai_vars :(
 import torch
+
+# We have to go through aiserver to initialize koboldai_vars :(
 from aiserver import GenericHFTorchInferenceModel
 from aiserver import koboldai_vars
+
 from modeling.inference_model import InferenceModel
+from modeling.inference_models.api import APIInferenceModel
+from modeling.inference_models.horde import HordeInferenceModel
 
 
 model: InferenceModel
@@ -12,6 +16,7 @@ TEST_PROMPT = "Once upon a time I found myself"
 TEST_GEN_TOKEN_COUNT = 20
 TEST_SEED = 1337
 
 
+# HF Torch
 def test_generic_hf_torch_load() -> None:
     global model
@@ -29,11 +34,11 @@ def test_generic_hf_torch_low_mem_load() -> None:
     GenericHFTorchInferenceModel(TEST_MODEL_HF_ID, lazy_load=False, low_mem=True).load()
 
 
-def test_model_gen() -> None:
+def test_torch_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     print(x.decoded)
     assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
-    assert len(x.encoded[0]) == 20, "Wrong token amount (requested 20)"
+    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
 
     y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
 
@@ -42,3 +47,28 @@
     ), f"Faulty full determinism! {x.decoded} vs {y.decoded}"
 
     print(x)
+
+# Horde
+def test_horde_load() -> None:
+    global model
+    # TODO: Make this a property and sync it with kaivars
+    koboldai_vars.cluster_requested_models = []
+    model = HordeInferenceModel()
+    model.load()
+
+def test_horde_inference() -> None:
+    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
+    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    print(x)
+
+# API
+def test_api_load() -> None:
+    global model
+    model = APIInferenceModel()
+    model.load()
+
+def test_api_inference() -> None:
+    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
+    # NOTE: The test below is flaky due to Horde worker-defined constraints
+    # assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    print(x)
\ No newline at end of file
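
A rough usage sketch, outside the patch itself: the new tests are the reference for the added base_url constructor, but the same flow can be exercised by hand against a KoboldAI instance serving its API locally. This assumes koboldai_vars is initialized as a side effect of importing aiserver (as modeling/test_generation.py does) and that a server is actually listening on the default base_url; the prompt text, token count, and seed below are arbitrary.

    # Rough sketch mirroring test_api_load/test_api_inference above.
    # Assumes importing aiserver has initialized koboldai_vars and that a
    # KoboldAI API server is reachable at the default base_url.
    from aiserver import koboldai_vars  # noqa: F401  (imported for its side effects)
    from modeling.inference_models.api import APIInferenceModel

    model = APIInferenceModel(base_url="http://localhost:5000")
    model.load()  # resolves the remote tokenizer id via GET {base_url}/api/v1/model

    result = model.raw_generate(
        "Once upon a time I found myself", max_new=20, seed=1337
    )
    print(result.decoded)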