Model: WIP horde and API tests

Author: somebody
Date: 2023-03-13 14:11:06 -05:00
Commit: 0320678b27 (parent cd8ccf0a5e)

4 changed files with 46 additions and 18 deletions

View File

@@ -23,10 +23,12 @@ class APIException(Exception):
 class APIInferenceModel(InferenceModel):
+    def __init__(self, base_url: str = "http://localhost:5000") -> None:
+        super().__init__()
+        self.base_url = base_url
+
     def _load(self, save_model: bool, initial_load: bool) -> None:
-        tokenizer_id = requests.get(
-            utils.koboldai_vars.colaburl[:-8] + "/api/v1/model",
-        ).json()["result"]
+        tokenizer_id = requests.get(f"{self.base_url}/api/v1/model").json()["result"]
         self.tokenizer = self._get_tokenizer(tokenizer_id)
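
For context, the new constructor lets the API backend target an arbitrary KoboldAI instance instead of deriving the endpoint by slicing colaburl. A minimal usage sketch (hypothetical, not code from the commit):

    from modeling.inference_models.api import APIInferenceModel

    # Target a non-default host; "http://localhost:5000" remains the default.
    model = APIInferenceModel(base_url="http://example-host:5000")
    model.load()  # _load() fetches the remote model id and builds a matching tokenizer
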
@@ -73,13 +75,10 @@ class APIInferenceModel(InferenceModel):
         # Create request
         while True:
-            req = requests.post(
-                utils.koboldai_vars.colaburl[:-8] + "/api/v1/generate",
-                json=reqdata,
-            )
-            if (
-                req.status_code == 503
-            ):  # Server is currently generating something else so poll until it's our turn
+            req = requests.post(f"{self.base_url}/api/v1/generate", json=reqdata)
+
+            if req.status_code == 503:
+                # Server is currently generating something else so poll until it's our turn
                 time.sleep(1)
                 continue
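
The rewrite keeps the existing retry behavior: HTTP 503 means the remote server is busy generating for another client, so the request is retried once per second until it is accepted. In isolation the pattern looks roughly like this (a sketch with a hypothetical helper name, not code from the commit):

    import time

    import requests

    def post_when_free(url: str, payload: dict) -> requests.Response:
        # 503 signals "busy generating for someone else": back off and retry.
        while True:
            resp = requests.post(url, json=payload)
            if resp.status_code == 503:
                time.sleep(1)
                continue
            return resp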

View File

@@ -471,7 +471,6 @@ class HFTorchInferenceModel(HFInferenceModel):
         additional_bad_words_ids = [self.tokenizer.encode("\n")] if single_line else []

         if seed is not None:
-            print("seeding", seed)
             torch.manual_seed(seed)

         with torch.no_grad():
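
Dropping the debug print leaves the actual seeding intact; torch.manual_seed is what makes the determinism assertions in the new tests feasible, since identical seeds yield identical RNG draws:

    import torch

    # Same seed, same draws: this property underpins the determinism check in the tests.
    torch.manual_seed(1337)
    a = torch.rand(3)
    torch.manual_seed(1337)
    b = torch.rand(3)
    assert torch.equal(a, b)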

View File

@@ -4,7 +4,7 @@ import time
 import torch
 import requests
 import numpy as np
-from typing import List, Union
+from typing import List, Optional, Union

 import utils
 from logger import logger
@@ -87,7 +87,7 @@ class HordeInferenceModel(InferenceModel):
         try:
             # Create request
             req = requests.post(
-                utils.koboldai_vars.colaburl[:-8] + "/api/v2/generate/text/async",
+                f"{utils.koboldai_vars.horde_url}/api/v2/generate/text/async",
                 json=cluster_metadata,
                 headers=cluster_headers,
             )
@@ -102,8 +102,8 @@ class HordeInferenceModel(InferenceModel):
             raise HordeException(errmsg)
         elif not req.ok:
             errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
+            logger.error(req.url)
             logger.error(errmsg)
-            logger.error(f"HTTP {req.status_code}!!!")
             logger.error(req.text)
             raise HordeException(errmsg)
@@ -125,7 +125,7 @@ class HordeInferenceModel(InferenceModel):
         while not finished:
             try:
                 req = requests.get(
-                    f"{utils.koboldai_vars.colaburl[:-8]}/api/v2/generate/text/status/{request_id}",
+                    f"{utils.koboldai_vars.horde_url}/api/v2/generate/text/status/{request_id}",
                     headers=cluster_agent_headers,
                 )
             except requests.exceptions.ConnectionError:
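
Both Horde hunks swap the colaburl slicing for a dedicated horde_url variable; the underlying flow is unchanged: submit a job to the async endpoint, then poll the status endpoint with the returned id. A condensed sketch, assuming the standard Horde response shape (an "id" field on submit, a "done" flag on status):

    import time

    import requests

    def horde_generate(horde_url: str, payload: dict, headers: dict) -> dict:
        # Submit the generation job; the Horde replies with a request id.
        submit = requests.post(
            f"{horde_url}/api/v2/generate/text/async", json=payload, headers=headers
        )
        request_id = submit.json()["id"]

        # Poll until the cluster reports the job finished.
        while True:
            status = requests.get(
                f"{horde_url}/api/v2/generate/text/status/{request_id}"
            ).json()
            if status.get("done"):
                return status
            time.sleep(1)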

View File

@@ -1,8 +1,12 @@
-# We have to go through aiserver to initalize koboldai_vars :(
 import torch

+# We have to go through aiserver to initalize koboldai_vars :(
 from aiserver import GenericHFTorchInferenceModel
 from aiserver import koboldai_vars
+
 from modeling.inference_model import InferenceModel
+from modeling.inference_models.api import APIInferenceModel
+from modeling.inference_models.horde import HordeInferenceModel
+

 model: InferenceModel
@@ -12,6 +16,7 @@ TEST_PROMPT = "Once upon a time I found myself"
 TEST_GEN_TOKEN_COUNT = 20
 TEST_SEED = 1337


+# HF Torch
 def test_generic_hf_torch_load() -> None:
     global model
@@ -29,11 +34,11 @@ def test_generic_hf_torch_low_mem_load() -> None:
     GenericHFTorchInferenceModel(TEST_MODEL_HF_ID, lazy_load=False, low_mem=True).load()

-def test_model_gen() -> None:
+def test_torch_inference() -> None:
     x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
     print(x.decoded)
     assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
-    assert len(x.encoded[0]) == 20, "Wrong token amount (requested 20)"
+    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"

     y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
@@ -42,3 +47,28 @@ def test_model_gen() -> None:
     ), f"Faulty full determinism! {x.decoded} vs {y.decoded}"
     print(x)
+
+# Horde
+def test_horde_load() -> None:
+    global model
+    # TODO: Make this a property and sync it with kaivars
+    koboldai_vars.cluster_requested_models = []
+    model = HordeInferenceModel()
+    model.load()
+
+def test_horde_inference() -> None:
+    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
+    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    print(x)
+
+# API
+def test_api_load() -> None:
+    global model
+    model = APIInferenceModel()
+    model.load()
+
+def test_api_inference() -> None:
+    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
+    # NOTE: Below test is flakey due to Horde worker-defined constraints
+    # assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+    print(x)
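
The commented-out length check is disabled because a worker may return fewer tokens than the request asked for. If an assertion is wanted anyway, a tolerant upper-bound variant could look like this (a sketch under the assumption that workers under-generate but never over-generate; not part of the commit):

    def assert_token_count(encoded, requested: int) -> None:
        # Assumes workers may return fewer tokens than requested, never more.
        got = len(encoded[0])
        assert 0 < got <= requested, f"Got {got} tokens, requested up to {requested}"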