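"""Integration tests for KoboldAI's inference backends (HF Torch, Horde, API)."""
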
import torch

# We have to go through aiserver to initialize koboldai_vars :(
from aiserver import koboldai_vars

from modeling.inference_model import InferenceModel
from modeling.inference_models.api import APIInferenceModel
from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
from modeling.inference_models.horde import HordeInferenceModel

model: InferenceModel
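# Each *_load test assigns its freshly loaded backend to this module-level
# global, so it must run before the matching *_inference test that reuses it.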

# Preferably teeny tiny
TEST_MODEL_HF_ID = "EleutherAI/pythia-70m"
TEST_PROMPT = "Once upon a time I found myself"
TEST_GEN_TOKEN_COUNT = 20
TEST_SEED = 1337


# HF Torch
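# NOTE: The first run downloads TEST_MODEL_HF_ID from the Hugging Face Hub.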


def test_generic_hf_torch_load() -> None:
    global model
    model = GenericHFTorchInferenceModel(
        TEST_MODEL_HF_ID, lazy_load=False, low_mem=False
    )
    model.load()


def test_generic_hf_torch_lazy_load() -> None:
    GenericHFTorchInferenceModel(TEST_MODEL_HF_ID, lazy_load=True, low_mem=False).load()


def test_generic_hf_torch_low_mem_load() -> None:
    GenericHFTorchInferenceModel(TEST_MODEL_HF_ID, lazy_load=False, low_mem=True).load()


def test_torch_inference() -> None:
    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
    print(x.decoded)
    assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
    assert (
        len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
    ), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"

    # Generating again with the same seed should reproduce the exact same tokens.
    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)

    assert torch.equal(
        x.encoded[0], y.encoded[0]
    ), f"Faulty full determinism! {x.decoded} vs {y.decoded}"

    print(x)


# Horde
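# NOTE: These tests talk to a live Horde cluster, so they depend on network
# access and on worker availability.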
def test_horde_load() -> None:
    global model
    # TODO: Make this a property and sync it with kaivars
    koboldai_vars.cluster_requested_models = []
    model = HordeInferenceModel()
    model.load()


def test_horde_inference() -> None:
    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
    assert (
        len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT
    ), f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
    print(x)


# API
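# NOTE: These tests connect to a remote KoboldAI API endpoint, so they also
# need network access.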
def test_api_load() -> None:
    global model
    model = APIInferenceModel()
    model.load()


def test_api_inference() -> None:
    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT, seed=TEST_SEED)
    # NOTE: The assertion below is flaky due to Horde worker-defined constraints
    # assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
    print(x)