Model: WIP horde and API tests

This commit is contained in:
somebody
2023-03-13 14:11:06 -05:00
parent cd8ccf0a5e
commit 0320678b27
4 changed files with 46 additions and 18 deletions

View File

@@ -1,8 +1,12 @@
# We have to go through aiserver to initalize koboldai_vars :(
import torch
# We have to go through aiserver to initalize koboldai_vars :(
from aiserver import GenericHFTorchInferenceModel
from aiserver import koboldai_vars
from modeling.inference_model import InferenceModel
from modeling.inference_models.api import APIInferenceModel
from modeling.inference_models.horde import HordeInferenceModel
model: InferenceModel
@@ -12,6 +16,7 @@ TEST_PROMPT = "Once upon a time I found myself"
TEST_GEN_TOKEN_COUNT = 20
TEST_SEED = 1337
# HF Torch
def test_generic_hf_torch_load() -> None:
global model
@@ -29,11 +34,11 @@ def test_generic_hf_torch_low_mem_load() -> None:
GenericHFTorchInferenceModel(TEST_MODEL_HF_ID, lazy_load=False, low_mem=True).load()
def test_model_gen() -> None:
def test_torch_inference() -> None:
x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
print(x.decoded)
assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
assert len(x.encoded[0]) == 20, "Wrong token amount (requested 20)"
assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
@@ -42,3 +47,28 @@ def test_model_gen() -> None:
), f"Faulty full determinism! {x.decoded} vs {y.decoded}"
print(x)
# Horde
def test_horde_load() -> None:
global model
# TODO: Make this a property and sync it with kaivars
koboldai_vars.cluster_requested_models = []
model = HordeInferenceModel()
model.load()
def test_horde_inference() -> None:
x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
print(x)
# API
def test_api_load() -> None:
global model
model = APIInferenceModel()
model.load()
def test_api_inference() -> None:
x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
# NOTE: Below test is flakey due to Horde worker-defined constraints
# assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
print(x)