diff --git a/modeling/test_generation.py b/modeling/test_generation.py
new file mode 100644
index 00000000..0d9173b5
--- /dev/null
+++ b/modeling/test_generation.py
@@ -0,0 +1,47 @@
+# We have to go through aiserver to initialize koboldai_vars :(
+import torch
+from aiserver import GenericHFTorchInferenceModel
+from aiserver import koboldai_vars
+from modeling.inference_model import InferenceModel
+
+model: InferenceModel
+
+# Preferably teeny tiny
+TEST_MODEL_HF_ID = "EleutherAI/pythia-70m"
+TEST_PROMPT = "Once upon a time I found myself"
+TEST_GEN_TOKEN_COUNT = 20
+
+
+def test_generic_hf_torch_load() -> None:
+    global model
+    model = GenericHFTorchInferenceModel(
+        TEST_MODEL_HF_ID, lazy_load=False, low_mem=False
+    )
+    model.load()
+
+
+def test_generic_hf_torch_lazy_load() -> None:
+    GenericHFTorchInferenceModel(TEST_MODEL_HF_ID, lazy_load=True, low_mem=False).load()
+
+
+def test_generic_hf_torch_low_mem_load() -> None:
+    GenericHFTorchInferenceModel(TEST_MODEL_HF_ID, lazy_load=False, low_mem=True).load()
+
+
+def test_model_gen() -> None:
+    # This should probably be supported in the model interface!
+    koboldai_vars.full_determinism = True
+    koboldai_vars.seed = 1337
+
+    x = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+    print(x.decoded)
+    assert len(x.encoded) == 1, "Bad output shape (too many batches!)"
+    assert len(x.encoded[0]) == TEST_GEN_TOKEN_COUNT, f"Wrong token amount (requested {TEST_GEN_TOKEN_COUNT})"
+
+    y = model.raw_generate(TEST_PROMPT, max_new=TEST_GEN_TOKEN_COUNT)
+
+    assert torch.equal(
+        x.encoded[0], y.encoded[0]
+    ), f"Faulty full determinism! {x.decoded} vs {y.decoded}"
+
+    print(x)
\ No newline at end of file
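
Usage note: the test_* naming follows pytest conventions, so assuming pytest is the project's test runner, the new tests could be invoked with something like:

    pytest modeling/test_generation.py -v

Note that test_model_gen reuses the module-level `model` populated by test_generic_hf_torch_load, so the tests in this file are order-dependent and should be run together rather than individually selected.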